In [1]:
import os
import json
os.environ["CUDA_VISIBLE_DEVICES"] = '-1' ### run on CPU
import tensorflow as tf
print(tf.__version__)
if tf.__version__[0] == '1':
    tf.compat.v1.enable_eager_execution()

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from cooltools.lib.numutils import set_diag
from basenji import dataset, dna_io, seqnn


1.15.0

Load trained model


In [2]:
### load params, specify model ###

model_dir = './'
params_file = model_dir+'params.json'
model_file  = model_dir+'model_best.h5'
with open(params_file) as params_open:
    params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

seqnn_model = seqnn.SeqNN(params_model)


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
sequence (InputLayer)           [(None, 1048576, 4)] 0                                            
__________________________________________________________________________________________________
stochastic_reverse_complement ( ((None, 1048576, 4), 0           sequence[0][0]                   
__________________________________________________________________________________________________
stochastic_shift (StochasticShi (None, 1048576, 4)   0           stochastic_reverse_complement[0][
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 1048576, 4)   0           stochastic_shift[0][0]           
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, 1048576, 96)  4224        re_lu[0][0]                      
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 1048576, 96)  384         conv1d[0][0]                     
__________________________________________________________________________________________________
max_pooling1d (MaxPooling1D)    (None, 524288, 96)   0           batch_normalization[0][0]        
__________________________________________________________________________________________________
re_lu_1 (ReLU)                  (None, 524288, 96)   0           max_pooling1d[0][0]              
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 524288, 96)   46080       re_lu_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 524288, 96)   384         conv1d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D)  (None, 262144, 96)   0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
re_lu_2 (ReLU)                  (None, 262144, 96)   0           max_pooling1d_1[0][0]            
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 262144, 96)   46080       re_lu_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 262144, 96)   384         conv1d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D)  (None, 131072, 96)   0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
re_lu_3 (ReLU)                  (None, 131072, 96)   0           max_pooling1d_2[0][0]            
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 131072, 96)   46080       re_lu_3[0][0]                    
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 131072, 96)   384         conv1d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_3 (MaxPooling1D)  (None, 65536, 96)    0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
re_lu_4 (ReLU)                  (None, 65536, 96)    0           max_pooling1d_3[0][0]            
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 65536, 96)    46080       re_lu_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 65536, 96)    384         conv1d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_4 (MaxPooling1D)  (None, 32768, 96)    0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
re_lu_5 (ReLU)                  (None, 32768, 96)    0           max_pooling1d_4[0][0]            
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, 32768, 96)    46080       re_lu_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32768, 96)    384         conv1d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_5 (MaxPooling1D)  (None, 16384, 96)    0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
re_lu_6 (ReLU)                  (None, 16384, 96)    0           max_pooling1d_5[0][0]            
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, 16384, 96)    46080       re_lu_6[0][0]                    
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16384, 96)    384         conv1d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_6 (MaxPooling1D)  (None, 8192, 96)     0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
re_lu_7 (ReLU)                  (None, 8192, 96)     0           max_pooling1d_6[0][0]            
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 8192, 96)     46080       re_lu_7[0][0]                    
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8192, 96)     384         conv1d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_7 (MaxPooling1D)  (None, 4096, 96)     0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
re_lu_8 (ReLU)                  (None, 4096, 96)     0           max_pooling1d_7[0][0]            
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, 4096, 96)     46080       re_lu_8[0][0]                    
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4096, 96)     384         conv1d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_8 (MaxPooling1D)  (None, 2048, 96)     0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
re_lu_9 (ReLU)                  (None, 2048, 96)     0           max_pooling1d_8[0][0]            
__________________________________________________________________________________________________
conv1d_9 (Conv1D)               (None, 2048, 96)     46080       re_lu_9[0][0]                    
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2048, 96)     384         conv1d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling1d_9 (MaxPooling1D)  (None, 1024, 96)     0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
re_lu_10 (ReLU)                 (None, 1024, 96)     0           max_pooling1d_9[0][0]            
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, 1024, 96)     46080       re_lu_10[0][0]                   
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 1024, 96)     384         conv1d_10[0][0]                  
__________________________________________________________________________________________________
max_pooling1d_10 (MaxPooling1D) (None, 512, 96)      0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
re_lu_11 (ReLU)                 (None, 512, 96)      0           max_pooling1d_10[0][0]           
__________________________________________________________________________________________________
conv1d_11 (Conv1D)              (None, 512, 48)      13824       re_lu_11[0][0]                   
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 512, 48)      192         conv1d_11[0][0]                  
__________________________________________________________________________________________________
re_lu_12 (ReLU)                 (None, 512, 48)      0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv1d_12 (Conv1D)              (None, 512, 96)      4608        re_lu_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 512, 96)      384         conv1d_12[0][0]                  
__________________________________________________________________________________________________
dropout (Dropout)               (None, 512, 96)      0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
add (Add)                       (None, 512, 96)      0           max_pooling1d_10[0][0]           
                                                                 dropout[0][0]                    
__________________________________________________________________________________________________
re_lu_13 (ReLU)                 (None, 512, 96)      0           add[0][0]                        
__________________________________________________________________________________________________
conv1d_13 (Conv1D)              (None, 512, 48)      13824       re_lu_13[0][0]                   
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 512, 48)      192         conv1d_13[0][0]                  
__________________________________________________________________________________________________
re_lu_14 (ReLU)                 (None, 512, 48)      0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv1d_14 (Conv1D)              (None, 512, 96)      4608        re_lu_14[0][0]                   
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 512, 96)      384         conv1d_14[0][0]                  
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512, 96)      0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
add_1 (Add)                     (None, 512, 96)      0           add[0][0]                        
                                                                 dropout_1[0][0]                  
__________________________________________________________________________________________________
re_lu_15 (ReLU)                 (None, 512, 96)      0           add_1[0][0]                      
__________________________________________________________________________________________________
conv1d_15 (Conv1D)              (None, 512, 48)      13824       re_lu_15[0][0]                   
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 512, 48)      192         conv1d_15[0][0]                  
__________________________________________________________________________________________________
re_lu_16 (ReLU)                 (None, 512, 48)      0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
conv1d_16 (Conv1D)              (None, 512, 96)      4608        re_lu_16[0][0]                   
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 512, 96)      384         conv1d_16[0][0]                  
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 512, 96)      0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
add_2 (Add)                     (None, 512, 96)      0           add_1[0][0]                      
                                                                 dropout_2[0][0]                  
__________________________________________________________________________________________________
re_lu_17 (ReLU)                 (None, 512, 96)      0           add_2[0][0]                      
__________________________________________________________________________________________________
conv1d_17 (Conv1D)              (None, 512, 48)      13824       re_lu_17[0][0]                   
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 512, 48)      192         conv1d_17[0][0]                  
__________________________________________________________________________________________________
re_lu_18 (ReLU)                 (None, 512, 48)      0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
conv1d_18 (Conv1D)              (None, 512, 96)      4608        re_lu_18[0][0]                   
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 512, 96)      384         conv1d_18[0][0]                  
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 512, 96)      0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
add_3 (Add)                     (None, 512, 96)      0           add_2[0][0]                      
                                                                 dropout_3[0][0]                  
__________________________________________________________________________________________________
re_lu_19 (ReLU)                 (None, 512, 96)      0           add_3[0][0]                      
__________________________________________________________________________________________________
conv1d_19 (Conv1D)              (None, 512, 48)      13824       re_lu_19[0][0]                   
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 512, 48)      192         conv1d_19[0][0]                  
__________________________________________________________________________________________________
re_lu_20 (ReLU)                 (None, 512, 48)      0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
conv1d_20 (Conv1D)              (None, 512, 96)      4608        re_lu_20[0][0]                   
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 512, 96)      384         conv1d_20[0][0]                  
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 512, 96)      0           batch_normalization_20[0][0]     
__________________________________________________________________________________________________
add_4 (Add)                     (None, 512, 96)      0           add_3[0][0]                      
                                                                 dropout_4[0][0]                  
__________________________________________________________________________________________________
re_lu_21 (ReLU)                 (None, 512, 96)      0           add_4[0][0]                      
__________________________________________________________________________________________________
conv1d_21 (Conv1D)              (None, 512, 48)      13824       re_lu_21[0][0]                   
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 512, 48)      192         conv1d_21[0][0]                  
__________________________________________________________________________________________________
re_lu_22 (ReLU)                 (None, 512, 48)      0           batch_normalization_21[0][0]     
__________________________________________________________________________________________________
conv1d_22 (Conv1D)              (None, 512, 96)      4608        re_lu_22[0][0]                   
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, 512, 96)      384         conv1d_22[0][0]                  
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 512, 96)      0           batch_normalization_22[0][0]     
__________________________________________________________________________________________________
add_5 (Add)                     (None, 512, 96)      0           add_4[0][0]                      
                                                                 dropout_5[0][0]                  
__________________________________________________________________________________________________
re_lu_23 (ReLU)                 (None, 512, 96)      0           add_5[0][0]                      
__________________________________________________________________________________________________
conv1d_23 (Conv1D)              (None, 512, 48)      13824       re_lu_23[0][0]                   
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, 512, 48)      192         conv1d_23[0][0]                  
__________________________________________________________________________________________________
re_lu_24 (ReLU)                 (None, 512, 48)      0           batch_normalization_23[0][0]     
__________________________________________________________________________________________________
conv1d_24 (Conv1D)              (None, 512, 96)      4608        re_lu_24[0][0]                   
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 512, 96)      384         conv1d_24[0][0]                  
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 512, 96)      0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
add_6 (Add)                     (None, 512, 96)      0           add_5[0][0]                      
                                                                 dropout_6[0][0]                  
__________________________________________________________________________________________________
re_lu_25 (ReLU)                 (None, 512, 96)      0           add_6[0][0]                      
__________________________________________________________________________________________________
conv1d_25 (Conv1D)              (None, 512, 48)      13824       re_lu_25[0][0]                   
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 512, 48)      192         conv1d_25[0][0]                  
__________________________________________________________________________________________________
re_lu_26 (ReLU)                 (None, 512, 48)      0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
conv1d_26 (Conv1D)              (None, 512, 96)      4608        re_lu_26[0][0]                   
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 512, 96)      384         conv1d_26[0][0]                  
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 512, 96)      0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
add_7 (Add)                     (None, 512, 96)      0           add_6[0][0]                      
                                                                 dropout_7[0][0]                  
__________________________________________________________________________________________________
re_lu_27 (ReLU)                 (None, 512, 96)      0           add_7[0][0]                      
__________________________________________________________________________________________________
conv1d_27 (Conv1D)              (None, 512, 64)      30720       re_lu_27[0][0]                   
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 512, 64)      256         conv1d_27[0][0]                  
__________________________________________________________________________________________________
re_lu_28 (ReLU)                 (None, 512, 64)      0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
one_to_two (OneToTwo)           (None, 512, 512, 64) 0           re_lu_28[0][0]                   
__________________________________________________________________________________________________
concat_dist2d (ConcatDist2D)    (None, 512, 512, 65) 0           one_to_two[0][0]                 
__________________________________________________________________________________________________
re_lu_29 (ReLU)                 (None, 512, 512, 65) 0           concat_dist2d[0][0]              
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 48) 28080       re_lu_29[0][0]                   
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 512, 512, 48) 192         conv2d[0][0]                     
__________________________________________________________________________________________________
symmetrize2d (Symmetrize2D)     (None, 512, 512, 48) 0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
re_lu_30 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d[0][0]               
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 24) 10368       re_lu_30[0][0]                   
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 512, 512, 24) 96          conv2d_1[0][0]                   
__________________________________________________________________________________________________
re_lu_31 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 512, 512, 48) 1152        re_lu_31[0][0]                   
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 512, 512, 48) 192         conv2d_2[0][0]                   
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 512, 512, 48) 0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
add_8 (Add)                     (None, 512, 512, 48) 0           symmetrize2d[0][0]               
                                                                 dropout_8[0][0]                  
__________________________________________________________________________________________________
symmetrize2d_1 (Symmetrize2D)   (None, 512, 512, 48) 0           add_8[0][0]                      
__________________________________________________________________________________________________
re_lu_32 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d_1[0][0]             
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 512, 512, 24) 10368       re_lu_32[0][0]                   
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 512, 512, 24) 96          conv2d_3[0][0]                   
__________________________________________________________________________________________________
re_lu_33 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 512, 512, 48) 1152        re_lu_33[0][0]                   
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 512, 512, 48) 192         conv2d_4[0][0]                   
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, 512, 512, 48) 0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
add_9 (Add)                     (None, 512, 512, 48) 0           symmetrize2d_1[0][0]             
                                                                 dropout_9[0][0]                  
__________________________________________________________________________________________________
symmetrize2d_2 (Symmetrize2D)   (None, 512, 512, 48) 0           add_9[0][0]                      
__________________________________________________________________________________________________
re_lu_34 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d_2[0][0]             
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 512, 512, 24) 10368       re_lu_34[0][0]                   
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 512, 512, 24) 96          conv2d_5[0][0]                   
__________________________________________________________________________________________________
re_lu_35 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 512, 512, 48) 1152        re_lu_35[0][0]                   
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 512, 512, 48) 192         conv2d_6[0][0]                   
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, 512, 512, 48) 0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
add_10 (Add)                    (None, 512, 512, 48) 0           symmetrize2d_2[0][0]             
                                                                 dropout_10[0][0]                 
__________________________________________________________________________________________________
symmetrize2d_3 (Symmetrize2D)   (None, 512, 512, 48) 0           add_10[0][0]                     
__________________________________________________________________________________________________
re_lu_36 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d_3[0][0]             
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 512, 512, 24) 10368       re_lu_36[0][0]                   
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 512, 512, 24) 96          conv2d_7[0][0]                   
__________________________________________________________________________________________________
re_lu_37 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 512, 512, 48) 1152        re_lu_37[0][0]                   
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 512, 512, 48) 192         conv2d_8[0][0]                   
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, 512, 512, 48) 0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 512, 512, 48) 0           symmetrize2d_3[0][0]             
                                                                 dropout_11[0][0]                 
__________________________________________________________________________________________________
symmetrize2d_4 (Symmetrize2D)   (None, 512, 512, 48) 0           add_11[0][0]                     
__________________________________________________________________________________________________
re_lu_38 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d_4[0][0]             
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 512, 512, 24) 10368       re_lu_38[0][0]                   
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 512, 512, 24) 96          conv2d_9[0][0]                   
__________________________________________________________________________________________________
re_lu_39 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_37[0][0]     
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 512, 512, 48) 1152        re_lu_39[0][0]                   
__________________________________________________________________________________________________
batch_normalization_38 (BatchNo (None, 512, 512, 48) 192         conv2d_10[0][0]                  
__________________________________________________________________________________________________
dropout_12 (Dropout)            (None, 512, 512, 48) 0           batch_normalization_38[0][0]     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 512, 512, 48) 0           symmetrize2d_4[0][0]             
                                                                 dropout_12[0][0]                 
__________________________________________________________________________________________________
symmetrize2d_5 (Symmetrize2D)   (None, 512, 512, 48) 0           add_12[0][0]                     
__________________________________________________________________________________________________
re_lu_40 (ReLU)                 (None, 512, 512, 48) 0           symmetrize2d_5[0][0]             
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 512, 512, 24) 10368       re_lu_40[0][0]                   
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 512, 512, 24) 96          conv2d_11[0][0]                  
__________________________________________________________________________________________________
re_lu_41 (ReLU)                 (None, 512, 512, 24) 0           batch_normalization_39[0][0]     
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 512, 512, 48) 1152        re_lu_41[0][0]                   
__________________________________________________________________________________________________
batch_normalization_40 (BatchNo (None, 512, 512, 48) 192         conv2d_12[0][0]                  
__________________________________________________________________________________________________
dropout_13 (Dropout)            (None, 512, 512, 48) 0           batch_normalization_40[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 512, 512, 48) 0           symmetrize2d_5[0][0]             
                                                                 dropout_13[0][0]                 
__________________________________________________________________________________________________
symmetrize2d_6 (Symmetrize2D)   (None, 512, 512, 48) 0           add_13[0][0]                     
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 448, 448, 48) 0           symmetrize2d_6[0][0]             
__________________________________________________________________________________________________
upper_tri (UpperTri)            (None, 99681, 48)    0           cropping2d[0][0]                 
__________________________________________________________________________________________________
dense (Dense)                   (None, 99681, 5)     245         upper_tri[0][0]                  
__________________________________________________________________________________________________
switch_reverse_triu (SwitchReve (None, 99681, 5)     0           dense[0][0]                      
                                                                 stochastic_reverse_complement[0][
==================================================================================================
Total params: 751,653
Trainable params: 746,149
Non-trainable params: 5,504
__________________________________________________________________________________________________
None
model_strides [2048]
target_lengths [99681]
target_crops [-49585]
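
A quick arithmetic check of the printed shapes (a sketch using only numbers visible in the summary above): the 1,048,576 bp input is pooled down to 512 bins of 2,048 bp each, cropping2d trims the map to 448x448 bins, and 99,681 is exactly the number of entries above the first two diagonals of a 448x448 matrix (consistent with the diagonal_offset loaded below).

# numbers copied from the model summary above
seq_len_bp, bins, cropped_bins = 1048576, 512, 448
print(seq_len_bp // bins)                            # 2048 bp per bin, i.e. model_strides
print((bins - cropped_bins) // 2)                    # 32 bins cropped from each side
print((cropped_bins - 2) * (cropped_bins - 1) // 2)  # 99681, matching target_lengths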

In [3]:
### restore model ###
# note: run %%bash get_model.sh 
# if you have not already downloaded the model
seqnn_model.restore(model_file)
print('successfully loaded')


successfully loaded

In [4]:
### names of targets ###
data_dir = './data/'

hic_targets = pd.read_csv(data_dir+'targets.txt', sep='\t')
hic_file_dict_num = dict(zip(hic_targets['index'].values, hic_targets['file'].values) )
hic_file_dict     = dict(zip(hic_targets['identifier'].values, hic_targets['file'].values) )
hic_num_to_name_dict = dict(zip(hic_targets['index'].values, hic_targets['identifier'].values) )

# read data parameters
data_stats_file = '%s/statistics.json' % data_dir
with open(data_stats_file) as data_stats_open:
    data_stats = json.load(data_stats_open)
seq_length = data_stats['seq_length']
target_length = data_stats['target_length']
hic_diags =  data_stats['diagonal_offset']
target_crop = data_stats['crop_bp'] // data_stats['pool_width']
target_length1 = data_stats['seq_length'] // data_stats['pool_width']
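
The derived quantities can be checked against the shapes in the model summary: the flattened target length should equal the number of entries above the first hic_diags diagonals of the cropped bin matrix. A small consistency check (not part of the original notebook):

# consistency check: flattened upper-triangular length vs. cropped map size
cropped_bins = target_length1 - 2 * target_crop
assert target_length == (cropped_bins - hic_diags) * (cropped_bins - hic_diags + 1) // 2
print(cropped_bins, target_length)   # expected: 448 99681 for this dataset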

Make predictions for saved tfrecords


In [5]:
### load data ###

# note: run %%bash get_data.sh 
# if you have not already downloaded the data

sequences = pd.read_csv(data_dir+'sequences.bed',sep='\t',  names=['chr','start','stop','type'])
seqs_per_tf_default = 256
test_tr_num = 0
sequences_test = sequences.iloc[  sequences['type'].values=='test']
sequences_test.reset_index(inplace=True, drop=True)

tfr_pattern_path = (data_dir +'tfrecords/%s' % ('test-*.tfr' ) )
test_data = dataset.SeqDataset(tfr_pattern_path,
                               seq_length=seq_length,
                               target_length=target_length,
                               batch_size=8)

# test_targets is a float array with shape
# [#regions, #pixels, #target datasets]
# representing log(obs/exp) data, where #pixels
# corresponds to the number of entries in the flattened
# upper-triangular representation of the matrix
test_targets = test_data.numpy(return_inputs=False, return_outputs=True)

# test_inputs are 1-hot encoded arrays with shape
# [#regions, 2^20 bp, 4 nucleotides]
test_inputs = test_data.numpy(return_inputs=True, return_outputs=False)


./data/tfrecords/test-*.tfr has 413 sequences with 5/5 targets
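
A quick shape check (hedged: the sizes follow from the printout above, 413 test regions with 5 targets, and from the comments on the arrays):

# expected: (413, 99681, 5) for test_targets and (413, 1048576, 4) for test_inputs
print(test_targets.shape, test_inputs.shape)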

In [6]:
### for converting from a flattened upper-triangular vector to a symmetric matrix ###

def from_upper_triu(vector_repr, matrix_len, num_diags):
    # fill the upper triangle (offset by num_diags) from the flattened vector,
    # set the excluded central diagonals to NaN, then symmetrize
    z = np.zeros((matrix_len, matrix_len))
    triu_tup = np.triu_indices(matrix_len, num_diags)
    z[triu_tup] = vector_repr
    for i in range(-num_diags+1, num_diags):
        set_diag(z, np.nan, i)
    return z + z.T

target_length1_cropped = target_length1 - 2*target_crop
print('flattened representation length:', target_length) 
print('symmetric matrix size:', '('+str(target_length1_cropped)+','+str(target_length1_cropped)+')')


flattened representation length: 99681
symmetric matrix size: (448,448)
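
A minimal usage example of from_upper_triu on a toy 5x5 matrix with one excluded diagonal (a sketch to illustrate the round trip; not part of the original notebook):

# toy example: 5x5 matrix, main diagonal excluded (num_diags=1)
toy_len, toy_diags = 5, 1
toy_vec = np.arange(np.triu_indices(toy_len, toy_diags)[0].size, dtype=float)
toy_mat = from_upper_triu(toy_vec, toy_len, toy_diags)
print(toy_mat.shape)   # (5, 5): symmetric, with NaN on the excluded diagonal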

In [7]:
fig2_examples = [   'chr12:115163136-116211712',
                    'chr11:75429888-76478464',
                    'chr15:63281152-64329728' ]
fig2_inds = []
for seq in fig2_examples:
    print(seq)
    chrm, coords = seq.split(':')
    start, stop = coords.split('-')
    test_ind = np.where( (sequences_test['chr'].values == chrm) &
                         (sequences_test['start'].values == int(start)) &
                         (sequences_test['stop'].values == int(stop)) )[0][0]
    fig2_inds.append(test_ind)
fig2_inds


chr12:115163136-116211712
chr11:75429888-76478464
chr15:63281152-64329728
Out[7]:
[85, 402, 393]

In [8]:
### make predictions and plot the three examples above ###

target_index = 0 # HFF 

for test_index in fig2_inds:
    chrm, seq_start, seq_end = sequences_test.iloc[test_index][0:3]
    myseq_str = chrm+':'+str(seq_start)+'-'+str(seq_end)
    print(' ')
    print(myseq_str)
    
    test_target = test_targets[test_index:test_index+1,:,:]
    test_pred = seqnn_model.model.predict(test_inputs[test_index:test_index+1,:,:])

    plt.figure(figsize=(8,4))
    target_index = 0
    vmin=-2; vmax=2

    # plot pred
    plt.subplot(121) 
    mat = from_upper_triu(test_pred[:,:,target_index], target_length1_cropped, hic_diags)
    im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
    plt.title('pred-'+str(hic_num_to_name_dict[target_index]),y=1.15 )
    plt.ylabel(myseq_str)

    # plot target 
    plt.subplot(122) 
    mat = from_upper_triu(test_target[:,:,target_index], target_length1_cropped, hic_diags)
    im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
    plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
    plt.title( 'target-'+str(hic_num_to_name_dict[target_index]),y=1.15)

    plt.tight_layout()
    plt.show()


 
chr12:115163136-116211712
 
chr11:75429888-76478464
 
chr15:63281152-64329728
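
Beyond the visual comparison, per-example agreement can be quantified, for instance with a Pearson correlation over the flattened upper-triangular values (a hedged sketch, not part of the original notebook; any non-finite target entries are masked out):

# per-example Pearson correlation between prediction and target (HFF track)
for test_index in fig2_inds:
    pred = seqnn_model.model.predict(test_inputs[test_index:test_index+1, :, :])[0, :, target_index]
    targ = test_targets[test_index, :, target_index]
    mask = np.isfinite(targ)
    print(sequences_test.iloc[test_index, 0],
          'Pearson r = %.3f' % np.corrcoef(pred[mask], targ[mask])[0, 1])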

Make a prediction from sequence


In [9]:
### make a prediction from sequence ###

import subprocess
if not os.path.isfile('./data/hg38.ml.fa'):
    print('downloading hg38.ml.fa')
    subprocess.call('curl -o ./data/hg38.ml.fa.gz https://storage.googleapis.com/basenji_barnyard/hg38.ml.fa.gz', shell=True)
    subprocess.call('gunzip ./data/hg38.ml.fa.gz', shell=True)

import pysam
fasta_open = pysam.Fastafile('./data/hg38.ml.fa')

In [10]:
# this example uses the sequence for the test set region
# with the corresponding test_index, but
# predictions can be made for any DNA sequence of length = seq_length = 2^20

chrm, seq_start, seq_end = sequences_test.iloc[test_index][0:3] 
seq = fasta_open.fetch( chrm, seq_start, seq_end ).upper()
if len(seq) != seq_length: raise ValueError('len(seq) != seq_length')

# seq_1hot is a np.array with shape [2^20 bp, 4 nucleotides]
# representing 1-hot encoded DNA sequence
seq_1hot = dna_io.dna_1hot(seq)

In [11]:
# expand input dimensions, as the model accepts arrays of shape [#regions, 2^20 bp, 4]
test_pred_from_seq = seqnn_model.model.predict(np.expand_dims(seq_1hot,0))
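
If the tfrecords were built from the same hg38 assembly and the stochastic reverse-complement/shift layers are inactive at inference (as their names suggest they are training-time augmentations), this prediction should closely match test_pred computed above for the same test_index. A quick hedged check:

# compare against the tfrecord-based prediction from the earlier loop
print('max |diff| vs tfrecord-based prediction:',
      np.abs(test_pred_from_seq - test_pred).max())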

In [12]:
# plot pred

plt.figure(figsize=(8,4))
target_index = 0
vmin=-2; vmax=2

#transform from flattened representation to symmetric matrix representation
mat = from_upper_triu(test_pred_from_seq[:,:,target_index], target_length1_cropped, hic_diags)

plt.subplot(121) 
im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
plt.title('pred-'+str(hic_num_to_name_dict[target_index]),y=1.15 )
plt.ylabel(myseq_str)

# plot target 
plt.subplot(122) 
mat = from_upper_triu(test_target[:,:,target_index], target_length1_cropped, hic_diags)
im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
plt.title( 'target-'+str(hic_num_to_name_dict[target_index]),y=1.15)

plt.tight_layout()
plt.show()