# Number of training epochs, set here so that these notebooks can be tested
# in our CI workflow.
epochs = 25
This example demonstrates how a simple image classifier written in PyTorch can be trained using federated learning with PySyft. We distribute the image data to two workers, Bob and Alice, to whom the model is sent and on whose data it is trained. After training, the model is sent back to its owner and used to make predictions.
Hrishikesh Kamath - GitHub: @kamathhrishi
#Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy # <-- NEW: import the Pysyft library
class Arguments():
    """Hyper-parameters and runtime switches for the federated CIFAR-10 example.

    All settings are plain instance attributes, so they can be overridden
    after construction (e.g. ``args.batch_size = 32``).
    """

    def __init__(self):
        # Batch sizes used by the training and test data loaders.
        self.batch_size = 64
        self.test_batch_size = 1000
        # Number of training epochs, taken from the module-level ``epochs``
        # setting. The previous line ``self.epochs = epochs = 0.01`` fused two
        # assignments: it stored the float 0.01 as the epoch count (which
        # ``range(1, args.epochs + 1)`` cannot accept) and left no learning
        # rate defined.
        self.epochs = epochs
        # Learning rate for the SGD optimizer.
        self.lr = 0.01
        # Kept for parity with the non-federated example; the optimizer is
        # currently created without momentum.
        self.momentum = 0.5
        # When True, training stays on the CPU even if CUDA is available.
        self.no_cuda = True
        self.seed = 1
        # Print the training loss every ``log_interval`` batches.
        self.log_interval = 200
        # When True, the trained weights are written to disk after training.
        self.save_model = False
args = Arguments()

# CUDA is used only when the settings allow it and a device is available.
use_cuda = not args.no_cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Extra data-loader options that only make sense when feeding a GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# <-- NEW: hook PyTorch, i.e. add the extra functionality needed for Federated Learning
hook = sy.TorchHook(torch)
# <-- NEW: define the two remote workers, bob and alice
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))
def load_data():
    """Build the federated CIFAR-10 training loader and the local test loader.

    The training split is distributed across the workers ``bob`` and
    ``alice`` and wrapped in PySyft's ``FederatedDataLoader``; the test split
    is served by a regular ``torch.utils.data.DataLoader``.

    Returns:
        tuple: ``(federated_train_loader, test_loader)``.
    """
    # ToTensor has to run before Normalize: Normalize works on tensors, while
    # CIFAR10 yields PIL images.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    federated_train_loader = sy.FederatedDataLoader(  # <-- this is now a FederatedDataLoader
        datasets.CIFAR10('../data', train=True, download=True,
                         transform=preprocess)
        .federate((bob, alice)),  # <-- NEW: distribute the dataset across the workers; it is now a FederatedDataset
        batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=False, transform=preprocess),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    return federated_train_loader, test_loader
class Net(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images, 10 classes.

    Two convolution + max-pool stages are followed by three fully connected
    layers; the output is a per-class log-probability vector.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 16 feature maps of 5x5 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return log-softmax class scores for the image batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten every sample's feature maps into a single vector.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        logits = self.fc3(x)
        return F.log_softmax(logits, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of federated training.

    For every batch, the model is sent to the worker that holds the batch,
    takes one optimisation step there, and is then fetched back.

    Args:
        args: settings object; ``log_interval`` and ``batch_size`` are read.
        model: network to train; must provide PySyft's ``send``/``get``.
        device: torch device the batch tensors are moved to.
        train_loader: iterable of ``(data, target)`` batches whose ``data``
            exposes a ``location`` attribute naming the owning worker.
        optimizer: optimizer bound to the parameters of ``model``.
        epoch: current epoch number; used only in the progress message.
    """
    model.train()
    num_batches = len(train_loader)
    # Iterate the loader that was passed in rather than a module-level global.
    for batch_idx, (data, target) in enumerate(train_loader):  # <-- now it is a distributed dataset
        model.send(data.location)  # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # <-- NEW: get the loss back
            # Sample counts are derived from the batch index and batch size.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, num_batches * args.batch_size,
                100. * batch_idx / num_batches, loss.item()))
def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        args: settings object; accepted for symmetry with ``train`` and not
            read here.
        model: network to evaluate; expected to return log-probabilities.
        device: torch device the batch tensors are moved to.
        test_loader: data loader yielding ``(data, target)`` batches; its
            ``dataset`` length is used to average the results.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))


# The name matches pytest's default discovery prefix; mark the helper so it is
# never collected as a test case.
test.__test__ = False
#<--Load federated training data and test data
federated_train_loader, test_loader = load_data()

#<--Create Neural Network model instance
model = Net().to(device)
# Plain SGD: momentum is not supported at the moment (TODO). The learning rate
# comes from the settings object, falling back to 0.01 when it has no ``lr``.
optimizer = optim.SGD(model.parameters(), lr=getattr(args, "lr", 0.01))

#<--Train Neural network and validate with test set after completion of training every epoch
for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    # NOTE(review): the original checkpoint path was lost when this file was
    # extracted; confirm the intended file name.
    torch.save(model.state_dict(), "cifar10_cnn.pt")
And voilà! We are now training a real-world learning model using Federated Learning! As you observed, we modified 10 lines of code to upgrade the official PyTorch example on CIFAR10 to a real Federated Learning setting!
Currently the momentum argument of the optimizer is not supported, which is why training a CNN on the CIFAR10 dataset took more epochs than it otherwise would.
Of course, there are dozens of improvements we could think of. We would like the computation to operate in parallel on the workers, to update the central model only every n
batches, to reduce the number of messages we use to communicate between workers, etc.
On the security side, it still has some major shortcomings. Most notably, when we call model.get()
and receive the updated model from Bob or Alice, we can actually learn a lot about Bob and Alice's training data by looking at their gradients. We could average the gradient across multiple individuals before uploading it to the central server, as we did in Part 4 of the tutorials section.
