In [1]:
# エポック数が10だとトレーニングに時間が掛かりすぎるため、ここでは2にしていますが、
# より高い精度が出ることを体験したい方はより大きなエポック数を試してみてください。
# epochs = 10
epochs = 2

Part 6 - Federated Learningを使ってMNIST

MNISTは手書き文字のデータセットです。CNNを使った分類モデルをトレーニングします。

10行でPytorchのチュートリアルをPyTorch + PySyftへアップグレード

背景

Federated Learningは分散配置された学習データで学習を行える、とてもエキサイティングで、今まさに盛り上がりつつある機械学習のテクニックです。学習データはデータ所有者(ここではワーカー)の元を離れず、モデルの方がワーカー間で共有されながら学習されていくという考え方です。この手法の利点は、データのプライバシーを守れる事です。アプリケーションの応用例としては、キーボードの予測入力があります。キーボードの予測入力ではあなたが入力するテキストを学習データとする必要がありますが、個人的なメッセージですから、サーバーに送りたくはないですよね!

ところで、Federated Learningが注目を集め初めているのは、個人情報の保護に関する意識の高まりと関係があります。2018年の5月にEUで施工されたGDPRをきっかけに一躍注目を集めるようになりました。法規制を見越して、アップルやグーグルはこの技術に大きな投資をしています。特にモバイルユーザーのプライバシーの保護を意識しています。しかしながら、彼らはソースコードをオープンにしていません。

私たち(OpenMined)は、機械学習に携わる者なら、誰でも簡単にプライバシーに配慮した学習手法にアクセスできるべきだと考えています。そこで私たちはたったの一行のコードでデータを暗号化できるツールを開発しました。また、PyTorch 1.0の新機能を使って、直感的に、セキュアに、かつ大規模にFederated Learningを実装できるフレームワークもリリースしました。 詳細はブログを参照してください

このチュートリアルでは、Pytorchの(公式)チュートリアルをベースに、PySyftのライブラリを使う事で簡単にFederated Learningを実装できる事を紹介します。チュートリアルのコードサンプルを元に、Federated Learning化するために必要な変更を、一行一行確認していきましょう。

このコンテンツは私たちのブログからも見つけることが可能です。

Authors:

Ok, let's get started!

必要なモデルやライブラリをインポート

まず、PyTorch関連のライブラリをインポートします。


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

次にPySyft関連の設定を行います。ここではリモートワーカーとしてalicebobを定義しています。


In [3]:
import syft as sy  # <-- NEW: Pysyftライブラリのインポート
hook = sy.TorchHook(torch)  # <-- NEW: PyTorchをホック(Federated Learningに必要な機能を追加)
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: リモートワーカー、Bobを追加
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: 同じくAliceを追加


/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
WARNING:tensorflow:From /home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tf_encrypted/session.py:24: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/ext-share/anaconda3/envs/pytorch/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])

学習処理のハイパーパラメータを定義します。


In [4]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = epochs
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

データをロードして、ワーカーへ送る

まず、データをロードし、.federateコマンドを使って、データを分割しつつ、PytorchのDataset型からPySyftのFederated Dataset型へ変更し、複数のワーカー(このケースではAliceとBob)に割り当てます。この際に出来上がったfederated datasetはFederated DataLoaderへ渡されます。テスト用のデータセットに変更はありません。


In [5]:
federated_train_loader = sy.FederatedDataLoader( # <-- FederatedDataLoader を使います
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: FederatedDatasetに変換し、分割してワーカーへ送ります。
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


WARNING:root:The following options are not supported: num_workers: 1, pin_memory: True

CNNモデルの定義

ここではPytorchの公式チュートリアルの事例と全く同じ設定とします。


In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

トレーニング関数とテスト関数の定義

トレーニング時は、データがalicebobに分散しているので、モデルを適宜各ワーカーへ送る必要があります。モデルを各ワーカーへ送った後は、ごく普通のPyTorchのトレーニングスクリプトと同様の構文で、リモートマシンでの学習を行うことができます。トレーニング完了後は、ロスと学習済みモデルを受け取ります。


In [7]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now FederatedDataLoaderです
        model.send(data.location) # <-- NEW: モデルをデータ所有者の元へ送ります
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: 学習済みモデルを受け取ります
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: ログ表示用にロスを受け取ります
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))

