By Seyone Chithrananda (Twitter)
Deep learning for chemistry and materials science remains a novel field with lots of potential. However, transfer learning methods, which are popular in areas such as NLP and computer vision, have not yet been effectively developed for computational chemistry and machine learning. Using HuggingFace's suite of models and the ByteLevel tokenizer, we are able to train a large transformer model, RoBERTa, on a corpus of 100k SMILES strings from a commonly used benchmark chemistry dataset, ZINC.
Training RoBERTa over 5 epochs, the model achieves a loss of 0.398, which would likely continue to decrease if the model were trained for more epochs. The model can predict masked tokens within a SMILES sequence/molecule, allowing variants of a molecule within discoverable chemical space to be predicted.
By applying the representations of functional groups and atoms learned by the model, we can try to tackle problems of toxicity, solubility, drug-likeness, and synthetic accessibility on smaller datasets. The learned representations can serve as features for graph convolution and attention models on the graph structure of molecules, and BERT itself can be fine-tuned. Finally, we propose the use of attention visualization as a helpful tool for chemistry practitioners and students to quickly identify the substructures that matter for various chemical properties.
Additionally, previous research has shown visualization of the attention mechanism to be valuable for chemical reaction classification. Open-sourcing large-scale transformer models such as RoBERTa with HuggingFace may help accelerate these individual research directions.
A link to a repository which includes the training, uploading and evaluation notebook (with sample predictions on compounds such as Remdesivir) can be found here. All of the notebooks can be copied into a new Colab runtime for easy execution.
For the sake of this tutorial, we'll be fine-tuning RoBERTa on a small-scale molecule dataset, to show the potential and effectiveness of HuggingFace's NLP-based transfer learning applied to computational chemistry. Output for some cells is purposely cleared for readability, so do not worry if some output messages for your cells differ!
Installing DeepChem from source, alongside RDKit for molecule visualizations
In [1]:
!pip install transformers
In [2]:
import sys
!test -d bertviz_repo && echo "FYI: bertviz_repo directory already exists, to pull latest version uncomment this line: !rm -r bertviz_repo"
# !rm -r bertviz_repo # Uncomment if you need a clean pull from repo
!test -d bertviz_repo || git clone https://github.com/jessevig/bertviz bertviz_repo
if 'bertviz_repo' not in sys.path:
    sys.path += ['bertviz_repo']
!pip install regex
We want to install NVIDIA's Apex tool, for the training pipeline used by simple-transformers and Weights and Biases.
In [ ]:
!git clone https://github.com/NVIDIA/apex
# Each `!` command runs in its own subshell, so `!cd` would not persist; install by absolute path instead
!pip install -v --no-cache-dir /content/apex
Now, to ensure our model demonstrates an understanding of chemical syntax and molecular structure, we'll be testing it on predicting a masked token/character within the SMILES molecule for Remdesivir.
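As a minimal sketch of what "masking a token" means here, the helper below swaps a single character of a SMILES string for a mask token. The function name `mask_at` and the hard-coded `<mask>` string are assumptions made for illustration; in the notebook itself the mask string comes from `tokenizer.mask_token`, and replacing one character only makes sense when that character is a token on its own (not part of `Br`, `Cl`, or a bracket atom such as `[C@H]`).

```python
MASK_TOKEN = "<mask>"  # assumed RoBERTa-style mask string; prefer tokenizer.mask_token in practice


def mask_at(smiles, index, mask_token=MASK_TOKEN):
    """Return a copy of `smiles` with the character at `index` replaced by the mask token."""
    return smiles[:index] + mask_token + smiles[index + 1:]


benzene = "C1=CC=CC=C1"
# Mask the last carbon, i.e. the character just before the ring-closure digit
masked = mask_at(benzene, len(benzene) - 2)
print(masked)
```

The same idea, applied to the final ring of Remdesivir, yields the masked string used in the fill-mask cell.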
In [4]:
# Test if NVIDIA apex training tool works
from apex import amp
In [5]:
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, RobertaModel, RobertaTokenizer
from bertviz import head_view
model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
In [6]:
remdesivir_mask = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=<mask>1"
remdesivir = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"
"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1"
masked_smi = fill_mask(remdesivir_mask)
for smi in masked_smi:
    print(smi)
Here, we get some interesting results. The final branch, C1=CC=CC=C1, is a benzene ring. Since it's a pretty common structure, the model is easily able to predict the final double-bonded carbon with a score of 0.60. Let's get a list of the top 5 predictions (including the target, Remdesivir), and visualize them (with a highlighted focus on the beginning of the final benzene-like pattern). Let's import a few RDKit packages to do so.
In [ ]:
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')
In [8]:
import torch
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import MolToImage
from PIL import Image
def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol

def find_matches_one(mol, submol):
    # Find all atoms in mol that match the maximum common substructure of mol and submol.
    mols = [mol, submol]  # pairwise search
    res = rdFMCS.FindMCS(mols)  # , ringMatchesRingOnly=True)
    mcsp = Chem.MolFromSmarts(res.smartsString)
    matches = mol.GetSubstructMatches(mcsp)
    return matches

# Draw the molecule
def get_image(mol, atomset):
    hcolor = colors.to_rgb('green')
    if atomset is not None:
        # Highlight the atom set while drawing the whole molecule.
        img = MolToImage(mol, size=(600, 600), fitImage=True, highlightAtoms=atomset, highlightColor=hcolor)
    else:
        img = MolToImage(mol, size=(400, 400), fitImage=True)
    return img
In [9]:
sequence = f"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC={tokenizer.mask_token}1"
substructure = "CC=CC"
substructure_mol = get_mol(substructure)  # built once, reused for every prediction
image_list = []

input_ids = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

token_logits = model(input_ids)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

# Indices of the five highest-scoring tokens at the masked position
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    smi = sequence.replace(tokenizer.mask_token, tokenizer.decode([token]))
    print(smi)
    smi_mol = get_mol(smi)
    if smi_mol is None:  # the model's token prediction isn't chemically feasible
        continue
    Draw.MolToFile(smi_mol, smi + ".png")
    matches = find_matches_one(smi_mol, substructure_mol)
    # Highlight the first match if there is one; otherwise draw without highlighting
    atomset = list(matches[0]) if matches else None
    img = get_image(smi_mol, atomset)
    img.format = "PNG"
    image_list.append(img)
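The cell above keeps the five highest logits at the masked position. As a rough, library-free sketch of how such logits relate to the scores reported by the fill-mask pipeline, the snippet below applies a softmax and then picks the top-k indices. The function names and the toy logits are made up for illustration and are not values produced by ChemBERTa.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    highest = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, k):
    """Return the indices of the k largest probabilities, best first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


toy_logits = [2.0, 0.5, -1.0, 1.0]  # hypothetical logits over a 4-token vocabulary
probs = softmax(toy_logits)
best = top_k(probs, 2)
print(best, [round(probs[i], 3) for i in best])
```

Because softmax preserves ordering, taking the top-k of the logits (as `torch.topk` does above) selects the same tokens as taking the top-k of the probabilities.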
In [10]:
from IPython.display import display
for img in image_list:
    display(img)
As we can see above, 2 of 4 of the model's MLM predictions are chemically valid. The one the model would've chosen (with a score of 0.6) is the first image, in which the ring at the top left resembles the benzene ring found in Remdesivir. Overall, the model seems to understand SMILES syntax with a fair degree of certainty.
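In the notebook, validity is decided by RDKit: `Chem.MolFromSmiles` returns `None` for a string it cannot parse or sanitize. For readers without RDKit at hand, the sketch below is a much weaker, syntax-only check written for illustration. The function name `smiles_syntax_ok` is hypothetical; it only verifies balanced parentheses and brackets and paired single-digit ring closures, it ignores two-digit `%nn` ring closures, and it says nothing about valence.

```python
def smiles_syntax_ok(smiles):
    """Rough syntax check: balanced () and [], and every ring-closure digit is paired."""
    depth = 0            # current nesting depth of branch parentheses
    in_bracket = False   # True while inside a bracket atom such as [C@H]
    open_rings = set()   # ring-closure digits that have been opened but not closed
    for ch in smiles:
        if in_bracket:
            if ch == "]":
                in_bracket = False
            elif ch == "[":
                return False  # nested bracket atoms are not allowed
            continue          # digits inside brackets are isotopes, H counts or charges
        if ch == "[":
            in_bracket = True
        elif ch == "]":
            return False
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False
        elif ch.isdigit():
            # The first occurrence opens a ring, the second closes it
            if ch in open_rings:
                open_rings.remove(ch)
            else:
                open_rings.add(ch)
    return depth == 0 and not in_bracket and not open_rings


print(smiles_syntax_ok("C1=CC=CC=C1"))  # closed benzene ring
print(smiles_syntax_ok("C1=CC=CC=C"))   # ring 1 is never closed
```

A string can pass this check and still be rejected by RDKit, so it is best read as a quick pre-filter rather than a replacement for `get_mol`.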
However, further training on a more specific dataset (say leads for a specific target) may generate a stronger MLM model. Let's now fine-tune our model on a dataset of our choice, Tox21.
BertViz is a tool for visualizing attention in the Transformer model, supporting all models from the transformers library (BERT, GPT-2, XLNet, RoBERTa, XLM, CTRL, etc.). It extends the Tensor2Tensor visualization tool by Llion Jones and the transformers library from HuggingFace.
Using this tool, we can easily plug in ChemBERTa from the HuggingFace model hub and visualize the attention patterns produced by one or more attention heads in a given transformer layer. This is known as the attention-head view.
Let's start by configuring RequireJS to load d3.js and jQuery, which are used to create the interactive visualizations:
In [11]:
%%javascript
require.config({
  paths: {
    d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
    jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});
In [12]:
def call_html():
    import IPython
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))
Now, we create an instance of ChemBERTa, tokenize a set of SMILES strings, and compute the attention for each head in the transformer. There are two available models hosted by DeepChem on HuggingFace's model hub: seyonec/ChemBERTa-zinc-base-v1, which is the ChemBERTa model trained via masked language modelling (MLM) on the ZINC100k dataset, and seyonec/ChemBERTa-zinc250k-v1, which is trained via MLM on the larger ZINC250k dataset.
In the following example, we take two SMILES molecules from the ZINC database with nearly identical chemical structure, the only difference being the chiral specification (hence the additional '@' symbol). Chirality markers appear on tetrahedral centres. '@' indicates that, looking from the first neighbour towards the chiral atom, the remaining neighbours are listed in counter-clockwise order, whereas '@@' indicates that they are listed in clockwise order. The model should ideally attend more strongly to similar substructures in each SMILES string.
Let's look at the first SMILES string, CCCCC[C@@H](Br)CC:
In [13]:
m = Chem.MolFromSmiles('CCCCC[C@@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200, 200))
And the second SMILES string, CCCCC[C@H](Br)CC:
In [14]:
m = Chem.MolFromSmiles('CCCCC[C@H](Br)CC')
fig = Draw.MolToMPL(m, size=(200,200))
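To make the difference between the two molecules explicit, the sketch below splits each SMILES string into atom-level tokens with a simple regular expression and lists the positions where they disagree. This regex is an illustrative assumption and is not the byte-level tokenizer that ChemBERTa actually uses; the names `SMILES_TOKEN`, `tokenize` and `diffs` are likewise hypothetical.

```python
import re

# Bracket atoms first, then two-letter halogens, then single letters, digits and bond/branch symbols
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|[A-Za-z]|\d|[=#()+\-/\\.%@]")


def tokenize(smiles):
    """Split a SMILES string into coarse atom-level tokens (illustrative only)."""
    return SMILES_TOKEN.findall(smiles)


tokens_a = tokenize("CCCCC[C@@H](Br)CC")
tokens_b = tokenize("CCCCC[C@H](Br)CC")

# Collect (position, token_a, token_b) wherever the two molecules differ
diffs = [(i, a, b) for i, (a, b) in enumerate(zip(tokens_a, tokens_b)) if a != b]
print(diffs)
```

Under this tokenization the two strings have the same length and differ only at the chiral centre, which is the position where we would hope to see the attention patterns diverge.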
The visualization below shows the attention induced by a sample input SMILES. This view visualizes attention as lines connecting the tokens being updated (left) with the tokens being attended to (right). Color intensity reflects the attention weight; weights close to one show as very dark lines, while weights close to zero appear as faint lines or are not visible at all. The user may highlight a particular SMILES character to see the attention from that token only. This visualization is called the attention-head view. It is based on the excellent Tensor2Tensor visualization tool, and the visualizations here are all generated by the BertViz library.
In [15]:
from transformers import RobertaModel, RobertaTokenizer
from bertviz import head_view
model_version = 'seyonec/ChemBERTa-zinc250k-v1'
model = RobertaModel.from_pretrained(model_version, output_attentions=True)
tokenizer = RobertaTokenizer.from_pretrained(model_version)
sentence_a = "CCCCC[C@@H](Br)CC"
sentence_b = "CCCCC[C@H](Br)CC"
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']
attention = model(input_ids)[-1]
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
call_html()
head_view(attention, tokens)
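Besides the interactive view, the attention weights can be inspected numerically. With `output_attentions=True`, the `attention` object above is a tuple with one tensor per layer, each of shape (batch, heads, sequence length, sequence length). The sketch below works on a plain nested list standing in for one head's matrix; the helper name `most_attended` and the toy 3x3 weights are made up for illustration and are not ChemBERTa output.

```python
def most_attended(attn_matrix, tokens, from_index):
    """Return the (token, weight) that tokens[from_index] attends to most strongly."""
    row = attn_matrix[from_index]  # attention from one token to every token
    to_index = max(range(len(row)), key=lambda j: row[j])
    return tokens[to_index], row[to_index]


# Hypothetical attention weights for a single head; each row sums to 1
toy_tokens = ["<s>", "C", "Br"]
toy_attention = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
]

token, weight = most_attended(toy_attention, toy_tokens, 1)
print(token, weight)
```

With the real model output, a matrix of this form could be obtained with something like `attention[layer][0][head].tolist()`, where `layer` and `head` are indices you choose.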