Federated Learning Training Plan: Create Plan

Let's try to make a protobuf-serializable Training Plan and Model that still work after deserialization :)

Current list of problems:

  • tensor.shape is not traceable inside the Plan (issue #3554).
  • Autograd/Plan tracing doesn't work with torch's native loss functions and optimizers.
  • others?

In [1]:
%load_ext autoreload
%autoreload 2

import syft as sy
from syft.serde import protobuf
from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB
from syft_proto.execution.v1.state_pb2 import State as StatePB
from syft.grid.clients.static_fl_client import StaticFLClient
from syft.execution.state import State
from syft.execution.placeholder import PlaceHolder
from syft.execution.translation import TranslationTarget

import torch as th
from torch import nn

import os
import websockets
import json
import requests

sy.make_hook(globals())
# force protobuf serialization for tensors
hook.local_worker.framework = None
th.random.manual_seed(1)


Setting up Sandbox...
Done!
Out[1]:
<torch._C.Generator at 0x27f0c28cbb0>

This utility function sets a list of tensors as a model's parameters.


In [2]:
def set_model_params(module, params_list, start_param_idx=0):
    """ Set params list into model recursively
    """
    param_idx = start_param_idx

    for name, param in module._parameters.items():
        module._parameters[name] = params_list[param_idx]
        param_idx += 1

    for name, child in module._modules.items():
        if child is not None:
            param_idx = set_model_params(child, params_list, param_idx)

    return param_idx
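For example (a minimal local check, not part of any Plan), calling it on a small container module replaces the parameters in traversal order and returns how many tensors were consumed:

# Hypothetical local check: set_model_params walks submodules recursively
tmp = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
flat = [p.data.clone() for p in tmp.parameters()]
consumed = set_model_params(tmp, flat)
print(consumed)  # 4: two weights and two biases, set recursively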

Step 1: Define the model

This model will train on MNIST data. It's very simple, yet enough to demonstrate the learning process. It consists of two linear layers with a ReLU in between:

  • Linear 784x392
  • ReLU
  • Linear 392x10

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 392)
        self.fc2 = nn.Linear(392, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

model = Net()

Step 2: Define Training Plan

Loss function

The batch size needs to be passed in explicitly because target.shape[0] cannot be traced inside a Plan yet (issue #3554).


In [4]:
def softmax_cross_entropy_with_logits(logits, targets, batch_size):
    """ Calculates softmax entropy
        Args:
            * logits: (NxC) outputs of dense layer
            * targets: (NxC) one-hot encoded labels
            * batch_size: value of N, temporarily required because Plan cannot trace .shape
    """
    # numerically stable log-softmax
    norm_logits = logits - logits.max()
    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()
    # NLL, reduction = mean
    return -(targets * log_probs).sum() / batch_size
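As a quick sanity check outside of Plan tracing, this should match PyTorch's built-in cross-entropy on one-hot targets (a minimal sketch on plain local tensors; not one of the tutorial's cells):

# Sanity check (local tensors only): compare against torch's built-in cross-entropy
logits = th.randn(8, 10)
labels = th.randint(0, 10, (8,))
targets = nn.functional.one_hot(labels, 10).float()
ours = softmax_cross_entropy_with_logits(logits, targets, batch_size=8)
reference = nn.functional.cross_entropy(logits, labels)
print(ours.item(), reference.item())  # the two values should (almost) coincide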

Optimization function

Just updates the weights with lr * grad.

Note: we can't do an in-place update because of Autograd/Plan tracing specifics.


In [5]:
def naive_sgd(param, **kwargs):
    return param - kwargs['lr'] * param.grad
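For reference, this is the same update a vanilla torch.optim.SGD step would apply (no momentum, no weight decay); a minimal sketch on a plain tensor, outside of Plan tracing:

# Reference check: naive_sgd should match one plain SGD step
w = th.randn(5, requires_grad=True)
(w ** 2).sum().backward()        # dummy loss to populate w.grad
manual = naive_sgd(w, lr=0.01)   # functional update: w - lr * w.grad
opt = th.optim.SGD([w], lr=0.01)
opt.step()                       # in-place update of w
print(th.allclose(manual, w))    # True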

Training Plan procedure

We define a routine that takes one batch of training data and the current model parameters, and returns the loss, the accuracy, and the model parameters updated by one SGD step on the loss defined above.


In [6]:
@sy.func2plan()
def training_plan(X, y, batch_size, lr, model_params):
    # inject params into model
    set_model_params(model, model_params)

    # forward pass
    logits = model.forward(X)
    
    # loss
    loss = softmax_cross_entropy_with_logits(logits, y, batch_size)

    # backprop
    loss.backward()

    # step
    updated_params = [
        naive_sgd(param, lr=lr)
        for param in model_params
    ]
    
    # accuracy
    pred = th.argmax(logits, dim=1)
    target = th.argmax(y, dim=1)
    acc = pred.eq(target).sum().float() / batch_size

    return (
        loss,
        acc,
        *updated_params
    )

Let's build this procedure into a Plan that we can serialize.


In [7]:
# Dummy input parameters to make the trace
model_params = [param.data for param in model.parameters()]  # raw tensors instead of nn.Parameter
X = th.randn(3, 28 * 28)
y = nn.functional.one_hot(th.tensor([1, 2, 3]), 10)
lr = th.tensor([0.01])
batch_size = th.tensor([3.0])

_ = training_plan.build(X, y, batch_size, lr, model_params, trace_autograd=True)

Let's look inside the Syft Plan and print out the list of operations recorded.


In [8]:
print(training_plan.code)


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = arg_5.t()
    var_1 = arg_1.matmul(var_0)
    var_2 = arg_6.add(var_1)
    var_3 = var_2.relu()
    var_4 = arg_7.t()
    var_5 = var_3.matmul(var_4)
    var_6 = arg_8.add(var_5)
    var_7 = var_6.max()
    var_8 = var_6.sub(var_7)
    var_9 = var_8.exp()
    var_10 = var_9.sum(dim=1, keepdim=True)
    var_11 = var_10.log()
    var_12 = var_8.sub(var_11)
    var_13 = arg_2.mul(var_12)
    var_14 = var_13.sum()
    var_15 = var_14.neg()
    out_1 = var_15.div(arg_3)
    var_16 = out_1.mul(0)
    var_17 = var_16.add(1)
    var_18 = var_17.div(arg_3)
    var_19 = var_18.mul(-1)
    var_20 = var_19.reshape([-1, 1])
    var_21 = var_13.mul(0)
    var_22 = var_21.add(1)
    var_23 = var_22.mul(var_20)
    var_24 = var_23.mul(arg_2)
    var_25 = var_24.add(0)
    var_26 = var_24.mul(-1)
    var_27 = var_26.sum(dim=[1], keepdim=True)
    var_28 = var_25.add(0)
    var_29 = var_28.add(0)
    var_30 = var_28.add(0)
    var_31 = var_29.sum(dim=[0])
    var_32 = var_31.copy()
    var_33 = var_4.t()
    var_34 = var_30.matmul(var_33)
    var_35 = var_3.t()
    var_36 = var_35.matmul(var_30)
    var_37 = var_2.mul(0)
    var_38 = var_2.__gt__(var_37)
    var_39 = var_38.mul(var_34)
    var_40 = var_39.add(0)
    var_41 = var_39.add(0)
    var_42 = var_40.sum(dim=[0])
    var_43 = var_42.copy()
    var_44 = arg_1.t()
    var_45 = var_44.matmul(var_41)
    var_46 = var_45.t()
    var_47 = var_46.copy()
    var_48 = var_36.t()
    var_49 = var_48.copy()
    var_50 = var_10.__rtruediv__(1)
    var_51 = var_27.mul(var_50)
    var_52 = var_51.reshape([-1, 1])
    var_53 = var_9.mul(0)
    var_54 = var_53.add(1)
    var_55 = var_54.mul(var_52)
    var_56 = var_8.exp()
    var_57 = var_55.mul(var_56)
    var_58 = var_57.add(0)
    var_59 = var_58.add(0)
    var_60 = var_58.add(0)
    var_61 = var_59.sum(dim=[0])
    var_32 = var_32.add_(var_61)
    var_62 = var_4.t()
    var_63 = var_60.matmul(var_62)
    var_64 = var_3.t()
    var_65 = var_64.matmul(var_60)
    var_66 = var_2.mul(0)
    var_67 = var_2.__gt__(var_66)
    var_68 = var_67.mul(var_63)
    var_69 = var_68.add(0)
    var_70 = var_68.add(0)
    var_71 = var_69.sum(dim=[0])
    var_43 = var_43.add_(var_71)
    var_72 = arg_1.t()
    var_73 = var_72.matmul(var_70)
    var_74 = var_73.t()
    var_47 = var_47.add_(var_74)
    var_75 = var_65.t()
    var_49 = var_49.add_(var_75)
    var_76 = arg_4.mul(var_47)
    out_3 = arg_5.sub(var_76)
    var_77 = arg_4.mul(var_43)
    out_4 = arg_6.sub(var_77)
    var_78 = arg_4.mul(var_49)
    out_5 = arg_7.sub(var_78)
    var_79 = arg_4.mul(var_32)
    out_6 = arg_8.sub(var_79)
    var_80 = torch.argmax(var_6, dim=1)
    var_81 = torch.argmax(arg_2, dim=1)
    var_82 = var_80.eq(var_81)
    var_83 = var_82.sum()
    var_84 = var_83.float()
    out_2 = var_84.div(arg_3)
    return out_1, out_2, out_3, out_4, out_5, out_6

The Plan should also be automatically translated to TorchScript and TensorFlow.js. Let's examine the TorchScript code:


In [9]:
print(training_plan.torchscript.code)


def <Plan training_plan id:49443480607 owner:me built>
(argument_0: Tensor,
    argument_1: Tensor,
    argument_2: Tensor,
    argument_3: Tensor,
    argument_4: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  _0, _1, _2, _3, = argument_4
  _4 = torch.add(_1, torch.matmul(argument_0, torch.t(_0)), alpha=1)
  _5 = torch.relu(_4)
  _6 = torch.t(_2)
  _7 = torch.add(_3, torch.matmul(_5, _6), alpha=1)
  _8 = torch.sub(_7, torch.max(_7), alpha=1)
  _9 = torch.exp(_8)
  _10 = torch.sum(_9, [1], True, dtype=None)
  _11 = torch.sub(_8, torch.log(_10), alpha=1)
  _12 = torch.mul(argument_1, _11)
  _13 = torch.div(torch.neg(torch.sum(_12, dtype=None)), argument_2)
  _14 = torch.add(torch.mul(_13, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _15 = torch.mul(torch.div(_14, argument_2), CONSTANTS.c2)
  _16 = torch.reshape(_15, [-1, 1])
  _17 = torch.add(torch.mul(_12, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _18 = torch.mul(torch.mul(_17, _16), argument_1)
  _19 = torch.add(_18, CONSTANTS.c0, alpha=1)
  _20 = torch.sum(torch.mul(_18, CONSTANTS.c2), [1], True, dtype=None)
  _21 = torch.add(_19, CONSTANTS.c0, alpha=1)
  _22 = torch.add(_21, CONSTANTS.c0, alpha=1)
  _23 = torch.add(_21, CONSTANTS.c0, alpha=1)
  _24 = torch.sum(_22, [0], False, dtype=None)
  _25 = torch.matmul(_23, torch.t(_6))
  _26 = torch.matmul(torch.t(_5), _23)
  _27 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))
  _28 = torch.mul(_27, _25)
  _29 = torch.add(_28, CONSTANTS.c0, alpha=1)
  _30 = torch.add(_28, CONSTANTS.c0, alpha=1)
  _31 = torch.sum(_29, [0], False, dtype=None)
  _32 = torch.matmul(torch.t(argument_0), _30)
  _33 = torch.t(_32)
  _34 = torch.t(_26)
  _35 = torch.mul(torch.reciprocal(_10), CONSTANTS.c1)
  _36 = torch.reshape(torch.mul(_20, _35), [-1, 1])
  _37 = torch.add(torch.mul(_9, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _38 = torch.mul(torch.mul(_37, _36), torch.exp(_8))
  _39 = torch.add(_38, CONSTANTS.c0, alpha=1)
  _40 = torch.add(_39, CONSTANTS.c0, alpha=1)
  _41 = torch.add(_39, CONSTANTS.c0, alpha=1)
  _42 = torch.sum(_40, [0], False, dtype=None)
  _43 = torch.add_(_24, _42, alpha=1)
  _44 = torch.matmul(_41, torch.t(_6))
  _45 = torch.matmul(torch.t(_5), _41)
  _46 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))
  _47 = torch.mul(_46, _44)
  _48 = torch.add(_47, CONSTANTS.c0, alpha=1)
  _49 = torch.add(_47, CONSTANTS.c0, alpha=1)
  _50 = torch.sum(_48, [0], False, dtype=None)
  _51 = torch.add_(_31, _50, alpha=1)
  _52 = torch.matmul(torch.t(argument_0), _49)
  _53 = torch.add_(_33, torch.t(_52), alpha=1)
  _54 = torch.add_(_34, torch.t(_45), alpha=1)
  _55 = torch.sub(_0, torch.mul(argument_3, _53), alpha=1)
  _56 = torch.sub(_1, torch.mul(argument_3, _51), alpha=1)
  _57 = torch.sub(_2, torch.mul(argument_3, _54), alpha=1)
  _58 = torch.sub(_3, torch.mul(argument_3, _43), alpha=1)
  _59 = torch.eq(torch.argmax(_7, 1, False), torch.argmax(argument_1, 1, False))
  _60 = torch.to(torch.sum(_59, dtype=None), 6, False, False, None)
  _61 = (_13, torch.div(_60, argument_2), _55, _56, _57, _58)
  return _61

TensorFlow.js code:


In [10]:
training_plan.base_framework = TranslationTarget.TENSORFLOW_JS.value
print(training_plan.code)
training_plan.base_framework = TranslationTarget.PYTORCH.value


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = tf.transpose(arg_5)
    var_1 = tf.matMul(arg_1, var_0)
    var_2 = tf.add(arg_6, var_1)
    var_3 = tf.relu(var_2)
    var_4 = tf.transpose(arg_7)
    var_5 = tf.matMul(var_3, var_4)
    var_6 = tf.add(arg_8, var_5)
    var_7 = tf.max(var_6)
    var_8 = tf.sub(var_6, var_7)
    var_9 = tf.exp(var_8)
    var_10 = tf.sum(var_9, 1, keepdim=True)
    var_11 = tf.log(var_10)
    var_12 = tf.sub(var_8, var_11)
    var_13 = tf.mul(arg_2, var_12)
    var_14 = tf.sum(var_13)
    var_15 = tf.neg(var_14)
    out_1 = tf.div(var_15, arg_3)
    var_16 = tf.mul(out_1, 0)
    var_17 = tf.add(var_16, 1)
    var_18 = tf.div(var_17, arg_3)
    var_19 = tf.mul(var_18, -1)
    var_20 = tf.layers.reshape(var_19, [-1, 1])
    var_21 = tf.mul(var_13, 0)
    var_22 = tf.add(var_21, 1)
    var_23 = tf.mul(var_22, var_20)
    var_24 = tf.mul(var_23, arg_2)
    var_25 = tf.add(var_24, 0)
    var_26 = tf.mul(var_24, -1)
    var_27 = tf.sum(var_26, [1], keepdim=True)
    var_28 = tf.add(var_25, 0)
    var_29 = tf.add(var_28, 0)
    var_30 = tf.add(var_28, 0)
    var_31 = tf.sum(var_29, [0])
    var_32 = clone(var_31)
    var_33 = tf.transpose(var_4)
    var_34 = tf.matMul(var_30, var_33)
    var_35 = tf.transpose(var_3)
    var_36 = tf.matMul(var_35, var_30)
    var_37 = tf.mul(var_2, 0)
    var_38 = tf.greater(var_2, var_37)
    var_39 = tf.mul(var_38, var_34)
    var_40 = tf.add(var_39, 0)
    var_41 = tf.add(var_39, 0)
    var_42 = tf.sum(var_40, [0])
    var_43 = clone(var_42)
    var_44 = tf.transpose(arg_1)
    var_45 = tf.matMul(var_44, var_41)
    var_46 = tf.transpose(var_45)
    var_47 = clone(var_46)
    var_48 = tf.transpose(var_36)
    var_49 = clone(var_48)
    var_50 = tf.div(1, var_10)
    var_51 = tf.mul(var_27, var_50)
    var_52 = tf.layers.reshape(var_51, [-1, 1])
    var_53 = tf.mul(var_9, 0)
    var_54 = tf.add(var_53, 1)
    var_55 = tf.mul(var_54, var_52)
    var_56 = tf.exp(var_8)
    var_57 = tf.mul(var_55, var_56)
    var_58 = tf.add(var_57, 0)
    var_59 = tf.add(var_58, 0)
    var_60 = tf.add(var_58, 0)
    var_61 = tf.sum(var_59, [0])
    var_32 = tf.add(var_32, var_61)
    var_62 = tf.transpose(var_4)
    var_63 = tf.matMul(var_60, var_62)
    var_64 = tf.transpose(var_3)
    var_65 = tf.matMul(var_64, var_60)
    var_66 = tf.mul(var_2, 0)
    var_67 = tf.greater(var_2, var_66)
    var_68 = tf.mul(var_67, var_63)
    var_69 = tf.add(var_68, 0)
    var_70 = tf.add(var_68, 0)
    var_71 = tf.sum(var_69, [0])
    var_43 = tf.add(var_43, var_71)
    var_72 = tf.transpose(arg_1)
    var_73 = tf.matMul(var_72, var_70)
    var_74 = tf.transpose(var_73)
    var_47 = tf.add(var_47, var_74)
    var_75 = tf.transpose(var_65)
    var_49 = tf.add(var_49, var_75)
    var_76 = tf.mul(arg_4, var_47)
    out_3 = tf.sub(arg_5, var_76)
    var_77 = tf.mul(arg_4, var_43)
    out_4 = tf.sub(arg_6, var_77)
    var_78 = tf.mul(arg_4, var_49)
    out_5 = tf.sub(arg_7, var_78)
    var_79 = tf.mul(arg_4, var_32)
    out_6 = tf.sub(arg_8, var_79)
    var_80 = tf.argMax(var_6, 1)
    var_81 = tf.argMax(arg_2, 1)
    var_82 = tf.equal(var_80, var_81)
    var_83 = tf.sum(var_82)
    var_84 = tf.cast(var_83, float32)
    out_2 = tf.div(var_84, arg_3)
    return out_1, out_2, out_3, out_4, out_5, out_6

Step 3: Define Averaging Plan

The averaging Plan is executed by PyGrid at the end of each cycle to average the diffs submitted by workers, update the model, and create a new checkpoint for the next cycle.

A diff is the difference between the client-trained model params and the original model params, so it contains the same number of tensors, with the same shapes, as the model parameters.

We define a Plan that processes one diff at a time. Such Plans require the iterative_plan flag to be set to True in server_config when hosting the FL model to PyGrid.

The Plan below calculates a simple running mean of each parameter.
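The update (avg * num + item) / (num + 1) is just the incremental form of the arithmetic mean; a quick numeric illustration with plain Python floats:

# Incremental mean: start from the first value and fold the rest in one by one
values = [1.0, 5.5, 7.0, 55.0]
avg = values[0]
for num, item in enumerate(values[1:], start=1):
    avg = (avg * num + item) / (num + 1)
print(avg, sum(values) / len(values))  # both 17.125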


In [11]:
@sy.func2plan()
def avg_plan(avg, item, num):
    new_avg = []
    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg

# Build the Plan
_ = avg_plan.build(model_params, model_params, th.tensor([1.0]))

In [12]:
# Let's check Plan contents
print(avg_plan.code)


def avg_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9):
    var_0 = arg_1.__mul__(arg_9)
    var_1 = var_0.__add__(arg_5)
    var_2 = arg_9.__add__(1)
    out_1 = var_1.__truediv__(var_2)
    var_3 = arg_2.__mul__(arg_9)
    var_4 = var_3.__add__(arg_6)
    var_5 = arg_9.__add__(1)
    out_2 = var_4.__truediv__(var_5)
    var_6 = arg_3.__mul__(arg_9)
    var_7 = var_6.__add__(arg_7)
    var_8 = arg_9.__add__(1)
    out_3 = var_7.__truediv__(var_8)
    var_9 = arg_4.__mul__(arg_9)
    var_10 = var_9.__add__(arg_8)
    var_11 = arg_9.__add__(1)
    out_4 = var_10.__truediv__(var_11)
    return out_1, out_2, out_3, out_4

In [13]:
# Test averaging plan
# Pretend these are diffs, where all params are ones * dummy_coeffs
dummy_coeffs = [1, 5.5, 7, 55]
dummy_diffs = [[th.ones_like(param) * i for param in model_params] for i in dummy_coeffs]
mean_coeff = th.tensor(dummy_coeffs).mean().item()

# Remove original function to make sure we execute traced Plan
avg_plan.forward = None

# Calculate avg value using our plan
avg = dummy_diffs[0]
for i, diff in enumerate(dummy_diffs[1:]):
    avg = avg_plan(list(avg), diff, th.tensor([i + 1]))

# Avg should be ones*mean_coeff for each param
for i, param in enumerate(model_params):
    expected = th.ones_like(param) * mean_coeff
    assert avg[i].eq(expected).all(), f"param #{i}"

Step 4: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries (syft.js, KotlinSyft, SwiftSyft, or even PySyft itself).

First, we need a function to send websocket messages to PyGrid.


In [14]:
async def sendWsMessage(data):
    async with websockets.connect('ws://' + gatewayWsUrl) as websocket:
        await websocket.send(json.dumps(data))
        message = await websocket.recv()
        return json.loads(message)

Follow PyGrid's README.md to build the openmined/grid-gateway image from the latest dev branch and spin up PyGrid using docker-compose up --build.

Define the name, version, and configs.


In [15]:
# Default gateway address when running locally 
gatewayWsUrl = "127.0.0.1:5000"
grid = StaticFLClient(id="test", address=gatewayWsUrl, secure=False)
grid.connect()

# Name and version that workers will use to request this model
name = "mnist"
version = "1.0.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.005,
    "max_updates": 100  # custom syft.js option that limits number of training loops per worker
}

server_config = {
    "min_workers": 5,
    "max_workers": 5,
    "pool_selection": "random",
    "do_not_reuse_workers_until_cycle": 6,
    "cycle_length": 28800,  # max cycle length in seconds
    "num_cycles": 5,  # max number of cycles
    "max_diffs": 1,  # number of diffs to collect before avg
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
    "iterative_plan": True  # tells PyGrid that avg plan is executed per diff
}

Authentication (optional)

Let's additionally protect the model with simple authentication for workers.

PyGrid supports authentication via JWT tokens (HMAC, RSA) or opaque tokens validated via a remote API.

We'll try JWT with RSA. Suppose we generate an RSA key pair:

openssl genrsa -out private.pem
openssl rsa -in private.pem -pubout -out public.pem

In [16]:
private_key = """
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAzQMcI09qonB9OZT20X3Z/oigSmybR2xfBQ1YJ1oSjQ3YgV+G
FUuhEsGDgqt0rok9BreT4toHqniFixddncTHg7EJzU79KZelk2m9I2sEsKUqEsEF
lMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYvGFphwwh4TNJXxkCg69/RsvPBIPi2
9vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNVQhUFABDyWN4h/67M1eArGA540vyd
kYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+LzmjEnjTJqUzr7kM9Rzq3BY01DNi
TVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3ZQIDAQABAoIBAD+xbKeHv+BxxGYE
Yt5ZFEYhGnOk5GU/RRIjwDSRplvOZmpjTBwHoCZcmsgZDqo/FwekNzzuch1DTnIV
M0+V2EqQ0TPJC5xFcfqnikybrhxXZAfpkhtU+gR5lDb5Q+8mkhPAYZdNioG6PGPS
oGz8BsuxINhgJEfxvbVpVNWTdun6hLOAMZaH3DHgi0uyTBg8ofARoZP5RIbHwW+D
p+5vd9x/x7tByu76nd2UbMp3yqomlB5jQktqyilexCIknEnfb3i/9jqFv8qVE5P6
e3jdYoJY+FoomWhqEvtfPpmUFTY5lx4EERCb1qhWG3a7sVBqTwO6jJJBsxy3RLIS
Ic0qZcECgYEA6GsBP11a2T4InZ7cixd5qwSeznOFCzfDVvVNI8KUw+n4DOPndpao
TUskWOpoV8MyiEGdQHgmTOgGaCXN7bC0ERembK0J64FI3TdKKg0v5nKa7xHb7Qcv
t9ccrDZVn4y/Yk5PCqjNWTR3/wDR88XouzIGaWkGlili5IJqdLEvPvUCgYEA4dA+
5MNEQmNFezyWs//FS6G3lTRWgjlWg2E6BXXvkEag6G5SBD31v3q9JIjs+sYdOmwj
kfkQrxEtbs173xgYWzcDG1FI796LTlJ/YzuoKZml8vEF3T8C4Bkbl6qj9DZljb2j
ehjTv5jA256sSUEqOa/mtNFUbFlBjgOZh3TCsLECgYAc701tdRLdXuK1tNRiIJ8O
Enou26Thm6SfC9T5sbzRkyxFdo4XbnQvgz5YL36kBnIhEoIgR5UFGBHMH4C+qbQR
OK+IchZ9ElBe8gYyrAedmgD96GxH2xAuxAIW0oDgZyZgd71RZ2iBRY322kRJJAdw
Xq77qo6eXTKpni7grjpijQKBgDHWRAs5DVeZkTwhoyEW0fRfPKUxZ+ZVwUI9sxCB
dt3guKKTtoY5JoOcEyJ9FdBC6TB7rV4KGiSJJf3OXAhgyP9YpNbimbZW52fhzTuZ
bwO/ZWC40RKDVZ8f63cNsiGz37XopKvNzu36SJYv7tY8C5WvvLsrd/ZxvIYbRUcf
/dgBAoGBAMdR5DXBcOWk3+KyEHXw2qwWcGXyzxtca5SRNLPR2uXvrBYXbhFB/PVj
h3rGBsiZbnIvSnSIE+8fFe6MshTl2Qxzw+F2WV3OhhZLLtBnN5qqeSe9PdHLHm49
XDce6NV2D1mQLBe8648OI5CScQENuRGxF2/h9igeR4oRRsM1gzJN
-----END RSA PRIVATE KEY-----
""".strip()

public_key = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzQMcI09qonB9OZT20X3Z
/oigSmybR2xfBQ1YJ1oSjQ3YgV+GFUuhEsGDgqt0rok9BreT4toHqniFixddncTH
g7EJzU79KZelk2m9I2sEsKUqEsEFlMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYv
GFphwwh4TNJXxkCg69/RsvPBIPi29vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNV
QhUFABDyWN4h/67M1eArGA540vydkYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+
LzmjEnjTJqUzr7kM9Rzq3BY01DNiTVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3
ZQIDAQAB
-----END PUBLIC KEY-----
""".strip()

If we set the public key in the model's authentication config, PyGrid will validate that the submitted JWT auth token was signed with the corresponding private key.
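For intuition, that server-side check boils down to verifying the token's signature against this public key; a minimal sketch with pyjwt (illustration only, not what PyGrid literally runs):

# Illustration: a token signed with private_key verifies against public_key
import jwt
token = jwt.encode({}, private_key, algorithm="RS256")
jwt.decode(token, public_key, algorithms=["RS256"])  # raises if the signature is invalid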


In [17]:
server_config["authentication"] = {
    "type": "jwt",
    "pub_key": public_key,
}

Now we're ready to host our federated training plan!


In [18]:
model_params_state = State(
    state_placeholders=[
        PlaceHolder().instantiate(param)
        for param in model_params
    ]
)

response = grid.host_federated_training(
    model=model_params_state,
    client_plans={'training_plan': training_plan},
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config
)

print("Host response:", response)


Host response: {'type': 'model_centric/host-training', 'data': {'status': 'success'}}

Let's double-check that everything is loaded by requesting a cycle.

First, create an authentication token.


In [19]:
!pip install pyjwt
import jwt
auth_token = jwt.encode({}, private_key, algorithm='RS256').decode('ascii')

print(auth_token)


Requirement already satisfied: pyjwt in d:\anaconda3\envs\syft\lib\site-packages (1.7.1)
eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30.Cn_0cSjCw1QKtcYDx_mYN_q9jO2KkpcUoiVbILmKVB4LUCQvZ7YeuyQ51r9h3562KQoSas_ehbjpz2dw1Dk24hQEoN6ObGxfJDOlemF5flvLO_sqAHJDGGE24JRE4lIAXRK6aGyy4f4kmlICL6wG8sGSpSrkZlrFLOVRJckTptgaiOTIm5Udfmi45NljPBQKVpqXFSmmb3dRy_e8g3l5eBVFLgrBhKPQ1VbNfRK712KlQWs7jJ31fGpW2NxMloO1qcd6rux48quivzQBCvyK8PV5Sqrfw_OMOoNLcSvzePDcZXa2nPHSu3qQIikUdZIeCnkJX-w0t8uEFG3DfH1fVA

Make the authentication request:


In [20]:
auth_request = {
    "type": "model_centric/authenticate",
    "data": {
        "model_name": name,
        "model_version": version,
        "auth_token": auth_token,
    }
}
auth_response = await sendWsMessage(auth_request)
print('Auth response: ', json.dumps(auth_response, indent=2))


Auth response:  {
  "type": "model_centric/authenticate",
  "data": {
    "status": "success",
    "worker_id": "ab9639f1-6746-4be8-947b-c5bef9d16cb0"
  }
}

Make the cycle request:


In [21]:
cycle_request = {
    "type": "model_centric/cycle-request",
    "data": {
        "worker_id": auth_response['data']['worker_id'],
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    }
}
cycle_response = await sendWsMessage(cycle_request)
print('Cycle response:', json.dumps(cycle_response, indent=2))

worker_id = auth_response['data']['worker_id']
request_key = cycle_response['data']['request_key']
model_id = cycle_response['data']['model_id'] 
training_plan_id = cycle_response['data']['plans']['training_plan']


Cycle response: {
  "type": "model_centric/cycle-request",
  "data": {
    "status": "accepted",
    "request_key": "756f3ea81b7d68d94c8a989ed2d9c7b727b1c85fb5f4c4599ef8c2b1b70e196d",
    "version": "1.0.0",
    "model": "mnist",
    "plans": {
      "training_plan": 2
    },
    "protocols": {},
    "client_config": {
      "name": "mnist",
      "version": "1.0.0",
      "batch_size": 64,
      "lr": 0.005,
      "max_updates": 100
    },
    "model_id": 1
  }
}

Let's download the model and the plan (in each available format) and check that they are actually usable.


In [22]:
# Model
req = requests.get(f"http://{gatewayWsUrl}/model_centric/get-model?worker_id={worker_id}&request_key={request_key}&model_id={model_id}")
model_data = req.content
pb = StatePB()
pb.ParseFromString(req.content)
model_params_downloaded = protobuf.serde._unbufferize(hook.local_worker, pb)
print("Params shapes:", [p.shape for p in model_params_downloaded.tensors()])


Params shapes: [torch.Size([392, 784]), torch.Size([392]), torch.Size([10, 392]), torch.Size([10])]

In [23]:
# Plan "list of ops"
req = requests.get(f"http://{gatewayWsUrl}/model_centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=list")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ops = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ops.code)
print(plan_ops.torchscript)


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = arg_5.t()
    var_1 = arg_1.matmul(var_0)
    var_2 = arg_6.add(var_1)
    var_3 = var_2.relu()
    var_4 = arg_7.t()
    var_5 = var_3.matmul(var_4)
    var_6 = arg_8.add(var_5)
    var_7 = var_6.max()
    var_8 = var_6.sub(var_7)
    var_9 = var_8.exp()
    var_10 = var_9.sum(dim=1, keepdim=True)
    var_11 = var_10.log()
    var_12 = var_8.sub(var_11)
    var_13 = arg_2.mul(var_12)
    var_14 = var_13.sum()
    var_15 = var_14.neg()
    out_1 = var_15.div(arg_3)
    var_16 = out_1.mul(0)
    var_17 = var_16.add(1)
    var_18 = var_17.div(arg_3)
    var_19 = var_18.mul(-1)
    var_20 = var_19.reshape([-1, 1])
    var_21 = var_13.mul(0)
    var_22 = var_21.add(1)
    var_23 = var_22.mul(var_20)
    var_24 = var_23.mul(arg_2)
    var_25 = var_24.add(0)
    var_26 = var_24.mul(-1)
    var_27 = var_26.sum(dim=[1], keepdim=True)
    var_28 = var_25.add(0)
    var_29 = var_28.add(0)
    var_30 = var_28.add(0)
    var_31 = var_29.sum(dim=[0])
    var_32 = var_31.copy()
    var_33 = var_4.t()
    var_34 = var_30.matmul(var_33)
    var_35 = var_3.t()
    var_36 = var_35.matmul(var_30)
    var_37 = var_2.mul(0)
    var_38 = var_2.__gt__(var_37)
    var_39 = var_38.mul(var_34)
    var_40 = var_39.add(0)
    var_41 = var_39.add(0)
    var_42 = var_40.sum(dim=[0])
    var_43 = var_42.copy()
    var_44 = arg_1.t()
    var_45 = var_44.matmul(var_41)
    var_46 = var_45.t()
    var_47 = var_46.copy()
    var_48 = var_36.t()
    var_49 = var_48.copy()
    var_50 = var_10.__rtruediv__(1)
    var_51 = var_27.mul(var_50)
    var_52 = var_51.reshape([-1, 1])
    var_53 = var_9.mul(0)
    var_54 = var_53.add(1)
    var_55 = var_54.mul(var_52)
    var_56 = var_8.exp()
    var_57 = var_55.mul(var_56)
    var_58 = var_57.add(0)
    var_59 = var_58.add(0)
    var_60 = var_58.add(0)
    var_61 = var_59.sum(dim=[0])
    var_32 = var_32.add_(var_61)
    var_62 = var_4.t()
    var_63 = var_60.matmul(var_62)
    var_64 = var_3.t()
    var_65 = var_64.matmul(var_60)
    var_66 = var_2.mul(0)
    var_67 = var_2.__gt__(var_66)
    var_68 = var_67.mul(var_63)
    var_69 = var_68.add(0)
    var_70 = var_68.add(0)
    var_71 = var_69.sum(dim=[0])
    var_43 = var_43.add_(var_71)
    var_72 = arg_1.t()
    var_73 = var_72.matmul(var_70)
    var_74 = var_73.t()
    var_47 = var_47.add_(var_74)
    var_75 = var_65.t()
    var_49 = var_49.add_(var_75)
    var_76 = arg_4.mul(var_47)
    out_3 = arg_5.sub(var_76)
    var_77 = arg_4.mul(var_43)
    out_4 = arg_6.sub(var_77)
    var_78 = arg_4.mul(var_49)
    out_5 = arg_7.sub(var_78)
    var_79 = arg_4.mul(var_32)
    out_6 = arg_8.sub(var_79)
    var_80 = torch.argmax(var_6, dim=1)
    var_81 = torch.argmax(arg_2, dim=1)
    var_82 = var_80.eq(var_81)
    var_83 = var_82.sum()
    var_84 = var_83.float()
    out_2 = var_84.div(arg_3)
    return out_1, out_2, out_3, out_4, out_5, out_6
None

In [24]:
# Plan "torchscript"
req = requests.get(f"http://{gatewayWsUrl}/model_centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=torchscript")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ts = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ts.code)
print(plan_ts.torchscript.code)


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    return out_1, out_2, out_3, out_4, out_5, out_6
def forward(self,
    argument_1: Tensor,
    argument_2: Tensor,
    argument_3: Tensor,
    argument_4: Tensor,
    argument_5: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  _0, _1, _2, _3, = argument_5
  _4 = torch.add(_1, torch.matmul(argument_1, torch.t(_0)), alpha=1)
  _5 = torch.relu(_4)
  _6 = torch.t(_2)
  _7 = torch.add(_3, torch.matmul(_5, _6), alpha=1)
  _8 = torch.sub(_7, torch.max(_7), alpha=1)
  _9 = torch.exp(_8)
  _10 = torch.sum(_9, [1], True, dtype=None)
  _11 = torch.sub(_8, torch.log(_10), alpha=1)
  _12 = torch.mul(argument_2, _11)
  _13 = torch.div(torch.neg(torch.sum(_12, dtype=None)), argument_3)
  _14 = torch.add(torch.mul(_13, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _15 = torch.mul(torch.div(_14, argument_3), CONSTANTS.c2)
  _16 = torch.reshape(_15, [-1, 1])
  _17 = torch.add(torch.mul(_12, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _18 = torch.mul(torch.mul(_17, _16), argument_2)
  _19 = torch.add(_18, CONSTANTS.c0, alpha=1)
  _20 = torch.sum(torch.mul(_18, CONSTANTS.c2), [1], True, dtype=None)
  _21 = torch.add(_19, CONSTANTS.c0, alpha=1)
  _22 = torch.add(_21, CONSTANTS.c0, alpha=1)
  _23 = torch.add(_21, CONSTANTS.c0, alpha=1)
  _24 = torch.sum(_22, [0], False, dtype=None)
  _25 = torch.matmul(_23, torch.t(_6))
  _26 = torch.matmul(torch.t(_5), _23)
  _27 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))
  _28 = torch.mul(_27, _25)
  _29 = torch.add(_28, CONSTANTS.c0, alpha=1)
  _30 = torch.add(_28, CONSTANTS.c0, alpha=1)
  _31 = torch.sum(_29, [0], False, dtype=None)
  _32 = torch.matmul(torch.t(argument_1), _30)
  _33 = torch.t(_32)
  _34 = torch.t(_26)
  _35 = torch.mul(torch.reciprocal(_10), CONSTANTS.c1)
  _36 = torch.reshape(torch.mul(_20, _35), [-1, 1])
  _37 = torch.add(torch.mul(_9, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _38 = torch.mul(torch.mul(_37, _36), torch.exp(_8))
  _39 = torch.add(_38, CONSTANTS.c0, alpha=1)
  _40 = torch.add(_39, CONSTANTS.c0, alpha=1)
  _41 = torch.add(_39, CONSTANTS.c0, alpha=1)
  _42 = torch.sum(_40, [0], False, dtype=None)
  _43 = torch.add_(_24, _42, alpha=1)
  _44 = torch.matmul(_41, torch.t(_6))
  _45 = torch.matmul(torch.t(_5), _41)
  _46 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))
  _47 = torch.mul(_46, _44)
  _48 = torch.add(_47, CONSTANTS.c0, alpha=1)
  _49 = torch.add(_47, CONSTANTS.c0, alpha=1)
  _50 = torch.sum(_48, [0], False, dtype=None)
  _51 = torch.add_(_31, _50, alpha=1)
  _52 = torch.matmul(torch.t(argument_1), _49)
  _53 = torch.add_(_33, torch.t(_52), alpha=1)
  _54 = torch.add_(_34, torch.t(_45), alpha=1)
  _55 = torch.sub(_0, torch.mul(argument_4, _53), alpha=1)
  _56 = torch.sub(_1, torch.mul(argument_4, _51), alpha=1)
  _57 = torch.sub(_2, torch.mul(argument_4, _54), alpha=1)
  _58 = torch.sub(_3, torch.mul(argument_4, _43), alpha=1)
  _59 = torch.eq(torch.argmax(_7, 1, False), torch.argmax(argument_2, 1, False))
  _60 = torch.to(torch.sum(_59, dtype=None), 6, False, False, None)
  _61 = (_13, torch.div(_60, argument_3), _55, _56, _57, _58)
  return _61


In [25]:
# Plan "tfjs"
req = requests.get(f"http://{gatewayWsUrl}/model_centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=tfjs")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_tfjs = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_tfjs.code)


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = tf.transpose(arg_5)
    var_1 = tf.matMul(arg_1, var_0)
    var_2 = tf.add(arg_6, var_1)
    var_3 = tf.relu(var_2)
    var_4 = tf.transpose(arg_7)
    var_5 = tf.matMul(var_3, var_4)
    var_6 = tf.add(arg_8, var_5)
    var_7 = tf.max(var_6)
    var_8 = tf.sub(var_6, var_7)
    var_9 = tf.exp(var_8)
    var_10 = tf.sum(var_9, 1, keepdim=True)
    var_11 = tf.log(var_10)
    var_12 = tf.sub(var_8, var_11)
    var_13 = tf.mul(arg_2, var_12)
    var_14 = tf.sum(var_13)
    var_15 = tf.neg(var_14)
    out_1 = tf.div(var_15, arg_3)
    var_16 = tf.mul(out_1, 0)
    var_17 = tf.add(var_16, 1)
    var_18 = tf.div(var_17, arg_3)
    var_19 = tf.mul(var_18, -1)
    var_20 = tf.layers.reshape(var_19, [-1, 1])
    var_21 = tf.mul(var_13, 0)
    var_22 = tf.add(var_21, 1)
    var_23 = tf.mul(var_22, var_20)
    var_24 = tf.mul(var_23, arg_2)
    var_25 = tf.add(var_24, 0)
    var_26 = tf.mul(var_24, -1)
    var_27 = tf.sum(var_26, [1], keepdim=True)
    var_28 = tf.add(var_25, 0)
    var_29 = tf.add(var_28, 0)
    var_30 = tf.add(var_28, 0)
    var_31 = tf.sum(var_29, [0])
    var_32 = clone(var_31)
    var_33 = tf.transpose(var_4)
    var_34 = tf.matMul(var_30, var_33)
    var_35 = tf.transpose(var_3)
    var_36 = tf.matMul(var_35, var_30)
    var_37 = tf.mul(var_2, 0)
    var_38 = tf.greater(var_2, var_37)
    var_39 = tf.mul(var_38, var_34)
    var_40 = tf.add(var_39, 0)
    var_41 = tf.add(var_39, 0)
    var_42 = tf.sum(var_40, [0])
    var_43 = clone(var_42)
    var_44 = tf.transpose(arg_1)
    var_45 = tf.matMul(var_44, var_41)
    var_46 = tf.transpose(var_45)
    var_47 = clone(var_46)
    var_48 = tf.transpose(var_36)
    var_49 = clone(var_48)
    var_50 = tf.div(1, var_10)
    var_51 = tf.mul(var_27, var_50)
    var_52 = tf.layers.reshape(var_51, [-1, 1])
    var_53 = tf.mul(var_9, 0)
    var_54 = tf.add(var_53, 1)
    var_55 = tf.mul(var_54, var_52)
    var_56 = tf.exp(var_8)
    var_57 = tf.mul(var_55, var_56)
    var_58 = tf.add(var_57, 0)
    var_59 = tf.add(var_58, 0)
    var_60 = tf.add(var_58, 0)
    var_61 = tf.sum(var_59, [0])
    var_32 = tf.add(var_32, var_61)
    var_62 = tf.transpose(var_4)
    var_63 = tf.matMul(var_60, var_62)
    var_64 = tf.transpose(var_3)
    var_65 = tf.matMul(var_64, var_60)
    var_66 = tf.mul(var_2, 0)
    var_67 = tf.greater(var_2, var_66)
    var_68 = tf.mul(var_67, var_63)
    var_69 = tf.add(var_68, 0)
    var_70 = tf.add(var_68, 0)
    var_71 = tf.sum(var_69, [0])
    var_43 = tf.add(var_43, var_71)
    var_72 = tf.transpose(arg_1)
    var_73 = tf.matMul(var_72, var_70)
    var_74 = tf.transpose(var_73)
    var_47 = tf.add(var_47, var_74)
    var_75 = tf.transpose(var_65)
    var_49 = tf.add(var_49, var_75)
    var_76 = tf.mul(arg_4, var_47)
    out_3 = tf.sub(arg_5, var_76)
    var_77 = tf.mul(arg_4, var_43)
    out_4 = tf.sub(arg_6, var_77)
    var_78 = tf.mul(arg_4, var_49)
    out_5 = tf.sub(arg_7, var_78)
    var_79 = tf.mul(arg_4, var_32)
    out_6 = tf.sub(arg_8, var_79)
    var_80 = tf.argMax(var_6, 1)
    var_81 = tf.argMax(arg_2, 1)
    var_82 = tf.equal(var_80, var_81)
    var_83 = tf.sum(var_82)
    var_84 = tf.cast(var_83, float32)
    out_2 = tf.div(var_84, arg_3)
    return out_1, out_2, out_3, out_4, out_5, out_6

Step 5: Train

To train the hosted model, use one of the existing FL workers: