In [1]:
import matplotlib.pyplot as plt
import pickle as pkl
import pandas as pd
import json
# Prefer autograd's NumPy wrapper when it is installed; otherwise fall back to plain NumPy.
try:
    import autograd.numpy as np
except ImportError:
    import numpy as np
import seaborn as sns

# Seaborn 'poster' context: scale up fonts and line widths for all figures below.
sns.set_context('poster')

# `pin` supplies ProteinInteractionNetwork, used below to build residue graphs from PDB files.
from pin import pin

%matplotlib inline


/Users/ericmjl/anaconda/lib/python3.5/site-packages/sklearn/preprocessing/data.py:583: DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.
  warnings.warn(DEPRECATION_MSG_1D, DeprecationWarning)
/Users/ericmjl/anaconda/lib/python3.5/site-packages/sklearn/preprocessing/data.py:583: DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.
  warnings.warn(DEPRECATION_MSG_1D, DeprecationWarning)

In [2]:
# Trained weights/biases, looked up below by layer name (e.g. 'layer0_GraphConvLayer').
# NOTE: pickle can execute arbitrary code on load; only open files produced by this project.
with open('../experiments/all_graphs/outputs/all-graphs_1000-iters_wbs.pkl', 'rb') as wbs_file:
    wb = pkl.load(wbs_file)

In [3]:
# Visualize importance of each of the features in the linear regression stage

ylim = (-0.3, 0.6)  # fixed y-range used for the saved figure; larger weights would be clipped

# layer0 / layer1 are extracted for inspection but are not plotted in this cell.
layer0 = wb['layer0_GraphConvLayer']['biases'].T
layer1 = wb['layer1_FingerprintLayer']['weights'].T
layer2 = wb['layer2_LinearRegressionLayer']['linweights']

# Flatten to 1-D: linweights may be stored as a column vector (cell 4 prints one-element
# rows), while bar() expects exactly one height per bar.
linreg_heights = np.ravel(layer2)
n_feats = len(linreg_heights)

fig, ax = plt.subplots(figsize=(6, 4))
# align='edge' keeps every bar inside xlim(0, n_feats); since matplotlib 2.0 the default
# is 'center', which would clip half of the first bar at x=0.
ax.bar(range(n_feats), linreg_heights, align='edge', label='LinReg Weights', color='green')
ax.set_xlim(0, n_feats)
ax.set_ylim(*ylim)
ax.set_xlabel('Feature ID')
ax.legend(loc='best')
fig.subplots_adjust(bottom=0.2)
fig.savefig('figures/linreg_weights.pdf')



In [4]:
from sklearn.preprocessing import MinMaxScaler

# Map |linreg weight| of each feature onto the range 10-100 (used as opacities).
mms = MinMaxScaler(feature_range=(10, 100))
# Explicit single-feature column: scikit-learn deprecated 1-D input (see the warning in
# cell 1's output) and later versions raise on it. A no-op if layer2 is already (n, 1).
opacities = mms.fit_transform(np.abs(layer2).reshape(-1, 1))
for i, op in enumerate(opacities.astype(int)):
    print(i, op)


0 [10]
1 [31]
2 [10]
3 [10]
4 [11]
5 [54]
6 [10]
7 [10]
8 [11]
9 [16]
10 [10]
11 [100]
12 [11]
13 [10]
14 [10]
15 [10]
16 [10]
17 [10]
18 [11]
19 [98]
20 [17]
21 [10]
22 [20]
23 [16]
24 [10]
25 [24]
26 [33]
27 [10]
28 [10]
29 [10]
30 [56]
31 [10]
32 [10]
33 [35]
34 [10]
35 [10]

In [5]:
# Layer 0: Convolution Weights, shown as a heatmap.
conv_weights = wb['layer0_GraphConvLayer']['weights']

fig, ax = plt.subplots(figsize=(6, 5))
mesh = ax.pcolor(conv_weights, cmap=plt.cm.RdBu)
fig.colorbar(mesh, ax=ax)
# pcolor places columns along x and rows along y; size the axes from the matrix
# itself instead of a hard-coded 36.
n_rows, n_cols = np.shape(conv_weights)
ax.set_xlim(0, n_cols)
ax.set_ylim(0, n_rows)
ax.set_xlabel('Feature ID')
ax.set_ylabel('Feature ID')
ax.set_title('Layer 0: convolution weights');  # ';' suppresses the Text repr output


Out[5]:
<matplotlib.text.Text at 0x11a27e0f0>

In [6]:
# Load the expanded HIV protease resistance table and put FPV on a log10 scale.
drug_data = (
    pd.read_csv('../data/hiv_data/hiv-protease-data-expanded.csv', index_col=0)
    .assign(FPV=lambda frame: frame['FPV'].apply(np.log10))
)
drug_data.head(3)


Out[6]:
ATV DRV FPV IDV LPV NFV SQV SeqID TPV seqid sequence sequence_object weight
0 NaN NaN 0.397940 16.3 NaN 38.6 16.1 2996 NaN 2996-0 PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEDVNLPGRWKPKM... ID: 2996-0\nName: <unknown name>\nDescription:... 0.50
1 NaN NaN 0.397940 16.3 NaN 38.6 16.1 2996 NaN 2996-1 PQITLWQRPIVTIKIGGQLKEALLDTGADDTVLEDVNLPGRWKPKM... ID: 2996-1\nName: <unknown name>\nDescription:... 0.50
2 NaN NaN -0.154902 0.8 NaN 0.8 1.1 4387 NaN 4387-0 PQITLWQRPLVTIKVGGQLKEALLDTGADDTVLEDMELPGRWKPKM... ID: 4387-0\nName: <unknown name>\nDescription:... 0.25

In [7]:
# Rows at the maximum (log10) FPV value that carry full weight; show the last three.
drug_data.loc[
    (drug_data['FPV'] == drug_data['FPV'].max()) & (drug_data['weight'] == 1)
].tail(3)


Out[7]:
ATV DRV FPV IDV LPV NFV SQV SeqID TPV seqid sequence sequence_object weight
6007 104.0 120.0 2.60206 30.0 65.0 54.0 45.0 191476 10.0 191476-0 PQITLWQRPFIPVKVGGQPTEALLDTGADDTIFEGINLPGRWKPKM... ID: 191476-0\nName: <unknown name>\nDescriptio... 1.0
6168 700.0 200.0 2.60206 177.0 500.0 600.0 86.0 205644 800.0 205644-0 PQITLWQRPIISVRIGGQPVEALLDTGADDTILDNINLPGRWTPKL... ID: 205644-0\nName: <unknown name>\nDescriptio... 1.0
6518 109.0 239.0 2.60206 21.0 87.0 94.0 63.0 230072 7.4 230072-0 PQITLWQRPVITVRVGGQLTEALLDTGADDTIFEEINLPGKWKPKL... ID: 230072-0\nName: <unknown name>\nDescriptio... 1.0

In [8]:
# Example of "high" FPV resistance: the sequence with SeqID 230072-0 (one of the
# max-FPV rows above). Below we look up which modelling job was run for it.
seqid = '230072-0'

batch_summary_path = '../data/batch_summary.json'
with open(batch_summary_path, 'r') as summary_file:
    model_data = json.load(summary_file)

def get_project_code_by_title(model_data, title):
    """
    Return the first project in ``model_data['projects']`` whose ``'title'``
    equals ``title``, or ``None`` if no project matches.

    Despite the name, the whole project dict is returned; read ``['code']``
    from it to get the job code.
    """
    return next(
        (project for project in model_data['projects'] if project['title'] == title),
        None,
    )
    
# Fail with a readable message (instead of "'NoneType' object is not subscriptable")
# when the SeqID has no matching project in batch_summary.json.
project = get_project_code_by_title(model_data, seqid)
if project is None:
    raise LookupError('No project titled {0!r} in batch_summary.json'.format(seqid))
code = project['code']
p = pin.ProteinInteractionNetwork('../data/batch_models/{0}/model_01.pdb'.format(code))

In [9]:
# Score every node: dot its feature vector with the *linear-regression* weights
# (layer2 linweights — not the convolution weights). The highest/lowest scoring
# nodes are ranked in the cells below.
linreg_weights = wb['layer2_LinearRegressionLayer']['linweights']
activations = {
    n: np.dot(d['features'], linreg_weights)[0]
    for n, d in p.nodes(data=True)
}

activations


Out[9]:
{'A10VAL': array([-0.53198185]),
 'A11ILE': array([-3.88627819]),
 'A12THR': array([-0.49427965]),
 'A13VAL': array([-3.79999389]),
 'A14ARG': array([-0.61588536]),
 'A15VAL': array([-3.81228992]),
 'A16GLY': array([-0.47855725]),
 'A17GLY': array([-0.47794508]),
 'A18GLN': array([-0.50026661]),
 'A19LEU': array([-0.48830488]),
 'A1PRO': array([-0.84818636]),
 'A20THR': array([-0.99317094]),
 'A21GLU': array([-0.52904048]),
 'A22ALA': array([-2.34855718]),
 'A23LEU': array([-1.40680315]),
 'A24LEU': array([-5.11919298]),
 'A25ASP': array([-0.53442456]),
 'A26THR': array([-0.79100526]),
 'A27GLY': array([-0.47887889]),
 'A28ALA': array([-1.04855136]),
 'A29ASP': array([-0.56043727]),
 'A2GLN': array([-0.6879741]),
 'A30ASP': array([-0.28804816]),
 'A31THR': array([-0.91748931]),
 'A32ILE': array([-3.7359737]),
 'A33PHE': array([-5.14388299]),
 'A34GLU': array([-0.28331906]),
 'A35GLU': array([-0.28598398]),
 'A36ILE': array([-2.2330974]),
 'A37ASN': array([-0.4387759]),
 'A38LEU': array([-3.97185435]),
 'A39PRO': array([-0.48925899]),
 'A3ILE': array([-3.12785581]),
 'A40GLY': array([-0.47363696]),
 'A41LYS': array([-1.01061817]),
 'A42TRP': array([-1.0110934]),
 'A43LYS': array([-0.4469238]),
 'A44PRO': array([-0.95933563]),
 'A45LYS': array([-0.44790785]),
 'A46LEU': array([-0.9421579]),
 'A47ILE': array([-3.13292813]),
 'A48GLY': array([-0.47369903]),
 'A49GLY': array([-0.47392737]),
 'A4THR': array([-0.00466797]),
 'A50ILE': array([-2.03509766]),
 'A51GLY': array([-0.47075938]),
 'A52GLY': array([-0.47156982]),
 'A53PHE': array([-0.94682688]),
 'A54LEU': array([-3.19140391]),
 'A55LYS': array([-0.45035408]),
 'A56VAL': array([-3.07989865]),
 'A57LYS': array([-0.44584834]),
 'A58GLN': array([-0.49671444]),
 'A59TYR': array([-2.40894349]),
 'A5LEU': array([-1.51964508]),
 'A60GLU': array([-1.09745107]),
 'A61GLN': array([-0.50078265]),
 'A62ILE': array([-2.29342674]),
 'A63PRO': array([-0.48987344]),
 'A64ILE': array([-4.76348121]),
 'A65GLU': array([-0.78816505]),
 'A66ILE': array([-4.53321126]),
 'A67CYS': array([-0.66025709]),
 'A68GLU': array([-0.29012373]),
 'A69HIS': array([-0.96388101]),
 'A6TRP': array([ 0.29030681]),
 'A70THR': array([-0.49459114]),
 'A71ILE': array([-2.13716017]),
 'A72MET': array([-0.49955605]),
 'A73SER': array([-0.94010658]),
 'A74THR': array([-0.49501325]),
 'A75VAL': array([-3.32161033]),
 'A76LEU': array([-1.89811445]),
 'A77VAL': array([-2.70243144]),
 'A78GLY': array([-0.47198288]),
 'A79PRO': array([-1.52395411]),
 'A7GLN': array([-0.00910153]),
 'A80THR': array([-0.49410093]),
 'A81PRO': array([-0.94316806]),
 'A82VAL': array([-1.51466961]),
 'A83ASN': array([-0.93656865]),
 'A84VAL': array([-2.63849207]),
 'A85ILE': array([-5.27346963]),
 'A86GLY': array([-0.47478948]),
 'A87ARG': array([-0.63529135]),
 'A88ASN': array([-1.38286085]),
 'A89VAL': array([-2.80862273]),
 'A8ARG': array([ 0.12990423]),
 'A90MET': array([-2.74822449]),
 'A91THR': array([-0.49855878]),
 'A92GLN': array([-0.50157485]),
 'A93LEU': array([-2.03532168]),
 'A94GLY': array([-0.47828203]),
 'A95CYS': array([-0.65747657]),
 'A96THR': array([-0.86701197]),
 'A97LEU': array([-3.67468823]),
 'A98ASN': array([-1.62304095]),
 'A99PHE': array([-4.17985053]),
 'A9PRO': array([-2.53351114]),
 'B10VAL': array([-0.83185173]),
 'B11ILE': array([-3.8857176]),
 'B12THR': array([-0.4942744]),
 'B13VAL': array([-3.80784256]),
 'B14ARG': array([-0.36018518]),
 'B15VAL': array([-3.81889648]),
 'B16GLY': array([-0.47874022]),
 'B17GLY': array([-0.47933536]),
 'B18GLN': array([-1.04313391]),
 'B19LEU': array([-0.48833115]),
 'B1PRO': array([-0.82733071]),
 'B20THR': array([-0.9810832]),
 'B21GLU': array([-0.5291622]),
 'B22ALA': array([-2.3748266]),
 'B23LEU': array([-2.03402353]),
 'B24LEU': array([-5.11295078]),
 'B25ASP': array([-0.53275996]),
 'B26THR': array([-0.79054178]),
 'B27GLY': array([-0.47776205]),
 'B28ALA': array([-1.0482694]),
 'B29ASP': array([-0.55697618]),
 'B2GLN': array([-0.22854545]),
 'B30ASP': array([-0.28753064]),
 'B31THR': array([-0.91376394]),
 'B32ILE': array([-2.9322601]),
 'B33PHE': array([-5.15670857]),
 'B34GLU': array([-0.28562957]),
 'B35GLU': array([-0.28539954]),
 'B36ILE': array([-2.93088476]),
 'B37ASN': array([-0.97908867]),
 'B38LEU': array([-3.94693102]),
 'B39PRO': array([-0.48919239]),
 'B3ILE': array([-3.10800072]),
 'B40GLY': array([-0.47484421]),
 'B41LYS': array([-1.00934914]),
 'B42TRP': array([-1.32201122]),
 'B43LYS': array([-0.44765217]),
 'B44PRO': array([-0.96224071]),
 'B45LYS': array([-0.44787795]),
 'B46LEU': array([-0.94480843]),
 'B47ILE': array([-3.17575588]),
 'B48GLY': array([-0.4749543]),
 'B49GLY': array([-0.47496142]),
 'B4THR': array([-0.00466797]),
 'B50ILE': array([-3.28938655]),
 'B51GLY': array([-0.47683813]),
 'B52GLY': array([-0.47529147]),
 'B53PHE': array([-0.95250603]),
 'B54LEU': array([-1.82448417]),
 'B55LYS': array([-0.44799469]),
 'B56VAL': array([-3.08436653]),
 'B57LYS': array([-1.2192604]),
 'B58GLN': array([-0.4984363]),
 'B59TYR': array([-2.76509964]),
 'B5LEU': array([-1.52759316]),
 'B60GLU': array([-1.09438928]),
 'B61GLN': array([-0.50119651]),
 'B62ILE': array([-2.27226922]),
 'B63PRO': array([-0.48937767]),
 'B64ILE': array([-4.75316441]),
 'B65GLU': array([-0.88050343]),
 'B66ILE': array([-4.53576405]),
 'B67CYS': array([-0.66233286]),
 'B68GLU': array([-0.45872297]),
 'B69HIS': array([-1.48159159]),
 'B6TRP': array([ 0.29030681]),
 'B70THR': array([-0.49346746]),
 'B71ILE': array([-2.14743008]),
 'B72MET': array([-0.5019644]),
 'B73SER': array([-0.94465346]),
 'B74THR': array([-0.49587031]),
 'B75VAL': array([-4.06247477]),
 'B76LEU': array([-1.89229946]),
 'B77VAL': array([-2.68070911]),
 'B78GLY': array([-0.47331416]),
 'B79PRO': array([-1.51768174]),
 'B7GLN': array([-0.00910153]),
 'B80THR': array([-0.49710612]),
 'B81PRO': array([-0.95119968]),
 'B82VAL': array([-1.51231399]),
 'B83ASN': array([-0.92732113]),
 'B84VAL': array([-2.20343188]),
 'B85ILE': array([-5.29138724]),
 'B86GLY': array([-0.47500001]),
 'B87ARG': array([-0.63245266]),
 'B88ASN': array([-1.38169547]),
 'B89VAL': array([-2.81843513]),
 'B8ARG': array([ 0.12990423]),
 'B90MET': array([-2.76765012]),
 'B91THR': array([-0.49931487]),
 'B92GLN': array([-0.50086715]),
 'B93LEU': array([-2.04771597]),
 'B94GLY': array([-0.47919216]),
 'B95CYS': array([-0.65698478]),
 'B96THR': array([-1.31194283]),
 'B97LEU': array([-3.65730294]),
 'B98ASN': array([-1.62911373]),
 'B99PHE': array([-4.2108203]),
 'B9PRO': array([-2.54576489])}

In [10]:
# Some helper functions for the notebook.

def max_key(dictionary):
    """
    Return the first key (in iteration order) whose value equals the largest
    value in ``dictionary``; ``None`` if no value compares equal to it.

    Raises ValueError (from ``max``) when ``dictionary`` is empty.
    """
    top = max(list(dictionary.values()))
    return next((key for key, val in dictionary.items() if val == top), None)

def invert(dictionary):
    """
    Build a reverse lookup mapping each value to the list of keys that hold it.

    Keys are appended in the dictionary's iteration order; values must be
    hashable. The result is a ``collections.defaultdict(list)``.
    """
    from collections import defaultdict

    keys_for_value = defaultdict(list)
    for key, val in dictionary.items():
        keys_for_value[val].append(key)
    return keys_for_value

def get_keys_by_value(dictionary, value):
    """
    Return every key of ``dictionary`` whose value equals ``value``, in
    iteration order (an empty list if none match).

    More general than ``max_key()``, which returns only the first key holding
    the maximum value; this returns all keys holding any given value.
    """
    return [key for key, val in dictionary.items() if val == value]

In [11]:
# Indices of the largest and second-largest (signed) linear-regression weights.
layer = layer2
layer_idxd = {i: val for i, val in enumerate(layer)}
layer_values = sorted(layer, reverse=True)

maxkey = max_key(layer_idxd)
print(maxkey)
# Exclude maxkey so that a tie for first place does not report the same index twice.
nextmaxkey = [k for k in get_keys_by_value(layer_idxd, layer_values[1]) if k != maxkey][0]
print(nextmaxkey)


11
19

In [ ]:


In [ ]:


In [12]:
# (node, activation) pairs ranked from the highest activation to the lowest.
top_features = [(node, act[0]) for node, act in activations.items()]

top_feats_sorted = sorted(top_features, key=lambda pair: pair[1], reverse=True)
top_feats_sorted


Out[12]:
[('B6TRP', 0.29030681496727889),
 ('A6TRP', 0.29030681496727889),
 ('B8ARG', 0.12990422849136798),
 ('A8ARG', 0.12990422849136798),
 ('B4THR', -0.0046679749445809569),
 ('A4THR', -0.0046679749445809569),
 ('A7GLN', -0.0091015297997234686),
 ('B7GLN', -0.0091015297997234686),
 ('B2GLN', -0.22854544767900156),
 ('A34GLU', -0.28331905954142106),
 ('B35GLU', -0.28539953510153043),
 ('B34GLU', -0.28562956762571795),
 ('A35GLU', -0.2859839782233477),
 ('B30ASP', -0.28753064260426131),
 ('A30ASP', -0.28804816400311317),
 ('A68GLU', -0.29012373050084134),
 ('B14ARG', -0.36018517635201514),
 ('A37ASN', -0.4387759042074193),
 ('A57LYS', -0.44584833841803967),
 ('A43LYS', -0.44692380141441107),
 ('B43LYS', -0.44765216980582562),
 ('B45LYS', -0.44787795310613004),
 ('A45LYS', -0.4479078541832277),
 ('B55LYS', -0.44799468978375218),
 ('A55LYS', -0.45035408236889746),
 ('B68GLU', -0.45872297152273855),
 ('A51GLY', -0.47075938434742548),
 ('A52GLY', -0.47156981745258408),
 ('A78GLY', -0.47198288226683638),
 ('B78GLY', -0.47331415779309238),
 ('A40GLY', -0.47363695775873693),
 ('A48GLY', -0.47369902771304928),
 ('A49GLY', -0.47392737130113771),
 ('A86GLY', -0.47478947881318811),
 ('B40GLY', -0.47484420684899864),
 ('B48GLY', -0.47495429891369234),
 ('B49GLY', -0.47496142410141207),
 ('B86GLY', -0.47500001110190437),
 ('B52GLY', -0.47529147448185394),
 ('B51GLY', -0.47683813073740255),
 ('B27GLY', -0.47776205000369587),
 ('A17GLY', -0.47794508217085163),
 ('A94GLY', -0.47828202976913237),
 ('A16GLY', -0.47855724649864717),
 ('B16GLY', -0.47874022043312758),
 ('A27GLY', -0.47887889189781307),
 ('B94GLY', -0.47919215726171704),
 ('B17GLY', -0.47933535644214409),
 ('A19LEU', -0.4883048838105547),
 ('B19LEU', -0.48833115013742701),
 ('B39PRO', -0.48919238790695979),
 ('A39PRO', -0.48925899158327252),
 ('B63PRO', -0.4893776697078806),
 ('A63PRO', -0.48987343677468737),
 ('B70THR', -0.49346745785143681),
 ('A80THR', -0.49410092721794602),
 ('B12THR', -0.49427439524358324),
 ('A12THR', -0.4942796547600074),
 ('A70THR', -0.49459113677611161),
 ('A74THR', -0.49501324995481516),
 ('B74THR', -0.49587031458247549),
 ('A58GLN', -0.49671444327819791),
 ('B80THR', -0.49710611808515387),
 ('B58GLN', -0.49843629737369155),
 ('A91THR', -0.49855877501644535),
 ('B91THR', -0.49931487291541393),
 ('A72MET', -0.49955604899133438),
 ('A18GLN', -0.50026661011857698),
 ('A61GLN', -0.50078265280818335),
 ('B92GLN', -0.5008671483479924),
 ('B61GLN', -0.5011965148950206),
 ('A92GLN', -0.50157484998934221),
 ('B72MET', -0.50196439726461639),
 ('A21GLU', -0.52904048424434746),
 ('B21GLU', -0.5291621951361013),
 ('A10VAL', -0.53198185444181756),
 ('B25ASP', -0.5327599597392616),
 ('A25ASP', -0.53442455942344025),
 ('B29ASP', -0.55697617815697165),
 ('A29ASP', -0.56043727414495115),
 ('A14ARG', -0.61588536255051507),
 ('B87ARG', -0.63245266080642748),
 ('A87ARG', -0.63529134760902417),
 ('B95CYS', -0.65698478493971635),
 ('A95CYS', -0.65747656933049425),
 ('A67CYS', -0.6602570883304395),
 ('B67CYS', -0.6623328643151738),
 ('A2GLN', -0.68797410438935735),
 ('A65GLU', -0.78816504808382903),
 ('B26THR', -0.79054178084674143),
 ('A26THR', -0.79100525785949749),
 ('B1PRO', -0.82733071359502464),
 ('B10VAL', -0.83185173124346723),
 ('A1PRO', -0.84818636285804161),
 ('A96THR', -0.86701196715418516),
 ('B65GLU', -0.88050342844399621),
 ('B31THR', -0.91376393539681877),
 ('A31THR', -0.91748931263751388),
 ('B83ASN', -0.92732113359460377),
 ('A83ASN', -0.9365686477319185),
 ('A73SER', -0.94010658297089145),
 ('A46LEU', -0.94215789632156466),
 ('A81PRO', -0.94316806240305184),
 ('B73SER', -0.94465345540785228),
 ('B46LEU', -0.94480842537246046),
 ('A53PHE', -0.94682688004466664),
 ('B81PRO', -0.95119968091339047),
 ('B53PHE', -0.95250602570533904),
 ('A44PRO', -0.95933563033275826),
 ('B44PRO', -0.96224070919141713),
 ('A69HIS', -0.9638810133850686),
 ('B37ASN', -0.97908867160611834),
 ('B20THR', -0.98108319606923766),
 ('A20THR', -0.99317093757145503),
 ('B41LYS', -1.0093491388291813),
 ('A41LYS', -1.0106181674603687),
 ('A42TRP', -1.0110934014578439),
 ('B18GLN', -1.0431339111358562),
 ('B28ALA', -1.0482693980613031),
 ('A28ALA', -1.0485513597246006),
 ('B60GLU', -1.0943892752164381),
 ('A60GLU', -1.0974510669489572),
 ('B57LYS', -1.219260396487936),
 ('B96THR', -1.3119428334481806),
 ('B42TRP', -1.3220112226697545),
 ('B88ASN', -1.3816954717607008),
 ('A88ASN', -1.3828608543264427),
 ('A23LEU', -1.4068031498246958),
 ('B69HIS', -1.4815915889791804),
 ('B82VAL', -1.5123139937359587),
 ('A82VAL', -1.5146696086164002),
 ('B79PRO', -1.5176817388232446),
 ('A5LEU', -1.5196450773639927),
 ('A79PRO', -1.5239541112527908),
 ('B5LEU', -1.5275931564679628),
 ('A98ASN', -1.6230409497564731),
 ('B98ASN', -1.6291137264449129),
 ('B54LEU', -1.8244841692959302),
 ('B76LEU', -1.8922994578510903),
 ('A76LEU', -1.8981144466284428),
 ('B23LEU', -2.0340235321040709),
 ('A50ILE', -2.0350976560954201),
 ('A93LEU', -2.035321678008283),
 ('B93LEU', -2.0477159741080042),
 ('A71ILE', -2.1371601728760314),
 ('B71ILE', -2.1474300802667083),
 ('B84VAL', -2.2034318785598903),
 ('A36ILE', -2.2330973959621487),
 ('B62ILE', -2.272269215699072),
 ('A62ILE', -2.2934267408074076),
 ('A22ALA', -2.3485571752850012),
 ('B22ALA', -2.3748265978891565),
 ('A59TYR', -2.4089434886270857),
 ('A9PRO', -2.5335111427443002),
 ('B9PRO', -2.5457648905436612),
 ('A84VAL', -2.6384920656326369),
 ('B77VAL', -2.6807091063039552),
 ('A77VAL', -2.7024314366030597),
 ('A90MET', -2.7482244931988387),
 ('B59TYR', -2.7650996355941762),
 ('B90MET', -2.7676501160068021),
 ('A89VAL', -2.8086227281663874),
 ('B89VAL', -2.8184351251040418),
 ('B36ILE', -2.9308847558151747),
 ('B32ILE', -2.9322600994841164),
 ('A56VAL', -3.0798986507834378),
 ('B56VAL', -3.0843665259237101),
 ('B3ILE', -3.1080007216625427),
 ('A3ILE', -3.127855810143501),
 ('A47ILE', -3.1329281250409395),
 ('B47ILE', -3.1757558822122962),
 ('A54LEU', -3.1914039132326875),
 ('B50ILE', -3.2893865544513279),
 ('A75VAL', -3.3216103275193443),
 ('B97LEU', -3.6573029384096225),
 ('A97LEU', -3.6746882299826522),
 ('A32ILE', -3.7359737026849249),
 ('A13VAL', -3.7999938947692242),
 ('B13VAL', -3.8078425566651375),
 ('A15VAL', -3.8122899199809379),
 ('B15VAL', -3.8188964794183442),
 ('B11ILE', -3.8857176007070491),
 ('A11ILE', -3.8862781878680561),
 ('B38LEU', -3.9469310212458204),
 ('A38LEU', -3.9718543471266328),
 ('B75VAL', -4.0624747743562404),
 ('A99PHE', -4.1798505326990716),
 ('B99PHE', -4.2108202994646922),
 ('A66ILE', -4.5332112571166281),
 ('B66ILE', -4.5357640519699185),
 ('B64ILE', -4.7531644064728207),
 ('A64ILE', -4.7634812052823809),
 ('B24LEU', -5.1129507803057272),
 ('A24LEU', -5.1191929751357348),
 ('A33PHE', -5.1438829926885363),
 ('B33PHE', -5.1567085676305746),
 ('A85ILE', -5.27346963453064),
 ('B85ILE', -5.2913872441019301)]

