Fine-tuning a pre-trained model

This tutorial describes how to fine-tune a pre-trained model from the DeepCpG model zoo. Fine-tuning a model that has been pre-trained on cells that are similar to the cells of interest can considerably decrease training time.

Initialization

We first initialize some variables that will be used throughout the tutorial. Set test_mode=1 for testing, which speeds up computations by using only a subset of the data; for real applications, set test_mode=0.


In [1]:
# Prints a command between separator banners, then executes it.
function run {
  local cmd="$@"
  echo
  echo "#################################"
  echo $cmd  # unquoted on purpose: collapses the multi-line command to one line
  echo "#################################"
  eval $cmd
}

test_mode=1 # Set to 1 for testing and 0 otherwise
example_dir="../../data/" # Directory with example data.
cpg_dir="$example_dir/cpg" # Directory with CpG profiles.
dna_dir="$example_dir/dna/mm10" # Directory with DNA sequences.



Creating DeepCpG data files

First, we create DeepCpG data files using dcpg_data.py. Since we will fine-tune a CpG model, we do not extract DNA sequence windows; for a DNA or joint model, --dna_files and --dna_wlen must be specified (see the sketch after the output below).


In [4]:
data_dir="./data"
cmd="dcpg_data.py
    --cpg_profiles $cpg_dir/*.tsv
    --out_dir $data_dir
    --cpg_wlen 50
    "
if [[ $test_mode -eq 1 ]]; then
    cmd="$cmd
        --nb_sample 10000
        "
fi
run $cmd


#################################
dcpg_data.py --cpg_profiles ../../data//cpg/BS27_1_SER.tsv ../../data//cpg/BS27_3_SER.tsv ../../data//cpg/BS27_5_SER.tsv ../../data//cpg/BS27_6_SER.tsv ../../data//cpg/BS27_8_SER.tsv --out_dir ./data --cpg_wlen 50 --nb_sample 10000
#################################
INFO (2017-03-05 19:19:42,901): Reading single-cell profiles ...
INFO (2017-03-05 19:19:43,339): 10000 samples
INFO (2017-03-05 19:19:43,340): --------------------------------------------------------------------------------
INFO (2017-03-05 19:19:43,340): Chromosome 1 ...
INFO (2017-03-05 19:19:43,368): 10000 / 10000 (100.0%) sites matched minimum coverage filter
INFO (2017-03-05 19:19:43,369): Chunk 	1 / 1
INFO (2017-03-05 19:19:43,379): Extracting CpG neighbors ...
INFO (2017-03-05 19:19:44,498): Done!
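
If the data were instead prepared for a DNA or joint model, DNA sequence windows would also be extracted. A minimal sketch, assuming the mm10 FASTA files in $dna_dir from the initialization step (the window length of 1001 bp is only an illustrative choice):

# Sketch: data extraction for a DNA or joint model.
cmd="dcpg_data.py
    --cpg_profiles $cpg_dir/*.tsv
    --dna_files $dna_dir
    --dna_wlen 1001
    --out_dir $data_dir
    --cpg_wlen 50
    "
run $cmd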

Downloading a pre-trained model

dcpg_download.py downloads a pre-trained model from the DeepCpG model zoo. Available models and their corresponding descriptions can be found on the model zoo website, or listed with dcpg_download.py --show:


In [5]:
dcpg_download.py --show


Available models: https://github.com/cangermueller/deepcpg/blob/master/docs/models.md
Hou2016_HCC_cpg
Hou2016_HCC_dna
Hou2016_HCC_joint
Hou2016_HepG2_cpg
Hou2016_HepG2_dna
Hou2016_HepG2_joint
Hou2016_mESC_cpg
Hou2016_mESC_dna
Hou2016_mESC_joint
Smallwood2014_2i_cpg
Smallwood2014_2i_dna
Smallwood2014_2i_joint
Smallwood2014_serum_cpg
Smallwood2014_serum_dna
Smallwood2014_serum_joint

A model name consists of three parts separated by '_'. The first part corresponds to the publication, the second to the cell type, and the third to the model type (CpG, DNA, or joint model). Cells from 'Hou2016' were profiled using scRRBS-seq, cells from 'Smallwood2014' using scBS-seq. 'HCC' and 'HepG2' are human cancer cells; the remaining cells are mouse cells. You should use the model for the cell type that is most similar to the cells you are working with. More information about the available models can be found on the model zoo page linked above.
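
For example, the listing above can be filtered with standard shell tools to show only the CpG models (a convenience one-liner, not a dcpg_download.py feature):

dcpg_download.py --show | grep '_cpg$'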

Our example data are serum mESCs (replicates BS27_*_SER), so we will fine-tune the closely related CpG model 'Smallwood2014_2i_cpg', which was trained on 2i cells from the same study:


In [6]:
pretrained_model="./models/Smallwood2014_2i_cpg"
cmd="dcpg_download.py
  $(basename $pretrained_model)
  -o $pretrained_model
  "
run $cmd


#################################
dcpg_download.py Smallwood2014_2i_cpg -o ./models/Smallwood2014_2i_cpg
#################################
INFO (2017-03-05 19:19:51,601): Downloading model ...
INFO (2017-03-05 19:19:51,601): Model URL: http://www.ebi.ac.uk/~angermue/deepcpg/alias/f89b2e8344012d73e95504da06bcf378
--2017-03-05 19:19:51--  http://www.ebi.ac.uk/~angermue/deepcpg/alias/f89b2e8344012d73e95504da06bcf378
Resolving www.ebi.ac.uk (www.ebi.ac.uk)... 193.62.192.80
Connecting to www.ebi.ac.uk (www.ebi.ac.uk)|193.62.192.80|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31068468 (30M) [text/plain]
Saving to: ‘./models/Smallwood2014_2i_cpg/model.zip’

./models/Smallwood2 100%[===================>]  29.63M  10.1MB/s    in 2.9s    

2017-03-05 19:19:54 (10.1 MB/s) - ‘./models/Smallwood2014_2i_cpg/model.zip’ saved [31068468/31068468]

Archive:  ./models/Smallwood2014_2i_cpg/model.zip
  inflating: ./models/Smallwood2014_2i_cpg/model.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model.json  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights_train.h5  
  inflating: ./models/Smallwood2014_2i_cpg/model_weights_val.h5  
INFO (2017-03-05 19:19:55,376): Done!

The command downloads and stores the model files in the output directory, including the weights and a JSON file with the model specification:


In [7]:
ls $pretrained_model


model.h5               model_weights.h5       model_weights_val.h5
model.json             model_weights_train.h5

model.json stores the model specification, and model_weights_train.h5 and model_weights_val.h5 the weights that yielded the highest performance on the training and validation set, respectively. model.h5 combines model.json and model_weights_val.h5.
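
The JSON file is human-readable and can be inspected directly, for example to verify the download (any JSON pretty-printer works; python -m json.tool is assumed to be available here):

python -m json.tool $pretrained_model/model.json | head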

Fine-tuning the model

To fine-tune the downloaded model, we pass the model directory to --cpg_model and use --fine_tune to train only the output layers.

--cpg_model $pretrained_model is equivalent to --cpg_model $pretrained_model/model.json $pretrained_model/model_weights.h5, as the training log below shows. To fine-tune the weights that yielded the highest performance on the training or validation set specifically, pass model_weights_train.h5 or model_weights_val.h5 as the second argument instead; a sketch of this explicit form follows.
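
# Sketch: fine-tune the weights that performed best on the training set.
cmd="dcpg_train.py
    $data_dir/*.h5
    --cpg_model $pretrained_model/model.json $pretrained_model/model_weights_train.h5
    --out_dir ./models/cpg
    --fine_tune
    "
run $cmd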

Without --fine_tune, dcpg_train.py will train all weights, not only the output layers. This is recommended if the cells used for the pre-trained model are only distantly related to the cells of interest, e.g. if the cell types do not match. Training all weights can lead to higher prediction performance, but also increases training time; a sketch of this variant follows below.
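
A sketch of the full-training variant, assuming --learning_rate is available in your dcpg_train.py version (the default learning rate of 0.0001 appears in the training log below; a smaller value helps to preserve the pre-trained weights):

# Sketch: train all layers by dropping --fine_tune.
cmd="dcpg_train.py
    $data_dir/*.h5
    --cpg_model $pretrained_model
    --out_dir ./models/cpg
    --learning_rate 0.00001
    "
run $cmd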


In [8]:
cmd="dcpg_train.py
    $data_dir/*.h5
    --cpg_model $pretrained_model
    --out_dir ./models/cpg
    --fine_tune
  "
if [[ $test_mode -eq 1 ]]; then
  cmd="$cmd
    --nb_epoch 2
    --nb_train_sample 1000
    --nb_val_sample 1000
    "
else
  cmd="$cmd
    --nb_epoch 25
    --early_stopping 5
    "
fi
run $cmd


#################################
dcpg_train.py ./data/c1_000000-010000.h5 --cpg_model ./models/Smallwood2014_2i_cpg --out_dir ./models/cpg --fine_tune --nb_epoch 2 --nb_train_sample 1000 --nb_val_sample 1000
#################################
Using TensorFlow backend.
INFO (2017-03-05 19:20:27,727): Building model ...
Replicate names:
BS27_1_SER, BS27_3_SER, BS27_5_SER, BS27_6_SER, BS27_8_SER

INFO (2017-03-05 19:20:27,735): Loading existing CpG model ...
INFO (2017-03-05 19:20:27,736): Using model files ./models/Smallwood2014_2i_cpg/model.json ./models/Smallwood2014_2i_cpg/model_weights.h5
INFO (2017-03-05 19:20:28,772): Replicate names differ: Copying weights to new model ...
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
cpg/state (InputLayer)           (None, 5, 50)         0                                            
____________________________________________________________________________________________________
cpg/dist (InputLayer)            (None, 5, 50)         0                                            
____________________________________________________________________________________________________
cpg/merge_1 (Merge)              (None, 5, 100)        0           cpg/state[0][0]                  
                                                                   cpg/dist[0][0]                   
____________________________________________________________________________________________________
cpg/timedistributed_1 (TimeDistr (None, 5, 256)        25856       cpg/merge_1[0][0]                
____________________________________________________________________________________________________
cpg/bidirectional_1 (Bidirection (None, 512)           787968      cpg/timedistributed_1[0][0]      
____________________________________________________________________________________________________
cpg/dropout_1 (Dropout)          (None, 512)           0           cpg/bidirectional_1[0][0]        
____________________________________________________________________________________________________
cpg/BS27_1_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
____________________________________________________________________________________________________
cpg/BS27_3_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
____________________________________________________________________________________________________
cpg/BS27_5_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
____________________________________________________________________________________________________
cpg/BS27_6_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
____________________________________________________________________________________________________
cpg/BS27_8_SER (Dense)           (None, 1)             513         cpg/dropout_1[0][0]              
====================================================================================================
Total params: 816389
____________________________________________________________________________________________________
Layer trainability:
                layer | trainable
---------------------------------
cpg/timedistributed_1 |     False
  cpg/bidirectional_1 |     False
        cpg/dropout_1 |     False

INFO (2017-03-05 19:20:29,581): Computing output statistics ...
Output statistics:
          name | nb_tot | nb_obs | frac_obs | mean |  var
---------------------------------------------------------
cpg/BS27_1_SER |   1000 |    351 |     0.35 | 0.90 | 0.09
cpg/BS27_3_SER |   1000 |    146 |     0.15 | 0.79 | 0.16
cpg/BS27_5_SER |   1000 |    220 |     0.22 | 0.85 | 0.12
cpg/BS27_6_SER |   1000 |    336 |     0.34 | 0.67 | 0.22
cpg/BS27_8_SER |   1000 |    276 |     0.28 | 0.92 | 0.07

Class weights:
cpg/BS27_1_SER | cpg/BS27_3_SER | cpg/BS27_5_SER | cpg/BS27_6_SER | cpg/BS27_8_SER
----------------------------------------------------------------------------------
        0=0.90 |         0=0.79 |         0=0.85 |         0=0.67 |         0=0.92
        1=0.10 |         1=0.21 |         1=0.15 |         1=0.33 |         1=0.08

INFO (2017-03-05 19:20:29,862): Loading data ...
INFO (2017-03-05 19:20:29,864): Initializing callbacks ...
INFO (2017-03-05 19:20:29,865): Training model ...

Training samples: 1000
Epochs: 2
Learning rate: 0.0001
====================================================================================================
Epoch 1/2
====================================================================================================
done (%) | time |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_1_SER_acc | cpg/BS27_3_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_5_SER_acc
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    12.8 |  0.0 | 0.4260 | 0.2876 |              0.0935 |              0.0511 |              0.0711 |              0.0703 |              0.1185 |             0.0889 |             0.6000 |             0.0882 |             0.1379 |             0.5227
    25.6 |  0.0 | 0.3779 | 0.3107 |              0.0771 |              0.0522 |              0.0734 |              0.0597 |              0.0923 |             0.0955 |             0.6269 |             0.0569 |             0.1599 |             0.6143
    38.4 |  0.0 | 0.3637 | 0.3138 |              0.0712 |              0.0471 |              0.0775 |              0.0551 |              0.0894 |             0.0875 |             0.5897 |             0.0588 |             0.1773 |             0.6559
    51.2 |  0.0 | 0.3494 | 0.3247 |              0.0704 |              0.0472 |              0.0768 |              0.0509 |              0.0810 |             0.0976 |             0.5804 |             0.0644 |             0.1746 |             0.7062
    64.0 |  0.0 | 0.3473 | 0.3295 |              0.0714 |              0.0467 |              0.0810 |              0.0454 |              0.0802 |             0.1019 |             0.6063 |             0.0515 |             0.1859 |             0.7018
    76.8 |  0.0 | 0.3472 | 0.3322 |              0.0736 |              0.0444 |              0.0838 |              0.0426 |              0.0806 |             0.0988 |             0.6256 |             0.0524 |             0.1909 |             0.6932
    89.6 |  0.0 | 0.3368 | 0.3379 |              0.0710 |              0.0434 |              0.0839 |              0.0392 |              0.0773 |             0.0903 |             0.6358 |             0.0568 |             0.1954 |             0.7114
   100.0 |  0.0 | 0.3433 | 0.3480 |              0.0708 |              0.0464 |              0.0857 |              0.0401 |              0.0785 |             0.1055 |             0.6445 |             0.0695 |             0.2076 |             0.7130
Epoch 00000: loss improved from inf to 0.34332, saving model to ./models/cpg/model_weights_val.h5

 split |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 train | 0.3433 | 0.3480 |              0.0708 |              0.0401 |              0.0464 |              0.0857 |              0.0785 |             0.6445 |             0.2076 |             0.0695 |             0.7130 |             0.1055
Learning rate: 9.75e-05
====================================================================================================
Epoch 2/2
====================================================================================================
done (%) | time |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_1_SER_acc | cpg/BS27_3_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_5_SER_acc
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    12.8 |  0.0 | 0.3435 | 0.4013 |              0.0811 |              0.0311 |              0.0910 |              0.0470 |              0.0736 |             0.0750 |             0.7838 |             0.1892 |             0.2821 |             0.6765
    25.6 |  0.1 | 0.3261 | 0.3615 |              0.0656 |              0.0358 |              0.0917 |              0.0364 |              0.0771 |             0.0970 |             0.6597 |             0.1074 |             0.2331 |             0.7103
    38.4 |  0.1 | 0.3414 | 0.3604 |              0.0674 |              0.0385 |              0.0948 |              0.0360 |              0.0853 |             0.1053 |             0.6398 |             0.1272 |             0.2286 |             0.7012
    51.2 |  0.1 | 0.3441 | 0.3664 |              0.0680 |              0.0459 |              0.0859 |              0.0377 |              0.0872 |             0.1225 |             0.6443 |             0.1224 |             0.2232 |             0.7196
    64.0 |  0.1 | 0.3436 | 0.3729 |              0.0629 |              0.0472 |              0.0885 |              0.0392 |              0.0864 |             0.1370 |             0.6207 |             0.1403 |             0.2157 |             0.7507
    76.8 |  0.1 | 0.3285 | 0.3756 |              0.0590 |              0.0470 |              0.0837 |              0.0384 |              0.0811 |             0.1367 |             0.6222 |             0.1345 |             0.2200 |             0.7645
    89.6 |  0.1 | 0.3297 | 0.3798 |              0.0583 |              0.0481 |              0.0877 |              0.0384 |              0.0779 |             0.1426 |             0.6180 |             0.1269 |             0.2308 |             0.7807
   100.0 |  0.1 | 0.3315 | 0.3844 |              0.0610 |              0.0492 |              0.0879 |              0.0358 |              0.0784 |             0.1563 |             0.6242 |             0.1137 |             0.2470 |             0.7808
Epoch 00001: loss improved from 0.34332 to 0.33151, saving model to ./models/cpg/model_weights_val.h5

 split |   loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 train | 0.3315 | 0.3844 |              0.0610 |              0.0358 |              0.0492 |              0.0879 |              0.0784 |             0.6242 |             0.2470 |             0.1137 |             0.7808 |             0.1563
====================================================================================================

Training set performance:
  loss |    acc | cpg/BS27_3_SER_loss | cpg/BS27_8_SER_loss | cpg/BS27_1_SER_loss | cpg/BS27_6_SER_loss | cpg/BS27_5_SER_loss | cpg/BS27_3_SER_acc | cpg/BS27_6_SER_acc | cpg/BS27_8_SER_acc | cpg/BS27_5_SER_acc | cpg/BS27_1_SER_acc
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0.3433 | 0.3480 |              0.0708 |              0.0401 |              0.0464 |              0.0857 |              0.0785 |             0.6445 |             0.2076 |             0.0695 |             0.7130 |             0.1055
0.3315 | 0.3844 |              0.0610 |              0.0358 |              0.0492 |              0.0879 |              0.0784 |             0.6242 |             0.2470 |             0.1137 |             0.7808 |             0.1563
INFO (2017-03-05 19:20:41,962): Done!
Exception ignored in: <bound method BaseSession.__del__ of <tensorflow.python.client.session.Session object at 0x1147d1cc0>>
Traceback (most recent call last):
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 522, in __del__
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 518, in close
AttributeError: 'NoneType' object has no attribute 'raise_exception_on_not_ok_status'
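
The fine-tuned model is stored in ./models/cpg, using the same file layout as the downloaded model (the log above shows model_weights_val.h5 being saved there). It can be listed for a quick check:

ls ./models/cpg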

Imputing methylation profiles

Finally, we impute methylation profiles and evaluate our fine-tuned model using dcpg_eval.py:


In [9]:
eval_dir="./eval"
mkdir -p $eval_dir

cmd="dcpg_eval.py
    $data_dir/*.h5
    --model_files ./models/cpg
    --out_data $eval_dir/data.h5
    --out_report $eval_dir/report.tsv
    "
if [[ $test_mode -eq 1 ]]; then
    cmd="$cmd
        --nb_sample 1000
        "
fi
run $cmd


#################################
dcpg_eval.py ./data/c1_000000-010000.h5 --model_files ./models/cpg --out_data ./eval/data.h5 --out_report ./eval/report.tsv --nb_sample 1000
#################################
Using TensorFlow backend.
INFO (2017-03-05 19:20:46,772): Loading model ...
INFO (2017-03-05 19:20:47,542): Loading data ...
INFO (2017-03-05 19:20:47,571): Predicting ...
INFO (2017-03-05 19:20:47,587):  128/1000 (12.8%)
INFO (2017-03-05 19:20:47,686):  256/1000 (25.6%)
INFO (2017-03-05 19:20:47,759):  384/1000 (38.4%)
INFO (2017-03-05 19:20:47,833):  512/1000 (51.2%)
INFO (2017-03-05 19:20:47,914):  640/1000 (64.0%)
INFO (2017-03-05 19:20:47,991):  768/1000 (76.8%)
INFO (2017-03-05 19:20:48,078):  896/1000 (89.6%)
INFO (2017-03-05 19:20:48,158): 1000/1000 (100.0%)
/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/sklearn/metrics/classification.py:1074: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no predicted samples.
  'precision', 'predicted', average, warn_for)
           output       auc       acc       tpr       tnr        f1       mcc      n
2  cpg/BS27_5_SER  0.614279  0.850000  0.989362  0.031250  0.918519  0.062658  220.0
1  cpg/BS27_3_SER  0.574425  0.630137  0.663793  0.500000  0.740385  0.137086  146.0
0  cpg/BS27_1_SER  0.520723  0.139601  0.044444  0.972222  0.084848  0.025000  351.0
3  cpg/BS27_6_SER  0.437062  0.500000  0.580357  0.339286  0.607477 -0.077563  336.0
4  cpg/BS27_8_SER  0.424276  0.076087  0.000000  1.000000  0.000000  0.000000  276.0
INFO (2017-03-05 19:20:48,362): Done!
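
report.tsv is a plain TSV file with the per-cell metrics shown above, and data.h5 stores positions, observed methylation states, and predictions. Both can be inspected from the shell (h5ls is part of the HDF5 command-line tools and is assumed to be installed):

head $eval_dir/report.tsv
h5ls -r $eval_dir/data.h5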