In [42]:
import keras
import keras.preprocessing
import keras.preprocessing.text
import keras.preprocessing.sequence

import mhcflurry
import numpy as np
import pandas as pd
import skbio

In [43]:
# Build a lookup from four-digit allele name to the longest sequence seen
# for that allele in the FASTA file.
mhc_sequences = {}
for header, sequence in skbio.parse_fasta("mhc_seqs.fasta"):
    # skip null/questionable alleles (their names end in a letter)
    if header[-1].isalpha():
        continue
    # keep only the first two ":"-separated fields of the allele name
    allele_key = ":".join(header.split(":")[:2])
    # a duplicate only replaces the stored entry when it is strictly longer
    if (allele_key not in mhc_sequences
            or len(sequence) > len(mhc_sequences[allele_key])):
        mhc_sequences[allele_key] = sequence
print("Loaded sequences of %d MHC alleles" % len(mhc_sequences))


Loaded sequences of 2165 MHC alleles

In [44]:
# Read the measurement table into a DataFrame; the relative path is resolved
# against the kernel's current working directory.
df = pd.read_csv("combined_human_class1_dataset.csv")

In [45]:
# Display the loaded frame (pandas abbreviates the middle rows of a table this long).
df


Out[45]:
species mhc peptide peptide_length meas
0 cow BoLA-HD6 ALFYKDGKL 9 1.000000
1 cow BoLA-HD6 ALYEKKLAL 9 1.000000
2 cow BoLA-HD6 AMKDRFQPL 9 4.521706
3 cow BoLA-HD6 AQRELFFTL 9 1.000000
4 cow BoLA-HD6 FMKVKFEAL 9 1.576747
5 cow BoLA-HD6 FQHERLGQF 9 1.000000
6 cow BoLA-HD6 FQRAIMNAM 9 1.000000
7 cow BoLA-HD6 GQFLSFASL 9 1.000000
8 cow BoLA-HD6 GQFNRYAAM 9 1.000000
9 cow BoLA-HD6 ILNHKFCNL 9 1.000000
10 cow BoLA-HD6 IMALKQAGL 9 1.000000
11 cow BoLA-HD6 IQYDRRSFF 9 1.000000
12 cow BoLA-HD6 IQYVIRAQL 9 1.000000
13 cow BoLA-HD6 KMNAKAATL 9 1.000000
14 cow BoLA-HD6 KQFCLSILL 9 1.000000
15 cow BoLA-HD6 KQIPIWLPL 9 1.000000
16 cow BoLA-HD6 KQLELFWVI 9 1.000000
17 cow BoLA-HD6 KQRCALPSL 9 1.000000
18 cow BoLA-HD6 LLRRRPYPL 9 4.865762
19 cow BoLA-HD6 LQLARDGMF 9 1.000000
20 cow BoLA-HD6 LQYAIRSVF 9 3.659306
21 cow BoLA-HD6 MLKRRGFHL 9 5.000000
22 cow BoLA-HD6 MSRRRIYVL 9 1.000000
23 cow BoLA-HD6 RARQKGCTL 9 4.797945
24 cow BoLA-HD6 RARRHLAAL 9 1.000000
25 cow BoLA-HD6 RIHDIAVQL 9 1.000000
26 cow BoLA-HD6 RIYRKGNPL 9 1.000000
27 cow BoLA-HD6 RLARRARNI 9 1.000000
28 cow BoLA-HD6 RLKHIFLIF 9 2.796509
29 cow BoLA-HD6 RLRRRRHPL 9 1.000000
... ... ... ... ... ...
182072 human HLA-B*51:01 FPVRPQVPC 9 1300.000000
182073 human HLA-B*51:01 FPVRPQVPD 9 1300.000000
182074 human HLA-B*51:01 FPVRPQVPF 9 1300.000000
182075 human HLA-B*51:01 FPVRPQVPH 9 1300.000000
182076 human HLA-B*51:01 FPVRPQVPK 9 1300.000000
182077 human HLA-B*51:01 FPVRPQVPM 9 1300.000000
182078 human HLA-B*51:01 FPVRPQVPN 9 1300.000000
182079 human HLA-B*51:01 FPVRPQVPQ 9 1300.000000
182080 human HLA-B*51:01 FPVRPQVPT 9 1300.000000
182081 human HLA-B*51:01 FPVRPQVPW 9 1300.000000
182082 human HLA-B*51:01 FPVRPQVPY 9 1300.000000
182083 human HLA-B*51:01 FQVRPQVPL 9 1300.000000
182084 human HLA-B*51:01 FTVRPQVPL 9 1300.000000
182085 human HLA-B*53:01 FPVRPQVPA 9 1300.000000
182086 human HLA-B*53:01 FPVRPQVPC 9 1300.000000
182087 human HLA-B*53:01 FPVRPQVPD 9 1300.000000
182088 human HLA-B*53:01 FPVRPQVPH 9 1300.000000
182089 human HLA-B*53:01 FPVRPQVPK 9 1300.000000
182090 human HLA-B*53:01 FPVRPQVPN 9 1300.000000
182091 human HLA-B*53:01 FPVRPQVPQ 9 1300.000000
182092 human HLA-B*53:01 FPVRPQVPT 9 1300.000000
182093 human HLA-B*53:01 FPVRPQVPV 9 1300.000000
182094 human HLA-B*54:01 FPVRPQVPD 9 79.000000
182095 human HLA-B*54:01 FPVRPQVPF 9 79.000000
182096 human HLA-B*54:01 FPVRPQVPH 9 79.000000
182097 human HLA-B*54:01 FPVRPQVPK 9 79.000000
182098 human HLA-B*54:01 FPVRPQVPM 9 79.000000
182099 human HLA-B*54:01 FPVRPQVPQ 9 79.000000
182100 human HLA-B*54:01 FPVRPQVPW 9 79.000000
182101 human HLA-B*54:01 FPVRPQVPY 9 79.000000

182102 rows × 5 columns


In [47]:
# Restrict the dataset to HLA alleles that have an entry in mhc_sequences,
# filling three parallel lists (MHC sequence, peptide, measured value).
input_mhc_seqs = []
input_peptides = []
target_values = []

skipped = set()
for mhc_name, peptide, meas in zip(df.mhc, df.peptide, df.meas):
    # only rows whose MHC name carries the "HLA-" prefix are considered
    if mhc_name.startswith("HLA-"):
        allele = mhc_name.replace("HLA-", "")
        if allele not in mhc_sequences:
            # remember alleles without a sequence so they can be reported
            skipped.add(allele)
        else:
            input_mhc_seqs.append(mhc_sequences[allele])
            input_peptides.append(peptide)
            target_values.append(meas)

for allele in skipped:
    print("Skipped %s" % allele)

# the denominator counts every row of df, including the non-HLA ones
print("Kept %d/%d pMHC inputs" % (len(input_mhc_seqs), len(df)))


Skipped B*40:01
Skipped B*27:03
Skipped B*40:02
Skipped B*44:02
Skipped B*73:01
Skipped B40
Skipped A26
Skipped C*06:02
Skipped B*15:42
Skipped C*07:01
Skipped B*83:01
Skipped B*46:01
Skipped B*08:02
Skipped B*14:02
Skipped B*15:03
Skipped B*18:01
Skipped B*35:08
Skipped B*54:01
Skipped Cw4
Skipped B60
Skipped E*01:01
Skipped B*53:01
Skipped B*08:01
Skipped B*42:01
Skipped B*51:01
Skipped A3
Skipped E*01:03
Skipped B*27:02
Skipped B*27:06
Skipped B*27:04
Skipped B*57:03
Skipped B*58:02
Skipped A2
Skipped B*45:01
Skipped C*12:03
Skipped C*15:02
Skipped B*81:01
Skipped A11
Skipped Cw1
Skipped A24
Skipped C*08:02
Skipped B*39:01
Skipped C*14:02
Skipped C*05:01
Skipped B*27:05
Skipped B62
Skipped B*35:03
Skipped B*45:06
Skipped A3/11
Skipped B*40:13
Skipped B*07:02
Skipped B*57:01
Skipped B44
Skipped B8
Skipped B*15:09
Skipped B7
Skipped C*07:02
Skipped B*58:01
Skipped B*42:02
Skipped B*52:01
Skipped C*03:03
Skipped B*38:01
Skipped B*27:01
Skipped B39
Skipped A1
Skipped B51
Skipped B*27:10
Skipped B*27:20
Skipped B*14:01
Skipped B*35:01
Skipped B*44:03
Skipped B*15:01
Skipped B*48:01
Skipped B*57:02
Skipped B*15:17
Skipped B58
Skipped B*37:01
Skipped B27
Skipped C*04:01
Skipped B*15:02
Skipped B*08:03
Kept 96795/182102 pMHC inputs