In [13]:
# Column 0 = node label, column 1 = negated activation; rank by column 1, largest first.
top_feats_df = pd.DataFrame(top_feats_sorted)
top_feats_df[1] = top_feats_df[1].mul(-1)
top_feats_df = top_feats_df.sort_values(1, ascending=False)
top_feats_df.head()


Out[13]:
0 1
197 B85ILE 5.291387
196 A85ILE 5.273470
195 B33PHE 5.156709
194 A33PHE 5.143883
193 A24LEU 5.119193

In [14]:
# Neighbours of the top-ranked node in top_feats_df (B85ILE in the run shown above),
# taken from the ranking instead of a hard-coded label.
p.neighbors(top_feats_df[0].iloc[0])


Out[14]:
['B90MET',
 'B33PHE',
 'B66ILE',
 'B89VAL',
 'B86GLY',
 'B11ILE',
 'B13VAL',
 'B64ILE',
 'B24LEU',
 'B84VAL',
 'B22ALA']

In [15]:
from sklearn.preprocessing import MinMaxScaler

# Rescale the negated activations to [0, 1] for the node colouring below. Reshape the
# underlying array: Series.reshape was deprecated and later removed from pandas.
# Rows stay in top_feats_df's current (sorted) positional order.
ss = MinMaxScaler()
top_feats_transformed = ss.fit_transform(top_feats_df[1].values.reshape(-1, 1))

In [16]:
import networkx as nx

# Nodes of interest: the single node with the largest value in column 1.
ranked = top_feats_df.sort_values(1, ascending=False)
nois = ranked.iloc[:1][0].values
print(nois)
def get_subgraph_nodes_and_neighbors(p, nois):
    """
    Return the subgraph of ``p.to_directed()`` induced by the nodes in ``nois``
    plus every neighbour of each of them.

    Neighbours are looked up on the original (undirected) graph ``p``.
    """
    directed = p.to_directed()
    members = []
    for node in nois:
        members += [node] + list(p.neighbors(node))
    return directed.subgraph(members)

def draw_subgraph_activations(p, nois, cmap, feats_df=None, feats_scaled=None):
    """
    Draw the subgraph made of the nodes of interest and their neighbours,
    colouring each node by its min-max scaled value.

    Parameters
    ----------
    p : graph passed to ``get_subgraph_nodes_and_neighbors``.
    nois : iterable of node labels to centre the subgraph on.
    cmap : matplotlib colormap; called on the scaled values to obtain colours.
    feats_df : DataFrame with node labels in column 0, in the same row order
        as ``feats_scaled``. Defaults to the notebook-level ``top_feats_df``.
    feats_scaled : array of scaled values, one row per row of ``feats_df``.
        Defaults to the notebook-level ``top_feats_transformed``.
    """
    if feats_df is None:
        feats_df = top_feats_df
    if feats_scaled is None:
        feats_scaled = top_feats_transformed

    sg = get_subgraph_nodes_and_neighbors(p, nois)
    plt.figure(figsize=(10, 10))
    pos = nx.spring_layout(sg, scale=3, k=0.4, iterations=100)

    # Colour all rows once, then pair labels with colours by *position*.
    # feats_df was re-sorted in place, so its index labels no longer equal row
    # positions in feats_scaled; using the label as a row number reverses the mapping.
    colours = cmap(np.ravel(feats_scaled))
    nodemap = dict(zip(feats_df[0], colours))

    nx.draw(sg, pos, node_size=4000, node_color=[nodemap[n] for n in sg.nodes()])
    nx.draw_networkx_labels(sg, pos, font_size=16)

# 'Blues' maps the largest scaled value to the darkest colour, so the top-ranked
# nodes are drawn darkest.
draw_subgraph_activations(p, nois, cmap=plt.get_cmap('Blues'))
plt.savefig('figures/top_activating_features.pdf')


['B85ILE']

In [17]:
# Shell escape: list the files in the notebook's current working directory.
!ls


baseline.ipynb          figures                 script_prototypes.ipynb
custom_funcs.py         matrix-prototype.ipynb  visualize_weights.ipynb
cython_reindex_data.pyx predictions.ipynb

In [18]:
# Concatenate the residue numbers of the subgraph nodes, each followed by '+'.
resinums = ''.join(
    str(attrs['resi_num']) + '+'
    for _, attrs in get_subgraph_nodes_and_neighbors(p, nois).nodes(data=True)
)
resinums


Out[18]:
'33+66+86+11+85+24+89+64+90+13+84+22+'

In [19]:
# Display the repr of the ProteinInteractionNetwork built from model_01.pdb above.
p


Out[19]:
<pin.pin.ProteinInteractionNetwork at 0x11bcc70b8>

In [ ]: