By Seyone Chithrananda (Twitter)
Deep learning for chemistry and materials science remains a novel field with lots of potential. However, transfer learning methods, which are popular in areas such as NLP and computer vision, have not yet been effectively developed for computational chemistry. Using HuggingFace's suite of models and the ByteLevel tokenizer, we are able to train a large transformer model, RoBERTa, on a corpus of 100k SMILES strings from a commonly used benchmark chemistry dataset, ZINC.
Trained for 5 epochs, the model reaches a loss of 0.398, which would likely continue to decrease if the model were trained for more epochs. The model can predict masked tokens within a SMILES sequence/molecule, allowing variants of a molecule within discoverable chemical space to be predicted.
By applying the representations of functional groups and atoms learned by the model, we can try to tackle problems of toxicity, solubility, drug-likeness, and synthesis accessibility on smaller datasets using the learned representations as features for graph convolution and attention models on the graph structure of molecules, as well as fine-tuning of BERT. Finally, we propose the use of attention visualization as a helpful tool for chemistry practitioners and students to quickly identify important substructures in various chemical properties.
Additionally, previous research has shown visualization of the attention mechanism to be valuable for chemical reaction classification. Open-sourcing large-scale transformer models such as RoBERTa with HuggingFace may help accelerate these individual research directions.
A link to a repository which includes the training, uploading and evaluation notebook (with sample predictions on compounds such as Remdesivir) can be found here. All of the notebooks can be copied into a new Colab runtime for easy execution.
For the sake of this tutorial, we'll be fine-tuning RoBERTa on a small-scale molecule dataset, to show the potential and effectiveness of HuggingFace's NLP-based transfer learning applied to computational chemistry. Output for some cells is purposely cleared for readability, so do not worry if some output messages for your cells differ!
Installing DeepChem from source, alongside RDKit for molecule visualizations
In [1]:
!pip install transformers
In [2]:
import sys
!test -d bertviz_repo && echo "FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo"
# !rm -r bertviz_repo # Uncomment if you need a clean pull from repo
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
if 'bertviz_repo' not in sys.path:
    sys.path += ['bertviz_repo']
!pip install regex
We want to install NVIDIA's Apex tool, for the training pipeline used by simple-transformers and Weights and Biases.
In [ ]:
!git clone https://github.com/NVIDIA/apex /content/apex
# Each `!` command runs in its own subshell, so `!cd` would not persist; install by absolute path instead.
!pip install -v --no-cache-dir /content/apex
Now, to ensure our model demonstrates an understanding of chemical syntax and molecular structure, we'll be testing it on predicting a masked token/character within the SMILES molecule for Remdesivir.
In [4]:
# Test if NVIDIA apex training tool works
from apex import amp
In [5]:
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer
from bertviz import head_view
model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
In [6]:
remdesivir_mask = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=<mask>1"
remdesivir = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"
"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1"
masked_smi = fill_mask(remdesivir_mask)
for smi in masked_smi:
    print(smi)
Here, we get some interesting results. The final branch, C1=CC=CC=C1, is a benzene ring. Since it's a pretty common structure, the model easily predicts the carbon that completes the final double bond, with a score of 0.60. Let's get a list of the top 5 predictions (including the target, Remdesivir) and visualize them, highlighting the beginning of the final benzene-like pattern. Let's import a few RDKit packages to do so.
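The score reported next to each candidate is the model's softmax probability for that token at the masked position. As a minimal, library-free sketch of that ranking step (the five-token vocabulary and the logit values below are made up for illustration and are not ChemBERTa's actual outputs):

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to one."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(tokens, logits, k):
    """Return the k (token, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(zip(tokens, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical vocabulary and logits for the masked position.
toy_tokens = ["C", "N", "O", "S", "="]
toy_logits = [2.0, 0.5, 0.3, -1.0, -2.0]

candidates = top_k(toy_tokens, toy_logits, 3)
masked = "OC1=CC=CC=<mask>1"
for token, prob in candidates:
    # Fill the mask with each candidate token, as the fill-mask pipeline does.
    print(masked.replace("<mask>", token), round(prob, 3))
```

The real pipeline does the same thing over the model's full vocabulary, which is why several of the returned candidates can be syntactically valid SMILES yet chemically implausible.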
In [ ]:
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')
In [8]:
import torch
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import MolToImage
from PIL import Image
def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol
def find_matches_one(mol, submol):
    # find all atoms in mol that match the maximum common substructure shared with submol.
    match_dict = {}
    mols = [mol, submol]  # pairwise search
    res = rdFMCS.FindMCS(mols)  # ,ringMatchesRingOnly=True)
    mcsp = Chem.MolFromSmarts(res.smartsString)
    matches = mol.GetSubstructMatches(mcsp)
    return matches
# Draw the molecule
def get_image(mol, atomset):
    hcolor = colors.to_rgb('green')
    if atomset is not None:
        # highlight the atom set while drawing the whole molecule.
        img = MolToImage(mol, size=(600, 600), fitImage=True, highlightAtoms=atomset, highlightColor=hcolor)
    else:
        img = MolToImage(mol, size=(400, 400), fitImage=True)
    return img
In [9]:
sequence = f"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC={tokenizer.mask_token}1"
substructure = "CC=CC"
image_list = []
input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]
token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
    smi = sequence.replace(tokenizer.mask_token, tokenizer.decode([token]))
    print(smi)
    smi_mol = get_mol(smi)
    substructure_mol = get_mol(substructure)
    if smi_mol is None:  # if the model's token prediction isn't chemically feasible
        continue
    Draw.MolToFile(smi_mol, smi + ".png")
    matches = find_matches_one(smi_mol, substructure_mol)
    atomset = list(matches[0])
    img = get_image(smi_mol, atomset)
    img.format = "PNG"
    image_list.append(img)
In [10]:
from IPython.display import display
for img in image_list:
    display(img)
As we can see above, 2 of 4 of the model's MLM predictions are chemically valid. The one the model would have chosen (with a score of 0.6) is the first image, in which the top-left ring resembles the benzene ring found in Remdesivir. Overall, the model seems to understand SMILES syntax with a fairly high degree of certainty.
However, further training on a more specific dataset (say leads for a specific target) may generate a stronger MLM model. Let's now fine-tune our model on a dataset of our choice, Tox21.
BertViz is a tool for visualizing attention in the Transformer model, supporting all models from the transformers library (BERT, GPT-2, XLNet, RoBERTa, XLM, CTRL, etc.). It extends the Tensor2Tensor visualization tool by Llion Jones and the transformers library from HuggingFace.
Using this tool, we can easily plug in ChemBERTa from the HuggingFace model hub and visualize the attention patterns produced by one or more attention heads in a given transformer layer. This is known as the attention-head view.
Let's start by configuring d3.js and jQuery in JavaScript to create interactive visualizations:
In [11]:
%%javascript
require.config({
  paths: {
    d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
    jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});
In [12]:
def call_html():
    import IPython
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))
Now, we create an instance of ChemBERTa, tokenize a set of SMILES strings, and compute the attention for each head in the transformer. There are two available models hosted by DeepChem on HuggingFace's model hub: seyonec/ChemBERTa-zinc-base-v1, which is the ChemBERTa model trained via masked language modelling (MLM) on the ZINC100k dataset, and seyonec/ChemBERTa-zinc250k-v1, which is trained via MLM on the larger ZINC250k dataset.
In the following example, we take two SMILES strings from the ZINC database with nearly identical chemical structure, the only difference being the chiral specification (hence the additional '@' symbol). Chirality tags mark tetrahedral stereocentres: '@' indicates that the neighbours of the chiral atom appear in counter-clockwise order (viewed from the first neighbour listed), whereas '@@' indicates that they are ordered clockwise. The model should ideally assign higher attention weights to similar substructures in each SMILES string.
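To make the difference concrete, here is a small string-level sketch in plain Python (no cheminformatics toolkit) that pulls the chirality tag out of each bracket atom. The helper names `chiral_tags` and `strip_chirality` are hypothetical, and deleting '@' characters is only a textual trick, not a chemically aware operation:

```python
import re

smiles_a = "CCCCC[C@@H](Br)CC"
smiles_b = "CCCCC[C@H](Br)CC"

def chiral_tags(smiles):
    """Return the chirality tag ('@' or '@@') of every bracket atom that carries one."""
    return re.findall(r"\[[^\]@]*(@{1,2})[^\]]*\]", smiles)

def strip_chirality(smiles):
    """Remove chirality tags textually, leaving the rest of the string untouched."""
    return smiles.replace("@", "")

tags_a = chiral_tags(smiles_a)
tags_b = chiral_tags(smiles_b)
same_skeleton = strip_chirality(smiles_a) == strip_chirality(smiles_b)

print(tags_a, tags_b, same_skeleton)
```

Once the tags are removed the two strings are identical, which is what makes this pair a useful probe: any difference in how the model treats them has to come from the chirality token alone.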
Let's look at the first SMILES string, CCCCC[C@@H](Br)CC:
In [13]:
m = Chem.MolFromSmiles('CCCCC[C@@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200, 200))
And the second SMILES string, CCCCC[C@H](Br)CC:
In [14]:
m = Chem.MolFromSmiles('CCCCC[C@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200,200))
The visualization below shows the attention induced by a sample input SMILES. This view visualizes attention as lines connecting the tokens being updated (left) with the tokens being attended to (right), following the design of the figures above. Color intensity reflects the attention weight; weights close to one show as very dark lines, while weights close to zero appear as faint lines or are not visible at all. The user may highlight a particular SMILES character to see the attention from that token only. This visualization is called the attention-head view. It is based on the excellent Tensor2Tensor visualization tool and is generated by the BertViz library.
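For intuition about what the line intensities encode, the sketch below computes scaled dot-product attention weights for a few toy vectors in plain Python. The query and key values are invented for illustration; in the real model they come from learned projections of the token embeddings. Each row of weights sums to one, which is why a weight close to one means a token attends almost exclusively to a single other token.

```python
import math

def attention_weights(queries, keys):
    """Row i holds the softmax-normalised weights of query i over all keys."""
    dim = len(keys[0])
    rows = []
    for q in queries:
        # Scaled dot product between this query and every key.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in keys]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

# Hypothetical 2-dimensional queries and keys for three tokens.
toy_tokens = ["C", "[C@@H]", "Br"]
toy_queries = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
toy_keys = [[1.0, 0.0], [0.0, 3.0], [0.5, 0.5]]

weights = attention_weights(toy_queries, toy_keys)
for token, row in zip(toy_tokens, weights):
    print(token, [round(w, 2) for w in row])
```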
In [15]:
from transformers import RobertaModel, RobertaTokenizer
from bertviz import head_view
model_version = 'seyonec/ChemBERTa-zinc250k-v1'
model = RobertaModel.from_pretrained(model_version, output_attentions=True)
tokenizer = RobertaTokenizer.from_pretrained(model_version)
sentence_a = "CCCCC[C@@H](Br)CC"
sentence_b = "CCCCC[C@H](Br)CC"
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
attention = model(input_ids)[-1]
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
call_html()
head_view(attention, tokens)
The visualization shows that attention is highest between tokens that don't cross the boundary between the two SMILES strings; the model seems to understand that it should relate tokens to other tokens in the same molecule in order to best understand their context.
There are many other fascinating visualizations we can do, such as a neuron-by-neuron analysis of attention or a model overview that visualizes all of the heads at once:
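One hedged way to quantify this observation is to sum, for a single head, how much attention stays within a molecule versus how much crosses over to the other one. The helper below is a hypothetical illustration on a hand-written 4x4 attention matrix, not output from ChemBERTa; `segment_ids` marks which molecule each token belongs to.

```python
def within_segment_fraction(attn, segment_ids):
    """Fraction of total attention mass that links tokens of the same segment."""
    within = 0.0
    total = 0.0
    for i, row in enumerate(attn):
        for j, weight in enumerate(row):
            total += weight
            if segment_ids[i] == segment_ids[j]:
                within += weight
    return within / total

# Made-up attention matrix: rows are attending tokens, columns are attended tokens.
toy_attn = [
    [0.70, 0.20, 0.05, 0.05],
    [0.30, 0.60, 0.05, 0.05],
    [0.10, 0.10, 0.50, 0.30],
    [0.05, 0.05, 0.40, 0.50],
]
toy_segments = [0, 0, 1, 1]  # first two tokens: molecule A, last two: molecule B

fraction = within_segment_fraction(toy_attn, toy_segments)
print(round(fraction, 3))
```

With real model output, the segment of each token would have to be derived from the separator tokens in the encoded pair, and special tokens would need a decision of their own; both are left out of this sketch.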
The task we focus on is the tumor suppressor protein p53 assay (SR-p53). The p53 pathway is typically "off" and is activated when cells are under stress or damaged, which makes it a good indicator of DNA damage and other cellular stresses. Once activated, p53 induces DNA repair, cell cycle arrest and apoptosis.
The Tox21 challenge was introduced in 2014 in an attempt to build models that are successful in predicting compounds' interference in biochemical pathways using only chemical structure data. The computational models produced from the challenge could become decision-making tools for government agencies in determining which environmental chemicals and drugs are of the greatest potential concern to human health. Additionally, these models can act as drug screening tools in the drug discovery pipelines for toxicity.
Let's start by downloading the dataset from S3, before importing apex and transformers, the library that lets us load the pre-trained masked-language-modelling model trained on ZINC15.
In [16]:
!wget https://t.co/zrC7F8DcRs?amp=1
If you're only running the toxicity prediction portion of this tutorial, make sure you install transformers here. If you've run all the cells before, you can ignore this install, as we've already done pip install transformers before.
In [ ]:
!pip install transformers
In [ ]:
!pip install simpletransformers
!pip install wandb
From here, we want to load the Tox21 dataset for training the model. We're going to use a filtered dataset of 2100 compounds, as there are only 400 positive leads and we want to avoid a large class imbalance. We'll also use simple-transformers' auto_weights argument when defining our ChemBERTa model to apply automatic class weighting later on, to counteract this problem.
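To see why weighting matters, consider the class counts quoted above (roughly 400 positives out of 2100 compounds). A common inverse-frequency scheme, shown here as an assumption for illustration and not necessarily the exact formula simple-transformers applies when auto_weights is enabled, gives the rarer class a proportionally larger weight:

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}

# Synthetic labels that only reproduce the class balance described in the text.
labels = [1] * 400 + [0] * 1700
weights = inverse_frequency_weights(labels)

for cls in sorted(weights):
    print(cls, round(weights[cls], 3))
```

Under this scheme a misclassified positive compound contributes several times more to the loss than a misclassified negative one, which discourages the model from simply predicting the majority class.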
In [18]:
import pandas as pd
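# The downloaded file keeps the name of the shortened URL.
dataset_path = "/content/zrC7F8DcRs?amp=1"
# `warn_bad_lines` was removed in pandas 2.0; True was already its default, so it is omitted here.
df = pd.read_csv(dataset_path, sep=',', header=None)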
df.rename(columns={0:'smiles',1:'labels'}, inplace=True)
df.head()
From here, let's set up a logger to record any issues that occur and notify us if there are problems with the arguments we've set for the model.
In [19]:
from simpletransformers.classification import ClassificationModel
import logging
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
Now, using simple-transformers, let's load the pre-trained model from HuggingFace's useful model hub. We'll set the number of epochs to 3 in the arguments, but you can train for longer. Also make sure that auto_weights is set to True, as we are dealing with imbalanced toxicity datasets.
In [20]:
model = ClassificationModel('roberta', 'seyonec/ChemBERTa-zinc-base-v1', args={'num_train_epochs': 3, 'auto_weights': True}) # You can set class weights by using the optional weight argument
In [21]:
# Split the train and test dataset 80-20
train_size = 0.8
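train_dataset = df.sample(frac=train_size, random_state=200)
# Drop the sampled rows by their original index before resetting it, so train and test do not overlap.
test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)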
In [22]:
# check if our train and evaluation dataframes are setup properly. There should only be two columns for the SMILES string and its corresponding label.
print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))
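Beyond the shapes, it is worth confirming that the two splits do not share rows. The sketch below shows the idea on plain Python index lists; `split_indices` is a hypothetical helper, and with pandas the same check could be applied to the original row indices before they are reset:

```python
import random

def split_indices(n_rows, train_frac, seed):
    """Shuffle row indices reproducibly and cut them into train and test parts."""
    indices = list(range(n_rows))
    rng = random.Random(seed)
    rng.shuffle(indices)
    cut = int(round(n_rows * train_frac))
    return indices[:cut], indices[cut:]

train_idx, test_idx = split_indices(2100, 0.8, seed=200)

# A clean split has no shared rows and covers every row exactly once.
overlap = set(train_idx) & set(test_idx)
covers_all = sorted(train_idx + test_idx) == list(range(2100))
print(len(train_idx), len(test_idx), len(overlap), covers_all)
```

Any overlap between the two sets would let training rows leak into the evaluation and inflate the reported accuracy.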
Now that we've set everything up, let's get to the fun part: training the model! We use Weights and Biases, which is optional (simply remove wandb_project from the list of args). It's a really useful tool for monitoring the model's training results (such as accuracy, learning rate and loss), along with gradients and any custom visualizations you create.
When you run this cell, Weights and Biases will ask for an account, which you can set up through a GitHub account to obtain a key. Again, this is completely optional and it can be removed from the list of arguments.
In [23]:
!wandb login
In [24]:
# Create directory to store model weights (change path accordingly to where you want!)
# `!cd` does not persist between commands, so create the directory by absolute path.
!mkdir -p /content/chemberta_tox21
# Train the model
model.train_model(train_dataset, output_dir='/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name'})
Let's install scikit-learn now, to evaluate the model we've trained.
In [25]:
!pip install -U scikit-learn
The following cell can be ignored unless you are starting a new run-time and just want to load the model from your local directory.
In [ ]:
# Loading a saved model for evaluation
model = ClassificationModel('roberta', '/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name','num_train_epochs': 3})
In [26]:
import sklearn.metrics  # import the submodule explicitly so sklearn.metrics is available
result, model_outputs, wrong_predictions = model.eval_model(test_dataset, acc=sklearn.metrics.accuracy_score)
The model performs pretty well, with accuracy averaging above 91% after training on only ~2000 data samples and 400 positive leads! We can clearly see the predictive power of transfer learning, and approaches like these are becoming increasingly popular in the pharmaceutical industry, where large datasets are scarce. By training for more epochs and on more tasks, we can probably boost the accuracy as well!
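Because only about a fifth of the compounds are positive, accuracy is best read alongside per-class metrics: a classifier that always predicts 0 would already score roughly 1700/2100, or about 81%, on data with this class balance. The sketch below uses made-up labels, not the tutorial's actual predictions, to compute precision and recall for the positive class next to accuracy:

```python
def binary_metrics(y_true, y_pred):
    """Accuracy plus precision and recall for the positive class (label 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }

# Toy labels: 8 negatives and 2 positives.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]
always_zero = [0] * len(y_true)

model_like = binary_metrics(y_true, y_pred)
baseline = binary_metrics(y_true, always_zero)
print(model_like)
print(baseline)
```

In this toy case both predictors have the same accuracy, but only the first one recovers any positives, which is the distinction that matters when screening for toxic compounds.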
Let's test the model on one last string from outside the filtered toxicity dataset. The model should predict 0, meaning no interference in biochemical pathways for p53.
In [27]:
# Let's input a molecule with an SR-p53 value of 0
predictions, raw_outputs = model.predict(['CCCCOc1cc(C(=O)OCCN(CC)CC)ccc1N'])
In [28]:
print(predictions)
print(raw_outputs)
The model predicts the sample correctly! Some future tasks may include using the same model on multiple tasks (Tox21 provides multiple for toxicity), through multi-task classification, as well as training on a wider dataset. This will be expanded on in a future tutorial!
Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:
This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.
The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!