テスト用の関数は変更の必要はありません。


In [8]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # バッチロスを合計します
            pred = output.argmax(1, keepdim=True) # log-probabilityが最大のインデックスを取得します
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

実際にトレーニングしてみます


In [9]:
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentumは現在サポートされていません

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if (args.save_model):
    torch.save(model.state_dict(), "mnist_cnn.pt")


Train Epoch: 1 [0/60032 (0%)]	Loss: 2.305134
Train Epoch: 1 [1920/60032 (3%)]	Loss: 2.156802
Train Epoch: 1 [3840/60032 (6%)]	Loss: 1.896626
Train Epoch: 1 [5760/60032 (10%)]	Loss: 1.440404
Train Epoch: 1 [7680/60032 (13%)]	Loss: 0.866800
Train Epoch: 1 [9600/60032 (16%)]	Loss: 0.654367
Train Epoch: 1 [11520/60032 (19%)]	Loss: 0.593107
Train Epoch: 1 [13440/60032 (22%)]	Loss: 0.455813
Train Epoch: 1 [15360/60032 (26%)]	Loss: 0.370645
Train Epoch: 1 [17280/60032 (29%)]	Loss: 0.303963
Train Epoch: 1 [19200/60032 (32%)]	Loss: 0.313645
Train Epoch: 1 [21120/60032 (35%)]	Loss: 0.369348
Train Epoch: 1 [23040/60032 (38%)]	Loss: 0.237722
Train Epoch: 1 [24960/60032 (42%)]	Loss: 0.187720
Train Epoch: 1 [26880/60032 (45%)]	Loss: 0.524170
Train Epoch: 1 [28800/60032 (48%)]	Loss: 0.224550
Train Epoch: 1 [30720/60032 (51%)]	Loss: 0.143592
Train Epoch: 1 [32640/60032 (54%)]	Loss: 0.268505
Train Epoch: 1 [34560/60032 (58%)]	Loss: 0.187220
Train Epoch: 1 [36480/60032 (61%)]	Loss: 0.302562
Train Epoch: 1 [38400/60032 (64%)]	Loss: 0.239577
Train Epoch: 1 [40320/60032 (67%)]	Loss: 0.256233
Train Epoch: 1 [42240/60032 (70%)]	Loss: 0.192204
Train Epoch: 1 [44160/60032 (74%)]	Loss: 0.174780
Train Epoch: 1 [46080/60032 (77%)]	Loss: 0.221893
Train Epoch: 1 [48000/60032 (80%)]	Loss: 0.323514
Train Epoch: 1 [49920/60032 (83%)]	Loss: 0.274752
Train Epoch: 1 [51840/60032 (86%)]	Loss: 0.130486
Train Epoch: 1 [53760/60032 (90%)]	Loss: 0.184121
Train Epoch: 1 [55680/60032 (93%)]	Loss: 0.223131
Train Epoch: 1 [57600/60032 (96%)]	Loss: 0.080876
Train Epoch: 1 [59520/60032 (99%)]	Loss: 0.143369

Test set: Average loss: 0.1573, Accuracy: 9515/10000 (95%)

