Part 22, ChemBERTa: Pre-training a BERT-like model for masked language modelling of SMILES and molecular property prediction.

By Seyone Chithrananda (Twitter)

Deep learning for chemistry and materials science remains a novel field with lots of potential. However, the transfer-learning-based methods that have become popular in areas such as NLP and computer vision have not yet been widely adopted in computational chemistry and machine learning. Using HuggingFace's suite of models and its byte-level BPE tokenizer, we are able to train a large transformer model, RoBERTa, on a corpus of 100k SMILES strings from a commonly used benchmark chemistry dataset, ZINC.

Training RoBERTa over 5 epochs, the model achieves a loss of 0.398, which would likely continue to decrease if trained for a larger number of epochs. The model can predict masked tokens within a SMILES sequence/molecule, allowing variants of a molecule within discoverable chemical space to be predicted.

By applying the representations of functional groups and atoms learned by the model, we can try to tackle problems of toxicity, solubility, drug-likeness, and synthetic accessibility on smaller datasets, using the learned representations as features for graph convolution and attention models over the graph structure of molecules, as well as by fine-tuning BERT directly. Finally, we propose attention visualization as a helpful tool for chemistry practitioners and students to quickly identify the substructures important to various chemical properties.

Additionally, visualization of the attention mechanism has been shown in previous research to be incredibly valuable for chemical reaction classification. Open-sourcing large-scale transformer models such as RoBERTa with HuggingFace may help accelerate these individual research directions.

A link to a repository which includes the training, uploading and evaluation notebook (with sample predictions on compounds such as Remdesivir) can be found here. All of the notebooks can be copied into a new Colab runtime for easy execution.

For the sake of this tutorial, we'll be fine-tuning RoBERTa on a small-scale molecule dataset to show the potential and effectiveness of HuggingFace's NLP-based transfer learning applied to computational chemistry. The output of some cells is purposely cleared for readability, so do not worry if some output messages in your cells differ!

Installing HuggingFace's transformers library, alongside BertViz and NVIDIA Apex (RDKit is installed later for molecule visualizations)


In [1]:
!pip install transformers


Collecting transformers
  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
     |████████████████████████████████| 675kB 4.6MB/s 
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
     |████████████████████████████████| 1.1MB 23.9MB/s 
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Collecting tokenizers==0.7.0
  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
     |████████████████████████████████| 3.8MB 40.2MB/s 
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
     |████████████████████████████████| 890kB 57.9MB/s 
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (1.12.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.2)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.9)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.15.1)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=5b83ab4c2e1f1420040b2a1c7b2a43e2f0eb4c3ae1c251ab5ff24cc5baf3bff9
  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45
Successfully built sacremoses
Installing collected packages: sentencepiece, tokenizers, sacremoses, transformers
Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.11.0

In [2]:
import sys
!test -d bertviz_repo && echo "FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo"
# !rm -r bertviz_repo # Uncomment if you need a clean pull from repo
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
if 'bertviz_repo' not in sys.path:
  sys.path += ['bertviz_repo']
!pip install regex


Cloning into 'bertviz_repo'...
remote: Enumerating objects: 1074, done.
remote: Total 1074 (delta 0), reused 0 (delta 0), pack-reused 1074
Receiving objects: 100% (1074/1074), 99.41 MiB | 27.70 MiB/s, done.
Resolving deltas: 100% (687/687), done.
Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20)

We want to install NVIDIA's Apex tool, which is used by the training pipeline we'll run later with simple-transformers and Weights & Biases.


In [ ]:
!git clone https://github.com/NVIDIA/apex
# Each ! command runs in its own shell, so cd-ing in and out is unnecessary;
# install Apex directly from the cloned path instead.
!pip install -v --no-cache-dir /content/apex

Now, to ensure our model demonstrates an understanding of chemical syntax and molecular structure, we'll test it on predicting a masked token/character within the SMILES string for Remdesivir.


In [4]:
# Test if NVIDIA apex training tool works
from apex import amp

In [5]:
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer
from bertviz import head_view

model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)








/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils.py:831: FutureWarning: Parameter max_len is deprecated and will be removed in a future release. Use model_max_length instead.
  category=FutureWarning,

In [6]:
remdesivir_mask = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=<mask>1"
remdesivir = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"

"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1"

masked_smi = fill_mask(remdesivir_mask)

for smi in masked_smi:
  print(smi)


{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1</s>', 'score': 0.5986589789390564, 'token': 39}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1</s>', 'score': 0.09766950458288193, 'token': 51}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1</s>', 'score': 0.0769445151090622, 'token': 50}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21</s>', 'score': 0.024126358330249786, 'token': 22}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1</s>', 'score': 0.018853096291422844, 'token': 44}

Here, we get some interesting results. The final branch, C1=CC=CC=C1, is a benzene ring. Since it's a pretty common substructure, the model is easily able to predict the final ring-closing carbon with a score of 0.60. Let's get a list of the top 5 predictions (including the target, Remdesivir) and visualize them, with a highlighted focus on the beginning of the final benzene-like pattern. Let's import some RDKit packages to do so.


In [ ]:
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

In [8]:
import torch
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import MolToImage
from PIL import Image


def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol


def find_matches_one(mol, submol):
    # Find the atoms in mol that match the maximum common substructure shared with submol.
    mols = [mol, submol]  # pairwise search
    res = rdFMCS.FindMCS(mols)  # optionally: ringMatchesRingOnly=True
    mcsp = Chem.MolFromSmarts(res.smartsString)
    matches = mol.GetSubstructMatches(mcsp)
    return matches

#Draw the molecule
def get_image(mol,atomset):    
    hcolor = colors.to_rgb('green')
    if atomset is not None:
        #highlight the atoms set while drawing the whole molecule.
        img = MolToImage(mol, size=(600, 600),fitImage=True, highlightAtoms=atomset,highlightColor=hcolor)
    else:
        img = MolToImage(mol, size=(400, 400),fitImage=True)
    return img

In [9]:
sequence = f"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC={tokenizer.mask_token}1"
substructure = "CC=CC"
image_list = []

input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
  smi = (sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
  print (smi)
  smi_mol = get_mol(smi)
  substructure_mol = get_mol(substructure)
  if smi_mol is None: # if the model's token prediction isn't chemically feasible
    continue
  Draw.MolToFile(smi_mol, smi+".png")
  matches = find_matches_one(smi_mol, substructure_mol)
  atomset = list(matches[0])
  img = get_image(smi_mol, atomset)
  img.format="PNG" 
  image_list.append(img)


CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1

In [10]:
from IPython.display import Image 

for img in image_list:
  display(img)


As we can see above, 2 of 4 of the model's MLM predictions are chemically valid. The one the model would've chosen (with a score of 0.6) is the first image, in which the top-left molecular structure resembles the benzene ring found in the therapy Remdesivir. Overall, the model seems to understand SMILES syntax with a pretty decent degree of certainty.

However, further masked-language-model training on a more specific dataset (say, leads for a specific target) may generate a stronger MLM model; a sketch of what that continued training could look like is shown below. After visualizing the attention mechanism in the next section, we'll fine-tune our model on a dataset of our choice, Tox21.
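
For reference, here is a minimal sketch of what continued masked-language-model training on a custom file of SMILES strings might look like with HuggingFace's Trainer API. The file path and hyperparameters are illustrative, and exact class names and arguments may vary between transformers versions, so treat this as a starting point rather than the exact ChemBERTa training recipe.


In [ ]:
from transformers import (RobertaForMaskedLM, RobertaTokenizerFast,
                          DataCollatorForLanguageModeling, LineByLineTextDataset,
                          Trainer, TrainingArguments)

# Continue MLM training from the ChemBERTa checkpoint on a hypothetical file
# of target-specific SMILES strings, one molecule per line.
mlm_tokenizer = RobertaTokenizerFast.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mlm_model = RobertaForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

dataset = LineByLineTextDataset(tokenizer=mlm_tokenizer, file_path="my_smiles.txt", block_size=128)
collator = DataCollatorForLanguageModeling(tokenizer=mlm_tokenizer, mlm=True, mlm_probability=0.15)

trainer = Trainer(
    model=mlm_model,
    args=TrainingArguments(output_dir="./chemberta_mlm_continued",
                           num_train_epochs=5,
                           per_device_train_batch_size=32),
    data_collator=collator,
    train_dataset=dataset,
)
trainer.train()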

Visualizing the Attention Mechanism in ChemBERTa using BertViz

BertViz is a tool for visualizing attention in the Transformer model, supporting all models from the transformers library (BERT, GPT-2, XLNet, RoBERTa, XLM, CTRL, etc.). It extends the Tensor2Tensor visualization tool by Llion Jones and the transformers library from HuggingFace.

Using this tool, we can easily plug in ChemBERTa from the HuggingFace model hub and visualize the attention patterns produced by one or more attention heads in a given transformer layer. This is known as the attention-head view.

Let's start by loading the JavaScript dependencies, d3.js and jQuery, needed to create interactive visualizations:


In [11]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});



In [12]:
def call_html():
  import IPython
  display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

Now, we create an instance of ChemBERTa, tokenize a set of SMILES strings, and compute the attention for each head in the transformer. There are two available models hosted by DeepChem on HuggingFace's model hub: seyonec/ChemBERTa-zinc-base-v1, which is the ChemBERTa model trained via masked language modelling (MLM) on the ZINC100k dataset, and seyonec/ChemBERTa-zinc250k-v1, which is trained via MLM on the larger ZINC250k dataset.

In the following example, we take two SMILES molecules from the ZINC database with nearly identical chemical structure, the only difference being their chiral specification (hence the additional '@' symbol). Chirality indicates the presence of a tetrahedral centre: '@' tells us that the neighbours of the atom appear in a counter-clockwise order, whereas '@@' indicates that the neighbours are ordered clockwise. The model should ideally attend to similar substructures in each SMILES string with a higher attention weight.
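
As a quick, illustrative sanity check, RDKit can assign CIP labels to the stereocentre, confirming that the two strings encode opposite configurations of the same molecule:


In [ ]:
# Print the (atom index, 'R'/'S') label of each stereocentre in the two SMILES strings
for smi in ("CCCCC[C@@H](Br)CC", "CCCCC[C@H](Br)CC"):
    mol = Chem.MolFromSmiles(smi)
    print(smi, Chem.FindMolChiralCenters(mol, includeUnassigned=True))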

Let's look at the first SMILES string, CCCCC[C@@H](Br)CC:


In [13]:
m = Chem.MolFromSmiles('CCCCC[C@@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200, 200))


And the second SMILES string, CCCCC[C@H](Br)CC:


In [14]:
m = Chem.MolFromSmiles('CCCCC[C@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200,200))


The visualization below shows the attention induced by a sample input SMILES. This view visualizes attention as lines connecting the tokens being updated (left) with the tokens being attended to (right). Color intensity reflects the attention weight: weights close to one show as very dark lines, while weights close to zero appear as faint lines or are not visible at all. The user may highlight a particular SMILES character to see the attention from that token only. This visualization is called the attention-head view. It is based on the excellent Tensor2Tensor visualization tool and is generated by the BertViz library.


In [15]:
from transformers import RobertaModel, RobertaTokenizer
from bertviz import head_view

model_version = 'seyonec/ChemBERTa-zinc250k-v1'
model = RobertaModel.from_pretrained(model_version, output_attentions=True)
tokenizer = RobertaTokenizer.from_pretrained(model_version)

sentence_a = "CCCCC[C@@H](Br)CC"
sentence_b = "CCCCC[C@H](Br)CC"
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
attention = model(input_ids)[-1]
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

call_html()

head_view(attention, tokens)








/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils.py:831: FutureWarning: Parameter max_len is deprecated and will be removed in a future release. Use model_max_length instead.
  category=FutureWarning,

The visualization shows that attention is highest between tokens that don't cross the boundary between the two SMILES strings; the model seems to understand that it should relate tokens to other tokens in the same molecule in order to best understand their context.
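
As a rough numerical check on this observation, we can average the attention weights within and across the two molecules, reusing the attention, tokens, and tokenizer objects from the cell above (an illustrative sketch, not part of BertViz itself):


In [ ]:
import torch

# RoBERTa encodes a pair of sequences as: <s> A </s> </s> B </s>
sep_idx = [i for i, t in enumerate(tokens) if t == tokenizer.sep_token]
a_idx = list(range(1, sep_idx[0]))               # token positions of the first SMILES
b_idx = list(range(sep_idx[1] + 1, sep_idx[2]))  # token positions of the second SMILES

# attention is a tuple with one (1, num_heads, seq_len, seq_len) tensor per layer
attn = torch.stack(attention).squeeze(1)         # (layers, heads, seq, seq)

within = torch.cat([attn[:, :, a_idx][:, :, :, a_idx].flatten(),
                    attn[:, :, b_idx][:, :, :, b_idx].flatten()]).mean()
across = torch.cat([attn[:, :, a_idx][:, :, :, b_idx].flatten(),
                    attn[:, :, b_idx][:, :, :, a_idx].flatten()]).mean()
print("mean within-molecule attention:", within.item())
print("mean cross-molecule attention: ", across.item())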

There are many other fascinating visualizations we can do, such as a neuron-by-neuron analysis of attention or a model overview that visualizes all of the heads at once:

Attention by Head View:

Model View:

Neuron-by-neuron view:
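
For instance, the model view shown above can be generated directly with BertViz's model_view function (a minimal sketch reusing the attention and tokens from before; the neuron-by-neuron view requires BertViz's dedicated neuron-view model classes rather than the standard transformers ones):


In [ ]:
from bertviz import model_view

call_html()
model_view(attention, tokens)  # every layer and head at once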

Fine-tuning ChemBERTa on a Small Molecular Dataset

The Tox21 task we'll focus on is the stress-response pathway for the tumor suppressor protein p53 (SR-p53). Typically the p53 pathway is "off" and is activated when cells are under stress or damaged, hence being a good indicator of DNA damage and other cellular stresses. Once activated, p53 responds by inducing DNA repair, cell cycle arrest and apoptosis.

The Tox21 challenge was introduced in 2014 in an attempt to build models that are successful in predicting compounds' interference in biochemical pathways using only chemical structure data. The computational models produced from the challenge could become decision-making tools for government agencies in determining which environmental chemicals and drugs are of the greatest potential concern to human health. Additionally, these models can act as drug screening tools in the drug discovery pipelines for toxicity.

Let's start by loading the dataset from S3, before importing apex and transformers, the tools which will allow us to load the pre-trained masked-language-modelling model trained on ZINC15.


In [16]:
!wget https://t.co/zrC7F8DcRs?amp=1


--2020-06-21 00:04:17--  https://t.co/zrC7F8DcRs?amp=1
Resolving t.co (t.co)... 104.244.42.197, 104.244.42.5, 104.244.42.133, ...
Connecting to t.co (t.co)|104.244.42.197|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv [following]
--2020-06-21 00:04:18--  https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv
Resolving deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)... 52.219.120.233
Connecting to deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)|52.219.120.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 85962 (84K) [text/csv]
Saving to: ‘zrC7F8DcRs?amp=1’

zrC7F8DcRs?amp=1    100%[===================>]  83.95K  --.-KB/s    in 0.05s   

2020-06-21 00:04:18 (1.73 MB/s) - ‘zrC7F8DcRs?amp=1’ saved [85962/85962]

If you're only running the toxicity prediction portion of this tutorial, make sure you install transformers here. If you've run all the cells above, you can skip this install, as we've already done pip install transformers before.


In [ ]:
!pip install transformers

In [ ]:
!pip install simpletransformers
!pip install wandb

From here, we want to load the dataset from tox21 for training the model. We're going to use a filtered dataset of 2100 compounds, as there are only 400 positive leads and we want to avoid having a large data imbalance. We'll also use simple-transformer's auto_weights argument in defining our ChemBERTa model to do automatic weight balancing later on, to counteract this problem.


In [18]:
import pandas as pd

!cd ..
dataset_path = "/content/zrC7F8DcRs?amp=1"
df = pd.read_csv(dataset_path, sep = ',', warn_bad_lines=True, header=None)


df.rename(columns={0:'smiles',1:'labels'}, inplace=True)
df.head()


Out[18]:
smiles labels
0 CCCCCCCC/C=C\CCCCCCCC(N)=O 0
1 CCCCCCOC(=O)c1ccccc1 0
2 O=C(c1ccc(Cl)cc1)c1ccc(Cl)cc1 0
3 COc1cc(Cl)c(OC)cc1N 0
4 N[C@H](Cc1c[nH]c2ccccc12)C(=O)O 0
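
Before going further, it's worth confirming the class imbalance mentioned above. A quick, illustrative check with pandas counts how many compounds in the filtered dataset are labelled toxic (1) versus non-toxic (0):


In [ ]:
# Inspect the label balance of the filtered Tox21 subset
print(df['labels'].value_counts())
print("positive fraction:", (df['labels'] == 1).mean())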

From here, let's set up a logger to record if any issues occur, and to notify us if there are any problems with the arguments we've set for the model.


In [19]:
from simpletransformers.classification import ClassificationModel
import logging

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

Now, using simple-transformers, let's load the pre-trained model from HuggingFace's useful model hub. We'll set the number of epochs to 3 in the arguments, but you can train for longer. Also make sure that auto_weights is set to True, as we are dealing with an imbalanced toxicity dataset.


In [20]:
model = ClassificationModel('roberta', 'seyonec/ChemBERTa-zinc-base-v1', args={'num_train_epochs': 3, 'auto_weights': True}) # You can set class weights by using the optional weight argument


/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils.py:831: FutureWarning: Parameter max_len is deprecated and will be removed in a future release. Use model_max_length instead.
  category=FutureWarning,

In [21]:
# Split the train and test dataset 80-20

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200).reset_index(drop=True)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
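
With so few positive examples, a purely random split can leave the two halves with slightly different class ratios. If you want to guarantee the same ratio in both, a stratified split with scikit-learn is a minimal alternative (illustrative only; the rest of the tutorial continues with the random split above):


In [ ]:
from sklearn.model_selection import train_test_split

# Stratify on the labels so train and test keep the same toxic/non-toxic ratio
strat_train, strat_test = train_test_split(df, test_size=0.2, random_state=200,
                                           stratify=df['labels'])
print(strat_train['labels'].mean(), strat_test['labels'].mean())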

In [22]:
# Check that our train and evaluation dataframes are set up properly. There should only be two columns: the SMILES string and its corresponding label.

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))


FULL Dataset: (2142, 2)
TRAIN Dataset: (1714, 2)
TEST Dataset: (428, 2)

Now that we've set everything up, let's get to the fun part: training the model! We use Weights and Biases, which is optional (simply remove wandb_project from the list of args). It's a really useful tool for monitoring the model's training results (such as accuracy, learning rate and loss), along with custom visualizations you can create, as well as the gradients.

When you run this cell, Weights and Biases will ask for an account, which you can set up by obtaining an API key through a GitHub account. Again, this is completely optional and it can be removed from the list of arguments.


In [23]:
!wandb login


wandb: You can find your API key in your browser here: https://app.wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: 3453d85d7ddabfc34500f3fa6ac9ec2ba5683c2f
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Successfully logged in to Weights & Biases!

In [24]:
# Create directory to store model weights (change path accordingly to where you want!)
!cd /content
!mkdir chemberta_tox21

# Train the model
model.train_model(train_dataset, output_dir='/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name'})


/usr/local/lib/python3.7/site-packages/simpletransformers/classification/classification_model.py:267: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Warning:  multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback.  Original ImportError was: ModuleNotFoundError("No module named 'amp_C'",)
INFO:wandb.run_manager:system metrics and metadata threads started
INFO:wandb.run_manager:checking resume status, waiting at most 10 seconds
INFO:wandb.run_manager:resuming run from id: UnVuOnYxOnc1cDM0eG1oOnByb2plY3QtbmFtZTpzZXlvbmVj
INFO:wandb.run_manager:upserting run before process can begin, waiting at most 10 seconds
INFO:wandb.run_manager:saving pip packages
INFO:wandb.run_manager:initializing streaming files api
INFO:wandb.run_manager:unblocking file change observer, beginning sync with W&B servers
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/config.yaml
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/media/graph/graph_0_summary_692f3881.graph.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/wandb-events.jsonl
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/requirements.txt
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/media/graph
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200621_000615-w5p34xmh/media
Running loss: 1.016106
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:114: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
Running loss: 0.766425
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:231: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.866304
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.331168
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.096342
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json
Running loss: 0.467952
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.324419
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:200: UserWarning: Please also save or load the state of the optimzer when saving or loading the scheduler.
  warnings.warn(SAVE_STATE_WARNING, UserWarning)
Running loss: 0.078696
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.686080
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-events.jsonl
Running loss: 0.121916
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.513443
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json
Running loss: 0.120766
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.446782
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.229184
Running loss: 0.671774
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.015629
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json
Running loss: 0.053129
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.201588
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.021707
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-events.jsonl
Running loss: 0.024193
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-summary.json
Running loss: 0.031435
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json
Running loss: 0.002347

INFO:simpletransformers.classification.classification_model: Training of roberta model complete. Saved to /content/chemberta_tox21.
INFO:wandb.run_manager:shutting down system stats and metadata service
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-events.jsonl
INFO:wandb.run_manager:stopping streaming files and file change observer
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200621_000615-w5p34xmh/wandb-metadata.json

Let's install scikit-learn now, to evaluate the model we've trained.


In [25]:
!pip install -U scikit-learn


Requirement already up-to-date: scikit-learn in /usr/local/lib/python3.7/site-packages (0.23.1)
Requirement already satisfied, skipping upgrade: scipy>=0.19.1 in /usr/local/lib/python3.7/site-packages (from scikit-learn) (1.4.1)
Requirement already satisfied, skipping upgrade: numpy>=1.13.3 in /usr/local/lib/python3.7/site-packages (from scikit-learn) (1.18.5)
Requirement already satisfied, skipping upgrade: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn) (2.1.0)
Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.7/site-packages (from scikit-learn) (0.15.1)

The following cell can be ignored unless you are starting a new run-time and just want to load the model from your local directory.


In [ ]:
# Loading a saved model for evaluation
model = ClassificationModel('roberta', '/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name','num_train_epochs': 3})

In [26]:
import sklearn
result, model_outputs, wrong_predictions = model.eval_model(test_dataset, acc=sklearn.metrics.accuracy_score)


/usr/local/lib/python3.7/site-packages/simpletransformers/classification/classification_model.py:690: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.

INFO:simpletransformers.classification.classification_model:{'mcc': 0.7851764343873741, 'tp': 65, 'tn': 334, 'fp': 5, 'fn': 24, 'acc': 0.9322429906542056, 'eval_loss': 0.19206710794457682}

The model performs pretty well, averaging above 91% accuracy after training on only ~2000 data samples and 400 positive leads! We can clearly see the predictive power of transfer learning, and approaches like these are becoming increasingly popular in the pharmaceutical industry, where large labelled datasets are often scarce. By training on more epochs and tasks, we can probably boost the accuracy as well!
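
Accuracy can be flattering on an imbalanced test set, so it's also worth looking at a ranking metric such as ROC-AUC. Here is a small sketch, assuming model_outputs holds one row of raw logits per test example, as returned by simple-transformers above:


In [ ]:
from scipy.special import softmax
from sklearn.metrics import roc_auc_score

# Convert raw logits to class probabilities and score the toxic class
probs = softmax(model_outputs, axis=1)[:, 1]
print("ROC-AUC:", roc_auc_score(test_dataset['labels'], probs))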

Let's test the model on one last string outside of the filtered dataset for toxicity. The model should predict 0, meaning no interference with the SR-p53 biochemical pathway.


In [27]:
# Lets input a molecule with a SR-p53 value of 0
predictions, raw_outputs = model.predict(['CCCCOc1cc(C(=O)OCCN(CC)CC)ccc1N'])


INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.



In [28]:
print(predictions)
print(raw_outputs)


[0]
[[ 3.0878906 -2.9765625]]

The model predicts the sample correctly! Some future tasks may include using the same model on multiple tasks (Tox21 provides several toxicity assays) through multi-task classification, as well as training on a wider dataset. This will be expanded on in a future tutorial!

Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

Star DeepChem on Github

This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

Join the DeepChem Gitter

The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!