Fine-tuning a pre-trained model

This tutorial describes how to fine-tune a pre-trained model from the DeepCpG model zoo. Fine-tuning a model that was pre-trained on cells similar to the cells of interest can considerably decrease training time.


We first initialize some variables that are used throughout the tutorial. Setting test_mode=1 speeds up computations by using only a subset of the data and is meant for testing; for real applications, set test_mode=0.

In [1]:
function run {
  local cmd=$@
  echo "#################################"
  echo $cmd
  echo "#################################"
  eval $cmd
}

test_mode=1 # Set to 1 for testing and 0 otherwise
example_dir="../../data/" # Directory with example data.
cpg_dir="$example_dir/cpg" # Directory with CpG profiles.
dna_dir="$example_dir/dna/mm10" # Directory with DNA sequences.

Creating DeepCpG data files

First, we create DeepCpG data files. Since we will fine-tune a CpG model, we do not extract DNA sequence windows; otherwise, --dna_files and --dna_wlen must be specified.
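The choice between CpG-only and DNA-aware data files can be sketched by collecting the arguments in a bash array, which also avoids the re-parsing that `eval` performs in `run`. This is a minimal sketch under assumptions: `with_dna` is a hypothetical switch, the `--dna_wlen` value of 1001 is an illustrative assumption rather than a value from this tutorial, and the array holds only the arguments, not the command that consumes them.

```bash
#!/usr/bin/env bash
# Sketch: collect the arguments for creating DeepCpG data files in a bash array.
# with_dna and the DNA window length are hypothetical illustration values.
cpg_dir=${cpg_dir:-../../data/cpg}         # keep the tutorial's value if already set
dna_dir=${dna_dir:-../../data/dna/mm10}
data_dir=${data_dir:-./data}
with_dna=0  # 0: CpG model only (as in this tutorial); 1: also extract DNA windows

# With nullglob, a pattern that matches no file expands to nothing.
shopt -s nullglob
profiles=("$cpg_dir"/*.tsv)
shopt -u nullglob

data_args=(--cpg_profiles "${profiles[@]}" --out_dir "$data_dir" --cpg_wlen 50)
if [[ $with_dna -eq 1 ]]; then
  # DNA or Joint models additionally need sequence windows.
  data_args+=(--dna_files "$dna_dir" --dna_wlen 1001)  # 1001 is an assumed window length
fi

printf '%s\n' "${data_args[@]}"
```

Each array element stays a single argument even if a path contains spaces, which the string-plus-`eval` approach does not guarantee.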

In [4]:
data_dir="./data" # Output directory for DeepCpG data files.
cmd="
    --cpg_profiles $cpg_dir/*.tsv
    --out_dir $data_dir
    --cpg_wlen 50"
if [[ $test_mode -eq 1 ]]; then
    cmd="$cmd
    --nb_sample 10000"
fi
run $cmd

#################################
--cpg_profiles ../../data//cpg/BS27_1_SER.tsv ../../data//cpg/BS27_3_SER.tsv ../../data//cpg/BS27_5_SER.tsv ../../data//cpg/BS27_6_SER.tsv ../../data//cpg/BS27_8_SER.tsv --out_dir ./data --cpg_wlen 50 --nb_sample 10000
#################################
INFO (2017-03-05 19:19:42,901): Reading single-cell profiles ...
INFO (2017-03-05 19:19:43,339): 10000 samples
INFO (2017-03-05 19:19:43,340): --------------------------------------------------------------------------------
INFO (2017-03-05 19:19:43,340): Chromosome 1 ...
INFO (2017-03-05 19:19:43,368): 10000 / 10000 (100.0%) sites matched minimum coverage filter
INFO (2017-03-05 19:19:43,369): Chunk 	1 / 1
INFO (2017-03-05 19:19:43,379): Extracting CpG neighbors ...
INFO (2017-03-05 19:19:44,498): Done!

Downloading a pre-trained model

Next, we download a pre-trained model from the DeepCpG model zoo. Available models and their descriptions can be found on the model zoo website, or listed with --show:

In [5]:
--show

Available models:

A model name consists of three parts separated by '_': the first part corresponds to the publication, the second to the cell type, and the third to the model type (CpG, DNA, or Joint model). Cells from 'Hou2016' were profiled using scRRBS-seq, and cells from 'Smallwood2014' using scBS-seq. 'HCC' and 'HepG2' are human cancer cells; the remaining cells are mouse cells. Choose the model whose cell type is most similar to the cell type you are working with. More information about the available models can be found here.
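Because the naming scheme is fixed, a model name can be split into its parts with `read` and `IFS=_`. The helper `parse_model_name` below is hypothetical (not part of DeepCpG) and assumes that every name has exactly three parts, as described above.

```bash
#!/usr/bin/env bash
# Sketch: split a model-zoo name into its three '_'-separated parts.
# parse_model_name is a hypothetical helper, not part of DeepCpG.
parse_model_name() {
  local name=$1 publication cell_type model_type extra
  # Any fourth or later part ends up in 'extra'.
  IFS=_ read -r publication cell_type model_type extra <<< "$name"
  if [[ -z $model_type || -n $extra ]]; then
    echo "unexpected model name: $name" >&2
    return 1
  fi
  printf 'publication=%s cell_type=%s model_type=%s\n' \
    "$publication" "$cell_type" "$model_type"
}

parse_model_name Smallwood2014_2i_cpg
```

For 'Smallwood2014_2i_cpg' this reports the publication Smallwood2014, the cell type 2i, and the model type cpg; names with fewer or more parts are rejected.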

Since we are dealing with 2i cells and want to train a CpG model, we will fine-tune 'Smallwood2014_2i_cpg':

In [6]:
pretrained_model="./models/Smallwood2014_2i_cpg" # Directory for the downloaded model.
cmd="
  $(basename $pretrained_model)
  -o $pretrained_model"
run $cmd

#################################
Smallwood2014_2i_cpg -o ./models/Smallwood2014_2i_cpg
#################################
INFO (2017-03-05 19:19:51,601): Downloading model ...
INFO (2017-03-05 19:19:51,601): Model URL:
--2017-03-05 19:19:51--
Resolving (
Connecting to (||:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31068468 (30M) [text/plain]
Saving to: ‘./models/Smallwood2014_2i_cpg/’

./models/Smallwood2 100%[===================>]  29.63M  10.1MB/s    in 2.9s    

2017-03-05 19:19:54 (10.1 MB/s) - ‘./models/Smallwood2014_2i_cpg/’ saved [31068468/31068468]

Archive:  ./models/Smallwood2014_2i_cpg/
  inflating: ./models/Smallwood2014_2i_cpg/model.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model.json  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights_train.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights_val.h5  
INFO (2017-03-05 19:19:55,376): Done!

The command downloads the model files and stores them in the output directory, including the weights and the JSON file with the model specification:

In [7]:
ls $pretrained_model

model.h5               model_weights.h5       model_weights_val.h5
model.json             model_weights_train.h5

model.json stores the model specification, and model_weights_train.h5 and model_weights_val.h5 the weights that yielded the highest performance on the training and validation set, respectively. model.h5 combines model.json and model_weights_val.h5.
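Before fine-tuning, it can help to confirm that the download is complete. This sketch checks for the five files listed by `ls` above and treats an empty file as missing; `check_model_dir` is a hypothetical helper, and the default directory is the one used in this tutorial.

```bash
#!/usr/bin/env bash
# Sketch: verify that a downloaded model directory contains the expected files.
# check_model_dir is a hypothetical helper; the file list mirrors the ls output.
check_model_dir() {
  local dir=$1 f missing=0
  for f in model.h5 model.json model_weights.h5 \
           model_weights_train.h5 model_weights_val.h5; do
    if [[ ! -s "$dir/$f" ]]; then   # -s: file exists and is non-empty
      echo "missing or empty: $dir/$f" >&2
      missing=1
    fi
  done
  return "$missing"
}

pretrained_model=${pretrained_model:-./models/Smallwood2014_2i_cpg}
if check_model_dir "$pretrained_model"; then
  echo "model files complete"
else
  echo "model files incomplete; download the model first"
fi
```

The same check could guard the download cell, so that the model is only downloaded when `check_model_dir "$pretrained_model"` fails.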

Fine-tuning the model

To fine-tune the downloaded model, we pass the model directory to --cpg_model and use --fine_tune to train only the output layers.

--cpg_model $pretrained_model loads the model specification and weights from the model directory; in the run below, the log shows that it uses $pretrained_model/model.json together with $pretrained_model/model_weights.h5. To fine-tune the weights with the highest performance on the training set instead, pass the files explicitly: --cpg_model $pretrained_model/model.json $pretrained_model/model_weights_train.h5.
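Passing the files explicitly makes the choice of weights visible in the command. In this sketch, `weights` is a hypothetical switch that selects either `model_weights_val.h5` or `model_weights_train.h5` from the pre-trained model directory; the resulting array would take the place of `--cpg_model $pretrained_model` in the training arguments.

```bash
#!/usr/bin/env bash
# Sketch: build --cpg_model arguments for either set of pre-trained weights.
# 'weights' is a hypothetical switch introduced for illustration.
pretrained_model=${pretrained_model:-./models/Smallwood2014_2i_cpg}
weights="val"  # "val": best on validation set; "train": best on training set

case $weights in
  val|train)
    cpg_model_args=(--cpg_model "$pretrained_model/model.json"
                    "$pretrained_model/model_weights_${weights}.h5")
    ;;
  *)
    echo "unknown weights: $weights" >&2
    exit 1
    ;;
esac

printf '%s\n' "${cpg_model_args[@]}"
```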

Without --fine_tune, all weights are trained, not only the output layers. This is recommended if the cells used to pre-train the model are only distantly related to the cells of interest, e.g. if the cell types do not match. Training all weights can lead to higher prediction performance, but also increases training time.
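The decision between fine-tuning and training all weights can be folded into how the training arguments are assembled. This sketch uses a bash array instead of a string; `cell_types_match` is a hypothetical switch, while the data file, output directory, and epoch and sample settings mirror the training cell below. It prints only the arguments, not the training command itself.

```bash
#!/usr/bin/env bash
# Sketch: assemble training arguments, adding --fine_tune only when the
# pre-trained cells match the cells of interest.
# cell_types_match is a hypothetical switch; other values follow this tutorial.
test_mode=${test_mode:-1}
cell_types_match=1
pretrained_model=${pretrained_model:-./models/Smallwood2014_2i_cpg}

train_args=(./data/c1_000000-010000.h5
            --cpg_model "$pretrained_model"
            --out_dir ./models/cpg)
if [[ $cell_types_match -eq 1 ]]; then
  train_args+=(--fine_tune)   # train only the output layers
fi                            # otherwise all weights are trained
if [[ $test_mode -eq 1 ]]; then
  train_args+=(--nb_epoch 2 --nb_train_sample 1000 --nb_val_sample 1000)
else
  train_args+=(--nb_epoch 25 --early_stopping 5)
fi

printf '%s ' "${train_args[@]}"; echo
```

With the defaults shown, the printed arguments match the ones echoed in the training log below.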

In [8]:
cmd="./data/c1_000000-010000.h5
    --cpg_model $pretrained_model
    --out_dir ./models/cpg
    --fine_tune"
if [[ $test_mode -eq 1 ]]; then
    cmd="$cmd
    --nb_epoch 2
    --nb_train_sample 1000
    --nb_val_sample 1000"
else
    cmd="$cmd
    --nb_epoch 25
    --early_stopping 5"
fi
run $cmd

#################################
./data/c1_000000-010000.h5 --cpg_model ./models/Smallwood2014_2i_cpg --out_dir ./models/cpg --fine_tune --nb_epoch 2 --nb_train_sample 1000 --nb_val_sample 1000
#################################
Using TensorFlow backend.
INFO (2017-03-05 19:20:27,727): Building model ...
Replicate names:
BS27_1_SER, BS27_3_SER, BS27_5_SER, BS27_6_SER, BS27_8_SER

INFO (2017-03-05 19:20:27,735): Loading existing CpG model ...
INFO (2017-03-05 19:20:27,736): Using model files ./models/Smallwood2014_2i_cpg/model.json ./models/Smallwood2014_2i_cpg/model_weights.h5
INFO (2017-03-05 19:20:28,772): Replicate names differ: Copying weights to new model ...
Layer (type)                     Output Shape          Param #     Connected to                     
cpg/state (InputLayer)           (None, 5, 50)         0                                            
cpg/dist (InputLayer)            (None, 5, 50)         0                                            
cpg/merge_1 (Merge)              (None, 5, 100)        0           cpg/state[0][0]                  
cpg/timedistributed_1 (TimeDistr (None, 5, 256)        25856       cpg/merge_1[0][0]                
cpg/bidirectional_1 (Bidirection (None, 512)           787968      cpg/timedistributed_1[0][0]      
cpg/dropout_1 (Dropout)          (None, 512)           0           cpg/bidirectional_1[0][0]        
cpg/BS27_1_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
cpg/BS27_3_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
cpg/BS27_5_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
cpg/BS27_6_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
cpg/BS27_8_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
Total params: 816389
Layer trainability:
                layer | trainable
cpg/timedistributed_1 |     False
  cpg/bidirectional_1 |     False
        cpg/dropout_1 |     False

INFO (2017-03-05 19:20:29,581): Computing output statistics ...
Output statistics:
          name | nb_tot | nb_obs | frac_obs | mean |  var
cpg/BS27_1_SER |   1000 |    351 |     0.35 | 0.90 | 0.09
cpg/BS27_3_SER |   1000 |    146 |     0.15 | 0.79 | 0.16
cpg/BS27_5_SER |   1000 |    220 |     0.22 | 0.85 | 0.12
cpg/BS27_6_SER |   1000 |    336 |     0.34 | 0.67 | 0.22
cpg/BS27_8_SER |   1000 |    276 |     0.28 | 0.92 | 0.07

Class weights:
cpg/BS27_1_SER | cpg/BS27_3_SER | cpg/BS27_5_SER | cpg/BS27_6_SER | cpg/BS27_8_SER
        0=0.90 |         0=0.79 |         0=0.85 |         0=0.67 |         0=0.92
        1=0.10 |         1=0.21 |         1=0.15 |         1=0.33 |         1=0.08

INFO (2017-03-05 19:20:29,862): Loading data ...
INFO (2017-03-05 19:20:29,864): Initializing callbacks ...
INFO (2017-03-05 19:20:29,865): Training model ...

Training samples: 1000
Epochs: 2
Learning rate: 0.0001
Epoch 1/2
done (%) | time |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_1_SER_acc | cpg/BS27_3_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_5_SER_acc
    12.8 |  0.0 | 0.4260 | 0.2876 |              0.0935 |              0.0511 |              0.0711 |              0.0703 |              0.1185 |             0.0889 |             0.6000 |             0.0882 |             0.1379 |             0.5227
    25.6 |  0.0 | 0.3779 | 0.3107 |              0.0771 |              0.0522 |              0.0734 |              0.0597 |              0.0923 |             0.0955 |             0.6269 |             0.0569 |             0.1599 |             0.6143
    38.4 |  0.0 | 0.3637 | 0.3138 |              0.0712 |              0.0471 |              0.0775 |              0.0551 |              0.0894 |             0.0875 |             0.5897 |             0.0588 |             0.1773 |             0.6559
    51.2 |  0.0 | 0.3494 | 0.3247 |              0.0704 |              0.0472 |              0.0768 |              0.0509 |              0.0810 |             0.0976 |             0.5804 |             0.0644 |             0.1746 |             0.7062
    64.0 |  0.0 | 0.3473 | 0.3295 |              0.0714 |              0.0467 |              0.0810 |              0.0454 |              0.0802 |             0.1019 |             0.6063 |             0.0515 |             0.1859 |             0.7018
    76.8 |  0.0 | 0.3472 | 0.3322 |              0.0736 |              0.0444 |              0.0838 |              0.0426 |              0.0806 |             0.0988 |             0.6256 |             0.0524 |             0.1909 |             0.6932
    89.6 |  0.0 | 0.3368 | 0.3379 |              0.0710 |              0.0434 |              0.0839 |              0.0392 |              0.0773 |             0.0903 |             0.6358 |             0.0568 |             0.1954 |             0.7114
   100.0 |  0.0 | 0.3433 | 0.3480 |              0.0708 |              0.0464 |              0.0857 |              0.0401 |              0.0785 |             0.1055 |             0.6445 |             0.0695 |             0.2076 |             0.7130
Epoch 00000: loss improved from inf to 0.34332, saving model to ./models/cpg/model_weights_val.h5

 split |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
 train | 0.3433 | 0.3480 |              0.0708 |              0.0401 |              0.0464 |              0.0857 |              0.0785 |             0.6445 |             0.2076 |             0.0695 |             0.7130 |             0.1055
Learning rate: 9.75e-05
Epoch 2/2
done (%) | time |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_1_SER_acc | cpg/BS27_3_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_5_SER_acc
    12.8 |  0.0 | 0.3435 | 0.4013 |              0.0811 |              0.0311 |              0.0910 |              0.0470 |              0.0736 |             0.0750 |             0.7838 |             0.1892 |             0.2821 |             0.6765
    25.6 |  0.1 | 0.3261 | 0.3615 |              0.0656 |              0.0358 |              0.0917 |              0.0364 |              0.0771 |             0.0970 |             0.6597 |             0.1074 |             0.2331 |             0.7103
    38.4 |  0.1 | 0.3414 | 0.3604 |              0.0674 |              0.0385 |              0.0948 |              0.0360 |              0.0853 |             0.1053 |             0.6398 |             0.1272 |             0.2286 |             0.7012
    51.2 |  0.1 | 0.3441 | 0.3664 |              0.0680 |              0.0459 |              0.0859 |              0.0377 |              0.0872 |             0.1225 |             0.6443 |             0.1224 |             0.2232 |             0.7196
    64.0 |  0.1 | 0.3436 | 0.3729 |              0.0629 |              0.0472 |              0.0885 |              0.0392 |              0.0864 |             0.1370 |             0.6207 |             0.1403 |             0.2157 |             0.7507
    76.8 |  0.1 | 0.3285 | 0.3756 |              0.0590 |              0.0470 |              0.0837 |              0.0384 |              0.0811 |             0.1367 |             0.6222 |             0.1345 |             0.2200 |             0.7645
    89.6 |  0.1 | 0.3297 | 0.3798 |              0.0583 |              0.0481 |              0.0877 |              0.0384 |              0.0779 |             0.1426 |             0.6180 |             0.1269 |             0.2308 |             0.7807
   100.0 |  0.1 | 0.3315 | 0.3844 |              0.0610 |              0.0492 |              0.0879 |              0.0358 |              0.0784 |             0.1563 |             0.6242 |             0.1137 |             0.2470 |             0.7808
Epoch 00001: loss improved from 0.34332 to 0.33151, saving model to ./models/cpg/model_weights_val.h5

 split |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
 train | 0.3315 | 0.3844 |              0.0610 |              0.0358 |              0.0492 |              0.0879 |              0.0784 |             0.6242 |             0.2470 |             0.1137 |             0.7808 |             0.1563

Training set performance:
  loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
0.3433 | 0.3480 |              0.0708 |              0.0401 |              0.0464 |              0.0857 |              0.0785 |             0.6445 |             0.2076 |             0.0695 |             0.7130 |             0.1055
0.3315 | 0.3844 |              0.0610 |              0.0358 |              0.0492 |              0.0879 |              0.0784 |             0.6242 |             0.2470 |             0.1137 |             0.7808 |             0.1563
INFO (2017-03-05 19:20:41,962): Done!
Exception ignored in: <bound method BaseSession.__del__ of <tensorflow.python.client.session.Session object at 0x1147d1cc0>>
Traceback (most recent call last):
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/", line 522, in __del__
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/", line 518, in close
AttributeError: 'NoneType' object has no attribute 'raise_exception_on_not_ok_status'

Imputing methylation profiles

Finally, we impute methylation profiles and evaluate the fine-tuned model:

In [9]:
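eval_dir="./eval" # Output directory for evaluation results.
mkdir -p $eval_dir

cmd="./data/c1_000000-010000.h5
    --model_files ./models/cpg
    --out_data $eval_dir/data.h5
    --out_report $eval_dir/report.tsv"
if [[ $test_mode -eq 1 ]]; then
    cmd="$cmd
    --nb_sample 1000"
fi
run $cmd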

#################################
./data/c1_000000-010000.h5 --model_files ./models/cpg --out_data ./eval/data.h5 --out_report ./eval/report.tsv --nb_sample 1000
#################################
Using TensorFlow backend.
INFO (2017-03-05 19:20:46,772): Loading model ...
INFO (2017-03-05 19:20:47,542): Loading data ...
INFO (2017-03-05 19:20:47,571): Predicting ...
INFO (2017-03-05 19:20:47,587):  128/1000 (12.8%)
INFO (2017-03-05 19:20:47,686):  256/1000 (25.6%)
INFO (2017-03-05 19:20:47,759):  384/1000 (38.4%)
INFO (2017-03-05 19:20:47,833):  512/1000 (51.2%)
INFO (2017-03-05 19:20:47,914):  640/1000 (64.0%)
INFO (2017-03-05 19:20:47,991):  768/1000 (76.8%)
INFO (2017-03-05 19:20:48,078):  896/1000 (89.6%)
INFO (2017-03-05 19:20:48,158): 1000/1000 (100.0%)
/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/sklearn/metrics/ UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no predicted samples.
  'precision', 'predicted', average, warn_for)
           output       auc       acc       tpr       tnr        f1       mcc      n
2  cpg/BS27_5_SER  0.614279  0.850000  0.989362  0.031250  0.918519  0.062658  220.0
1  cpg/BS27_3_SER  0.574425  0.630137  0.663793  0.500000  0.740385  0.137086  146.0
0  cpg/BS27_1_SER  0.520723  0.139601  0.044444  0.972222  0.084848  0.025000  351.0
3  cpg/BS27_6_SER  0.437062  0.500000  0.580357  0.339286  0.607477 -0.077563  336.0
4  cpg/BS27_8_SER  0.424276  0.076087  0.000000  1.000000  0.000000  0.000000  276.0
INFO (2017-03-05 19:20:48,362): Done!
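The report can also be post-processed on the command line, for example to list the AUC of each output. This sketch assumes that report.tsv is tab-separated with a header row naming `output` and `auc` columns, as in the table printed above; because the exact file layout is an assumption, the columns are located by name rather than by position. `auc_by_output` is a hypothetical helper.

```bash
#!/usr/bin/env bash
# Sketch: print output name and AUC from an evaluation report, highest AUC first.
# Assumes a tab-separated file whose header row contains 'output' and 'auc'.
auc_by_output() {
  awk -F'\t' '
    NR == 1 {
      for (i = 1; i <= NF; i++) { if ($i == "output") o = i; if ($i == "auc") a = i }
      if (!o || !a) { print "missing output/auc column" > "/dev/stderr"; exit 1 }
      next
    }
    { print $o "\t" $a }
  ' "$1" | sort -t$'\t' -k2,2gr   # general-numeric sort on AUC, descending
}

report=./eval/report.tsv
if [[ -f $report ]]; then
  auc_by_output "$report"
fi
```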