epochs = 50

Part 7 - Federated Learning with FederatedDataset

Here we introduce a new tool for working with federated datasets: the FederatedDataset class, which is intended to be used like the PyTorch Dataset class. It is passed to a federated data loader, FederatedDataLoader, which iterates over it in a federated fashion.
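
To make "iterating in a federated fashion" more concrete, here is a minimal conceptual sketch in plain Python. It is not how PySyft implements these classes: `ToyFederatedDataset` and `toy_federated_batches` are hypothetical names invented for illustration, and the sketch simply assumes that each batch comes from exactly one worker and that workers are visited one after another.

```python
# Conceptual sketch only: plain Python stand-ins, not PySyft classes.
from typing import Dict, Iterator, List, Tuple

Sample = Tuple[List[float], float]  # (features, target)


class ToyFederatedDataset:
    """Hypothetical stand-in: maps a worker id to the samples that worker holds."""

    def __init__(self, per_worker: Dict[str, List[Sample]]) -> None:
        self.per_worker = per_worker

    @property
    def workers(self) -> List[str]:
        # Ids of the workers that hold part of the data
        return list(self.per_worker.keys())

    def __len__(self) -> int:
        # Total number of samples across all workers
        return sum(len(samples) for samples in self.per_worker.values())


def toy_federated_batches(
    dataset: ToyFederatedDataset, batch_size: int
) -> Iterator[Tuple[str, List[Sample]]]:
    """Yield (worker_id, batch); a batch never mixes samples from two workers."""
    for worker_id in dataset.workers:
        samples = dataset.per_worker[worker_id]
        for start in range(0, len(samples), batch_size):
            yield worker_id, samples[start:start + batch_size]


# Made-up toy data, unrelated to the Boston dataset used below
toy = ToyFederatedDataset({
    "alice": [([1.0, 2.0], 0.5), ([3.0, 4.0], 1.5), ([5.0, 6.0], 2.5)],
    "bob": [([7.0, 8.0], 3.5), ([9.0, 1.0], 4.5)],
})

batches = list(toy_federated_batches(toy, batch_size=2))
for worker_id, batch in batches:
    print(worker_id, len(batch))
```

The point of the sketch is that every batch carries the identity of the worker it lives on, which is what lets a training loop decide where the model has to go before it can process that batch.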


We use the sandbox introduced in the previous lesson:

import torch as th
import syft as sy
sy.create_sandbox(globals(), verbose=False)

Then search for a dataset

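# `grid` is assumed to be the object that create_sandbox() adds to globals()
boston_data = grid.search("#boston", "#data")
boston_target = grid.search("#boston", "#target")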

We define a model (the optimizers are created per worker further down)

n_features = boston_data['alice'][0].shape[1]
n_targets = 1

model = th.nn.Linear(n_features, n_targets)

Here we cast the fetched data into a FederatedDataset and look at the workers that hold part of the data.

# Wrap each worker's data and targets in a BaseDataset
datasets = []
for worker in boston_data.keys():
    dataset = sy.BaseDataset(boston_data[worker][0], boston_target[worker][0])
    datasets.append(dataset)

# Build the FederatedDataset object
dataset = sy.FederatedDataset(datasets)
print(dataset.workers)

# Create one optimizer per worker
optimizers = {}
for worker in dataset.workers:
    optimizers[worker] = th.optim.Adam(params=model.parameters(), lr=1e-2)

We put it in a FederatedDataLoader and specify options

train_loader = sy.FederatedDataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
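
As a rough illustration of what `batch_size` and `drop_last` imply, the sketch below counts batches per worker. It assumes that batches are formed separately for each worker and that `drop_last` follows the usual PyTorch DataLoader meaning (discard a final batch that is smaller than `batch_size`). The helper `batches_per_worker` and the sample counts are hypothetical, and the real loader may behave differently in its details.

```python
import math
from typing import Dict


def batches_per_worker(
    sizes: Dict[str, int], batch_size: int, drop_last: bool
) -> Dict[str, int]:
    """Count the batches each worker would yield under the stated assumptions."""
    counts = {}
    for worker_id, n_samples in sizes.items():
        if drop_last:
            # Incomplete final batch is discarded
            counts[worker_id] = n_samples // batch_size
        else:
            # Incomplete final batch is kept
            counts[worker_id] = math.ceil(n_samples / batch_size)
    return counts


# Hypothetical sample counts, not the actual Boston split
sizes = {"alice": 100, "bob": 70}

keep = batches_per_worker(sizes, batch_size=32, drop_last=False)
drop = batches_per_worker(sizes, batch_size=32, drop_last=True)
print(keep)
print(drop)
```

With `drop_last=False`, as used here, every sample is seen once per epoch, at the cost of some batches being smaller than 32.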

And finally we iterate over epochs. Notice how similar this is to pure, local PyTorch training!

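for epoch in range(1, epochs + 1):
    loss_accum = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # Send the model to the worker that holds this batch
        model.send(data.location)

        # Use the optimizer dedicated to that worker
        optimizer = optimizers[data.location.id]
        optimizer.zero_grad()
        pred = model(data)
        loss = ((pred.view(-1) - target)**2).mean()
        loss.backward()
        optimizer.step()

        # Retrieve the updated model and the loss value
        model.get()
        loss = loss.get()

        loss_accum += float(loss)
        if batch_idx % 8 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch loss: {:.6f}'.format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()))
    print('Total loss', loss_accum)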

