Concise

Concise extends keras (https://keras.io) by providing additional layers, initializers, regularizers, metrics and preprocessors suited for modelling genomic cis-regulatory elements.


In [57]:
## Concise extensions of keras:

import concise
import concise.layers as cl
import concise.initializers as ci
import concise.regularizers as cr
import concise.metrics as cm
from concise.preprocessing import encodeDNA, encodeSplines
from concise.data import attract, encode

## layers:
cl.ConvDNA
cl.ConvDNAQuantitySplines
cr.GAMRegularizer
cl.GAMSmooth
cl.GlobalSumPooling1D
cl.InputDNA
cl.InputSplines

## initializers:
ci.PSSMKernelInitializer
ci.PSSMBiasInitializer

## regularizers:
cr.GAMRegularizer

## metrics:
cm.var_explained

## Preprocessing:
encodeDNA
encodeSplines

## Known motifs:
attract
encode


Out[57]:
<module 'concise.data.encode' from '/home/avsec/bin/anaconda3/lib/python3.5/site-packages/concise/data/encode.py'>

In [63]:
attract.get_pwm_list([519])[0].plotPWM()


It also implements a PWM class, which is used by the PWM*Initializers.


In [10]:
from concise.utils import PWM

PWM


Out[10]:
concise.utils.pwm.PWM
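
For example, a PWM can be constructed directly from a consensus sequence and visualized as a sequence logo; both methods appear later in this notebook:

pwm = PWM.from_consensus("TTAATGA")  # PWM(name: None, consensus: TTAATGA)
pwm.plotPWM()                        # plot the motif as a sequence logo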

Simulated data case study

In this notebook, we will replicate the results from Positional_effect/Simulation/01_fixed_seq_len.html using concise models. Please have a look at that notebook first.


In [1]:
# Additional packages used
%matplotlib inline

import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

## required keras modules
from keras.models import Model, load_model
import keras.layers as kl
import keras.optimizers as ko


Using TensorFlow backend.

Single motif case

Prepare the data


In [59]:
# Load the data
data_dir = "../data/"
dt = pd.read_csv(data_dir + "/01-fixed_seq_len-1.csv")
motifs = ["TTAATGA"]

In [13]:
dt.head


Out[13]:
<bound method NDFrame.head of             y                       seq
0      5.4763  TTCTGGGAGGCGTCCTTACGT...
1      0.2309  AAGTATGATCATACGCACCGT...
2      6.0881  TATGCAGGATCGAAATACTCT...
...       ...                       ...
2997   0.2773  GAGGGGTCAGCATGCGTATGA...
2998  15.6793  CTACCAAATCCAAACTTCAGC...
2999  13.0256  CTGACCATACGTTTCACACCG...

[3000 rows x 2 columns]>

In [30]:
x_seq = encodeDNA(dt["seq"])
y = dt.y.as_matrix()

In [32]:
x_seq.shape # (n_samples, seq_length, n_bases)


Out[32]:
(3000, 500, 4)
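
As a minimal sketch of what encodeDNA produces (assuming it accepts a plain list of equal-length strings and one-hot encodes the bases along the last axis, consistent with the shape above):

toy = encodeDNA(["ACGTA", "TTAAT"])  # two toy sequences of equal length
toy.shape  # expected: (2, 5, 4) -- one row per position, one column per base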

Build the model

Concise is a thin wrapper around keras. To learn more about keras, read its documentation: https://keras.io/.

In this tutorial, I'll be using the keras functional API: https://keras.io/getting-started/functional-api-guide/. Feel free to use Concise with sequential models as well (see the sketch below).
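
As a rough sketch (not run in this notebook), the same building blocks could be stacked with the keras Sequential API; this assumes cl.ConvDNA accepts the standard keras input_shape argument, like any Conv1D layer:

from keras.models import Sequential

model_seq = Sequential([
    cl.ConvDNA(filters=1, kernel_size=7, activation="relu",
               input_shape=(500, 4)),  # (seq_length, 4 bases); hypothetical usage
    cl.GlobalSumPooling1D(),
    kl.Dense(units=1, activation="linear"),
])
model_seq.compile(optimizer="adam", loss="mse")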


In [33]:
## Parameters
seq_length = x_seq.shape[1]
## Motifs used to initialize the model
pwm_list = [PWM.from_consensus(motif) for motif in motifs]
motif_width = 7
pwm_list


Out[33]:
[PWM(name: None, consensus: TTAATGA)]

In [60]:
pwm_list[0].plotPWM()



In [34]:
np.random.seed(42)

# specify the input shape
input_dna = cl.InputDNA(seq_length)

# convolutional layer with filters initialized on a PWM
x = cl.ConvDNA(filters=1, 
               kernel_size=motif_width, ## motif width
               activation="relu", 
               kernel_initializer=ci.PSSMKernelInitializer(pwm_list),
               bias_initializer=ci.PSSMBiasInitializer(pwm_list, kernel_size=motif_width, mean_max_scale=1)
               ## mean_max_scale of 1 means that only consensus sequence gets score larger than 0
              )(input_dna)

## Smoothing layer - positional-dependent effect
# output = input * (1+ pos_effect)
x = cl.GAMSmooth(n_bases=10, l2_smooth=1e-3, l2=0)(x)
x = cl.GlobalSumPooling1D()(x)
x = kl.Dense(units=1, activation="linear")(x)
model = Model(inputs=input_dna, outputs=x)

# compile the model
model.compile(optimizer="adam", loss="mse", metrics=[cm.var_explained])

Train the model


In [35]:
## TODO - create a callback
from keras.callbacks import EarlyStopping

model.fit(x=x_seq, y=y, epochs=30, verbose=2,
          callbacks=[EarlyStopping(patience=5)],
          validation_split=.2
         )


Train on 2400 samples, validate on 600 samples
Epoch 1/30
0s - loss: 14.0230 - var_explained: 0.0429 - val_loss: 14.0957 - val_var_explained: 0.0963
Epoch 2/30
0s - loss: 11.4912 - var_explained: 0.1887 - val_loss: 10.7865 - val_var_explained: 0.2822
Epoch 3/30
0s - loss: 7.8519 - var_explained: 0.4124 - val_loss: 6.8443 - val_var_explained: 0.5144
Epoch 4/30
0s - loss: 4.5005 - var_explained: 0.6268 - val_loss: 3.9261 - val_var_explained: 0.6938
Epoch 5/30
0s - loss: 2.7617 - var_explained: 0.7420 - val_loss: 2.8493 - val_var_explained: 0.7654
Epoch 6/30
0s - loss: 2.3030 - var_explained: 0.7796 - val_loss: 2.5620 - val_var_explained: 0.7873
Epoch 7/30
0s - loss: 2.1777 - var_explained: 0.7912 - val_loss: 2.4332 - val_var_explained: 0.7977
Epoch 8/30
0s - loss: 2.0947 - var_explained: 0.8030 - val_loss: 2.3272 - val_var_explained: 0.8061
Epoch 9/30
0s - loss: 2.0244 - var_explained: 0.8094 - val_loss: 2.2341 - val_var_explained: 0.8134
Epoch 10/30
0s - loss: 1.9623 - var_explained: 0.8192 - val_loss: 2.1599 - val_var_explained: 0.8189
Epoch 11/30
0s - loss: 1.9079 - var_explained: 0.8197 - val_loss: 2.0903 - val_var_explained: 0.8242
Epoch 12/30
0s - loss: 1.8595 - var_explained: 0.8197 - val_loss: 2.0290 - val_var_explained: 0.8287
Epoch 13/30
0s - loss: 1.8182 - var_explained: 0.8237 - val_loss: 1.9727 - val_var_explained: 0.8329
Epoch 14/30
0s - loss: 1.7825 - var_explained: 0.8286 - val_loss: 1.9308 - val_var_explained: 0.8358
Epoch 15/30
0s - loss: 1.7531 - var_explained: 0.8292 - val_loss: 1.8878 - val_var_explained: 0.8389
Epoch 16/30
0s - loss: 1.7280 - var_explained: 0.8327 - val_loss: 1.8522 - val_var_explained: 0.8414
Epoch 17/30
0s - loss: 1.7052 - var_explained: 0.8346 - val_loss: 1.8149 - val_var_explained: 0.8441
Epoch 18/30
0s - loss: 1.6899 - var_explained: 0.8337 - val_loss: 1.7927 - val_var_explained: 0.8456
Epoch 19/30
0s - loss: 1.6716 - var_explained: 0.8338 - val_loss: 1.7564 - val_var_explained: 0.8482
Epoch 20/30
0s - loss: 1.6569 - var_explained: 0.8358 - val_loss: 1.7357 - val_var_explained: 0.8497
Epoch 21/30
0s - loss: 1.6461 - var_explained: 0.8295 - val_loss: 1.7164 - val_var_explained: 0.8510
Epoch 22/30
0s - loss: 1.6366 - var_explained: 0.8428 - val_loss: 1.7034 - val_var_explained: 0.8518
Epoch 23/30
0s - loss: 1.6292 - var_explained: 0.8396 - val_loss: 1.6882 - val_var_explained: 0.8529
Epoch 24/30
0s - loss: 1.6201 - var_explained: 0.8416 - val_loss: 1.6696 - val_var_explained: 0.8541
Epoch 25/30
0s - loss: 1.6133 - var_explained: 0.8404 - val_loss: 1.6552 - val_var_explained: 0.8551
Epoch 26/30
0s - loss: 1.6084 - var_explained: 0.8381 - val_loss: 1.6505 - val_var_explained: 0.8554
Epoch 27/30
0s - loss: 1.6038 - var_explained: 0.8430 - val_loss: 1.6441 - val_var_explained: 0.8559
Epoch 28/30
0s - loss: 1.5980 - var_explained: 0.8401 - val_loss: 1.6348 - val_var_explained: 0.8565
Epoch 29/30
0s - loss: 1.5924 - var_explained: 0.8400 - val_loss: 1.6291 - val_var_explained: 0.8569
Epoch 30/30
0s - loss: 1.5932 - var_explained: 0.8335 - val_loss: 1.6308 - val_var_explained: 0.8569
Out[35]:
<keras.callbacks.History at 0x7f75013eba58>

Save and load the model

Since concise is fully compatible with keras, we can save and load the entire model to an HDF5 file.


In [36]:
model.save("/tmp/model.h5") ## requires the h5py package: pip install h5py

In [37]:
%ls -la  /tmp/model*


-rw-rw-r-- 1 avsec avsec 23992 May 10 13:29 /tmp/model.h5

In [38]:
model2 = load_model("/tmp/model.h5")
model2


Out[38]:
<keras.engine.training.Model at 0x7f7501333ac8>
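
As a quick sanity check (a sketch, not part of the original notebook), the reloaded model should reproduce the original predictions:

np.allclose(model.predict(x_seq), model2.predict(x_seq))  # expected: True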

Interpret the model

Predictions


In [39]:
var_expl_history = model.history.history['val_var_explained']
plt.plot(var_expl_history)
plt.ylabel('Variance explained')
plt.xlabel('Epoch')
plt.title("Loss history")


Out[39]:
<matplotlib.text.Text at 0x7f74f97b79b0>

In [41]:
y_pred = model.predict(x_seq)
plt.scatter(y_pred, y)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("True vs predicted")


Out[41]:
<matplotlib.text.Text at 0x7f74f8687048>

The plot is the same as in the original motifp report.
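
To put a number on the fit (a small sketch using only numpy, not part of the original report), one can compute the squared Pearson correlation between predicted and observed values:

r2 = np.corrcoef(y_pred.ravel(), y)[0, 1] ** 2
r2  # should be in the same ballpark as the validation var_explained above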

Weights


In [42]:
# layers in the model
model.layers


Out[42]:
[<keras.engine.topology.InputLayer at 0x7f7501600ba8>,
 <concise.layers.ConvDNA at 0x7f7501600d68>,
 <concise.layers.GAMSmooth at 0x7f7501600b70>,
 <concise.layers.GlobalSumPooling1D at 0x7f7501600e10>,
 <keras.layers.core.Dense at 0x7f75016c57f0>]

In [43]:
## Convenience functions in layers
gam_layer = model.layers[2]
gam_layer.plot()
plt.title("Positional effect")


Out[43]:
<matplotlib.text.Text at 0x7f74f86048d0>

In [44]:
# Compare the curve to the theoretical
positions = gam_layer.positional_effect()["positions"]
pos_effect = gam_layer.positional_effect()["positional_effect"]

from scipy.stats import norm
pef = lambda x: 0.3*norm.pdf(x, 0.2, 0.1) + 0.05*np.sin(15*x) + 0.8
pos_effect_theoretical = pef(positions / positions.max())

# plot
plt.plot(positions, pos_effect, label="inferred")
plt.plot(positions, pos_effect_theoretical, label="theoretical")
plt.ylabel('Positional effect')
plt.xlabel('Position')
plt.title("Positional effect")
plt.legend()


Out[44]:
<matplotlib.legend.Legend at 0x7f74f8596390>

Qualitatively, the curves are the same; quantitatively, they differ because the scale is modulated by other parameters in the model. The plot is similar to the one in the original motifp report.
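
One way to quantify this qualitative agreement (a sketch, not part of the original analysis) is the Pearson correlation between the inferred and theoretical curves, which ignores the overall scale:

np.corrcoef(pos_effect.ravel(), pos_effect_theoretical.ravel())[0, 1]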


In [48]:
# plot the filters
model.layers[1].plot_weights(plot_type="motif_raw")
model.layers[1].plot_weights(plot_type="motif_pwm")
model.layers[1].plot_weights(plot_type="motif_pwm_info")
model.layers[1].plot_weights(plot_type="heatmap")


Out[48]:
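
The raw filter weights behind these plots can also be inspected directly. A sketch, assuming ConvDNA behaves like a standard keras Conv1D layer:

kernel, bias = model.layers[1].get_weights()
kernel.shape  # expected: (motif_width, 4, filters) = (7, 4, 1)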

Model with two motifs


In [49]:
dt = pd.read_csv(data_dir + "/01-fixed_seq_len-2.csv")
motifs = ["TTAATGA", "TATTTAT"]

In [51]:
## Parameters
seq_length = x_seq.shape[1]
## Motifs used to initialize the model
pwm_list = [PWM.from_consensus(motif) for motif in motifs]
motif_width = 7
pwm_list


Out[51]:
[PWM(name: None, consensus: TTAATGA), PWM(name: None, consensus: TATTTAT)]

In [52]:
np.random.seed(1)
input_dna = cl.InputDNA(seq_length)
# convolutional layer with filters initialized on a PWM
x = cl.ConvDNA(filters=2, 
               kernel_size=motif_width, ## motif width
               activation="relu", 
               kernel_initializer=ci.PWMKernelInitializer(pwm_list, stddev=0.0),
               bias_initializer=ci.PWMBiasInitializer(pwm_list, kernel_size=motif_width, mean_max_scale=0.999),
               ## mean_max_scale of 1 means that only consensus sequence gets score larger than 0
               trainable=False,
              )(input_dna)
## Smoothing layer - positional-dependent effect
x = cl.GAMSmooth(n_bases=10, l2_smooth=1e-6, l2=0)(x)
x = cl.GlobalSumPooling1D()(x)
x = kl.Dense(units=1, activation="linear")(x)
model = Model(inputs=input_dna, outputs=x)

# compile the model
model.compile(optimizer=ko.Adam(lr=0.01), loss="mse", metrics=[cm.var_explained])

In [53]:
x_seq = encodeDNA(dt["seq"])
y = dt.y.as_matrix()

In [54]:
model.fit(x=x_seq, y=y, epochs=100, verbose=2,
          callbacks=[EarlyStopping(patience=5)],
          validation_split=.2
         )


Train on 2400 samples, validate on 600 samples
Epoch 1/100
0s - loss: 10.5460 - var_explained: 0.0027 - val_loss: 8.9194 - val_var_explained: 0.0053
Epoch 2/100
0s - loss: 8.7751 - var_explained: 0.0097 - val_loss: 7.7195 - val_var_explained: 0.0151
Epoch 3/100
0s - loss: 7.9169 - var_explained: 0.0221 - val_loss: 7.1681 - val_var_explained: 0.0304
Epoch 4/100
0s - loss: 7.5098 - var_explained: 0.0397 - val_loss: 6.9053 - val_var_explained: 0.0501
Epoch 5/100
0s - loss: 7.2673 - var_explained: 0.0611 - val_loss: 6.7053 - val_var_explained: 0.0740
Epoch 6/100
0s - loss: 7.0555 - var_explained: 0.0872 - val_loss: 6.4965 - val_var_explained: 0.1022
Epoch 7/100
0s - loss: 6.8280 - var_explained: 0.1157 - val_loss: 6.2714 - val_var_explained: 0.1330
Epoch 8/100
0s - loss: 6.5806 - var_explained: 0.1504 - val_loss: 6.0207 - val_var_explained: 0.1666
Epoch 9/100
0s - loss: 6.3105 - var_explained: 0.1834 - val_loss: 5.7512 - val_var_explained: 0.2025
Epoch 10/100
0s - loss: 6.0217 - var_explained: 0.2218 - val_loss: 5.4682 - val_var_explained: 0.2406
Epoch 11/100
0s - loss: 5.7214 - var_explained: 0.2606 - val_loss: 5.1860 - val_var_explained: 0.2781
Epoch 12/100
0s - loss: 5.4198 - var_explained: 0.2976 - val_loss: 4.8931 - val_var_explained: 0.3167
Epoch 13/100
0s - loss: 5.1191 - var_explained: 0.3360 - val_loss: 4.6183 - val_var_explained: 0.3525
Epoch 14/100
0s - loss: 4.8268 - var_explained: 0.3739 - val_loss: 4.3536 - val_var_explained: 0.3881
Epoch 15/100
0s - loss: 4.5485 - var_explained: 0.4124 - val_loss: 4.0955 - val_var_explained: 0.4213
Epoch 16/100
0s - loss: 4.2829 - var_explained: 0.4421 - val_loss: 3.8627 - val_var_explained: 0.4520
Epoch 17/100
0s - loss: 4.0366 - var_explained: 0.4735 - val_loss: 3.6469 - val_var_explained: 0.4797
Epoch 18/100
0s - loss: 3.8103 - var_explained: 0.5024 - val_loss: 3.4548 - val_var_explained: 0.5040
Epoch 19/100
0s - loss: 3.6060 - var_explained: 0.5299 - val_loss: 3.2824 - val_var_explained: 0.5260
Epoch 20/100
0s - loss: 3.4212 - var_explained: 0.5460 - val_loss: 3.1342 - val_var_explained: 0.5448
Epoch 21/100
0s - loss: 3.2589 - var_explained: 0.5694 - val_loss: 3.0065 - val_var_explained: 0.5615
Epoch 22/100
0s - loss: 3.1164 - var_explained: 0.5962 - val_loss: 2.8984 - val_var_explained: 0.5743
Epoch 23/100
0s - loss: 2.9957 - var_explained: 0.6038 - val_loss: 2.8063 - val_var_explained: 0.5854
Epoch 24/100
0s - loss: 2.8915 - var_explained: 0.6196 - val_loss: 2.7302 - val_var_explained: 0.5937
Epoch 25/100
0s - loss: 2.8054 - var_explained: 0.6249 - val_loss: 2.6728 - val_var_explained: 0.6010
Epoch 26/100
0s - loss: 2.7316 - var_explained: 0.6254 - val_loss: 2.6252 - val_var_explained: 0.6066
Epoch 27/100
0s - loss: 2.6709 - var_explained: 0.6372 - val_loss: 2.5864 - val_var_explained: 0.6106
Epoch 28/100
0s - loss: 2.6193 - var_explained: 0.6515 - val_loss: 2.5579 - val_var_explained: 0.6138
Epoch 29/100
0s - loss: 2.5758 - var_explained: 0.6524 - val_loss: 2.5312 - val_var_explained: 0.6163
Epoch 30/100
0s - loss: 2.5427 - var_explained: 0.6600 - val_loss: 2.5141 - val_var_explained: 0.6182
Epoch 31/100
0s - loss: 2.5117 - var_explained: 0.6595 - val_loss: 2.4985 - val_var_explained: 0.6197
Epoch 32/100
0s - loss: 2.4872 - var_explained: 0.6663 - val_loss: 2.4834 - val_var_explained: 0.6209
Epoch 33/100
0s - loss: 2.4662 - var_explained: 0.6632 - val_loss: 2.4793 - val_var_explained: 0.6219
Epoch 34/100
0s - loss: 2.4478 - var_explained: 0.6667 - val_loss: 2.4670 - val_var_explained: 0.6229
Epoch 35/100
0s - loss: 2.4318 - var_explained: 0.6714 - val_loss: 2.4589 - val_var_explained: 0.6235
Epoch 36/100
0s - loss: 2.4183 - var_explained: 0.6687 - val_loss: 2.4481 - val_var_explained: 0.6243
Epoch 37/100
0s - loss: 2.4060 - var_explained: 0.6778 - val_loss: 2.4480 - val_var_explained: 0.6249
Epoch 38/100
0s - loss: 2.3949 - var_explained: 0.6708 - val_loss: 2.4382 - val_var_explained: 0.6255
Epoch 39/100
0s - loss: 2.3862 - var_explained: 0.6752 - val_loss: 2.4300 - val_var_explained: 0.6262
Epoch 40/100
0s - loss: 2.3768 - var_explained: 0.6757 - val_loss: 2.4247 - val_var_explained: 0.6268
Epoch 41/100
0s - loss: 2.3675 - var_explained: 0.6736 - val_loss: 2.4282 - val_var_explained: 0.6272
Epoch 42/100
0s - loss: 2.3613 - var_explained: 0.6678 - val_loss: 2.4196 - val_var_explained: 0.6277
Epoch 43/100
0s - loss: 2.3544 - var_explained: 0.6777 - val_loss: 2.4131 - val_var_explained: 0.6283
Epoch 44/100
0s - loss: 2.3491 - var_explained: 0.6742 - val_loss: 2.4061 - val_var_explained: 0.6288
Epoch 45/100
0s - loss: 2.3420 - var_explained: 0.6765 - val_loss: 2.4072 - val_var_explained: 0.6291
Epoch 46/100
0s - loss: 2.3369 - var_explained: 0.6847 - val_loss: 2.4012 - val_var_explained: 0.6295
Epoch 47/100
0s - loss: 2.3326 - var_explained: 0.6903 - val_loss: 2.4051 - val_var_explained: 0.6297
Epoch 48/100
0s - loss: 2.3291 - var_explained: 0.6826 - val_loss: 2.3962 - val_var_explained: 0.6302
Epoch 49/100
0s - loss: 2.3253 - var_explained: 0.6788 - val_loss: 2.3972 - val_var_explained: 0.6302
Epoch 50/100
0s - loss: 2.3211 - var_explained: 0.6734 - val_loss: 2.3919 - val_var_explained: 0.6311
Epoch 51/100
0s - loss: 2.3182 - var_explained: 0.6860 - val_loss: 2.3850 - val_var_explained: 0.6313
Epoch 52/100
0s - loss: 2.3158 - var_explained: 0.6826 - val_loss: 2.3852 - val_var_explained: 0.6312
Epoch 53/100
0s - loss: 2.3116 - var_explained: 0.6805 - val_loss: 2.3874 - val_var_explained: 0.6317
Epoch 54/100
0s - loss: 2.3097 - var_explained: 0.6857 - val_loss: 2.3808 - val_var_explained: 0.6319
Epoch 55/100
0s - loss: 2.3066 - var_explained: 0.6817 - val_loss: 2.3835 - val_var_explained: 0.6321
Epoch 56/100
0s - loss: 2.3061 - var_explained: 0.6844 - val_loss: 2.3783 - val_var_explained: 0.6322
Epoch 57/100
0s - loss: 2.3043 - var_explained: 0.6794 - val_loss: 2.3756 - val_var_explained: 0.6327
Epoch 58/100
0s - loss: 2.3015 - var_explained: 0.6856 - val_loss: 2.3763 - val_var_explained: 0.6328
Epoch 59/100
0s - loss: 2.2999 - var_explained: 0.6901 - val_loss: 2.3786 - val_var_explained: 0.6330
Epoch 60/100
0s - loss: 2.2991 - var_explained: 0.6759 - val_loss: 2.3799 - val_var_explained: 0.6329
Epoch 61/100
0s - loss: 2.2973 - var_explained: 0.6900 - val_loss: 2.3726 - val_var_explained: 0.6332
Epoch 62/100
0s - loss: 2.2952 - var_explained: 0.6891 - val_loss: 2.3723 - val_var_explained: 0.6333
Epoch 63/100
0s - loss: 2.2949 - var_explained: 0.6794 - val_loss: 2.3706 - val_var_explained: 0.6334
Epoch 64/100
0s - loss: 2.2950 - var_explained: 0.6896 - val_loss: 2.3793 - val_var_explained: 0.6333
Epoch 65/100
0s - loss: 2.2934 - var_explained: 0.6858 - val_loss: 2.3719 - val_var_explained: 0.6335
Epoch 66/100
0s - loss: 2.2915 - var_explained: 0.6866 - val_loss: 2.3746 - val_var_explained: 0.6335
Epoch 67/100
0s - loss: 2.2915 - var_explained: 0.6864 - val_loss: 2.3732 - val_var_explained: 0.6336
Epoch 68/100
0s - loss: 2.2904 - var_explained: 0.6779 - val_loss: 2.3677 - val_var_explained: 0.6338
Epoch 69/100
0s - loss: 2.2903 - var_explained: 0.6927 - val_loss: 2.3742 - val_var_explained: 0.6339
Epoch 70/100
0s - loss: 2.2915 - var_explained: 0.6916 - val_loss: 2.3649 - val_var_explained: 0.6339
Epoch 71/100
0s - loss: 2.2885 - var_explained: 0.6924 - val_loss: 2.3668 - val_var_explained: 0.6339
Epoch 72/100
0s - loss: 2.2871 - var_explained: 0.6879 - val_loss: 2.3653 - val_var_explained: 0.6341
Epoch 73/100
0s - loss: 2.2868 - var_explained: 0.6835 - val_loss: 2.3611 - val_var_explained: 0.6343
Epoch 74/100
0s - loss: 2.2887 - var_explained: 0.6851 - val_loss: 2.3625 - val_var_explained: 0.6341
Epoch 75/100
0s - loss: 2.2860 - var_explained: 0.6916 - val_loss: 2.3700 - val_var_explained: 0.6342
Epoch 76/100
0s - loss: 2.2866 - var_explained: 0.6874 - val_loss: 2.3639 - val_var_explained: 0.6342
Epoch 77/100
0s - loss: 2.2856 - var_explained: 0.6900 - val_loss: 2.3624 - val_var_explained: 0.6343
Epoch 78/100
0s - loss: 2.2849 - var_explained: 0.6834 - val_loss: 2.3644 - val_var_explained: 0.6339
Epoch 79/100
0s - loss: 2.2843 - var_explained: 0.6813 - val_loss: 2.3663 - val_var_explained: 0.6341
Out[54]:
<keras.callbacks.History at 0x7f74f84f7ba8>

In [55]:
y_pred = model.predict(x_seq)
plt.scatter(y_pred, y)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("True vs predicted")


Out[55]:
<matplotlib.text.Text at 0x7f74f822ceb8>

In [56]:
## TODO - update to the new syntax
gam_layer = model.layers[2]

positions = gam_layer.positional_effect()["positions"]
pos_effect = gam_layer.positional_effect()["positional_effect"]

## Theoretical plot - from the original simulation data
from scipy.stats import norm
# https://docs.scipy.org/doc/scipy-0.16.1/reference/generated/scipy.stats.norm.html#scipy.stats.norm
pef1 = lambda x: 0.3*norm.pdf(x, 0.2, 0.1) + 0.05*np.sin(15*x) + 0.8
pos_effect_theoretical1 = pef1(positions / positions.max())
pef2 = lambda x: 0.3*norm.pdf(x, 0.35, 0.1) + 0.05*np.sin(15*x) + 0.8
pos_effect_theoretical2 = pef2(positions / positions.max())


w_motifs = model.get_weights()[-2]
b = model.get_weights()[-1]

## Create a new plot
pos_effect_calibrated = (pos_effect / np.transpose(w_motifs))/ 0.8
plt.plot(positions, pos_effect_calibrated[:,0], label="inferred " + motifs[0])
plt.plot(positions, pos_effect_calibrated[:,1], label="inferred " + motifs[1])
plt.plot(positions, pos_effect_theoretical1, label="theoretical " + motifs[0])
plt.plot(positions, pos_effect_theoretical2, label="theoretical " + motifs[1])
plt.ylabel('Positional effect')
plt.xlabel('Position')
plt.title("Positional effect")
plt.legend()


Out[56]:
<matplotlib.legend.Legend at 0x7f74ca854be0>