In [48]:
def peptides_to_indices(peptides, letter_indices=None):
    """Convert peptide strings into lists of amino-acid indices.

    Parameters
    ----------
    peptides : iterable of str
        Peptide sequences, one letter per residue.
    letter_indices : mapping, optional
        Maps each residue letter to its index. When omitted,
        ``mhcflurry.data_helpers.amino_acid_letter_indices`` is used.

    Returns
    -------
    list of list
        One list of indices per peptide that contains no space character.
        Peptides containing a space are dropped, so the result can be
        shorter than ``peptides`` and positions need not line up with it.

    Raises
    ------
    KeyError
        If a residue letter is missing from a dict-like ``letter_indices``.
    """
    if letter_indices is None:
        # imported lazily so a caller-supplied mapping does not need mhcflurry
        from mhcflurry.data_helpers import amino_acid_letter_indices
        letter_indices = amino_acid_letter_indices
    return [
        [letter_indices[aa] for aa in peptide]
        for peptide in peptides
        if " " not in peptide
    ]

In [49]:
def onehot(peptides, letter_indices=None, n_letters=20):
    """One-hot encode sequences into a zero-padded boolean array.

    Parameters
    ----------
    peptides : iterable of str
        Sequences to encode, one letter per residue.
    letter_indices : mapping, optional
        Maps each residue letter to its position on the last axis. When
        omitted, ``mhcflurry.data_helpers.amino_acid_letter_indices`` is used.
    n_letters : int, optional
        Size of the last axis (default 20).

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(len(peptides), longest length, n_letters)``.
        Positions past the end of a shorter sequence stay all-False, and a
        sequence containing a space is left as an all-False row (it still
        counts towards the longest length). Empty input yields an array of
        shape ``(0, 0, n_letters)`` instead of raising.
    """
    if letter_indices is None:
        # imported lazily so a caller-supplied mapping does not need mhcflurry
        from mhcflurry.data_helpers import amino_acid_letter_indices
        letter_indices = amino_acid_letter_indices
    peptides = list(peptides)
    n = len(peptides)
    # max() of an empty sequence raises, so handle the empty case explicitly
    maxlen = max(len(peptide) for peptide in peptides) if peptides else 0
    result = np.zeros((n, maxlen, n_letters), dtype=bool)
    for i, peptide in enumerate(peptides):
        if " " in peptide:
            continue
        for j, aa in enumerate(peptide):
            result[i, j, letter_indices[aa]] = True
    return result

In [50]:
# One-hot encode both inputs; within each array, shorter sequences are
# zero-padded up to the longest sequence in that set.
padded_peptides = onehot(input_peptides)
padded_mhc = onehot(input_mhc_seqs)

In [51]:
# Report the encoded shapes: (number of inputs, longest sequence length, 20).
print(padded_peptides.shape)
print(padded_mhc.shape)


(96795, 30, 20)
(96795, 341, 20)

In [52]:
# IPython `??` introspection of JZS1. It needs JZS1 to already be in the kernel
# namespace; the import only appears in the following cell.
JZS1??

In [53]:
from keras.models import Graph 
from keras.layers.recurrent import JZS1
from keras.layers.core import Dense 

# Model hyperparameters.
RNN_OUTPUT_DIM = 32          # size of each RNN's final output vector
DENSE_OUTPUT_DIM = 32        # width of the hidden dense layer
N_DISTINCT_AMINO_ACIDS = 20  # matches the last axis produced by onehot()

# Length of the sequence axis of each one-hot encoded input.
max_peptide_length = padded_peptides.shape[1]
# read from the MHC array (not the peptide array) so it reflects MHC padding
max_mhc_length = padded_mhc.shape[1]

# graph model with two inputs and one output
graph = Graph()
graph.add_input(name='peptide', ndim=3)
graph.add_input(name='mhc', ndim=3)

# RNN for peptide sequences
graph.add_node(
    JZS1(
        input_dim=N_DISTINCT_AMINO_ACIDS,
        output_dim=RNN_OUTPUT_DIM),
    name="peptide_rnn",
    input="peptide")

# RNN for MHC sequences
graph.add_node(
    JZS1(
        input_dim=N_DISTINCT_AMINO_ACIDS,
        output_dim=RNN_OUTPUT_DIM),
    name="mhc_rnn",
    input="mhc")

# concatenate the last output of both RNNs (hence 2 * RNN_OUTPUT_DIM inputs)
# and transform them into a lower dimensional space
graph.add_node(
    Dense(RNN_OUTPUT_DIM * 2, DENSE_OUTPUT_DIM, activation="relu"),
    name="hidden",
    merge_mode="concat",
    inputs=("peptide_rnn", "mhc_rnn"))

# single sigmoid unit, so the predicted value lies between 0 and 1
graph.add_node(
    Dense(DENSE_OUTPUT_DIM, 1, activation="sigmoid"),
    name="affinity",
    input="hidden")

graph.add_output(name='affinity_output', input='affinity')

# mean squared error against the 'affinity_output' target, optimised with RMSprop
graph.compile('rmsprop', {'affinity_output':'mse'})

print(graph.get_config())


{   'input_config': [],
    'name': 'Graph',
    'node_config': [   {   'input': 'peptide',
                           'inputs': [],
                           'merge_mode': 'concat',
                           'name': 'peptide_rnn'},
                       {   'input': 'mhc',
                           'inputs': [],
                           'merge_mode': 'concat',
                           'name': 'mhc_rnn'},
                       {   'input': None,
                           'inputs': ('peptide_rnn', 'mhc_rnn'),
                           'merge_mode': 'concat',
                           'name': 'hidden'},
                       {   'input': 'hidden',
                           'inputs': [],
                           'merge_mode': 'concat',
                           'name': 'affinity'}],
    'nodes': [   {   'activation': 'tanh',
                     'init': 'glorot_uniform',
                     'inner_activation': 'sigmoid',
                     'inner_init': 'orthogonal',
                     'input_dim': 20,
                     'name': 'JZS1',
                     'output_dim': 32,
                     'return_sequences': False,
                     'truncate_gradient': -1},
                 {   'activation': 'tanh',
                     'init': 'glorot_uniform',
                     'inner_activation': 'sigmoid',
                     'inner_init': 'orthogonal',
                     'input_dim': 20,
                     'name': 'JZS1',
                     'output_dim': 32,
                     'return_sequences': False,
                     'truncate_gradient': -1},
                 {   'activation': 'relu',
                     'init': 'glorot_uniform',
                     'input_dim': 64,
                     'name': 'Dense',
                     'output_dim': 32},
                 {   'activation': 'sigmoid',
                     'init': 'glorot_uniform',
                     'input_dim': 32,
                     'name': 'Dense',
                     'output_dim': 1}],
    'output_config': [   {'dtype': 'float', 'name': 'peptide', 'ndim': 3},
                         {'dtype': 'float', 'name': 'mhc', 'ndim': 3},
                         {   'input': 'affinity',
                             'inputs': [],
                             'merge_mode': 'concat',
                             'name': 'affinity_output'}]}
{'output_config': [{'name': 'peptide', 'ndim': 3, 'dtype': 'float'}, {'name': 'mhc', 'ndim': 3, 'dtype': 'float'}, {'name': 'affinity_output', 'merge_mode': 'concat', 'input': 'affinity', 'inputs': []}], 'name': 'Graph', 'input_config': [], 'node_config': [{'name': 'peptide_rnn', 'merge_mode': 'concat', 'input': 'peptide', 'inputs': []}, {'name': 'mhc_rnn', 'merge_mode': 'concat', 'input': 'mhc', 'inputs': []}, {'name': 'hidden', 'merge_mode': 'concat', 'input': None, 'inputs': ('peptide_rnn', 'mhc_rnn')}, {'name': 'affinity', 'merge_mode': 'concat', 'input': 'hidden', 'inputs': []}], 'nodes': [{'output_dim': 32, 'inner_init': 'orthogonal', 'return_sequences': False, 'activation': 'tanh', 'inner_activation': 'sigmoid', 'truncate_gradient': -1, 'name': 'JZS1', 'input_dim': 20, 'init': 'glorot_uniform'}, {'output_dim': 32, 'inner_init': 'orthogonal', 'return_sequences': False, 'activation': 'tanh', 'inner_activation': 'sigmoid', 'truncate_gradient': -1, 'name': 'JZS1', 'input_dim': 20, 'init': 'glorot_uniform'}, {'output_dim': 32, 'name': 'Dense', 'input_dim': 64, 'init': 'glorot_uniform', 'activation': 'relu'}, {'output_dim': 1, 'name': 'Dense', 'input_dim': 32, 'init': 'glorot_uniform', 'activation': 'sigmoid'}]}

In [54]:
# Measurements at or above this value map to a target of 0.
# NOTE(review): presumably `meas` is an affinity in nM - confirm the units.
MAX_MEASUREMENT = 5000

# Log-scale the measurements into [0, 1]: MAX_MEASUREMENT -> 0 and 1 -> 1.
# Both ends are clipped because the model's sigmoid output cannot reach targets
# above 1, which measurements below 1 would otherwise produce (0 would give inf).
log_target_values = np.clip(
    1.0 - np.log(np.asarray(target_values, dtype=float)) / np.log(MAX_MEASUREMENT),
    0.0,
    1.0)

In [56]:
# Inputs shared by training and prediction, keyed by the graph's input names.
model_inputs = {
    'peptide': padded_peptides,
    'mhc': padded_mhc,
}

# Training additionally needs the target for the 'affinity_output' output.
training_data = dict(model_inputs)
training_data['affinity_output'] = log_target_values

history = graph.fit(training_data, nb_epoch=10)

# Predictions are made on the same inputs the model was trained on.
predictions = graph.predict(model_inputs)


Epoch 0
22400/96795 [=====>........................] - ETA: 371s - affinity_output: 0.0982
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-56-9d67aa3c8b0d> in <module>()
----> 1 history = graph.fit({'peptide':padded_peptides, 'mhc':padded_mhc, 'affinity_output':log_target_values}, nb_epoch=10)
      2 predictions = graph.predict({'peptide':padded_peptides, 'mhc':padded_mhc})

/Users/iskander/code/keras/keras/models.py in fit(self, data, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle)
    535         metrics = self.output_order + ['val_' + m for m in self.output_order]
    536         history = self._fit(f, ins, out_labels=out_labels, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, callbacks=callbacks, \
--> 537             validation_split=validation_split, val_f=val_f, val_ins=val_ins, shuffle=shuffle, metrics=metrics)
    538         return history
    539 

/Users/iskander/code/keras/keras/models.py in _fit(self, f, ins, out_labels, batch_size, nb_epoch, verbose, callbacks, validation_split, val_f, val_ins, shuffle, metrics)
    133                 batch_logs['size'] = len(batch_ids)
    134                 callbacks.on_batch_begin(batch_index, batch_logs)
--> 135                 outs = f(*ins_batch)
    136                 if type(outs) != list:
    137                     outs = [outs]

/Library/Frameworks/Python.framework/Versions/3.4/lib/python3.4/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    593         t0_fn = time.time()
    594         try:
--> 595             outputs = self.fn()
    596         except Exception:
    597             if hasattr(self.fn, 'position_of_error'):

KeyboardInterrupt: 

In [ ]: