Tutorial Part 22: Finetuning HuggingFace's RoBERTa for Tox21 Toxicity Predictions on SMILES Strings

By Seyone Chithrananda

Deep learning for chemistry and materials science remains a novel field with lots of potiential. However, the popularity of transfer learning based methods in areas such as NLP and computer vision have not yet been effectively developed in computational chemistry + machine learning. Using HuggingFace's suite of models and the ByteLevel tokenizer, we are able to train a large-transformer model, RoBERTa, on a large corpus of 100k SMILES strings from a commonly known benchmark chemistry dataset, ZINC.

Training RoBERTa over 5 epochs, the model achieves a pretty good loss of 0.398, and may likely continue to decrease if trained for a larger number of epochs. The model can predict tokens within a SMILES sequence/molecule, allowing for variants of a molecule within discoverable chemical space to be predicted.

By applying the representations of functional groups and atoms learned by the model, we can try to tackle problems of toxicity, solubility, drug-likeness, and synthesis accessibility on smaller datasets using the learned representations as features for graph convolution and attention models on the graph structure of molecules, as well as fine-tuning of BERT. Finally, we propose the use of attention visualization as a helpful tool for chemistry practitioners and students to quickly identify important substructures in various chemical properties.

Additionally, visualization of the attention mechanism have been seen through previous research as incredibly valuable towards chemical reaction classification. The applications of open-sourcing large-scale transformer models such as RoBERTa with HuggingFace may allow for the acceleration of these individual research directions.

A link to a repository which includes the training, uploading and evaluation notebook (with sample predictions on compounds such as Remdesivir) can be found here. All of the notebooks can be copied into a new Colab runtime for easy execution.

For the sake of this tutorial, we'll be fine-tuning RoBERTa on a small-scale molecule dataset, to show the potiential and effectiveness of HuggingFace's NLP-based transfer learning applied to computational chemistry.

Installing DeepChem from source, alongside RDKit for molecule visualizations


In [1]:
import tensorflow as tf
print("tf.__version__: %s" % str(tf.__version__))
device_name = tf.test.gpu_device_name()
if not device_name:
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))


tf.__version__: 2.2.0-rc3
Found GPU at: /device:GPU:0

In [2]:
!pip install transformers


Collecting transformers
  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
     |████████████████████████████████| 573kB 12.3MB/s 
Collecting tokenizers==0.5.2
  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
     |████████████████████████████████| 3.7MB 47.4MB/s 
Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.40)
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
     |████████████████████████████████| 1.0MB 27.1MB/s 
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
     |████████████████████████████████| 890kB 55.6MB/s 
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)
Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)
Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.40)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)
Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (2.8.1)
Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (0.15.2)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.41-cp36-none-any.whl size=893334 sha256=49159e7f355e0b67097229550cdd83f0d2fdb44567a06534b35c19bfedadfcc3
  Stored in directory: /root/.cache/pip/wheels/22/5a/d4/b020a81249de7dc63758a34222feaa668dbe8ebfe9170cc9b1
Successfully built sacremoses
Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers
Successfully installed sacremoses-0.0.41 sentencepiece-0.1.85 tokenizers-0.5.2 transformers-2.8.0

Now, to ensure our model demonstrates an understanding of chemical syntax and molecular structure, we'll be testing it on predicting a masked token/character within the SMILES molecule for Remdesivir.


In [3]:
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline

model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)









In [4]:
remdesivir_mask = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=<mask>1"
remdesivir = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"

"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1"

masked_smi = fill_mask(remdesivir_mask)

for smi in masked_smi:
  print(smi)


{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1</s>', 'score': 0.5986586809158325, 'token': 39}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1</s>', 'score': 0.09766950458288193, 'token': 51}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1</s>', 'score': 0.07694468647241592, 'token': 50}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21</s>', 'score': 0.0241263248026371, 'token': 22}
{'sequence': '<s> CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1</s>', 'score': 0.0188530795276165, 'token': 44}

Here, we get some interesting results. The final branch, C1=CC=CC=C1, is a benzene ring. Since its a pretty common molecule, the model is easily able to predict the final double carbon bond with a score of 0.60. Let's get a list of the top 5 predictions (including the target, Remdesivir), and visualize them (with a highlighted focus on the beginning of the final benzene-like pattern). Lets import some various RDKit packages to do so.


In [5]:
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')


--2020-04-24 04:23:25--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.130.3, 104.16.131.3, 2606:4700::6810:8203, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.130.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 85055499 (81M) [application/x-sh]
Saving to: ‘Miniconda3-latest-Linux-x86_64.sh’

Miniconda3-latest-L 100%[===================>]  81.12M  55.6MB/s    in 1.5s    

2020-04-24 04:23:27 (55.6 MB/s) - ‘Miniconda3-latest-Linux-x86_64.sh’ saved [85055499/85055499]

PREFIX=/usr/local
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ done
Solving environment: / - done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - _libgcc_mutex==0.1=main
    - asn1crypto==1.3.0=py37_0
    - ca-certificates==2020.1.1=0
    - certifi==2019.11.28=py37_0
    - cffi==1.14.0=py37h2e261b9_0
    - chardet==3.0.4=py37_1003
    - conda-package-handling==1.6.0=py37h7b6447c_0
    - conda==4.8.2=py37_0
    - cryptography==2.8=py37h1ba5d50_0
    - idna==2.8=py37_0
    - ld_impl_linux-64==2.33.1=h53a641e_7
    - libedit==3.1.20181209=hc058e9b_0
    - libffi==3.2.1=hd88cf55_4
    - libgcc-ng==9.1.0=hdf63c60_0
    - libstdcxx-ng==9.1.0=hdf63c60_0
    - ncurses==6.2=he6710b0_0
    - openssl==1.1.1d=h7b6447c_4
    - pip==20.0.2=py37_1
    - pycosat==0.6.3=py37h7b6447c_0
    - pycparser==2.19=py37_0
    - pyopenssl==19.1.0=py37_0
    - pysocks==1.7.1=py37_0
    - python==3.7.6=h0371630_2
    - readline==7.0=h7b6447c_5
    - requests==2.22.0=py37_1
    - ruamel_yaml==0.15.87=py37h7b6447c_0
    - setuptools==45.2.0=py37_0
    - six==1.14.0=py37_0
    - sqlite==3.31.1=h7b6447c_0
    - tk==8.6.8=hbc83047_0
    - tqdm==4.42.1=py_0
    - urllib3==1.25.8=py37_0
    - wheel==0.34.2=py37_0
    - xz==5.2.4=h14c3975_4
    - yaml==0.1.7=had09818_2
    - zlib==1.2.11=h7b6447c_3


The following NEW packages will be INSTALLED:

  _libgcc_mutex      pkgs/main/linux-64::_libgcc_mutex-0.1-main
  asn1crypto         pkgs/main/linux-64::asn1crypto-1.3.0-py37_0
  ca-certificates    pkgs/main/linux-64::ca-certificates-2020.1.1-0
  certifi            pkgs/main/linux-64::certifi-2019.11.28-py37_0
  cffi               pkgs/main/linux-64::cffi-1.14.0-py37h2e261b9_0
  chardet            pkgs/main/linux-64::chardet-3.0.4-py37_1003
  conda              pkgs/main/linux-64::conda-4.8.2-py37_0
  conda-package-han~ pkgs/main/linux-64::conda-package-handling-1.6.0-py37h7b6447c_0
  cryptography       pkgs/main/linux-64::cryptography-2.8-py37h1ba5d50_0
  idna               pkgs/main/linux-64::idna-2.8-py37_0
  ld_impl_linux-64   pkgs/main/linux-64::ld_impl_linux-64-2.33.1-h53a641e_7
  libedit            pkgs/main/linux-64::libedit-3.1.20181209-hc058e9b_0
  libffi             pkgs/main/linux-64::libffi-3.2.1-hd88cf55_4
  libgcc-ng          pkgs/main/linux-64::libgcc-ng-9.1.0-hdf63c60_0
  libstdcxx-ng       pkgs/main/linux-64::libstdcxx-ng-9.1.0-hdf63c60_0
  ncurses            pkgs/main/linux-64::ncurses-6.2-he6710b0_0
  openssl            pkgs/main/linux-64::openssl-1.1.1d-h7b6447c_4
  pip                pkgs/main/linux-64::pip-20.0.2-py37_1
  pycosat            pkgs/main/linux-64::pycosat-0.6.3-py37h7b6447c_0
  pycparser          pkgs/main/linux-64::pycparser-2.19-py37_0
  pyopenssl          pkgs/main/linux-64::pyopenssl-19.1.0-py37_0
  pysocks            pkgs/main/linux-64::pysocks-1.7.1-py37_0
  python             pkgs/main/linux-64::python-3.7.6-h0371630_2
  readline           pkgs/main/linux-64::readline-7.0-h7b6447c_5
  requests           pkgs/main/linux-64::requests-2.22.0-py37_1
  ruamel_yaml        pkgs/main/linux-64::ruamel_yaml-0.15.87-py37h7b6447c_0
  setuptools         pkgs/main/linux-64::setuptools-45.2.0-py37_0
  six                pkgs/main/linux-64::six-1.14.0-py37_0
  sqlite             pkgs/main/linux-64::sqlite-3.31.1-h7b6447c_0
  tk                 pkgs/main/linux-64::tk-8.6.8-hbc83047_0
  tqdm               pkgs/main/noarch::tqdm-4.42.1-py_0
  urllib3            pkgs/main/linux-64::urllib3-1.25.8-py37_0
  wheel              pkgs/main/linux-64::wheel-0.34.2-py37_0
  xz                 pkgs/main/linux-64::xz-5.2.4-h14c3975_4
  yaml               pkgs/main/linux-64::yaml-0.1.7-had09818_2
  zlib               pkgs/main/linux-64::zlib-1.2.11-h7b6447c_3


Preparing transaction: | / - done
Executing transaction: | / - \ | / - \ | / - \ | / - \ | / done
installation finished.
WARNING:
    You currently have a PYTHONPATH environment variable set. This may cause
    unexpected behavior when running the Python interpreter in Miniconda3.
    For best results, please verify that your PYTHONPATH only points to
    directories of packages that are compatible with the Python interpreter
    in Miniconda3: /usr/local
Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - rdkit


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    boost-1.72.0               |   py37h9de70de_0         316 KB  conda-forge
    boost-cpp-1.72.0           |       h8e57a91_0        21.8 MB  conda-forge
    bzip2-1.0.8                |       h516909a_2         396 KB  conda-forge
    ca-certificates-2020.4.5.1 |       hecc5488_0         146 KB  conda-forge
    cairo-1.16.0               |    hcf35c78_1003         1.5 MB  conda-forge
    certifi-2020.4.5.1         |   py37hc8dfbb8_0         151 KB  conda-forge
    conda-4.8.3                |   py37hc8dfbb8_1         3.0 MB  conda-forge
    fontconfig-2.13.1          |    h86ecdb6_1001         340 KB  conda-forge
    freetype-2.10.1            |       he06d7ca_0         877 KB  conda-forge
    gettext-0.19.8.1           |    hc5be6a0_1002         3.6 MB  conda-forge
    glib-2.64.2                |       h6f030ca_0         3.4 MB  conda-forge
    icu-64.2                   |       he1b5a44_1        12.6 MB  conda-forge
    jpeg-9c                    |    h14c3975_1001         251 KB  conda-forge
    libblas-3.8.0              |      14_openblas          10 KB  conda-forge
    libcblas-3.8.0             |      14_openblas          10 KB  conda-forge
    libgfortran-ng-7.3.0       |       hdf63c60_5         1.7 MB  conda-forge
    libiconv-1.15              |    h516909a_1006         2.0 MB  conda-forge
    liblapack-3.8.0            |      14_openblas          10 KB  conda-forge
    libopenblas-0.3.7          |       h5ec1e0e_6         7.6 MB  conda-forge
    libpng-1.6.37              |       hed695b0_1         308 KB  conda-forge
    libtiff-4.1.0              |       hc7e4089_6         668 KB  conda-forge
    libuuid-2.32.1             |    h14c3975_1000          26 KB  conda-forge
    libwebp-base-1.1.0         |       h516909a_3         845 KB  conda-forge
    libxcb-1.13                |    h14c3975_1002         396 KB  conda-forge
    libxml2-2.9.10             |       hee79883_0         1.3 MB  conda-forge
    lz4-c-1.8.3                |    he1b5a44_1001         187 KB  conda-forge
    numpy-1.18.1               |   py37h8960a57_1         5.2 MB  conda-forge
    olefile-0.46               |             py_0          31 KB  conda-forge
    openssl-1.1.1g             |       h516909a_0         2.1 MB  conda-forge
    pandas-1.0.3               |   py37h0da4684_1        11.1 MB  conda-forge
    pcre-8.44                  |       he1b5a44_0         261 KB  conda-forge
    pillow-7.0.0               |   py37hb39fc2d_0         598 KB
    pixman-0.38.0              |    h516909a_1003         594 KB  conda-forge
    pthread-stubs-0.4          |    h14c3975_1001           5 KB  conda-forge
    pycairo-1.19.1             |   py37h01af8b0_3          77 KB  conda-forge
    python-dateutil-2.8.1      |             py_0         220 KB  conda-forge
    python_abi-3.7             |          1_cp37m           4 KB  conda-forge
    pytz-2019.3                |             py_0         237 KB  conda-forge
    rdkit-2020.03.1            |   py37hdd87690_1        24.7 MB  conda-forge
    xorg-kbproto-1.0.7         |    h14c3975_1002          26 KB  conda-forge
    xorg-libice-1.0.10         |       h516909a_0          57 KB  conda-forge
    xorg-libsm-1.2.3           |    h84519dc_1000          25 KB  conda-forge
    xorg-libx11-1.6.9          |       h516909a_0         918 KB  conda-forge
    xorg-libxau-1.0.9          |       h14c3975_0          13 KB  conda-forge
    xorg-libxdmcp-1.1.3        |       h516909a_0          18 KB  conda-forge
    xorg-libxext-1.3.4         |       h516909a_0          51 KB  conda-forge
    xorg-libxrender-0.9.10     |    h516909a_1002          31 KB  conda-forge
    xorg-renderproto-0.11.1    |    h14c3975_1002           8 KB  conda-forge
    xorg-xextproto-7.3.0       |    h14c3975_1002          27 KB  conda-forge
    xorg-xproto-7.0.31         |    h14c3975_1007          72 KB  conda-forge
    zstd-1.4.4                 |       h3b9ef0a_2         982 KB  conda-forge
    ------------------------------------------------------------
                                           Total:       110.7 MB

The following NEW packages will be INSTALLED:

  boost              conda-forge/linux-64::boost-1.72.0-py37h9de70de_0
  boost-cpp          conda-forge/linux-64::boost-cpp-1.72.0-h8e57a91_0
  bzip2              conda-forge/linux-64::bzip2-1.0.8-h516909a_2
  cairo              conda-forge/linux-64::cairo-1.16.0-hcf35c78_1003
  fontconfig         conda-forge/linux-64::fontconfig-2.13.1-h86ecdb6_1001
  freetype           conda-forge/linux-64::freetype-2.10.1-he06d7ca_0
  gettext            conda-forge/linux-64::gettext-0.19.8.1-hc5be6a0_1002
  glib               conda-forge/linux-64::glib-2.64.2-h6f030ca_0
  icu                conda-forge/linux-64::icu-64.2-he1b5a44_1
  jpeg               conda-forge/linux-64::jpeg-9c-h14c3975_1001
  libblas            conda-forge/linux-64::libblas-3.8.0-14_openblas
  libcblas           conda-forge/linux-64::libcblas-3.8.0-14_openblas
  libgfortran-ng     conda-forge/linux-64::libgfortran-ng-7.3.0-hdf63c60_5
  libiconv           conda-forge/linux-64::libiconv-1.15-h516909a_1006
  liblapack          conda-forge/linux-64::liblapack-3.8.0-14_openblas
  libopenblas        conda-forge/linux-64::libopenblas-0.3.7-h5ec1e0e_6
  libpng             conda-forge/linux-64::libpng-1.6.37-hed695b0_1
  libtiff            conda-forge/linux-64::libtiff-4.1.0-hc7e4089_6
  libuuid            conda-forge/linux-64::libuuid-2.32.1-h14c3975_1000
  libwebp-base       conda-forge/linux-64::libwebp-base-1.1.0-h516909a_3
  libxcb             conda-forge/linux-64::libxcb-1.13-h14c3975_1002
  libxml2            conda-forge/linux-64::libxml2-2.9.10-hee79883_0
  lz4-c              conda-forge/linux-64::lz4-c-1.8.3-he1b5a44_1001
  numpy              conda-forge/linux-64::numpy-1.18.1-py37h8960a57_1
  olefile            conda-forge/noarch::olefile-0.46-py_0
  pandas             conda-forge/linux-64::pandas-1.0.3-py37h0da4684_1
  pcre               conda-forge/linux-64::pcre-8.44-he1b5a44_0
  pillow             pkgs/main/linux-64::pillow-7.0.0-py37hb39fc2d_0
  pixman             conda-forge/linux-64::pixman-0.38.0-h516909a_1003
  pthread-stubs      conda-forge/linux-64::pthread-stubs-0.4-h14c3975_1001
  pycairo            conda-forge/linux-64::pycairo-1.19.1-py37h01af8b0_3
  python-dateutil    conda-forge/noarch::python-dateutil-2.8.1-py_0
  python_abi         conda-forge/linux-64::python_abi-3.7-1_cp37m
  pytz               conda-forge/noarch::pytz-2019.3-py_0
  rdkit              conda-forge/linux-64::rdkit-2020.03.1-py37hdd87690_1
  xorg-kbproto       conda-forge/linux-64::xorg-kbproto-1.0.7-h14c3975_1002
  xorg-libice        conda-forge/linux-64::xorg-libice-1.0.10-h516909a_0
  xorg-libsm         conda-forge/linux-64::xorg-libsm-1.2.3-h84519dc_1000
  xorg-libx11        conda-forge/linux-64::xorg-libx11-1.6.9-h516909a_0
  xorg-libxau        conda-forge/linux-64::xorg-libxau-1.0.9-h14c3975_0
  xorg-libxdmcp      conda-forge/linux-64::xorg-libxdmcp-1.1.3-h516909a_0
  xorg-libxext       conda-forge/linux-64::xorg-libxext-1.3.4-h516909a_0
  xorg-libxrender    conda-forge/linux-64::xorg-libxrender-0.9.10-h516909a_1002
  xorg-renderproto   conda-forge/linux-64::xorg-renderproto-0.11.1-h14c3975_1002
  xorg-xextproto     conda-forge/linux-64::xorg-xextproto-7.3.0-h14c3975_1002
  xorg-xproto        conda-forge/linux-64::xorg-xproto-7.0.31-h14c3975_1007
  zstd               conda-forge/linux-64::zstd-1.4.4-h3b9ef0a_2

The following packages will be UPDATED:

  ca-certificates     pkgs/main::ca-certificates-2020.1.1-0 --> conda-forge::ca-certificates-2020.4.5.1-hecc5488_0
  certifi              pkgs/main::certifi-2019.11.28-py37_0 --> conda-forge::certifi-2020.4.5.1-py37hc8dfbb8_0
  conda                       pkgs/main::conda-4.8.2-py37_0 --> conda-forge::conda-4.8.3-py37hc8dfbb8_1
  openssl              pkgs/main::openssl-1.1.1d-h7b6447c_4 --> conda-forge::openssl-1.1.1g-h516909a_0


Preparing transaction: ...working... done
Verifying transaction: ...working... done
Executing transaction: ...working... done

real	0m37.898s
user	0m31.771s
sys	0m3.614s

In [0]:
import torch
import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Chem.Draw import MolToImage
from PIL import Image


def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol


def find_matches_one(mol,submol):
    #find all matching atoms for each submol in submol_list in mol.
    match_dict = {}
    mols = [mol,submol] #pairwise search
    res=rdFMCS.FindMCS(mols) #,ringMatchesRingOnly=True)
    mcsp = Chem.MolFromSmarts(res.smartsString)
    matches = mol.GetSubstructMatches(mcsp)
    return matches

#Draw the molecule
def get_image(mol,atomset):    
    hcolor = colors.to_rgb('green')
    if atomset is not None:
        #highlight the atoms set while drawing the whole molecule.
        img = MolToImage(mol, size=(600, 600),fitImage=True, highlightAtoms=atomset,highlightColor=hcolor)
    else:
        img = MolToImage(mol, size=(400, 400),fitImage=True)
    return img

In [7]:
sequence = f"CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC={tokenizer.mask_token}1"
substructure = "CC=CC"
image_list = []

input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
  smi = (sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
  print (smi)
  smi_mol = get_mol(smi)
  substructure_mol = get_mol(substructure)
  if smi_mol is None: # if the model's token prediction isn't chemically feasible
    continue
  Draw.MolToFile(smi_mol, smi+".png")
  matches = find_matches_one(smi_mol, substructure_mol)
  atomset = list(matches[0])
  img = get_image(smi_mol, atomset)
  img.format="PNG" 
  image_list.append(img)


CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=O1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=N1
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=21
CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=H1

In [8]:
from IPython.display import Image 

for img in image_list:
  display(img)


As we can see above, 2 of 4 of the model's MLM predictions are chemically valid. The one the model would've chosen (with a score of 0.6), is the first image, in which the top left molecular structure resembles the benzene found in the therapy Remdesivir. Overall, the model seems to understand syntax with a pretty decent degree of certainity.

However, further training on a more specific dataset (say leads for a specific target) may generate a stronger MLM model. Let's now fine-tune our model on a dataset of our choice, Tox21.

Fine-tuning ChemBERTa on a Small Mollecular Dataset

Tumor suppressor protein (SR.p53), typically the p53 pathway is “off” and is activated when cells are under stress or damaged, hence being a good indicator of DNA damage and other cellular stresses. Tumor suppressor protein p53 is activated by inducing DNA repair, cell cycle arrest and apoptosis.

The Tox21 challenge was introduced in 2014 in an attempt to build models that are successful in predicting compounds' interference in biochemical pathways using only chemical structure data. The computational models produced from the challenge could become decision-making tools for government agencies in determining which environmental chemicals and drugs are of the greatest potential concern to human health. Additionally, these models can act as drug screening tools in the drug discovery pipelines for toxicity.

Lets start by loading the dataset from s3, before importing apex and transformers, the tool which will allow us to import the pre-trained masked-language modelling architecture trained on ZINC15.


In [1]:
!wget https://t.co/zrC7F8DcRs?amp=1


--2020-04-24 04:31:11--  https://t.co/zrC7F8DcRs?amp=1
Resolving t.co (t.co)... 104.244.42.5, 104.244.42.133, 104.244.42.69, ...
Connecting to t.co (t.co)|104.244.42.5|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv [following]
--2020-04-24 04:31:12--  https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21_balanced_revised_no_id.csv
Resolving deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)... 52.219.112.193
Connecting to deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)|52.219.112.193|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 85962 (84K) [text/csv]
Saving to: ‘zrC7F8DcRs?amp=1’

zrC7F8DcRs?amp=1    100%[===================>]  83.95K   231KB/s    in 0.4s    

2020-04-24 04:31:13 (231 KB/s) - ‘zrC7F8DcRs?amp=1’ saved [85962/85962]

We want to install NVIDIA's Apex tool, for the training pipeline used by simple-transformer and Weights and Biases.


In [2]:
!git clone https://github.com/NVIDIA/apex
!cd /content/apex
!pip install -v --no-cache-dir /content/apex
!cd ..


Cloning into 'apex'...
remote: Enumerating objects: 4, done.
remote: Counting objects: 100% (4/4), done.
remote: Compressing objects: 100% (4/4), done.
remote: Total 6593 (delta 0), reused 0 (delta 0), pack-reused 6589
Receiving objects: 100% (6593/6593), 13.70 MiB | 1.52 MiB/s, done.
Resolving deltas: 100% (4383/4383), done.
Created temporary directory: /tmp/pip-ephem-wheel-cache-q5nbg4uh
Created temporary directory: /tmp/pip-req-tracker-ixo7f527
Created requirements tracker '/tmp/pip-req-tracker-ixo7f527'
Created temporary directory: /tmp/pip-install-4bvkod3b
Processing ./apex
  Created temporary directory: /tmp/pip-req-build-nltce4xy
  Added file:///content/apex to build tracker '/tmp/pip-req-tracker-ixo7f527'
    Running setup.py (path:/tmp/pip-req-build-nltce4xy/setup.py) egg_info for package from file:///content/apex
    Running command python setup.py egg_info
    torch.__version__  =  1.4.0
    running egg_info
    creating /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info
    writing /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/PKG-INFO
    writing dependency_links to /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/dependency_links.txt
    writing top-level names to /tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/top_level.txt
    writing manifest file '/tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/SOURCES.txt'
    writing manifest file '/tmp/pip-req-build-nltce4xy/pip-egg-info/apex.egg-info/SOURCES.txt'
    /tmp/pip-req-build-nltce4xy/setup.py:46: UserWarning: Option --pyprof not specified. Not installing PyProf dependencies!
      warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
  Source in /tmp/pip-req-build-nltce4xy has version 0.1, which satisfies requirement apex==0.1 from file:///content/apex
  Removed apex==0.1 from file:///content/apex from build tracker '/tmp/pip-req-tracker-ixo7f527'
