In [ ]:
epochs = 15  # number of passes over the training set

Tutorial 2 - MultiLayer Split Neural Network

Recap: The previous tutorial looked at building a basic SplitNN, where an NN was split into two segments on two separate hosts. However, there is a lot more that we can do with this technique: an NN can be split any number of times without affecting the accuracy of the model.

Description: Here we define a class which can process a SplitNN with any number of segments. All it needs is a list of the distributed models and their optimizers.

In this tutorial, we demonstrate the SplitNN class with a three-segment distribution [1]. This time:

  • Alice
    • Has Model Segment 1
    • Has the handwritten images
  • Bob
    • Has Model Segment 2
  • Claire
    • Has Model Segment 3
    • Has the image labels

We use the exact same model as in the previous tutorial, only this time we split it over three hosts rather than two. We still see the same loss being reported, as there is no reduction in accuracy when training this way; the short local sketch below illustrates why. While we only use three segments here, this can be done for any arbitrary number of models.

Author:


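To see why this works, here is a short, purely local sketch (plain PyTorch only, with hypothetical layer sizes, not the model used below): composing the segments in order reproduces the unsplit forward pass exactly, so splitting changes where the computation happens, not what is computed.

In [ ]:
# Local sketch: splitting a Sequential model into segments and chaining them
# gives exactly the same output as the original, unsplit model.
import torch
from torch import nn

torch.manual_seed(0)
full = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
seg1 = nn.Sequential(full[0], full[1])   # segment 1: first linear + ReLU
seg2 = nn.Sequential(full[2])            # segment 2: final linear

x = torch.randn(3, 8)
print(torch.allclose(full(x), seg2(seg1(x))))  # True - identical outputs
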
In [ ]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)

In [ ]:
class SplitNN(torch.nn.Module):
    def __init__(self, models, optimizers):
        super().__init__()
        self.models = models            # one model segment per host, in forward order
        self.optimizers = optimizers    # one optimizer per segment
        self.outputs = [None]*len(self.models)
        self.inputs = [None]*len(self.models)

    def forward(self, x):
        # The first segment consumes the raw input
        self.inputs[0] = x
        self.outputs[0] = self.models[0](self.inputs[0])

        for i in range(1, len(self.models)):
            # Detach at the segment boundary so each host keeps only its own
            # autograd graph; gradients are relayed manually in backward()
            self.inputs[i] = self.outputs[i-1].detach().requires_grad_()
            if self.outputs[i-1].location != self.models[i].location:
                # Move the intermediate activation to the next segment's host
                self.inputs[i] = self.inputs[i].move(self.models[i].location).requires_grad_()
            self.outputs[i] = self.models[i](self.inputs[i])

        return self.outputs[-1]

    def backward(self):
        # Walk the segments in reverse, relaying each boundary gradient
        # (computed on the downstream host) back into the upstream segment
        for i in range(len(self.models)-2, -1, -1):
            grad_in = self.inputs[i+1].grad.copy()
            if self.outputs[i].location != self.inputs[i+1].location:
                grad_in = grad_in.move(self.outputs[i].location)
            self.outputs[i].backward(grad_in)

    def zero_grads(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

    def train(self):
        for model in self.models:
            model.train()

    def eval(self):
        for model in self.models:
            model.eval()

    @property
    def location(self):
        return self.models[0].location if self.models else None

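The key mechanic is in forward() and backward(): the autograd graph is cut at every segment boundary with detach(), and the boundary gradient is later fed back into the upstream segment by hand. A minimal single-machine sketch of that trick (plain PyTorch, hypothetical sizes, no PySyft):

In [ ]:
# Two local "segments"; in the real SplitNN these live on different workers
seg_a = nn.Linear(3, 5)
seg_b = nn.Linear(5, 2)

x = torch.randn(4, 3)
out_a = seg_a(x)
in_b = out_a.detach().requires_grad_()   # cut the graph at the boundary
out_b = seg_b(in_b)

loss = out_b.sum()
loss.backward()                          # fills in_b.grad only (graph was cut)
out_a.backward(in_b.grad)                # relay the boundary gradient into segment A
print(seg_a.weight.grad is not None)     # True - segment A received gradients
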
In [ ]:
# Data preprocessing
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [ ]:
torch.manual_seed(0)

# Define our model segments

input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
                nn.Linear(input_size, hidden_sizes[0]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[1], output_size),
                nn.LogSoftmax(dim=1)
    )
]

# Create an optimizer for each segment, linked to that segment's parameters
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
claire = sy.VirtualWorker(hook, id="claire")

# Send Model Segments to model locations
model_locations = [alice, bob, claire]
for model, location in zip(models, model_locations):
    model.send(location)

# Instantiate a SplitNN with our distributed segments and their respective optimizers
splitNN = SplitNN(models, optimizers)

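As an optional sanity check (assuming the cells above have run), we can confirm that each segment now lives on its intended worker:

In [ ]:
# Each model segment should report the worker it was sent to
for segment, worker in zip(models, model_locations):
    print("expected:", worker.id, "| actual:", segment.location.id)
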
In [ ]:
def train(x, target, splitNN):
    
    #1) Zero our grads
    splitNN.zero_grads()
    
    #2) Make a prediction
    pred = splitNN.forward(x)
  
    #3) Figure out how much we missed by
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)
    
    #4) Backprop the loss on the end layer
    loss.backward()
    
    #5) Feed gradients backward through the network
    splitNN.backward()
    
    #6) Change the weights
    splitNN.step()
    
    return loss

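Before running the full loop, a quick one-batch check (forward pass and loss only, so no weights are updated) confirms the distributed wiring end to end:

In [ ]:
# One-batch check: distributed forward pass plus loss, without an optimizer step
images, labels = next(iter(trainloader))
images = images.send(models[0].location)
images = images.view(images.shape[0], -1)
labels = labels.send(models[-1].location)
pred = splitNN.forward(images)
print("Initial loss:", nn.NLLLoss()(pred, labels).get())
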
In [ ]:
for i in range(epochs):
    running_loss = 0
    splitNN.train()
    for images, labels in trainloader:
        images = images.send(models[0].location)
        images = images.view(images.shape[0], -1)
        labels = labels.send(models[-1].location)
        loss = train(images, labels, splitNN)
        running_loss += loss.get()

    print("Epoch {} - Training loss: {}".format(i, running_loss/len(trainloader)))

In [ ]:
def test(model, dataloader, dataset_name):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            data = data.view(data.shape[0], -1).send(model.location)
            output = model(data).get()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    print("{}: Accuracy {}/{} ({:.0f}%)".format(dataset_name, 
                                                correct,
                                                len(dataloader.dataset), 
                                                100. * correct / len(dataloader.dataset)))

In [ ]:
testset = datasets.MNIST('mnist', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
test(splitNN, testloader, "Test set")
test(splitNN, trainloader, "Train set")