In [1]:
epochs = 15

Introduction to Split Neural Network (SplitNN)

Traditionally, PySyft has been used to facilitate federated learning. However, we can also leverage the tools included in this framework to implement distributed neural networks.

What is a SplitNN?

The training of a neural network (NN) is 'split' accross one or more hosts. Each model segment is a self contained NN that feeds into the segment in front. In this example Alice has unlabeled training data and the bottom of the network whereas Bob has the corresponding labels and the top of the network. The image below shows this training process where Bob has all the labels and there are multiple alices with X data [1]. Once $Alice_1$ has trained she sends a copy of her trained bottom model to the next Alice. This continues until $Alice_n$ has trained.

In this case, both parties can train the model without knowing each others data or full details of the model. When Alice is finished training, she passes it to the next person with data.

Why use a SplitNN?

The SplitNN has been shown to provide a dramatic reduction to the computational burden of training while maintaining higher accuracies when training over large number of clients [2]. In the figure below, the Blue line denotes distributed deep learning using SplitNN, red line indicate federated learning (FL) and green line indicates Large Batch Stochastic Gradient Descent (LBSGD).

Table 1 shows computational resources consumed when training CIFAR 10 over VGG. Theses are a fraction of the resources of FL and LBSGD. Table 2 shows the bandwith usage when training CIFAR 100 over ResNet. Federated learning is less bandwidth intensive with fewer than 100 clients. However, the SplitNN outperforms other approaches as the number of clients grow[ 2].

Advantages

  • The accuracy should be identical to a non-split version of the same model, trained locally.
  • the model is distributed, meaning all segment holders must consent in order to aggregate the model at the end of training.
  • The scalability of this approach, in terms of both network and computational resources, could make this an a valid alternative to FL and LBSGD, particularly on low power devices.
  • This could be an effective mechanism for both horizontal and vertical data distributions.
  • As computational cost is already quite low, the cost of applying homomorphic encryption is also minimised.
  • Only activation signal gradients are sent/ recieved, meaning that malicious actors cannot use gradients of model parameters to reverse engineer the original values.

Constraints

  • A new technique with little surroundung literature, a large amount of comparison and evaluation is still to be performed.
  • This approach requires all hosts to remain online during the entire learning process (less fesible for hand-held devices).
  • Not as established in privacy-preserving toolkits as FL and LBSGD.
  • Activation signals and their corresponding gradients still have the capacity to leak information, however this is yet to be fully addressed in the literature.

Tutorial

This tutorial demonstrates a basic example of SplitNN which

  • Has two paticipants: Alice and Bob.
    • Bob has labels
    • Alice has X values
  • Has two model segments.
    • Alice has the bottom half
    • Bob has the top half
  • Trains on the MNIST dataset.

Authors:


In [2]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)

In [3]:
# Data preprocessing
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [4]:
torch.manual_seed(0)

# Define our model segments

input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
                nn.Linear(input_size, hidden_sizes[0]),
                nn.ReLU(),
                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                nn.ReLU(),
    ),
    nn.Sequential(
                nn.Linear(hidden_sizes[1], output_size),
                nn.LogSoftmax(dim=1)
    )
]

# Create optimisers for each segment and link to their segment
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = alice, bob

# Send Model Segments to starting locations
model_locations = [alice, bob]

for model, location in zip(models, model_locations):
    model.send(location)

In [5]:
def train(x, target, models, optimizers):
    # Training Logic

    #1) erase previous gradients (if they exist)
    for opt in optimizers:
        opt.zero_grad()

    #2) make a prediction
    a = models[0](x)

    #3) break the computation graph link, and send the activation signal to the next model
    remote_a = a.move(models[1].location, requires_grad=True)

    #4) make prediction on next model using recieved signal
    pred = models[1](remote_a)

    #5) calculate how much we missed
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)

    #6) figure out which weights caused us to miss
    loss.backward()

    # 7) send gradient of the recieved activation signal to the model behind
    # grad_a = remote_a.grad.copy().move(models[0].location)

    # 8) backpropagate on bottom model given this gradient
    # a.backward(grad_a)

    #9) change the weights
    for opt in optimizers:
        opt.step()

    #10) print our progress
    return loss.detach().get()

In [6]:
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        images = images.send(alice)
        images = images.view(images.shape[0], -1)
        labels = labels.send(bob)
        
        loss = train(images, labels, models, optimizers)
        running_loss += loss

    else:
        print("Epoch {} - Training loss: {}".format(i, running_loss/len(trainloader)))


Epoch 0 - Training loss: 0.5366032123565674
Epoch 1 - Training loss: 0.25951144099235535
Epoch 2 - Training loss: 0.19616961479187012
Epoch 3 - Training loss: 0.1603524535894394
Epoch 4 - Training loss: 0.13461314141750336
Epoch 5 - Training loss: 0.11615928262472153
Epoch 6 - Training loss: 0.10251112282276154
Epoch 7 - Training loss: 0.0917714536190033
Epoch 8 - Training loss: 0.081630177795887
Epoch 9 - Training loss: 0.07489361613988876
Epoch 10 - Training loss: 0.06842804700136185
Epoch 11 - Training loss: 0.06335976719856262
Epoch 12 - Training loss: 0.05796955153346062
Epoch 13 - Training loss: 0.05360636115074158
Epoch 14 - Training loss: 0.04987649992108345

In [ ]: