In [1]:
import pandas as pd
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

In [24]:
import deepchem as dc
import deepchem
from rdkit import Chem
import numpy as np
import os
import pickle
from deepchem.contrib.torch.pytorch_graphconv import symmetric_normalize_adj
np.random.seed(2017)


def _normalize_graph_inputs(samples, progress_every=None):
    """Rewrite featurized samples in place as ``{"g": ..., "x": ...}`` dicts.

    Every entry of ``samples`` is indexed as a pair: element 0 is passed
    through ``symmetric_normalize_adj`` and stored under ``"g"``, element 1 is
    stored unchanged under ``"x"``.  (Presumably element 0 is the adjacency
    matrix and element 1 the per-atom feature matrix produced by the
    "AdjacencyConv" featurizer -- confirm against the featurizer.)

    Args:
        samples: Mutable, integer-indexable collection of pairs.  It is
            modified in place: each pair is replaced by a dict.
        progress_every: When set, print the current index every
            ``progress_every`` samples; ``None`` disables progress output.

    Returns:
        The same ``samples`` object that was passed in.
    """
    for idx in range(len(samples)):
        if progress_every and idx % progress_every == 0:
            print(idx)
        graph, features = samples[idx][0], samples[idx][1]
        samples[idx] = {"g": symmetric_normalize_adj(graph), "x": features}
    return samples


# Load Tox21 dataset.  Featurization is only run when no cached pickle exists;
# delete the cache file to force it to be regenerated.
save_file = "./tox21_featurized_nxn.pkl"
if not os.path.exists(save_file):
    tox21_tasks, tox21_datasets, transformers = dc.molnet.load_tox21(featurizer="AdjacencyConv")
    with open(save_file, "wb") as f:
        pickle.dump((tox21_tasks, tox21_datasets, transformers), f, protocol=2)
else:
    # NOTE(review): pickle.load can execute arbitrary code from the file.
    # Only load a cache that this notebook wrote itself.
    with open(save_file, "rb") as f:
        tox21_tasks, tox21_datasets, transformers = pickle.load(f)
train_dataset, valid_dataset, test_dataset = tox21_datasets

# The arrays returned by ``.X`` are converted in place; only the (largest)
# training split reports progress, matching the recorded cell output.
train_dataset_x = _normalize_graph_inputs(train_dataset.X, progress_every=100)
valid_dataset_x = _normalize_graph_inputs(valid_dataset.X)
test_dataset_x = _normalize_graph_inputs(test_dataset.X)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400

In [25]:
# Sanity check: display the shape of the normalized "g" entry of the first
# training sample (the value the model's max_n_atoms is later read from).
first_train_sample = train_dataset_x[0]
first_train_sample["g"].shape


Out[25]:
(150, 150)

In [102]:
# Re-import the graph-convolution module so that edits made to its source
# file are picked up without restarting the kernel.
import deepchem.contrib.torch.pytorch_graphconv
try:
    # Python 3.4+: reload lives in importlib.
    from importlib import reload
except ImportError:
    # Python 2: importlib has no reload; the builtin reload() is used instead.
    pass
reload(deepchem.contrib.torch.pytorch_graphconv)
# Import from the same module that was just reloaded (the previous path
# duplicated the module name and could not be imported).
from deepchem.contrib.torch.pytorch_graphconv import GraphConvolution, SingleTaskGraphConvolution, MultiTaskGraphConvolution

In [98]:
# Build the shared graph-convolution network.  The input sizes are read off
# the first training sample: rows of "g" give max_n_atoms, columns of "x"
# give n_atom_types.
# NOTE(review): conv_layer_dims lists three widths while n_conv_layers is 1;
# the recorded printout shows a single conv layer, so the extra entries look
# unused -- confirm against GraphConvolution.
# NOTE(review): the name `gc` is also the name of a standard-library module;
# it is kept because later cells refer to it.
reference_sample = train_dataset_x[0]
gc_settings = dict(
    n_conv_layers=1,
    max_n_atoms=reference_sample["g"].shape[0],
    n_atom_types=reference_sample["x"].shape[1],
    max_valence=4,
    conv_layer_dims=[128, 128, 256],
    n_fc_layers=2,
    fc_layer_dims=[64, 12],
    dropout=0.,
    return_sigmoid=True,
)
gc = GraphConvolution(**gc_settings)

# Show the assembled layer stack.
print(gc)


GraphConvolution (
  (conv_ops): ModuleList (
    (0): Sequential (
      (0): Linear (75 -> 128)
      (1): Dropout (p = 0.0)
      (2): LeakyReLU (0.01, inplace)
      (3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True)
    )
  )
  (fc_ops): ModuleList (
    (0): Sequential (
      (0): Linear (128 -> 64)
      (1): Dropout (p = 0.0)
      (2): ReLU (inplace)
      (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): Sequential (
      (0): Linear (64 -> 12)
      (1): Dropout (p = 0.0)
      (2): ReLU (inplace)
      (3): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True)
    )
  )
)

In [99]:
# Wrap the network in the multi-task trainer.  The task count is taken from
# the second dimension of the training label array.
n_label_columns = train_dataset.y.shape[1]
mtgc = MultiTaskGraphConvolution(
    net=gc,
    lr=1e-3,
    weight_decay=0.,
    n_tasks=n_label_columns,
)


<pytorch_graphconv.MultiTaskGraphConvolution object at 0x7fe4d472fa50>

In [101]:
# Train for 100 epochs; after each pass, evaluate on the training split and
# on the validation split.
# NOTE(review): the second split passed to evaluate() is the validation set,
# although the recorded output labels it "TEST".
# NOTE(review): in the recorded output RMSE climbs from about 0.53 to 0.95
# while ROC AUC stays near 0.8 -- worth investigating before trusting RMSE.
evaluation_options = {"transformer": None, "batch_size": 32}
for epoch in range(100):
    print("epoch: %d" % (epoch,))
    mtgc.train_epoch(train_dataset_x, train_dataset.y)
    mtgc.evaluate(
        train_dataset_x, valid_dataset_x,
        train_dataset.y, valid_dataset.y,
        **evaluation_options
    )


epoch: 0
TRAIN:
RMSE:
0.53186988363
ROC AUC:
0.723114679447
TEST:
RMSE:
0.528469609187
ROC AUC:
0.675233267089
epoch: 1
TRAIN:
RMSE:
0.598940708157
ROC AUC:
0.755874104233
TEST:
RMSE:
0.590225975884
ROC AUC:
0.711647401725
epoch: 2
TRAIN:
RMSE:
0.562084877244
ROC AUC:
0.78211365085
TEST:
RMSE:
0.556865888178
ROC AUC:
0.726447343016
epoch: 3
TRAIN:
RMSE:
0.642413079526
ROC AUC:
0.78587978098
TEST:
RMSE:
0.634735195488
ROC AUC:
0.737737387723
epoch: 4
TRAIN:
RMSE:
0.638659640579
ROC AUC:
0.803120459635
TEST:
RMSE:
0.629150083298
ROC AUC:
0.769429276589
epoch: 5
TRAIN:
RMSE:
0.582474446838
ROC AUC:
0.81245649018
TEST:
RMSE:
0.572140395229
ROC AUC:
0.772381808464
epoch: 6
TRAIN:
RMSE:
0.704796795508
ROC AUC:
0.801768784871
TEST:
RMSE:
0.695747050214
ROC AUC:
0.773810687118
epoch: 7
TRAIN:
RMSE:
0.685908526704
ROC AUC:
0.823920923992
TEST:
RMSE:
0.677127303167
ROC AUC:
0.772956146258
epoch: 8
TRAIN:
RMSE:
0.659291154141
ROC AUC:
0.818992835694
TEST:
RMSE:
0.65217092742
ROC AUC:
0.781351347957
epoch: 9
TRAIN:
RMSE:
0.671502161027
ROC AUC:
0.810644545312
TEST:
RMSE:
0.660665267729
ROC AUC:
0.763459361922
epoch: 10
TRAIN:
RMSE:
0.700032647626
ROC AUC:
0.810779775378
TEST:
RMSE:
0.69290967366
ROC AUC:
0.761431912225
epoch: 11
TRAIN:
RMSE:
0.767156844417
ROC AUC:
0.811739557383
TEST:
RMSE:
0.758572523042
ROC AUC:
0.764784832302
epoch: 12
TRAIN:
RMSE:
0.753884446232
ROC AUC:
0.819315866735
TEST:
RMSE:
0.743418487105
ROC AUC:
0.764869501505
epoch: 13
TRAIN:
RMSE:
0.718770698779
ROC AUC:
0.817167256792
TEST:
RMSE:
0.711835244379
ROC AUC:
0.788994632188
epoch: 14
TRAIN:
RMSE:
0.724336389134
ROC AUC:
0.832878055325
TEST:
RMSE:
0.715701222555
ROC AUC:
0.787385426772
epoch: 15
TRAIN:
RMSE:
0.747999631188
ROC AUC:
0.829874422947
TEST:
RMSE:
0.741302864203
ROC AUC:
0.783744651026
epoch: 16
TRAIN:
RMSE:
0.749683401153
ROC AUC:
0.837899146847
TEST:
RMSE:
0.743330795796
ROC AUC:
0.799209093492
epoch: 17
TRAIN:
RMSE:
0.789491601515
ROC AUC:
0.838638054584
TEST:
RMSE:
0.780108520322
ROC AUC:
0.785109169312
epoch: 18
TRAIN:
RMSE:
0.794947192656
ROC AUC:
0.829175652813
TEST:
RMSE:
0.789303090961
ROC AUC:
0.784858791783
epoch: 19
TRAIN:
RMSE:
0.791379053443
ROC AUC:
0.834376097899
TEST:
RMSE:
0.785039878347
ROC AUC:
0.782490310625
epoch: 20
TRAIN:
RMSE:
0.790213421906
ROC AUC:
0.84486576613
TEST:
RMSE:
0.783333710333
ROC AUC:
0.792433496903
epoch: 21
TRAIN:
RMSE:
0.776783335309
ROC AUC:
0.840139980925
TEST:
RMSE:
0.77025888597
ROC AUC:
0.795943982149
epoch: 22
TRAIN:
RMSE:
0.782282614857
ROC AUC:
0.841350465749
TEST:
RMSE:
0.777188343188
ROC AUC:
0.792490400887
epoch: 23
TRAIN:
RMSE:
0.819948831147
ROC AUC:
0.844551760251
TEST:
RMSE:
0.815227094932
ROC AUC:
0.781483110113
epoch: 24
TRAIN:
RMSE:
0.805488880979
ROC AUC:
0.823387590349
TEST:
RMSE:
0.797630158931
ROC AUC:
0.779858501377
epoch: 25
TRAIN:
RMSE:
0.854889571883
ROC AUC:
0.818578959239
TEST:
RMSE:
0.852101354073
ROC AUC:
0.753321720996
epoch: 26
TRAIN:
RMSE:
0.817167701228
ROC AUC:
0.84299811066
TEST:
RMSE:
0.813433980716
ROC AUC:
0.779987025892
epoch: 27
TRAIN:
RMSE:
0.84590004396
ROC AUC:
0.844080048939
TEST:
RMSE:
0.842470858171
ROC AUC:
0.789743802568
epoch: 28
TRAIN:
RMSE:
0.83433587212
ROC AUC:
0.84903604871
TEST:
RMSE:
0.830326107974
ROC AUC:
0.795776998389
epoch: 29
TRAIN:
RMSE:
0.833986482866
ROC AUC:
0.852971315268
TEST:
RMSE:
0.83034489568
ROC AUC:
0.798362597679
epoch: 30
TRAIN:
RMSE:
0.860026995721
ROC AUC:
0.839949677824
TEST:
RMSE:
0.857066908998
ROC AUC:
0.775130172768
epoch: 31
TRAIN:
RMSE:
0.847605611651
ROC AUC:
0.851982281002
TEST:
RMSE:
0.844974399196
ROC AUC:
0.794970041034
epoch: 32
TRAIN:
RMSE:
0.854665903194
ROC AUC:
0.851502578936
TEST:
RMSE:
0.852252192472
ROC AUC:
0.791160221211
epoch: 33
TRAIN:
RMSE:
0.862152987779
ROC AUC:
0.84607879539
TEST:
RMSE:
0.86036796003
ROC AUC:
0.779487644379
epoch: 34
TRAIN:
RMSE:
0.859375767801
ROC AUC:
0.853598858699
TEST:
RMSE:
0.854968626042
ROC AUC:
0.7845451331
epoch: 35
TRAIN:
RMSE:
0.869298698531
ROC AUC:
0.854877828829
TEST:
RMSE:
0.864919059602
ROC AUC:
0.7917297516
epoch: 36
TRAIN:
RMSE:
0.869651776755
ROC AUC:
0.855105576283
TEST:
RMSE:
0.865064118839
ROC AUC:
0.798612582766
epoch: 37
TRAIN:
RMSE:
0.886879476241
ROC AUC:
0.855505081926
TEST:
RMSE:
0.885691688422
ROC AUC:
0.800056766629
epoch: 38
TRAIN:
RMSE:
0.871914537234
ROC AUC:
0.849655303757
TEST:
RMSE:
0.868528931363
ROC AUC:
0.793178154207
epoch: 39
TRAIN:
RMSE:
0.867837743555
ROC AUC:
0.849937208016
TEST:
RMSE:
0.864016387228
ROC AUC:
0.795752863251
epoch: 40
TRAIN:
RMSE:
0.857926416644
ROC AUC:
0.837417886605
TEST:
RMSE:
0.855494448824
ROC AUC:
0.796943137615
epoch: 41
TRAIN:
RMSE:
0.894374857055
ROC AUC:
0.843325575498
TEST:
RMSE:
0.892262398245
ROC AUC:
0.789975931577
epoch: 42
TRAIN:
RMSE:
0.890075166468
ROC AUC:
0.852815544425
TEST:
RMSE:
0.889214264798
ROC AUC:
0.78515518305
epoch: 43
TRAIN:
RMSE:
0.892523775966
ROC AUC:
0.859971686262
TEST:
RMSE:
0.889855235803
ROC AUC:
0.796565609116
epoch: 44
TRAIN:
RMSE:
0.894561573761
ROC AUC:
0.861110406025
TEST:
RMSE:
0.892693982821
ROC AUC:
0.801407451361
epoch: 45
TRAIN:
RMSE:
0.895778291671
ROC AUC:
0.864573838719
TEST:
RMSE:
0.893402096689
ROC AUC:
0.807178103975
epoch: 46
TRAIN:
RMSE:
0.901926233982
ROC AUC:
0.858844375611
TEST:
RMSE:
0.900232845846
ROC AUC:
0.8054827577
epoch: 47
TRAIN:
RMSE:
0.90112203259
ROC AUC:
0.860794426165
TEST:
RMSE:
0.899489209332
ROC AUC:
0.799393050336
epoch: 48
TRAIN:
RMSE:
0.906220887096
ROC AUC:
0.865673990726
TEST:
RMSE:
0.903847771065
ROC AUC:
0.808401245295
epoch: 49
TRAIN:
RMSE:
0.908549273683
ROC AUC:
0.86483991466
TEST:
RMSE:
0.906795882571
ROC AUC:
0.7979263011
epoch: 50
TRAIN:
RMSE:
0.913077018159
ROC AUC:
0.867931245462
TEST:
RMSE:
0.911690965412
ROC AUC:
0.80672120423
epoch: 51
TRAIN:
RMSE:
0.910657974465
ROC AUC:
0.853133264831
TEST:
RMSE:
0.907812835612
ROC AUC:
0.792085990161
epoch: 52
TRAIN:
RMSE:
0.914595355296
ROC AUC:
0.855434915892
TEST:
RMSE:
0.912994891771
ROC AUC:
0.779533658118
epoch: 53
TRAIN:
RMSE:
0.904838862712
ROC AUC:
0.860277297254
TEST:
RMSE:
0.902881918376
ROC AUC:
0.80415100826
epoch: 54
TRAIN:
RMSE:
0.916614034918
ROC AUC:
0.854245365766
TEST:
RMSE:
0.915204816861
ROC AUC:
0.794301713556
epoch: 55
TRAIN:
RMSE:
0.920304433292
ROC AUC:
0.843420751121
TEST:
RMSE:
0.918212983712
ROC AUC:
0.785868739384
epoch: 56
TRAIN:
RMSE:
0.92549607619
ROC AUC:
0.843905393245
TEST:
RMSE:
0.923702627051
ROC AUC:
0.79529105799
epoch: 57
TRAIN:
RMSE:
0.920326548937
ROC AUC:
0.855349356311
TEST:
RMSE:
0.917578781078
ROC AUC:
0.80014379048
epoch: 58
TRAIN:
RMSE:
0.918235023484
ROC AUC:
0.858402225861
TEST:
RMSE:
0.915700849948
ROC AUC:
0.798753567291
epoch: 59
TRAIN:
RMSE:
0.923372252821
ROC AUC:
0.852344962353
TEST:
RMSE:
0.922896909112
ROC AUC:
0.797422897064
epoch: 60
TRAIN:
RMSE:
0.920716157336
ROC AUC:
0.846140102585
TEST:
RMSE:
0.920963904548
ROC AUC:
0.788893578562
epoch: 61
TRAIN:
RMSE:
0.920366277151
ROC AUC:
0.854903023707
TEST:
RMSE:
0.920194833465
ROC AUC:
0.808726971546
epoch: 62
TRAIN:
RMSE:
0.926801732535
ROC AUC:
0.859927874983
TEST:
RMSE:
0.926253023518
ROC AUC:
0.801274511882
epoch: 63
TRAIN:
RMSE:
0.926582790221
ROC AUC:
0.864419491826
TEST:
RMSE:
0.925949254929
ROC AUC:
0.807576529972
epoch: 64
TRAIN:
RMSE:
0.931422508418
ROC AUC:
0.855404230926
TEST:
RMSE:
0.931795863437
ROC AUC:
0.790034307216
epoch: 65
TRAIN:
RMSE:
0.928686837991
ROC AUC:
0.857457358862
TEST:
RMSE:
0.929081629356
ROC AUC:
0.793033932042
epoch: 66
TRAIN:
RMSE:
0.931913446688
ROC AUC:
0.860368834217
TEST:
RMSE:
0.932202917079
ROC AUC:
0.802801893294
epoch: 67
TRAIN:
RMSE:
0.93074883291
ROC AUC:
0.859710499316
TEST:
RMSE:
0.930041598504
ROC AUC:
0.797978103347
epoch: 68
TRAIN:
RMSE:
0.930061635495
ROC AUC:
0.864354573691
TEST:
RMSE:
0.929510783324
ROC AUC:
0.800886681799
epoch: 69
TRAIN:
RMSE:
0.926299685943
ROC AUC:
0.865296090905
TEST:
RMSE:
0.926035663576
ROC AUC:
0.797711733837
epoch: 70
TRAIN:
RMSE:
0.936108132936
ROC AUC:
0.844312896937
TEST:
RMSE:
0.93649247942
ROC AUC:
0.77931153636
epoch: 71
TRAIN:
RMSE:
0.9309734069
ROC AUC:
0.855362592047
TEST:
RMSE:
0.930042538452
ROC AUC:
0.794882722852
epoch: 72
TRAIN:
RMSE:
0.928986652152
ROC AUC:
0.819924243734
TEST:
RMSE:
0.927448826211
ROC AUC:
0.767221892572
epoch: 73
TRAIN:
RMSE:
0.935343863728
ROC AUC:
0.85891439429
TEST:
RMSE:
0.93542656286
ROC AUC:
0.793173052471
epoch: 74
TRAIN:
RMSE:
0.939686779958
ROC AUC:
0.864057691692
TEST:
RMSE:
0.939869228688
ROC AUC:
0.796912527196
epoch: 75
TRAIN:
RMSE:
0.935859721835
ROC AUC:
0.865996652648
TEST:
RMSE:
0.935587767454
ROC AUC:
0.795789850841
epoch: 76
TRAIN:
RMSE:
0.927668072835
ROC AUC:
0.837019199442
TEST:
RMSE:
0.926876897414
ROC AUC:
0.789698573711
epoch: 77
TRAIN:
RMSE:
0.937896748472
ROC AUC:
0.841826440158
TEST:
RMSE:
0.938453850481
ROC AUC:
0.7903299136
epoch: 78
TRAIN:
RMSE:
0.940171536989
ROC AUC:
0.840430777579
TEST:
RMSE:
0.9405977132
ROC AUC:
0.781582005312
epoch: 79
TRAIN:
RMSE:
0.94313998238
ROC AUC:
0.820048123048
TEST:
RMSE:
0.943101489585
ROC AUC:
0.758384605785
epoch: 80
TRAIN:
RMSE:
0.943944238094
ROC AUC:
0.84294497659
TEST:
RMSE:
0.944225736878
ROC AUC:
0.791365075553
epoch: 81
TRAIN:
RMSE:
0.941252839501
ROC AUC:
0.850100454599
TEST:
RMSE:
0.941178139709
ROC AUC:
0.805032823788
epoch: 82
TRAIN:
RMSE:
0.94334845548
ROC AUC:
0.85880782998
TEST:
RMSE:
0.942989117748
ROC AUC:
0.805144178997
epoch: 83
TRAIN:
RMSE:
0.942757197439
ROC AUC:
0.858633969422
TEST:
RMSE:
0.942967791298
ROC AUC:
0.802188998145
epoch: 84
TRAIN:
RMSE:
0.941984003547
ROC AUC:
0.863459588728
TEST:
RMSE:
0.94202779962
ROC AUC:
0.805272801623
epoch: 85
TRAIN:
RMSE:
0.943861135122
ROC AUC:
0.869559424854
TEST:
RMSE:
0.944059712049
ROC AUC:
0.80357490448
epoch: 86
TRAIN:
RMSE:
0.943892160782
ROC AUC:
0.870380012782
TEST:
RMSE:
0.944058229597
ROC AUC:
0.809208300761
epoch: 87
TRAIN:
RMSE:
0.940230995333
ROC AUC:
0.87246376003
TEST:
RMSE:
0.939949200727
ROC AUC:
0.808967047492
epoch: 88
TRAIN:
RMSE:
0.944266292794
ROC AUC:
0.866078344511
TEST:
RMSE:
0.944297956638
ROC AUC:
0.801329159328
epoch: 89
TRAIN:
RMSE:
0.946623793073
ROC AUC:
0.847425492164
TEST:
RMSE:
0.947071148307
ROC AUC:
0.784504515429
epoch: 90
TRAIN:
RMSE:
0.937446099607
ROC AUC:
0.83987666243
TEST:
RMSE:
0.937755393829
ROC AUC:
0.795186079951
epoch: 91
TRAIN:
RMSE:
0.945965778444
ROC AUC:
0.864741295877
TEST:
RMSE:
0.946522999419
ROC AUC:
0.80281003645
epoch: 92
TRAIN:
RMSE:
0.945757927527
ROC AUC:
0.799365484756
TEST:
RMSE:
0.945521591078
ROC AUC:
0.749642387895
epoch: 93
TRAIN:
RMSE:
0.948769988161
ROC AUC:
0.797458654599
TEST:
RMSE:
0.949179327489
ROC AUC:
0.755016478609
epoch: 94
TRAIN:
RMSE:
0.95036789814
ROC AUC:
0.810493345408
TEST:
RMSE:
0.951081825452
ROC AUC:
0.748300042541
epoch: 95
TRAIN:
RMSE:
0.946148022472
ROC AUC:
0.806541112539
TEST:
RMSE:
0.946281742483
ROC AUC:
0.758772828309
epoch: 96
TRAIN:
RMSE:
0.942774428703
ROC AUC:
0.785112083474
TEST:
RMSE:
0.942970572363
ROC AUC:
0.739490128532
epoch: 97
TRAIN:
RMSE:
0.948110874706
ROC AUC:
0.801159481428
TEST:
RMSE:
0.948190755217
ROC AUC:
0.748907443512
epoch: 98
TRAIN:
RMSE:
0.950300825258
ROC AUC:
0.813960766914
TEST:
RMSE:
0.950183085411
ROC AUC:
0.776723484313
epoch: 99
TRAIN:
RMSE:
0.953419139379
ROC AUC:
0.822994361177
TEST:
RMSE:
0.953429165286
ROC AUC:
0.778934596523