Building wheels for collected packages: apex
  Created temporary directory: /tmp/pip-wheel-rym8vea2
  Building wheel for apex (setup.py) ...   Destination directory: /tmp/pip-wheel-rym8vea2
  Running command /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-nltce4xy/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-nltce4xy/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-rym8vea2 --python-tag cp36
  torch.__version__  =  1.4.0
  /tmp/pip-req-build-nltce4xy/setup.py:46: UserWarning: Option --pyprof not specified. Not installing PyProf dependencies!
    warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib
  creating build/lib/apex
  copying apex/__init__.py -> build/lib/apex
  creating build/lib/apex/optimizers
  copying apex/optimizers/__init__.py -> build/lib/apex/optimizers
  copying apex/optimizers/fused_lamb.py -> build/lib/apex/optimizers
  copying apex/optimizers/fused_sgd.py -> build/lib/apex/optimizers
  copying apex/optimizers/fused_novograd.py -> build/lib/apex/optimizers
  copying apex/optimizers/fused_adam.py -> build/lib/apex/optimizers
  creating build/lib/apex/pyprof
  copying apex/pyprof/__init__.py -> build/lib/apex/pyprof
  creating build/lib/apex/normalization
  copying apex/normalization/__init__.py -> build/lib/apex/normalization
  copying apex/normalization/fused_layer_norm.py -> build/lib/apex/normalization
  creating build/lib/apex/multi_tensor_apply
  copying apex/multi_tensor_apply/__init__.py -> build/lib/apex/multi_tensor_apply
  copying apex/multi_tensor_apply/multi_tensor_apply.py -> build/lib/apex/multi_tensor_apply
  creating build/lib/apex/parallel
  copying apex/parallel/optimized_sync_batchnorm_kernel.py -> build/lib/apex/parallel
  copying apex/parallel/optimized_sync_batchnorm.py -> build/lib/apex/parallel
  copying apex/parallel/__init__.py -> build/lib/apex/parallel
  copying apex/parallel/LARC.py -> build/lib/apex/parallel
  copying apex/parallel/sync_batchnorm_kernel.py -> build/lib/apex/parallel
  copying apex/parallel/distributed.py -> build/lib/apex/parallel
  copying apex/parallel/sync_batchnorm.py -> build/lib/apex/parallel
  copying apex/parallel/multiproc.py -> build/lib/apex/parallel
  creating build/lib/apex/fp16_utils
  copying apex/fp16_utils/__init__.py -> build/lib/apex/fp16_utils
  copying apex/fp16_utils/fp16util.py -> build/lib/apex/fp16_utils
  copying apex/fp16_utils/fp16_optimizer.py -> build/lib/apex/fp16_utils
  copying apex/fp16_utils/loss_scaler.py -> build/lib/apex/fp16_utils
  creating build/lib/apex/reparameterization
  copying apex/reparameterization/__init__.py -> build/lib/apex/reparameterization
  copying apex/reparameterization/reparameterization.py -> build/lib/apex/reparameterization
  copying apex/reparameterization/weight_norm.py -> build/lib/apex/reparameterization
  creating build/lib/apex/contrib
  copying apex/contrib/__init__.py -> build/lib/apex/contrib
  creating build/lib/apex/mlp
  copying apex/mlp/__init__.py -> build/lib/apex/mlp
  copying apex/mlp/mlp.py -> build/lib/apex/mlp
  creating build/lib/apex/RNN
  copying apex/RNN/cells.py -> build/lib/apex/RNN
  copying apex/RNN/__init__.py -> build/lib/apex/RNN
  copying apex/RNN/RNNBackend.py -> build/lib/apex/RNN
  copying apex/RNN/models.py -> build/lib/apex/RNN
  creating build/lib/apex/amp
  copying apex/amp/opt.py -> build/lib/apex/amp
  copying apex/amp/__init__.py -> build/lib/apex/amp
  copying apex/amp/scaler.py -> build/lib/apex/amp
  copying apex/amp/__version__.py -> build/lib/apex/amp
  copying apex/amp/_amp_state.py -> build/lib/apex/amp
  copying apex/amp/rnn_compat.py -> build/lib/apex/amp
  copying apex/amp/utils.py -> build/lib/apex/amp
  copying apex/amp/wrap.py -> build/lib/apex/amp
  copying apex/amp/_initialize.py -> build/lib/apex/amp
  copying apex/amp/frontend.py -> build/lib/apex/amp
  copying apex/amp/_process_optimizer.py -> build/lib/apex/amp
  copying apex/amp/compat.py -> build/lib/apex/amp
  copying apex/amp/amp.py -> build/lib/apex/amp
  copying apex/amp/handle.py -> build/lib/apex/amp
  creating build/lib/apex/pyprof/nvtx
  copying apex/pyprof/nvtx/__init__.py -> build/lib/apex/pyprof/nvtx
  copying apex/pyprof/nvtx/nvmarker.py -> build/lib/apex/pyprof/nvtx
  creating build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/__main__.py -> build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/__init__.py -> build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/kernel.py -> build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/db.py -> build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/nvvp.py -> build/lib/apex/pyprof/parse
  copying apex/pyprof/parse/parse.py -> build/lib/apex/pyprof/parse
  creating build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/utility.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/prof.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/__main__.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/__init__.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/misc.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/conv.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/embedding.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/output.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/usage.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/index_slice_join_mutate.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/convert.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/pooling.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/linear.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/pointwise.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/activation.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/base.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/dropout.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/optim.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/data.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/reduction.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/blas.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/loss.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/softmax.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/randomSample.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/normalization.py -> build/lib/apex/pyprof/prof
  copying apex/pyprof/prof/recurrentCell.py -> build/lib/apex/pyprof/prof
  creating build/lib/apex/contrib/optimizers
  copying apex/contrib/optimizers/__init__.py -> build/lib/apex/contrib/optimizers
  copying apex/contrib/optimizers/fused_lamb.py -> build/lib/apex/contrib/optimizers
  copying apex/contrib/optimizers/fused_sgd.py -> build/lib/apex/contrib/optimizers
  copying apex/contrib/optimizers/fp16_optimizer.py -> build/lib/apex/contrib/optimizers
  copying apex/contrib/optimizers/fused_adam.py -> build/lib/apex/contrib/optimizers
  creating build/lib/apex/contrib/groupbn
  copying apex/contrib/groupbn/__init__.py -> build/lib/apex/contrib/groupbn
  copying apex/contrib/groupbn/batch_norm.py -> build/lib/apex/contrib/groupbn
  creating build/lib/apex/contrib/xentropy
  copying apex/contrib/xentropy/__init__.py -> build/lib/apex/contrib/xentropy
  copying apex/contrib/xentropy/softmax_xentropy.py -> build/lib/apex/contrib/xentropy
  creating build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/__init__.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/encdec_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/encdec_multihead_attn.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/fast_self_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/self_multihead_attn_func.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/self_multihead_attn.py -> build/lib/apex/contrib/multihead_attn
  copying apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py -> build/lib/apex/contrib/multihead_attn
  creating build/lib/apex/amp/lists
  copying apex/amp/lists/__init__.py -> build/lib/apex/amp/lists
  copying apex/amp/lists/functional_overrides.py -> build/lib/apex/amp/lists
  copying apex/amp/lists/tensor_overrides.py -> build/lib/apex/amp/lists
  copying apex/amp/lists/torch_overrides.py -> build/lib/apex/amp/lists
  installing to build/bdist.linux-x86_64/wheel
  running install
  running install_lib
  creating build/bdist.linux-x86_64
  creating build/bdist.linux-x86_64/wheel
  creating build/bdist.linux-x86_64/wheel/apex
  creating build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/optimizers/__init__.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/optimizers/fused_lamb.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/optimizers/fused_sgd.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/optimizers/fused_novograd.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/optimizers/fused_adam.py -> build/bdist.linux-x86_64/wheel/apex/optimizers
  copying build/lib/apex/__init__.py -> build/bdist.linux-x86_64/wheel/apex
  creating build/bdist.linux-x86_64/wheel/apex/pyprof
  creating build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
  copying build/lib/apex/pyprof/nvtx/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
  copying build/lib/apex/pyprof/nvtx/nvmarker.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/nvtx
  copying build/lib/apex/pyprof/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof
  creating build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/__main__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/kernel.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/db.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/nvvp.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  copying build/lib/apex/pyprof/parse/parse.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/parse
  creating build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/utility.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/prof.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/__main__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/__init__.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/misc.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/conv.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/embedding.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/output.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/usage.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/index_slice_join_mutate.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/convert.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/pooling.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/linear.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/pointwise.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/activation.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/base.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/dropout.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/optim.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/data.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/reduction.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/blas.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/loss.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/softmax.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/randomSample.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/normalization.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  copying build/lib/apex/pyprof/prof/recurrentCell.py -> build/bdist.linux-x86_64/wheel/apex/pyprof/prof
  creating build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/fused_layer_norm.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  creating build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
  copying build/lib/apex/multi_tensor_apply/__init__.py -> build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
  copying build/lib/apex/multi_tensor_apply/multi_tensor_apply.py -> build/bdist.linux-x86_64/wheel/apex/multi_tensor_apply
  creating build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/optimized_sync_batchnorm_kernel.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/optimized_sync_batchnorm.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/__init__.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/LARC.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/sync_batchnorm_kernel.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/distributed.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/sync_batchnorm.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  copying build/lib/apex/parallel/multiproc.py -> build/bdist.linux-x86_64/wheel/apex/parallel
  creating build/bdist.linux-x86_64/wheel/apex/fp16_utils
  copying build/lib/apex/fp16_utils/__init__.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
  copying build/lib/apex/fp16_utils/fp16util.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
  copying build/lib/apex/fp16_utils/fp16_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
  copying build/lib/apex/fp16_utils/loss_scaler.py -> build/bdist.linux-x86_64/wheel/apex/fp16_utils
  creating build/bdist.linux-x86_64/wheel/apex/reparameterization
  copying build/lib/apex/reparameterization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
  copying build/lib/apex/reparameterization/reparameterization.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
  copying build/lib/apex/reparameterization/weight_norm.py -> build/bdist.linux-x86_64/wheel/apex/reparameterization
  creating build/bdist.linux-x86_64/wheel/apex/contrib
  creating build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/optimizers/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/optimizers/fused_lamb.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/optimizers/fused_sgd.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/optimizers/fp16_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/optimizers/fused_adam.py -> build/bdist.linux-x86_64/wheel/apex/contrib/optimizers
  copying build/lib/apex/contrib/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib
  creating build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
  copying build/lib/apex/contrib/groupbn/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
  copying build/lib/apex/contrib/groupbn/batch_norm.py -> build/bdist.linux-x86_64/wheel/apex/contrib/groupbn
  creating build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
  copying build/lib/apex/contrib/xentropy/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
  copying build/lib/apex/contrib/xentropy/softmax_xentropy.py -> build/bdist.linux-x86_64/wheel/apex/contrib/xentropy
  creating build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/__init__.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/encdec_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/encdec_multihead_attn.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/self_multihead_attn_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/self_multihead_attn.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  copying build/lib/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py -> build/bdist.linux-x86_64/wheel/apex/contrib/multihead_attn
  creating build/bdist.linux-x86_64/wheel/apex/mlp
  copying build/lib/apex/mlp/__init__.py -> build/bdist.linux-x86_64/wheel/apex/mlp
  copying build/lib/apex/mlp/mlp.py -> build/bdist.linux-x86_64/wheel/apex/mlp
  creating build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/cells.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/__init__.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/RNNBackend.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/models.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  creating build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/opt.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/__init__.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/scaler.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/__version__.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/_amp_state.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/rnn_compat.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/utils.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/wrap.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/_initialize.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/frontend.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/_process_optimizer.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/compat.py -> build/bdist.linux-x86_64/wheel/apex/amp
  creating build/bdist.linux-x86_64/wheel/apex/amp/lists
  copying build/lib/apex/amp/lists/__init__.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
  copying build/lib/apex/amp/lists/functional_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
  copying build/lib/apex/amp/lists/tensor_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
  copying build/lib/apex/amp/lists/torch_overrides.py -> build/bdist.linux-x86_64/wheel/apex/amp/lists
  copying build/lib/apex/amp/amp.py -> build/bdist.linux-x86_64/wheel/apex/amp
  copying build/lib/apex/amp/handle.py -> build/bdist.linux-x86_64/wheel/apex/amp
  running install_egg_info
  running egg_info
  creating apex.egg-info
  writing apex.egg-info/PKG-INFO
  writing dependency_links to apex.egg-info/dependency_links.txt
  writing top-level names to apex.egg-info/top_level.txt
  writing manifest file 'apex.egg-info/SOURCES.txt'
  writing manifest file 'apex.egg-info/SOURCES.txt'
  Copying apex.egg-info to build/bdist.linux-x86_64/wheel/apex-0.1-py3.6.egg-info
  running install_scripts
  adding license file "LICENSE" (matched pattern "LICEN[CS]E*")
  creating build/bdist.linux-x86_64/wheel/apex-0.1.dist-info/WHEEL
  creating '/tmp/pip-wheel-rym8vea2/apex-0.1-cp36-none-any.whl' and adding 'build/bdist.linux-x86_64/wheel' to it
  adding 'apex/__init__.py'
  adding 'apex/RNN/RNNBackend.py'
  adding 'apex/RNN/__init__.py'
  adding 'apex/RNN/cells.py'
  adding 'apex/RNN/models.py'
  adding 'apex/amp/__init__.py'
  adding 'apex/amp/__version__.py'
  adding 'apex/amp/_amp_state.py'
  adding 'apex/amp/_initialize.py'
  adding 'apex/amp/_process_optimizer.py'
  adding 'apex/amp/amp.py'
  adding 'apex/amp/compat.py'
  adding 'apex/amp/frontend.py'
  adding 'apex/amp/handle.py'
  adding 'apex/amp/opt.py'
  adding 'apex/amp/rnn_compat.py'
  adding 'apex/amp/scaler.py'
  adding 'apex/amp/utils.py'
  adding 'apex/amp/wrap.py'
  adding 'apex/amp/lists/__init__.py'
  adding 'apex/amp/lists/functional_overrides.py'
  adding 'apex/amp/lists/tensor_overrides.py'
  adding 'apex/amp/lists/torch_overrides.py'
  adding 'apex/contrib/__init__.py'
  adding 'apex/contrib/groupbn/__init__.py'
  adding 'apex/contrib/groupbn/batch_norm.py'
  adding 'apex/contrib/multihead_attn/__init__.py'
  adding 'apex/contrib/multihead_attn/encdec_multihead_attn.py'
  adding 'apex/contrib/multihead_attn/encdec_multihead_attn_func.py'
  adding 'apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py'
  adding 'apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py'
  adding 'apex/contrib/multihead_attn/fast_self_multihead_attn_func.py'
  adding 'apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py'
  adding 'apex/contrib/multihead_attn/self_multihead_attn.py'
  adding 'apex/contrib/multihead_attn/self_multihead_attn_func.py'
  adding 'apex/contrib/optimizers/__init__.py'
  adding 'apex/contrib/optimizers/fp16_optimizer.py'
  adding 'apex/contrib/optimizers/fused_adam.py'
  adding 'apex/contrib/optimizers/fused_lamb.py'
  adding 'apex/contrib/optimizers/fused_sgd.py'
  adding 'apex/contrib/xentropy/__init__.py'
  adding 'apex/contrib/xentropy/softmax_xentropy.py'
  adding 'apex/fp16_utils/__init__.py'
  adding 'apex/fp16_utils/fp16_optimizer.py'
  adding 'apex/fp16_utils/fp16util.py'
  adding 'apex/fp16_utils/loss_scaler.py'
  adding 'apex/mlp/__init__.py'
  adding 'apex/mlp/mlp.py'
  adding 'apex/multi_tensor_apply/__init__.py'
  adding 'apex/multi_tensor_apply/multi_tensor_apply.py'
  adding 'apex/normalization/__init__.py'
  adding 'apex/normalization/fused_layer_norm.py'
  adding 'apex/optimizers/__init__.py'
  adding 'apex/optimizers/fused_adam.py'
  adding 'apex/optimizers/fused_lamb.py'
  adding 'apex/optimizers/fused_novograd.py'
  adding 'apex/optimizers/fused_sgd.py'
  adding 'apex/parallel/LARC.py'
  adding 'apex/parallel/__init__.py'
  adding 'apex/parallel/distributed.py'
  adding 'apex/parallel/multiproc.py'
  adding 'apex/parallel/optimized_sync_batchnorm.py'
  adding 'apex/parallel/optimized_sync_batchnorm_kernel.py'
  adding 'apex/parallel/sync_batchnorm.py'
  adding 'apex/parallel/sync_batchnorm_kernel.py'
  adding 'apex/pyprof/__init__.py'
  adding 'apex/pyprof/nvtx/__init__.py'
  adding 'apex/pyprof/nvtx/nvmarker.py'
  adding 'apex/pyprof/parse/__init__.py'
  adding 'apex/pyprof/parse/__main__.py'
  adding 'apex/pyprof/parse/db.py'
  adding 'apex/pyprof/parse/kernel.py'
  adding 'apex/pyprof/parse/nvvp.py'
  adding 'apex/pyprof/parse/parse.py'
  adding 'apex/pyprof/prof/__init__.py'
  adding 'apex/pyprof/prof/__main__.py'
  adding 'apex/pyprof/prof/activation.py'
  adding 'apex/pyprof/prof/base.py'
  adding 'apex/pyprof/prof/blas.py'
  adding 'apex/pyprof/prof/conv.py'
  adding 'apex/pyprof/prof/convert.py'
  adding 'apex/pyprof/prof/data.py'
  adding 'apex/pyprof/prof/dropout.py'
  adding 'apex/pyprof/prof/embedding.py'
  adding 'apex/pyprof/prof/index_slice_join_mutate.py'
  adding 'apex/pyprof/prof/linear.py'
  adding 'apex/pyprof/prof/loss.py'
  adding 'apex/pyprof/prof/misc.py'
  adding 'apex/pyprof/prof/normalization.py'
  adding 'apex/pyprof/prof/optim.py'
  adding 'apex/pyprof/prof/output.py'
  adding 'apex/pyprof/prof/pointwise.py'
  adding 'apex/pyprof/prof/pooling.py'
  adding 'apex/pyprof/prof/prof.py'
  adding 'apex/pyprof/prof/randomSample.py'
  adding 'apex/pyprof/prof/recurrentCell.py'
  adding 'apex/pyprof/prof/reduction.py'
  adding 'apex/pyprof/prof/softmax.py'
  adding 'apex/pyprof/prof/usage.py'
  adding 'apex/pyprof/prof/utility.py'
  adding 'apex/reparameterization/__init__.py'
  adding 'apex/reparameterization/reparameterization.py'
  adding 'apex/reparameterization/weight_norm.py'
  adding 'apex-0.1.dist-info/LICENSE'
  adding 'apex-0.1.dist-info/METADATA'
  adding 'apex-0.1.dist-info/WHEEL'
  adding 'apex-0.1.dist-info/top_level.txt'
  adding 'apex-0.1.dist-info/RECORD'
  removing build/bdist.linux-x86_64/wheel
done
  Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=157194 sha256=9d3ee1058b54c45a2be01bb279de4ec903855ec8cf6d6f3d5559dffda5dfaf89
  Stored in directory: /tmp/pip-ephem-wheel-cache-q5nbg4uh/wheels/b1/3a/aa/d84906eaab780ae580c7a5686a33bf2820d8590ac3b60d5967
  Removing source in /tmp/pip-req-build-nltce4xy
Successfully built apex
Installing collected packages: apex

Successfully installed apex-0.1
Cleaning up...
Removed build tracker '/tmp/pip-req-tracker-ixo7f527'

In [0]:
# Test if NVIDIA apex training tool works
from apex import amp

If you're only running the toxicity prediction portion of this tutorial, make sure you install transformers here. If you've ran all the cells before, you can ignore this install as we've already done pip install transformers before.


In [4]:
!pip install transformers


Collecting transformers
  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
     |████████████████████████████████| 573kB 33.4MB/s 
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)
Collecting tokenizers==0.5.2
  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
     |████████████████████████████████| 3.7MB 59.5MB/s 
Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.40)
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
     |████████████████████████████████| 1.0MB 54.3MB/s 
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
     |████████████████████████████████| 890kB 52.9MB/s 
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)
Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.40)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)
Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (2.8.1)
Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers) (0.15.2)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.41-cp36-none-any.whl size=893334 sha256=5debc70bf2760c36e513997d6cfe94649592d95dd3b9654ec526cd2a32cad0f2
  Stored in directory: /root/.cache/pip/wheels/22/5a/d4/b020a81249de7dc63758a34222feaa668dbe8ebfe9170cc9b1
Successfully built sacremoses
Installing collected packages: tokenizers, sentencepiece, sacremoses, transformers
Successfully installed sacremoses-0.0.41 sentencepiece-0.1.85 tokenizers-0.5.2 transformers-2.8.0

In [5]:
!pip install simpletransformers
!pip install wandb


Collecting simpletransformers
  Downloading https://files.pythonhosted.org/packages/ce/cc/4b42c1c362c7c3b939ebf5b628145abf69aeb8e1ac3f79770577466319c1/simpletransformers-0.25.0-py3-none-any.whl (157kB)
     |████████████████████████████████| 163kB 15.8MB/s 
Requirement already satisfied: tokenizers in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (0.5.2)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2.21.0)
Collecting seqeval
  Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (0.22.2.post1)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.0.3)
Collecting tensorboardx
  Downloading https://files.pythonhosted.org/packages/35/f1/5843425495765c8c2dd0784a851a93ef204d314fc87bcc2bbb9f662a3ad1/tensorboardX-2.0-py2.py3-none-any.whl (195kB)
     |████████████████████████████████| 204kB 26.6MB/s 
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (4.38.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.4.1)
Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2019.12.20)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (1.18.2)
Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (from simpletransformers) (2.8.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (2020.4.5.1)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (3.0.4)
Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (2.8)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->simpletransformers) (1.24.3)
Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->simpletransformers) (2.3.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->simpletransformers) (0.14.1)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->simpletransformers) (2018.9)
Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->simpletransformers) (2.8.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tensorboardx->simpletransformers) (1.12.0)
Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardx->simpletransformers) (3.10.0)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.7)
Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (1.12.40)
Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.1.85)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (0.0.41)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers->simpletransformers) (3.0.12)
Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (2.10.0)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (3.13)
Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (1.1.0)
Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from Keras>=2.2.4->seqeval->simpletransformers) (1.0.8)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorboardx->simpletransformers) (46.1.3)
Requirement already satisfied: botocore<1.16.0,>=1.15.40 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (1.15.40)
Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (0.3.3)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers->simpletransformers) (0.9.5)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers->simpletransformers) (7.1.1)
Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.40->boto3->transformers->simpletransformers) (0.15.2)
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... done
  Created wheel for seqeval: filename=seqeval-0.0.12-cp36-none-any.whl size=7424 sha256=9e0590e9e861f1347cbf412dcdb58e84cd910ed9897342f135f0caf8b102ccab
  Stored in directory: /root/.cache/pip/wheels/4f/32/0a/df3b340a82583566975377d65e724895b3fad101a3fb729f68
Successfully built seqeval
Installing collected packages: seqeval, tensorboardx, simpletransformers
Successfully installed seqeval-0.0.12 simpletransformers-0.25.0 tensorboardx-2.0
Collecting wandb
  Downloading https://files.pythonhosted.org/packages/68/dd/ce719d36c4172b56c7579a79fcfd2f731c386b39f258bb186ef17b73fd7d/wandb-0.8.32-py2.py3-none-any.whl (1.4MB)
     |████████████████████████████████| 1.4MB 22.0MB/s 
Requirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.21.0)
Collecting gql==0.2.0
  Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz
Collecting docker-pycreds>=0.4.0
  Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl
Collecting sentry-sdk>=0.4.0
  Downloading https://files.pythonhosted.org/packages/20/7e/19545324e83db4522b885808cd913c3b93ecc0c88b03e037b78c6a417fa8/sentry_sdk-0.14.3-py2.py3-none-any.whl (103kB)
     |████████████████████████████████| 112kB 49.4MB/s 
Collecting GitPython>=1.0.0
  Downloading https://files.pythonhosted.org/packages/19/1a/0df85d2bddbca33665d2148173d3281b290ac054b5f50163ea735740ac7b/GitPython-3.1.1-py3-none-any.whl (450kB)
     |████████████████████████████████| 460kB 55.5MB/s 
Collecting shortuuid>=0.5.0
  Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl
Collecting watchdog>=0.8.3
  Downloading https://files.pythonhosted.org/packages/73/c3/ed6d992006837e011baca89476a4bbffb0a91602432f73bd4473816c76e2/watchdog-0.10.2.tar.gz (95kB)
     |████████████████████████████████| 102kB 13.4MB/s 
Requirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.1)
Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)
Collecting subprocess32>=3.5.3
  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)
     |████████████████████████████████| 102kB 13.1MB/s 
Collecting configparser>=3.8.1
  Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl
Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)
Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)
Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)
Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2020.4.5.1)
Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.8)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (1.24.3)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)
Collecting graphql-core<2,>=0.5.0
  Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)
     |████████████████████████████████| 71kB 9.5MB/s 
Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)
Collecting gitdb<5,>=4.0.1
  Downloading https://files.pythonhosted.org/packages/74/52/ca35448b56c53a079d3ffe18b1978c6e424f6d4df02404877094c89f5bfb/gitdb-4.0.4-py3-none-any.whl (63kB)
     |████████████████████████████████| 71kB 11.1MB/s 
Collecting pathtools>=0.1.1
  Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz
Collecting smmap<4,>=3.0.1
  Downloading https://files.pythonhosted.org/packages/27/b1/e379cfb7c07bbf8faee29c4a1a2469dbea525f047c2b454c4afdefa20a30/smmap-3.0.2-py2.py3-none-any.whl
Building wheels for collected packages: gql, watchdog, subprocess32, graphql-core, pathtools
  Building wheel for gql (setup.py) ... done
  Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=88d395f05da00e481a02baafbdb85210b0ac9c1ebd46282674b22ba931f49b49
  Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23
  Building wheel for watchdog (setup.py) ... done
  Created wheel for watchdog: filename=watchdog-0.10.2-cp36-none-any.whl size=73605 sha256=fc4714af40db86d8a9cc75d2531cb62ef7118b522f522607c9175386229ed6da
  Stored in directory: /root/.cache/pip/wheels/bc/ed/6c/028dea90d31b359cd2a7c8b0da4db80e41d24a59614154072e
  Building wheel for subprocess32 (setup.py) ... done
  Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=48d3d5ee6e337e6149a50fc26fba262fbee92d8780e397d39ed5ef26ee59e6b4
  Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1
  Building wheel for graphql-core (setup.py) ... done
  Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=fad902d3f06c9fc44b5fe47e103a279b23b3c3b71b3c75b8b1ae66e430e3ac61
  Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5
  Building wheel for pathtools (setup.py) ... done
  Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=75048560edcc7c400832dc7e006d620847ca8c44625c452ac9a650464f549bd1
  Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843
Successfully built gql watchdog subprocess32 graphql-core pathtools
Installing collected packages: graphql-core, gql, docker-pycreds, sentry-sdk, smmap, gitdb, GitPython, shortuuid, pathtools, watchdog, subprocess32, configparser, wandb
Successfully installed GitPython-3.1.1 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.4 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.14.3 shortuuid-1.0.1 smmap-3.0.2 subprocess32-3.5.4 wandb-0.8.32 watchdog-0.10.2

From here, we want to load the dataset from tox21 for training the model. We're going to use a filtered dataset of 2100 compounds, as there are only 400 positive leads and we want to avoid having a large data imbalance. We'll also use simple-transformer's auto_weights argument in defining our ChemBERTa model to do automatic weight balancing later on, to counteract this problem.


In [7]:
!cd ..
dataset_path = "/content/zrC7F8DcRs?amp=1"
df = pd.read_csv(dataset_path, sep = ',', warn_bad_lines=True, header=None)


df.rename(columns={0:'smiles',1:'labels'}, inplace=True)
df.head()


Out[7]:
smiles labels
0 CCCCCCCC/C=C\CCCCCCCC(N)=O 0
1 CCCCCCOC(=O)c1ccccc1 0
2 O=C(c1ccc(Cl)cc1)c1ccc(Cl)cc1 0
3 COc1cc(Cl)c(OC)cc1N 0
4 N[C@H](Cc1c[nH]c2ccccc12)C(=O)O 0

From here, lets set up a logger to record if any issues occur, and notify us if there are any problems with the arguments we've set for the model.


In [0]:
from simpletransformers.classification import ClassificationModel
import logging

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

Now, using simple-transformer, let's load the pre-trained model from HuggingFace's useful model-hub. We'll set the number of epochs to 3 in the arguments, but you can train for longer. Also make sure that auto_weights is set to True as we are dealing with imbalanced toxicity datasets.


In [0]:
model = ClassificationModel('roberta', 'seyonec/ChemBERTa-zinc-base-v1', args={'num_train_epochs': 3, 'auto_weights': True}) # You can set class weights by using the optional weight argument

In [0]:
# Split the train and test dataset 80-20

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200).reset_index(drop=True)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)

In [11]:
# check if our train and evaluation dataframes are setup properly. There should only be two columns for the SMILES string and its corresponding label.

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))


FULL Dataset: (2142, 2)
TRAIN Dataset: (1714, 2)
TEST Dataset: (428, 2)

Now that we've set everything up, lets get to the fun part: training the model! We use Weights and Biases, which is optional (simply remove wandb_project from the list of args). Its a really useful tool for monitering the model's training results (such as accuracy, learning rate and loss), alongside with custom visualizations you can create as well as the gradients.

When you run this cell, Weights and Biases will ask for an account, which you can setup when you get a key through a Github account. Again, this is completely optional and it can be removed from the list of arguments.


In [14]:
# Create directory to store model weights (change path accordingly to where you want!)
!cd /content
!mkdir chemberta_tox21

# Train the model
model.train_model(train_dataset, output_dir='/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name'})


mkdir: cannot create directory ‘chemberta_tox21’: File exists
/usr/local/lib/python3.6/dist-packages/simpletransformers/classification/classification_model.py:243: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
wandb: ERROR Not authenticated.  Copy a key from https://app.wandb.ai/authorize
API Key: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
INFO:wandb.run_manager:system metrics and metadata threads started
INFO:wandb.run_manager:checking resume status, waiting at most 10 seconds
INFO:wandb.run_manager:resuming run from id: UnVuOnYxOjEyMGVtb2htOnByb2plY3QtbmFtZTpzZXlvbmVj
INFO:wandb.run_manager:upserting run before process can begin, waiting at most 10 seconds
INFO:wandb.run_manager:saving patches
INFO:wandb.run_manager:saving pip packages
INFO:wandb.run_manager:initializing streaming files api
INFO:wandb.run_manager:unblocking file change observer, beginning sync with W&B servers
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:113: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/config.yaml
Running loss: 0.788242
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media/graph/graph_0_summary_0fce41b2.graph.json
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/requirements.txt
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media/graph
INFO:wandb.run_manager:file/dir created: /content/wandb/run-20200424_043823-120emohm/media
Running loss: 0.511236
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:224: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.602639
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.232389
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.853039
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
Running loss: 0.501909
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.066326
Running loss: 0.170885
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.221705
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
Running loss: 0.207991
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.173742
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.456498
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
Running loss: 1.234981
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
Running loss: 0.397285
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.094101
Running loss: 0.043023
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.053245
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
Running loss: 0.175583
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.182486
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
Running loss: 0.071419
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.565325
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json
Running loss: 0.601075
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-history.jsonl
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-summary.json
Running loss: 0.970592

INFO:simpletransformers.classification.classification_model: Training of roberta model complete. Saved to /content/chemberta_tox21.
INFO:wandb.run_manager:shutting down system stats and metadata service
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-events.jsonl
INFO:wandb.run_manager:stopping streaming files and file change observer
INFO:wandb.run_manager:file/dir modified: /content/wandb/run-20200424_043823-120emohm/wandb-metadata.json

Let's install scikit-learn now, to evaluate the model we've trained.


In [15]:
!pip install -U scikit-learn


Requirement already up-to-date: scikit-learn in /usr/local/lib/python3.6/dist-packages (0.22.2.post1)
Requirement already satisfied, skipping upgrade: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.18.2)
Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (0.14.1)
Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.4.1)

The following cell can be ignored unless you are starting a new run-time and just want to load the model from your local directory.


In [0]:
# Loading a saved model for evaluation
model = ClassificationModel('roberta', '/content/chemberta_tox21', num_labels=2, use_cuda=True, args={'wandb_project': 'project-name','num_train_epochs': 3})

In [16]:
import sklearn
result, model_outputs, wrong_predictions = model.eval_model(test_dataset, acc=sklearn.metrics.accuracy_score)


/usr/local/lib/python3.6/dist-packages/simpletransformers/classification/classification_model.py:660: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.

INFO:simpletransformers.classification.classification_model:{'mcc': 0.7136017700095658, 'tp': 55, 'tn': 335, 'fp': 4, 'fn': 34, 'acc': 0.9112149532710281, 'eval_loss': 0.2323251810890657}

The model performs pretty well, averaging above 91% after training on only ~2000 data samples and 400 positive leads! We can clearly see the predictive power of transfer learning, and approaches like these are becoming increasing popular in the pharmaceutical industry where larger datasets are scarce. By training on more epochs and tasks, we can probably boost the accuracy as well!

Lets train the model on one last string outside of the filtered dataset for toxicity. The model should predict 0, meaning no interference in biochemical pathways for p53.


In [17]:
# Lets input a molecule with a SR-p53 value of 0
predictions, raw_outputs = model.predict(['CCCCOc1cc(C(=O)OCCN(CC)CC)ccc1N'])


INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.



In [18]:
print(predictions)
print(raw_outputs)


[0]
[[ 2.9023438 -2.859375 ]]

The model predicts the sample correctly! Some future tasks may include using the same model on multiple tasks (Tox21 provides multiple for toxicity), through multi-task classification, as well as training on a wider dataset. This will be expanded on in a future tutorial!

Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

Star DeepChem on Github

This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

Join the DeepChem Gitter

The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!