Train Epoch: 2 [0/60032 (0%)]	Loss: 0.102808
Train Epoch: 2 [1920/60032 (3%)]	Loss: 0.106100
Train Epoch: 2 [3840/60032 (6%)]	Loss: 0.146959
Train Epoch: 2 [5760/60032 (10%)]	Loss: 0.148886
Train Epoch: 2 [7680/60032 (13%)]	Loss: 0.109027
Train Epoch: 2 [9600/60032 (16%)]	Loss: 0.110443
Train Epoch: 2 [11520/60032 (19%)]	Loss: 0.118914
Train Epoch: 2 [13440/60032 (22%)]	Loss: 0.062979
Train Epoch: 2 [15360/60032 (26%)]	Loss: 0.089123
Train Epoch: 2 [17280/60032 (29%)]	Loss: 0.156774
Train Epoch: 2 [19200/60032 (32%)]	Loss: 0.161360
Train Epoch: 2 [21120/60032 (35%)]	Loss: 0.157510
Train Epoch: 2 [23040/60032 (38%)]	Loss: 0.229683
Train Epoch: 2 [24960/60032 (42%)]	Loss: 0.196785
Train Epoch: 2 [26880/60032 (45%)]	Loss: 0.206010
Train Epoch: 2 [28800/60032 (48%)]	Loss: 0.079425
Train Epoch: 2 [30720/60032 (51%)]	Loss: 0.062955
Train Epoch: 2 [32640/60032 (54%)]	Loss: 0.158972
Train Epoch: 2 [34560/60032 (58%)]	Loss: 0.156671
Train Epoch: 2 [36480/60032 (61%)]	Loss: 0.074501
Train Epoch: 2 [38400/60032 (64%)]	Loss: 0.161591
Train Epoch: 2 [40320/60032 (67%)]	Loss: 0.073496
Train Epoch: 2 [42240/60032 (70%)]	Loss: 0.152694
Train Epoch: 2 [44160/60032 (74%)]	Loss: 0.047764
Train Epoch: 2 [46080/60032 (77%)]	Loss: 0.085315
Train Epoch: 2 [48000/60032 (80%)]	Loss: 0.100825
Train Epoch: 2 [49920/60032 (83%)]	Loss: 0.154736
Train Epoch: 2 [51840/60032 (86%)]	Loss: 0.031952
Train Epoch: 2 [53760/60032 (90%)]	Loss: 0.073943
Train Epoch: 2 [55680/60032 (93%)]	Loss: 0.113156
Train Epoch: 2 [57600/60032 (96%)]	Loss: 0.112269
Train Epoch: 2 [59520/60032 (99%)]	Loss: 0.068695

Test set: Average loss: 0.0901, Accuracy: 9737/10000 (97%)

CPU times: user 1min 38s, sys: 1.14 s, total: 1min 39s
Wall time: 1min 43s

ジャジャーン! Federated Learningを使ってリモートデータでのモデル学習に成功しました!

最後に

気に掛かっている事はありませんか: 学習にかかる時間です Federated Learningでの学習って通常の学習より時間が掛かりそうな気がするけど、どの程度長く掛かっちゃうのか気になりませんか?

コンピューテーションにかかる時間は、もちろん通常の学習よりは時間がかかるけれど倍までは行かないというものです。だいたい1.9倍くらいの時間が掛かります。でも、得られるメリットを考えたら小さなマイナスですよね。

結論

見て頂いた通り、Pytorchの公式チュートリアルのソースコードを10行程度変更するだけで、Federated Learningを使ってMNISTを学習することができました。

もちろん、改善の余地はまだまだあります。各ワーカーのコンピューテーションを並列化するとか、バッチ毎に集計をするのではなく、数バッチに1回だけ集計をするようにするとか、ワーカーどうしのやりとりの頻度を減らすとか、色々あります。これらは、Federated Learningをプロダクション環境で使えるようにするために私たちが取り組んでいる機能です。それらの機能がリリースされしだい、チュートリアルにも反映させていきたいと思います。

もし、やろうと思えば、ご自身でもFederated Learningを実装できると思います。もし、PySyft、プライバシーに配慮したディープラーニング、非中央集権的なAIの学習データ、あるいは学習データのサプライチェーンに関する活動に参加したい、貢献したいって思われた方は以下を参考にしてみてください。

PySyftのGitHubレポジトリにスターをつける

一番簡単に貢献できる方法はこのGitHubのレポジトリにスターを付けていただくことです。スターが増えると露出が増え、より多くのデベロッパーにこのクールな技術の事を知って貰えます。

Slackに入る

最新の開発状況のトラッキングする一番良い方法はSlackに入ることです。 下記フォームから入る事ができます。 http://slack.openmined.org

コードプロジェクトに参加する

コミュニティに貢献する一番良い方法はソースコードのコントリビューターになることです。PySyftのGitHubへアクセスしてIssueのページを開き、"Projects"で検索してみてください。参加し得るプロジェクトの状況を把握することができます。また、"good first issue"とマークされているIssueを探す事でミニプロジェクトを探すこともできます。

寄付

もし、ソースコードで貢献できるほどの時間は取れないけど、是非何かサポートしたいという場合は、寄付をしていただくことも可能です。寄附金の全ては、ハッカソンやミートアップの開催といった、コミュニティ運営経費として利用されます。

OpenMined's Open Collective Page


In [ ]: