In [1]:
import pandas as pd
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
In [24]:
import deepchem as dc
import deepchem
from rdkit import Chem
import numpy as np
import os
import pickle
from deepchem.contrib.torch.pytorch_graphconv import symmetric_normalize_adj
# Seed NumPy's global random generator so runs of this notebook are repeatable.
np.random.seed(2017)


def _normalize_graph_inputs(samples, progress_every=None):
    """Rewrite each featurized sample in place as a ``{"g": ..., "x": ...}`` dict.

    Every entry of ``samples`` is indexed as ``entry[0]`` and ``entry[1]``:
    ``entry[0]`` is passed through ``symmetric_normalize_adj`` and stored
    under ``"g"``; ``entry[1]`` is stored unchanged under ``"x"``.
    (Presumably these are the adjacency matrix and the per-atom feature
    matrix produced by the "AdjacencyConv" featurizer -- TODO confirm.)

    Args:
        samples: Mutable, integer-indexable container of featurized samples.
            It is modified in place so its container type is preserved.
        progress_every: If given, print the current index every
            ``progress_every`` samples as a coarse progress indicator.

    Returns:
        The same ``samples`` object that was passed in.
    """
    for idx in range(len(samples)):
        if progress_every and idx % progress_every == 0:
            print(idx)
        sample = samples[idx]
        samples[idx] = {"g": symmetric_normalize_adj(sample[0]), "x": sample[1]}
    return samples


# Load the Tox21 dataset. Featurization only happens when no cache file
# exists yet; later runs read the pickled result instead.
save_file = "./tox21_featurized_nxn.pkl"
if not os.path.exists(save_file):
    tox21_tasks, tox21_datasets, transformers = dc.molnet.load_tox21(featurizer="AdjacencyConv")
    with open(save_file, "wb") as f:
        # Protocol 2 keeps the cache readable from Python 2 as well as Python 3.
        pickle.dump((tox21_tasks, tox21_datasets, transformers), f, protocol=2)
else:
    # SECURITY: unpickling can execute arbitrary code. Only load a cache file
    # that this notebook wrote itself; delete it and re-featurize if in doubt.
    with open(save_file, "rb") as f:
        tox21_tasks, tox21_datasets, transformers = pickle.load(f)

train_dataset, valid_dataset, test_dataset = tox21_datasets

# Convert every split to the {"g", "x"} dict form. Progress is printed only
# for the training split, every 100 samples.
train_dataset_x = _normalize_graph_inputs(train_dataset.X, progress_every=100)
valid_dataset_x = _normalize_graph_inputs(valid_dataset.X)
test_dataset_x = _normalize_graph_inputs(test_dataset.X)
In [25]:
# Sanity check: display the shape of the first training sample's "g" entry.
first_train_graph = train_dataset_x[0]["g"]
first_train_graph.shape
Out[25]:
In [102]:
# ``reload`` is not a builtin on Python 3; take it from importlib when it is
# available there and fall back to the Python 2 builtin otherwise.
try:
    from importlib import reload
except ImportError:  # Python 2: ``reload`` is already a builtin
    pass

import deepchem.contrib.torch.pytorch_graphconv

# Re-execute the module so that edits to its source are picked up without
# restarting the kernel, then re-bind the classes from the reloaded module.
reload(deepchem.contrib.torch.pytorch_graphconv)
# The module path must match the one imported and reloaded above; the doubled
# name "pytorch_graphconvpytorch_graphconv" was a typo that made this import fail.
from deepchem.contrib.torch.pytorch_graphconv import (
    GraphConvolution,
    SingleTaskGraphConvolution,
    MultiTaskGraphConvolution,
)
In [98]:
# Build the graph-convolution network. The two size arguments are read off the
# first training sample: the leading dimension of its "g" entry and the second
# dimension of its "x" entry.
# NOTE(review): n_conv_layers=1 while conv_layer_dims lists three widths --
# confirm how GraphConvolution reconciles the two.
# NOTE(review): the name `gc` would shadow the standard-library `gc` module if
# that module is ever imported in this notebook.
first_sample = train_dataset_x[0]
graph_conv_config = dict(
    n_conv_layers=1,
    max_n_atoms=first_sample["g"].shape[0],
    n_atom_types=first_sample["x"].shape[1],
    max_valence=4,
    conv_layer_dims=[128, 128, 256],
    n_fc_layers=2,
    fc_layer_dims=[64, 12],
    dropout=0.,
    return_sigmoid=True,
)
gc = GraphConvolution(**graph_conv_config)
print(gc)
In [99]:
# Wrap the network for multitask training; the task count is the second
# dimension of the training label array.
n_tasks = train_dataset.y.shape[1]
mtgc = MultiTaskGraphConvolution(
    net=gc,
    lr=1e-3,
    weight_decay=0.,
    n_tasks=n_tasks,
)
In [101]:
# Train for a fixed number of epochs, evaluating on the training and
# validation splits after each one.
# NOTE(review): the exported notebook lost its indentation, so it is assumed
# here that evaluate() runs once per epoch inside the loop -- confirm.
n_epochs = 100
for epoch in range(n_epochs):
    print("epoch: %d" % epoch)
    mtgc.train_epoch(train_dataset_x, train_dataset.y)
    mtgc.evaluate(
        train_dataset_x,
        valid_dataset_x,
        train_dataset.y,
        valid_dataset.y,
        transformer=None,
        batch_size=32,
    )