In [117]:
import pandas as pd
import numpy as np
import nolearn
import matplotlib.pyplot as plt
import seaborn
import sklearn.linear_model as lm
import scipy.stats as sps
import math

from Bio import SeqIO
from collections import Counter
from decimal import Decimal
from lasagne import layers, nonlinearities
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.svm import SVR

seaborn.set_style('white')
seaborn.set_context('poster')

%matplotlib inline

In [91]:
# Read in the NNRTI resistance data
data = pd.read_csv('hiv-nnrt-data.csv', index_col='SeqID')
drug_cols = data.columns[0:4]
feat_cols = data.columns[4:]

In [92]:
# Read in the consensus data
consensus = SeqIO.read('hiv-rt-consensus.fasta', 'fasta')

consensus_map = {i+1:letter for i, letter in enumerate(str(consensus.seq))}
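
The map is keyed 1-based so positions line up with the P1..P240 column names. A minimal sanity check on the indexing (not in the original notebook):

assert consensus_map[1] == str(consensus.seq)[0]
assert len(consensus_map) == len(consensus.seq)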

In [93]:
# The dataset uses '-' to indicate that a position matches the consensus sequence,
# so those characters need to be replaced with the actual consensus letter.

for i, col in enumerate(feat_cols):
    # Replace '-' with the consensus letter, and '.' or 'X'
    # (missing/ambiguous calls) with np.nan, in one pass.
    data[col] = data[col].replace({'-': consensus_map[i+1],
                                   '.': np.nan,
                                   'X': np.nan})

In [94]:
# Drop any rows with np.nan in the feature columns. We don't want low-quality sequences.
data.dropna(inplace=True, subset=feat_cols)

In [95]:
data


Out[95]:
EFV NVP ETR RPV P1 P2 P3 P4 P5 P6 ... P231 P232 P233 P234 P235 P236 P237 P238 P239 P240
SeqID
4427 0.4 0.4 NaN NaN P I S P I E ... G Y E L H P D K W T
4433 0.7 0.7 NaN NaN P I S P I ED ... G Y E L H P D K W T
4483 0.3 0.4 NaN NaN P I S P I E ... G Y E L H P D K W T
4487 7.0 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
4689 0.5 0.3 NaN NaN P I S P I E ... G Y E L H P D K W T
4697 0.5 0.5 NaN NaN P I S P I E ... G Y E L H P D K W T
5071 200.0 200.0 9.6 6.0 P I S P I E ... G Y E L H P D K W T
5222 0.3 0.3 NaN NaN P I S P I E ... G Y E L H P D K W T
5280 0.5 1.0 NaN NaN P I S P I E ... G Y E L H P D K W T
5445 200.0 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
5463 1.0 0.7 NaN NaN P I S P I E ... G Y E L H P D K W T
5465 17.4 42.0 NaN NaN P I S P I E ... G Y E L H P D K W T
5641 0.4 0.5 NaN NaN P I S P I E ... G Y E L H P D K W T
5682 154.0 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
5708 200.0 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
6029 200.0 79.0 NaN NaN P I S P I E ... G Y E L H P D K W T
6485 2.8 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
6519 18.0 25.0 NaN NaN P I S P I E ... G Y E LH H P D K W T
6540 1.0 53.0 NaN NaN P I S P I E ... G Y E L H P D K W T
6569 45.6 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
6709 0.4 0.6 NaN NaN P I S P I E ... G Y E L H P D K W T
6796 0.2 0.3 NaN NaN P I S P I E ... G Y E L H P D K W T
6820 4.3 200.0 NaN NaN P I S P I E ... G Y E L H P E K W T
6859 0.4 0.6 NaN NaN P I S P I E ... G Y E L H P D K W T
6876 0.3 0.5 NaN NaN P I S P I E ... G Y E L H P D K W T
7328 24.0 200.0 NaN NaN P I S P I E ... G Y E L H P D K W T
7347 0.5 0.6 NaN NaN P I S P I E ... G Y E L H P D K W T
7348 0.4 0.3 NaN NaN P I S P I E ... G Y E L H P D K W T
7350 1.0 1.1 NaN NaN P I S P I E ... G Y E L H P D K W T
7360 0.6 0.9 NaN NaN P I S P I E ... G Y E L H P D K W T
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
259222 11.7 144.0 44.0 19.2 P I S P I E ... G Y E L H P D K W T
259224 1.4 1.3 1.3 1.5 P I S P I ED ... G Y E L H P D K W T
259228 0.6 0.4 0.7 0.7 P I S P I E ... G Y E L H P D K W T
259230 0.7 153.0 53.2 27.7 P I S P I E ... G Y E L H P D K W T
259232 0.7 180.0 67.7 32.1 P I S P I E ... G Y E L H P D K W T
259234 1.2 0.9 1.1 1.2 P I S P I E ... G Y E L H P D K W T
259236 18.2 22.2 11.1 10.4 P I S P I E ... G Y E LI H P D K W T
259238 0.9 0.7 0.8 0.7 P I S P I E ... G Y E L H P D K W T
259242 0.8 0.8 0.8 0.9 P I S P I E ... G Y E L H P D K W T
259244 119.0 91.8 11.6 15.6 P I S P I E ... G Y E L H P D K W T
259246 0.8 0.9 0.9 0.9 P I S P I E ... G Y E L H P D K W T
259250 0.6 0.2 0.6 0.6 P I S P I E ... G Y E L H P D K W T
259254 1.0 0.7 1.0 1.1 P I S P I E ... G Y E L H P D K W T
259256 0.8 153.0 17.7 25.9 P I S P I E ... G Y E L H P D K W T
259258 1.1 1.4 1.3 1.3 P I S P I E ... G Y E L H P D K W T
259266 0.8 0.4 0.7 0.8 P I S P I E ... G Y E L H P D K W T
259268 1.1 194.0 10.0 22.8 P I S P I E ... G Y E L H P D K W T
268425 1.2 1.0 1.0 1.1 P I S P I E ... G Y E L H P D K W T
268426 1500.0 20.7 6.5 17.4 P I S P I E ... G Y E L H P D K W T
268427 1500.0 51.3 8.2 13.8 P I S P I E ... G Y E L H P D K W T
268428 6.8 2.1 1.3 0.7 P I S P I E ... G Y E L H P D K W T
268429 6.8 34.8 1.1 1.2 P I S P I E ... G Y E L H P D K W T
268430 14.2 16.5 3.1 2.8 P I S P I E ... G Y E L H P D K W T
268431 9.3 1.3 1.0 1.0 P I S P I E ... G Y E L H P D K W T
268432 1500.0 26.4 20.3 6.8 P I S P I E ... G Y E L H P D K W T
270334 2.1 3.3 3.1 3.2 P I S P I E ... G Y E L H P D K W T
270335 1.6 0.8 3.3 3.2 P I S P I E ... G Y E L H P D K W T
270336 1.9 2.0 3.2 3.3 P I S P I E ... G Y E L H P D K W T
270337 1.3 0.9 2.5 2.3 P I S P I E ... G Y E L H P D K W T
270338 2.1 170.0 4.9 3.0 P I S P I E ... G Y E L H P D K W T

1495 rows × 244 columns


In [96]:
# Drop any feat_cols that are completely conserved.

# The nonconserved_cols list will serve as a convenient selector for the X- data from the 
# original dataframe.
nonconserved_cols = []
for col in feat_cols:
    if len(pd.unique(data[col])) == 1:
        data.drop(col, axis=1, inplace=True)
        
    else:
        nonconserved_cols.append(col)
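
Equivalently, the selector could be built in one comprehension (a sketch; the loop above, which also drops the conserved columns from data, is what the notebook ran):

nonconserved_cols = [col for col in feat_cols if data[col].nunique() > 1]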

In [97]:
drug_cols


Out[97]:
Index(['EFV', 'NVP', 'ETR', 'RPV'], dtype='object')

In [98]:
def x_equals_y(y_test):
    """
    Return a range spanning floor(min(y_test)) to ceil(max(y_test)),
    used for drawing the y = x reference line in the plots below.
    """
    floor = math.floor(min(y_test))
    ceil = math.ceil(max(y_test))
    return range(floor, ceil + 1)  # +1 so the range reaches the ceiling value

TWOPLACES = Decimal(10) ** -2
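
TWOPLACES is Decimal('0.01'); quantizing against it rounds the MSE annotations below to two decimal places. A quick illustration with made-up values:

print(Decimal('0.997215').quantize(TWOPLACES))  # 1.00
print(Decimal(1.5).quantize(TWOPLACES))         # 1.50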

In [170]:
colnum = 3  # index into drug_cols; 3 selects RPV

drug_df = pd.DataFrame()
drug_df[drug_cols[colnum]] = data[drug_cols[colnum]]
drug_df[nonconserved_cols] = data[nonconserved_cols]
# Drop sequences containing amino acid mixtures (entries like 'ED'
# that list more than one residue at a position).
for col in nonconserved_cols:
    drug_df[col] = drug_df[col].apply(lambda x: np.nan if len(x) > 1 else x)
drug_df.dropna(inplace=True)

drug_X = drug_df[nonconserved_cols]
# Log-transform the fold-resistance values, which span several orders of magnitude.
drug_Y = drug_df[drug_cols[colnum]].apply(np.log)
# drug_Y.values
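
The len(x) > 1 filter above removes sequences that report amino acid mixtures, such as the 'ED' entries visible in the P6 column of the table above. A toy illustration with made-up values:

s = pd.Series(['E', 'ED', 'K'])
print(s.apply(lambda x: np.nan if len(x) > 1 else x))
# 0      E
# 1    NaN
# 2      K
# dtype: object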

In [171]:
from isoelectric_point import isoelectric_points
from molecular_weight import molecular_weights

# Encode each residue as its isoelectric point (pI); 7 is neutral.
drug_X_pi = drug_X.replace(isoelectric_points)

# Encode each residue as its molecular weight.
drug_X_mw = drug_X.replace(molecular_weights)

# Binarize (one-hot encode) the drug_X matrix.
from sklearn.preprocessing import LabelBinarizer
drug_X_bi = pd.DataFrame()
binarizers = dict()

for col in drug_X.columns:
    lb = LabelBinarizer()
    binarized_cols = lb.fit_transform(drug_X[col])
    if len(lb.classes_) == 2:
        # Two observed residues: LabelBinarizer returns a single 0/1 column.
        drug_X_bi[col] = pd.Series(binarized_cols[:,0])
    else:
        # Three or more residues: one indicator column per residue.
        for i, c in enumerate(lb.classes_):
            drug_X_bi[col + '_' + c] = binarized_cols[:,i]
        
drug_X_bi


Out[171]:
P1_P P2_I P3_S P4 P5_I P6 P7 P8_V P9_P P10_V ... P230 P232_Y P233_E P234_L P235_H P236_P P237 P238_K P239_W P240
0 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
1 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
2 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
3 0 0 0 0 0 1 0 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
4 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
5 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
6 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
7 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 1 0 0 1
8 0 0 0 0 0 0 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
9 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
10 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
11 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
12 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
13 0 0 0 0 0 1 1 0 0 0 ... 0 0 0 0 0 0 0 0 0 1
14 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
15 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
16 0 0 0 0 0 1 1 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
17 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
18 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
19 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
20 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
21 0 0 0 0 0 1 1 0 0 0 ... 0 0 0 0 0 0 0 0 0 1
22 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
23 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
24 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
25 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
26 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
27 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
28 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
29 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
51 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
52 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
53 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
54 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
55 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
56 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
57 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
58 0 0 0 1 0 0 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
59 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
60 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
61 0 0 0 0 0 0 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
62 0 0 0 0 0 1 1 0 0 0 ... 0 0 0 0 0 0 0 0 0 1
63 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
64 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
65 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
66 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
67 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
68 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
69 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
70 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
71 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
72 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
73 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
74 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
75 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
76 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
77 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
78 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
79 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1
80 0 0 0 0 0 1 1 0 0 0 ... 1 0 0 0 0 0 0 0 0 1

81 rows × 301 columns
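
A toy example (made-up residue values) of the LabelBinarizer behavior the loop above relies on: two observed classes yield a single 0/1 column, while three or more yield one indicator column per class.

lb = LabelBinarizer()
print(lb.fit_transform(['K', 'R', 'K']))
# [[0]
#  [1]
#  [0]]
print(lb.classes_)  # ['K' 'R']

print(lb.fit_transform(['K', 'R', 'E']))
# [[0 1 0]
#  [0 0 1]
#  [1 0 0]]
print(lb.classes_)  # ['E' 'K' 'R']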


In [172]:
fig = plt.figure(figsize=(3,3))
drug_Y.hist(grid=False)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('{0} Distribution'.format(drug_cols[colnum]))


Out[172]:
<matplotlib.text.Text at 0x7f4b4a6a2eb8>

In [173]:
# Here, let's try the Random Forest Regressor. This will be the baseline.
x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

rfr = RandomForestRegressor(n_estimators=500, n_jobs=-1, oob_score=True)

rfr.fit(x_train, y_train)
rfr_preds = rfr.predict(x_test)
print(rfr.score(x_test, y_test), mean_squared_error(rfr_preds, y_test))
rfr_mse = mean_squared_error(rfr_preds, y_test)
# print(rfr.oob_score_)
# Wrap in print: a mid-cell expression is not echoed by the notebook.
print(sps.pearsonr(rfr_preds, y_test))

plt.figure(figsize=(3,3))
plt.scatter(y_test, rfr_preds,)
plt.title('{0} Random Forest'.format(drug_cols[colnum]))
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(rfr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


-0.0326395088037 0.997215655263
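
The scatter-plot scaffolding above is repeated verbatim for every model below. A small helper like this sketch (not part of the original notebook cells) would factor it out:

def plot_predictions(y_test, preds, mse, model_name):
    # Scatter actual vs. predicted, annotate with the MSE, and draw a y=x line.
    plt.figure(figsize=(3,3))
    plt.scatter(y_test, preds)
    plt.xlabel('Actual')
    plt.ylabel('Predicted')
    plt.title('{0} {1}'.format(drug_cols[colnum], model_name))
    plt.gca().set_aspect('equal', 'datalim')
    plt.annotate(s='mse: {0}'.format(Decimal(mse).quantize(TWOPLACES)),
                 xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
    plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
    plt.show()

# Usage: plot_predictions(y_test, rfr_preds, rfr_mse, 'Random Forest')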

In [174]:
plt.bar(range(len(rfr.feature_importances_)), rfr.feature_importances_)
plt.xlabel('Position')
plt.ylabel('Relative Importance')
plt.title('{0} Random Forest'.format(drug_cols[colnum]))


Out[174]:
<matplotlib.text.Text at 0x7f4b481912e8>

In [175]:
# Get back the importance of each feature.
feat_impt = [(p, i) for p, i in zip(drug_X_bi.columns, rfr.feature_importances_)]
sorted(feat_impt, key=lambda x:x[1], reverse=True)


Out[175]:
[('P181_I', 0.12515626224281362),
 ('P67_G', 0.10834144812999441),
 ('P230', 0.10333738102254802),
 ('P188', 0.043665958219112858),
 ('P100', 0.034264079953305784),
 ('P41', 0.033258661397242165),
 ('P103_N', 0.032187303067942478),
 ('P215_Y', 0.028652404593970164),
 ('P35_I', 0.02808556779219902),
 ('P138_E', 0.02198352996815621),
 ('P181_Y', 0.01983358547595742),
 ('P184_V', 0.019073040109787217),
 ('P179_V', 0.017454897104972032),
 ('P179_D', 0.012491730797141909),
 ('P98_A', 0.012049857441204017),
 ('P207_E', 0.011957507913412723),
 ('P195_T', 0.011288385576928458),
 ('P67_D', 0.010728785431529882),
 ('P122_E', 0.010684235575673948),
 ('P103_K', 0.010345663396107365),
 ('P135_I', 0.009952372670047744),
 ('P179_I', 0.0097610503276108571),
 ('P103_R', 0.0097095949360301755),
 ('P101_K', 0.0087772262674080028),
 ('P43_E', 0.0084231987482012837),
 ('P214', 0.0076189049074741003),
 ('P74_V', 0.0073734665064819358),
 ('P135_L', 0.0072737617119804391),
 ('P101_E', 0.0070969933961763081),
 ('P211_R', 0.0070795751946143249),
 ('P211_K', 0.0070746055832158431),
 ('P219_E', 0.0066887974714789434),
 ('P102', 0.0064385755967013378),
 ('P123_N', 0.0062315572471197895),
 ('P122_K', 0.0054058040361537035),
 ('P138_K', 0.0053633114651937082),
 ('P123_E', 0.0053012136176515323),
 ('P177_E', 0.0047038751739025171),
 ('P210_L', 0.0046705057085067395),
 ('P123_G', 0.0046144170257456296),
 ('P221', 0.004593564787154252),
 ('P184_M', 0.0045195353383546575),
 ('P228_H', 0.0044447035262805802),
 ('P123_D', 0.0043774748819879715),
 ('P142_I', 0.0042898180130194231),
 ('P35_V', 0.0042087759292818664),
 ('P74_L', 0.0040257365668179973),
 ('P135_V', 0.0038112149898448331),
 ('P20', 0.0037140094426395638),
 ('P60_A', 0.0035510594080978186),
 ('P60_V', 0.00333818793996337),
 ('P70', 0.0032934178016638723),
 ('P215_T', 0.0032101000614233803),
 ('P190_G', 0.0031984322276173158),
 ('P202', 0.0031684465671659974),
 ('P39', 0.0030841521810167564),
 ('P219_Q', 0.0030639040076665394),
 ('P181_C', 0.003029332174913241),
 ('P215_D', 0.0029472944930296783),
 ('P142_V', 0.0029202352775381623),
 ('P207_A', 0.0029079177980571688),
 ('P68', 0.0028332231667971661),
 ('P207_Q', 0.002749845108264684),
 ('P162_S', 0.0026702275244181732),
 ('P64_N', 0.0026610037579367645),
 ('P196', 0.0026581024324151808),
 ('P43_K', 0.0026350249639297278),
 ('P200_T', 0.002553049308552499),
 ('P210_W', 0.0025026765700731108),
 ('P118', 0.0024518223525403101),
 ('P90', 0.0024267467978901446),
 ('P219_K', 0.002376178278273943),
 ('P177_N', 0.0022034152933523956),
 ('P101_Q', 0.0021924345075517234),
 ('P177_D', 0.0021504738197370759),
 ('P228_L', 0.0021070868528241973),
 ('P178_L', 0.0020852415987475258),
 ('P190_A', 0.0020608851822308755),
 ('P225', 0.0019019387712490539),
 ('P138_G', 0.0018663749550823908),
 ('P211_N', 0.0015566439722936778),
 ('P74_I', 0.0015458467058144579),
 ('P184_I', 0.0015036854209839621),
 ('P240', 0.0014876415036924435),
 ('P200_A', 0.0014785415452070847),
 ('P215_F', 0.0014445595229011004),
 ('P106_V', 0.0014352149395090079),
 ('P162_C', 0.001398856715483846),
 ('P98_S', 0.0013860935618076927),
 ('P60_I', 0.0013462407429566342),
 ('P98_G', 0.0013006878390810427),
 ('P35_T', 0.0012423853049942),
 ('P227', 0.0011903834925342729),
 ('P35_L', 0.0011741673069497346),
 ('P192', 0.001143039849028499),
 ('P211_T', 0.0011145572925410395),
 ('P165_T', 0.0010870443187673414),
 ('P162_Y', 0.001078912772970125),
 ('P195_I', 0.0010306803853387413),
 ('P69_T', 0.00098654634064042107),
 ('P101_P', 0.00096750264906824074),
 ('P67_N', 0.00095948776640328185),
 ('P218', 0.00086484327814864573),
 ('P69_N', 0.00083354981503933534),
 ('P166', 0.00082871944869750244),
 ('P138_R', 0.00079582185516189079),
 ('P64_K', 0.00078298463458083343),
 ('P64_H', 0.00077564921540242127),
 ('P173', 0.00076267208371457698),
 ('P165_I', 0.00068431400076820125),
 ('P211_A', 0.00064004344129215145),
 ('P103_S', 0.00056655774539188404),
 ('P210_F', 0.00051988662161328049),
 ('P121', 0.00050682881116857054),
 ('P138_Q', 0.000503198973034824),
 ('P228_R', 0.00050191556235062193),
 ('P106_A', 0.00044672767714506057),
 ('P106_I', 0.00044079303705608453),
 ('P174', 0.00041855325130677428),
 ('P198', 0.00034485629962465836),
 ('P207_K', 0.00033302736928496324),
 ('P49', 0.00031699528154714463),
 ('P179_F', 0.00031339196315206716),
 ('P135_T', 0.00031219171039836315),
 ('P122_Q', 0.00030252464215748109),
 ('P82', 0.00028992038832000331),
 ('P178_I', 0.00027207811854838506),
 ('P138_A', 0.00025705873583400934),
 ('P176', 0.00023257023097364313),
 ('P207_D', 0.00023081744467714582),
 ('P215_E', 0.00022970855791083663),
 ('P135_M', 0.00022698925765269386),
 ('P237', 0.00020830930328917592),
 ('P204_E', 0.00018006393045651517),
 ('P80', 0.00017982258429690857),
 ('P108', 0.00017804435285747935),
 ('P6', 0.00016862850720856306),
 ('P43_R', 0.00014590104489037572),
 ('P197', 0.00013771159217070548),
 ('P122_P', 0.00013225902414368461),
 ('P204_Q', 0.00012963434340721594),
 ('P200_I', 0.00012875722156110474),
 ('P195_L', 0.0001225817979781304),
 ('P64_Y', 0.00011497155581516753),
 ('P178_M', 0.0001041206484110235),
 ('P35_M', 9.4574602756123733e-05),
 ('P142_T', 7.8949302585763607e-05),
 ('P28', 6.759517417808197e-05),
 ('P83', 6.6591835798483035e-05),
 ('P211_S', 4.6261360503589671e-05),
 ('P169', 3.6468480649915733e-05),
 ('P189', 3.2871185323285835e-05),
 ('P31', 3.0910196972604575e-05),
 ('P35_R', 1.0797965754040618e-05),
 ('P165_L', 1.067082457928276e-05),
 ('P162_D', 9.0406616431302377e-06),
 ('P7', 8.3956447528148519e-06),
 ('P1_P', 0.0),
 ('P2_I', 0.0),
 ('P3_S', 0.0),
 ('P4', 0.0),
 ('P5_I', 0.0),
 ('P8_V', 0.0),
 ('P9_P', 0.0),
 ('P10_V', 0.0),
 ('P11_K', 0.0),
 ('P13_K', 0.0),
 ('P14_P', 0.0),
 ('P15_G', 0.0),
 ('P16_M', 0.0),
 ('P18_G', 0.0),
 ('P19_P', 0.0),
 ('P21_V', 0.0),
 ('P22', 0.0),
 ('P23_Q', 0.0),
 ('P24_W', 0.0),
 ('P27_T', 0.0),
 ('P29_E', 0.0),
 ('P30_K', 0.0),
 ('P32_K', 0.0),
 ('P33_A', 0.0),
 ('P34_L', 0.0),
 ('P35_E', 0.0),
 ('P35_K', 0.0),
 ('P36_E', 0.0),
 ('P37_I', 0.0),
 ('P38_C', 0.0),
 ('P40', 0.0),
 ('P44_E', 0.0),
 ('P45_G', 0.0),
 ('P46_K', 0.0),
 ('P47_I', 0.0),
 ('P48', 0.0),
 ('P50_I', 0.0),
 ('P52_P', 0.0),
 ('P53_E', 0.0),
 ('P54_N', 0.0),
 ('P56_Y', 0.0),
 ('P57_N', 0.0),
 ('P58_T', 0.0),
 ('P59_P', 0.0),
 ('P62', 0.0),
 ('P63_I', 0.0),
 ('P65_K', 0.0),
 ('P66_K', 0.0),
 ('P69_D', 0.0),
 ('P72_R', 0.0),
 ('P73_K', 0.0),
 ('P75', 0.0),
 ('P77_F', 0.0),
 ('P79_E', 0.0),
 ('P84_T', 0.0),
 ('P85_Q', 0.0),
 ('P86_D', 0.0),
 ('P87_F', 0.0),
 ('P88_W', 0.0),
 ('P89_E', 0.0),
 ('P91_Q', 0.0),
 ('P93_G', 0.0),
 ('P94_I', 0.0),
 ('P95_P', 0.0),
 ('P96_H', 0.0),
 ('P101_H', 0.0),
 ('P101_R', 0.0),
 ('P104', 0.0),
 ('P105_S', 0.0),
 ('P107_T', 0.0),
 ('P109_L', 0.0),
 ('P110_D', 0.0),
 ('P111_V', 0.0),
 ('P113_D', 0.0),
 ('P115_Y', 0.0),
 ('P116', 0.0),
 ('P119_P', 0.0),
 ('P124_F', 0.0),
 ('P125_R', 0.0),
 ('P126_K', 0.0),
 ('P127_Y', 0.0),
 ('P128_T', 0.0),
 ('P130_F', 0.0),
 ('P131_T', 0.0),
 ('P132_I', 0.0),
 ('P134_S', 0.0),
 ('P136_N', 0.0),
 ('P137_N', 0.0),
 ('P139_T', 0.0),
 ('P140_P', 0.0),
 ('P141_G', 0.0),
 ('P143_R', 0.0),
 ('P144_Y', 0.0),
 ('P145_Q', 0.0),
 ('P146_Y', 0.0),
 ('P147_N', 0.0),
 ('P151', 0.0),
 ('P154_K', 0.0),
 ('P155_G', 0.0),
 ('P157_P', 0.0),
 ('P158_A', 0.0),
 ('P159_I', 0.0),
 ('P161_Q', 0.0),
 ('P163_S', 0.0),
 ('P164_M', 0.0),
 ('P167_I', 0.0),
 ('P168_L', 0.0),
 ('P170_P', 0.0),
 ('P171_F', 0.0),
 ('P172_R', 0.0),
 ('P175_N', 0.0),
 ('P180_I', 0.0),
 ('P182_Q', 0.0),
 ('P187_L', 0.0),
 ('P190_E', 0.0),
 ('P190_S', 0.0),
 ('P191_S', 0.0),
 ('P193_L', 0.0),
 ('P194_E', 0.0),
 ('P199_R', 0.0),
 ('P200_E', 0.0),
 ('P201_K', 0.0),
 ('P203_E', 0.0),
 ('P204_N', 0.0),
 ('P205_L', 0.0),
 ('P206_R', 0.0),
 ('P208', 0.0),
 ('P209_L', 0.0),
 ('P212_W', 0.0),
 ('P216_T', 0.0),
 ('P217_P', 0.0),
 ('P219_N', 0.0),
 ('P222_Q', 0.0),
 ('P223_K', 0.0),
 ('P224_E', 0.0),
 ('P226_P', 0.0),
 ('P229_W', 0.0),
 ('P232_Y', 0.0),
 ('P233_E', 0.0),
 ('P234_L', 0.0),
 ('P235_H', 0.0),
 ('P236_P', 0.0),
 ('P238_K', 0.0),
 ('P239_W', 0.0)]
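
The top few features carry a large share of the importance. A compact view of just the strongest signals, over the same feat_impt list (a sketch):

for name, importance in sorted(feat_impt, key=lambda t: t[1], reverse=True)[:10]:
    print('{0}\t{1:.4f}'.format(name, importance))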

In [176]:
# # Here, let's try a parameter grid search to figure out the best hyperparameters.
# from sklearn.grid_search import GridSearchCV
# import numpy as np

# param_grid = [{'n_estimators':[100, 500, 1000],
#                #'max_features':['auto', 'sqrt', 'log2'],
#                #'min_samples_leaf':np.arange(1,20,1),
#               }]

# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)


# rfr_gs = GridSearchCV(RandomForestRegressor(), param_grid=param_grid, n_jobs=-1)
# rfr_gs.fit(x_train, y_train)
# print(rfr_gs.best_estimator_)
# print(rfr_gs.best_params_)

In [177]:
# Try Bayesian Ridge Regression
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

brr = lm.BayesianRidge()
brr.fit(x_train, y_train)
brr_preds = brr.predict(x_test)
print(brr.score(x_test, y_test), mean_squared_error(brr_preds, y_test))
print(sps.pearsonr(brr_preds, y_test))
brr_mse = mean_squared_error(brr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, brr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Bayesian Ridge'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(brr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.127821841659 0.842258800154
(0.43881339541178765, 0.046588693002101465)

In [178]:
# Try ARD regression
ardr = lm.ARDRegression()
ardr.fit(x_train, y_train)
ardr_preds = ardr.predict(x_test)
ardr_mse = mean_squared_error(ardr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, ardr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} ARD Regression'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(ardr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [179]:
# Try Gradient Boost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

gbr = GradientBoostingRegressor()
gbr.fit(x_train, y_train)
gbr_preds = gbr.predict(x_test)
print(gbr.score(x_test, y_test), mean_squared_error(gbr_preds, y_test))
print(sps.pearsonr(gbr_preds, y_test))
gbr_mse = mean_squared_error(gbr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, gbr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Grad. Boost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(gbr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.0103796304678 0.95567225237
(0.39743565271055858, 0.074410522130373941)

In [180]:
plt.bar(range(len(gbr.feature_importances_)), gbr.feature_importances_)


Out[180]:
<Container object of 301 artists>

In [181]:
# Try AdaBoost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

abr = AdaBoostRegressor()
abr.fit(x_train, y_train)
abr_preds = abr.predict(x_test)
print(abr.score(x_test, y_test), mean_squared_error(abr_preds, y_test))
print(sps.pearsonr(abr_preds, y_test))
abr_mse = mean_squared_error(abr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(x=y_test, y=abr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} AdaBoost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(abr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


-0.141282096696 1.10213134806
(0.025866145510122118, 0.91138362307506238)

In [182]:
plt.bar(range(len(abr.feature_importances_)), abr.feature_importances_)


Out[182]:
<Container object of 301 artists>

In [183]:
# Try support vector regression
svr = SVR()
svr.fit(x_train, y_train)
svr_preds = svr.predict(x_test)

svr_mse = mean_squared_error(svr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, svr_preds, )
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} SVR'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(svr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [184]:
# Neural network 1 specification: feed-forward ANN with one hidden layer.
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

# Cast to float32, the dtype the Theano-backed Lasagne network expects here.
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)

net1 = NeuralNet(
    layers=[  # one hidden layer: input -> dense -> dropout -> nonlinearity -> output
        ('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('dropout1', layers.DropoutLayer),
        #('hidden2', layers.DenseLayer),
        #('dropout2', layers.DropoutLayer),
        ('nonlinear', layers.NonlinearityLayer),
        ('output', layers.DenseLayer),
        ],
    # layer parameters:
    input_shape=(None, x_train.shape[1]),  # batch size unspecified; one input unit per feature
    hidden1_num_units=math.ceil(x_train.shape[1] / 2),  # number of units in hidden layer
    hidden1_nonlinearity=nonlinearities.tanh,
    dropout1_p=0.5,
    #hidden2_num_units=math.ceil(x_train.shape[1] / 2),
    #dropout2_p=0.5,
    output_nonlinearity=None,  # output layer uses the identity function
    output_num_units=1,  # a single regression target
    
    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.95,

    regression=True,  # flag to indicate we're dealing with a regression problem
    max_epochs=500,  # we want to train this many epochs
    verbose=1,
    )
net1.fit(x_train.values, y_train.values)


# Neural Network with 45754 learnable parameters

## Layer information

  #  name         size
---  ---------  ------
  0  input         301
  1  hidden1       151
  2  dropout1      151
  3  nonlinear     151
  4  output          1

  epoch    train loss    valid loss    train/val  dur
-------  ------------  ------------  -----------  -----
      1       3.09138       1.87169      1.65165  0.00s
      2       2.23954       2.15974      1.03695  0.00s
      3       2.21964       1.49758      1.48215  0.00s
      4       2.05813       1.13670      1.81062  0.00s
      5       2.11053       1.05192      2.00636  0.00s
      6       1.89762       0.93186      2.03637  0.00s
      7       1.51066       1.14422      1.32025  0.00s
      8       1.66632       1.12656      1.47912  0.00s
      9       1.38603       0.90850      1.52563  0.00s
     10       1.59926       0.80087      1.99690  0.00s
     11       1.42572       0.74973      1.90164  0.00s
     12       1.45986       0.68320      2.13679  0.00s
     13       1.15931       0.84708      1.36858  0.00s
     14       1.28390       0.74891      1.71436  0.00s
     15       1.16442       0.60978      1.90957  0.00s
     16       0.91826       0.53520      1.71576  0.00s
     17       1.03375       0.52452      1.97086  0.00s
     18       0.88357       0.55902      1.58055  0.00s
     19       1.03875       0.51289      2.02530  0.00s
     20       0.55124       0.44621      1.23538  0.00s
     21       0.81158       0.40679      1.99511  0.00s
     22       0.74313       0.52161      1.42470  0.00s
     23       0.52532       0.47521      1.10546  0.00s
     24       0.57771       0.49559      1.16570  0.00s
     25       0.60192       0.55973      1.07538  0.00s
     26       0.65065       0.60645      1.07288  0.00s
     27       0.51811       0.55793      0.92862  0.00s
     28       0.69553       0.62857      1.10653  0.00s
     29       0.60919       0.56841      1.07174  0.00s
     30       0.51942       0.64703      0.80278  0.00s
     31       0.46033       0.56315      0.81741  0.00s
     32       0.53810       0.50626      1.06289  0.00s
     33       0.59144       0.49816      1.18724  0.00s
     34       0.42397       0.53463      0.79302  0.00s
     35       0.47340       0.58048      0.81553  0.00s
     36       0.44709       0.62009      0.72101  0.00s
     37       0.44645       0.52953      0.84312  0.00s
     38       0.40098       0.60618      0.66149  0.00s
     39       0.46663       0.46865      0.99569  0.00s
     40       0.51329       0.46833      1.09599  0.00s
     41       0.41557       0.50491      0.82306  0.00s
     42       0.32113       0.56337      0.57001  0.00s
     43       0.49427       0.49019      1.00833  0.00s
     44       0.50834       0.52046      0.97672  0.00s
     45       0.46611       0.47508      0.98111  0.00s
     46       0.23285       0.49051      0.47471  0.00s
     47       0.39066       0.55530      0.70351  0.00s
     48       0.34065       0.61602      0.55299  0.00s
     49       0.35575       0.68071      0.52262  0.00s
     50       0.46853       0.54531      0.85920  0.00s
     51       0.32417       0.53988      0.60045  0.00s
     52       0.35954       0.54593      0.65858  0.00s
     53       0.29628       0.59466      0.49823  0.00s
     54       0.39856       0.53345      0.74714  0.00s
     55       0.30403       0.56129      0.54166  0.00s
     56       0.33263       0.50355      0.66058  0.00s
     57       0.48097       0.57542      0.83587  0.00s
     58       0.23384       0.54406      0.42981  0.00s
     59       0.36108       0.51444      0.70188  0.00s
     60       0.20788       0.46661      0.44552  0.00s
     61       0.28687       0.49280      0.58212  0.00s
     62       0.24756       0.51272      0.48285  0.00s
     63       0.34467       0.47157      0.73088  0.00s
     64       0.25737       0.46671      0.55146  0.00s
     65       0.22398       0.47088      0.47566  0.00s
     66       0.27373       0.52378      0.52261  0.00s
     67       0.30621       0.50902      0.60156  0.00s
     68       0.32212       0.54688      0.58900  0.00s
     69       0.22235       0.57572      0.38621  0.00s
     70       0.30402       0.54859      0.55418  0.00s
     71       0.32921       0.57469      0.57285  0.00s
     72       0.43078       0.59750      0.72097  0.00s
     73       0.30278       0.41999      0.72094  0.00s
     74       0.22134       0.44486      0.49754  0.00s
     75       0.21523       0.47938      0.44897  0.00s
     76       0.24382       0.54551      0.44696  0.00s
     77       0.24298       0.46924      0.51782  0.00s
     78       0.25483       0.42821      0.59510  0.00s
     79       0.23100       0.44146      0.52325  0.00s
     80       0.27043       0.48716      0.55512  0.00s
     81       0.26689       0.40340      0.66161  0.00s
     82       0.23394       0.38443      0.60855  0.00s
     83       0.25832       0.41631      0.62050  0.00s
     84       0.27356       0.42058      0.65043  0.00s
     85       0.28602       0.43938      0.65097  0.00s
     86       0.42921       0.49612      0.86514  0.00s
     87       0.22773       0.43381      0.52496  0.00s
     88       0.33308       0.44382      0.75048  0.00s
     89       0.20340       0.44419      0.45790  0.00s
     90       0.33828       0.47155      0.71737  0.00s
     91       0.21919       0.52556      0.41706  0.00s
     92       0.25578       0.60145      0.42527  0.00s
     93       0.15603       0.59955      0.26024  0.00s
     94       0.29586       0.55079      0.53715  0.00s
     95       0.16876       0.51785      0.32588  0.00s
     96       0.19347       0.50411      0.38377  0.00s
     97       0.19466       0.51190      0.38027  0.00s
     98       0.20521       0.56051      0.36612  0.00s
     99       0.22197       0.55149      0.40249  0.00s
    100       0.25016       0.53098      0.47113  0.00s
    101       0.24881       0.52797      0.47126  0.00s
    102       0.10956       0.54175      0.20224  0.00s
    103       0.23224       0.49617      0.46805  0.00s
    104       0.19941       0.50594      0.39413  0.00s
    105       0.18224       0.48829      0.37323  0.00s
    106       0.15698       0.51548      0.30452  0.00s
    107       0.29843       0.51177      0.58313  0.00s
    108       0.27643       0.52835      0.52320  0.00s
    109       0.19681       0.49970      0.39385  0.00s
    110       0.21056       0.42282      0.49798  0.00s
    111       0.21823       0.38774      0.56284  0.00s
    112       0.26796       0.39494      0.67848  0.00s
    113       0.19231       0.41157      0.46726  0.00s
    114       0.27135       0.41391      0.65556  0.00s
    115       0.29406       0.34087      0.86267  0.00s
    116       0.16451       0.35832      0.45911  0.00s
    117       0.23582       0.37801      0.62385  0.00s
    118       0.10371       0.44939      0.23078  0.00s
    119       0.19807       0.45647      0.43393  0.00s
    120       0.22335       0.45223      0.49389  0.00s
    121       0.12276       0.46000      0.26687  0.00s
    122       0.23238       0.41623      0.55830  0.00s
    123       0.24280       0.41475      0.58541  0.00s
    124       0.17967       0.45277      0.39682  0.00s
    125       0.24324       0.43866      0.55451  0.00s
    126       0.22363       0.42584      0.52515  0.00s
    127       0.32865       0.42247      0.77792  0.00s
    128       0.15613       0.44600      0.35007  0.00s
    129       0.14728       0.48491      0.30373  0.00s
    130       0.16325       0.49518      0.32968  0.00s
    131       0.24730       0.52127      0.47443  0.00s
    132       0.18786       0.44274      0.42432  0.00s
    133       0.22262       0.39334      0.56598  0.00s
    134       0.18142       0.42768      0.42420  0.00s
    135       0.23724       0.47547      0.49896  0.00s
    136       0.24842       0.43624      0.56946  0.00s
    137       0.18097       0.40796      0.44361  0.00s
    138       0.16755       0.37406      0.44793  0.00s
    139       0.18657       0.34842      0.53549  0.00s
    140       0.13599       0.36185      0.37584  0.00s
    141       0.13461       0.40427      0.33298  0.00s
    142       0.18162       0.34902      0.52037  0.00s
    143       0.31479       0.37467      0.84018  0.00s
    144       0.18973       0.35463      0.53499  0.00s
    145       0.25297       0.37509      0.67444  0.00s
    146       0.21293       0.40665      0.52361  0.00s
    147       0.26757       0.42913      0.62351  0.00s
    148       0.28121       0.41158      0.68323  0.00s
    149       0.14542       0.41373      0.35148  0.00s
    150       0.17847       0.42901      0.41600  0.00s
    151       0.20040       0.36735      0.54553  0.00s
    152       0.20503       0.36798      0.55719  0.00s
    153       0.22552       0.38445      0.58661  0.00s
    154       0.24421       0.45644      0.53504  0.00s
    155       0.18467       0.48839      0.37813  0.00s
    156       0.22159       0.43567      0.50862  0.00s
    157       0.14652       0.41941      0.34935  0.00s
    158       0.19862       0.39194      0.50676  0.00s
    159       0.14240       0.40723      0.34969  0.00s
    160       0.18289       0.41172      0.44422  0.00s
    161       0.17957       0.44499      0.40353  0.00s
    162       0.18945       0.44876      0.42216  0.00s
    163       0.18226       0.45907      0.39702  0.00s
    164       0.15764       0.48007      0.32836  0.00s
    165       0.22186       0.44761      0.49566  0.00s
    166       0.17152       0.45007      0.38109  0.00s
    167       0.20805       0.44765      0.46475  0.00s
    168       0.19255       0.45398      0.42414  0.00s
    169       0.25317       0.48876      0.51798  0.00s
    170       0.17812       0.44153      0.40343  0.00s
    171       0.12683       0.42089      0.30133  0.00s
    172       0.25779       0.39733      0.64881  0.00s
    173       0.19952       0.43371      0.46003  0.00s
    174       0.18193       0.46627      0.39019  0.00s
    175       0.13520       0.48848      0.27678  0.00s
    176       0.18287       0.50763      0.36024  0.00s
    177       0.22439       0.48094      0.46657  0.00s
    178       0.18779       0.45807      0.40997  0.00s
    179       0.24129       0.42817      0.56354  0.00s
    180       0.16144       0.41753      0.38665  0.00s
    181       0.25348       0.39531      0.64123  0.00s
    182       0.15385       0.38145      0.40333  0.00s
    183       0.16376       0.35135      0.46609  0.00s
    184       0.16853       0.41728      0.40387  0.00s
    185       0.22920       0.44394      0.51630  0.00s
    186       0.23342       0.39058      0.59762  0.00s
    187       0.16090       0.34413      0.46756  0.00s
    188       0.17448       0.36420      0.47908  0.00s
    189       0.18982       0.36391      0.52161  0.00s
    190       0.20931       0.38052      0.55007  0.00s
    191       0.18204       0.43381      0.41964  0.00s
    192       0.11557       0.43488      0.26575  0.00s
    193       0.12561       0.45786      0.27435  0.00s
    194       0.16443       0.42465      0.38721  0.00s
    195       0.15605       0.42623      0.36613  0.00s
    196       0.18383       0.46564      0.39479  0.00s
    197       0.21353       0.45860      0.46562  0.00s
    198       0.13469       0.41361      0.32564  0.00s
    199       0.17229       0.41449      0.41567  0.00s
    200       0.15719       0.46143      0.34065  0.00s
    201       0.24497       0.43210      0.56693  0.00s
    202       0.15500       0.42007      0.36899  0.00s
    203       0.16956       0.44163      0.38395  0.00s
    204       0.19798       0.43762      0.45241  0.00s
    205       0.18616       0.44630      0.41712  0.00s
    206       0.23306       0.47927      0.48628  0.00s
    207       0.22098       0.48083      0.45958  0.00s
    208       0.26807       0.46174      0.58056  0.00s
    209       0.12048       0.47331      0.25455  0.00s
    210       0.16628       0.44177      0.37640  0.00s
    211       0.14991       0.41311      0.36289  0.00s
    212       0.16667       0.41545      0.40117  0.00s
    213       0.20032       0.45987      0.43559  0.00s
    214       0.24823       0.55946      0.44370  0.00s
    215       0.15568       0.56430      0.27589  0.00s
    216       0.13624       0.52889      0.25760  0.00s
    217       0.16734       0.50172      0.33353  0.00s
    218       0.13921       0.49355      0.28206  0.00s
    219       0.13079       0.42782      0.30571  0.00s
    220       0.25422       0.38193      0.66563  0.00s
    221       0.22030       0.37029      0.59494  0.00s
    222       0.17243       0.38554      0.44725  0.00s
    223       0.11812       0.36707      0.32179  0.00s
    224       0.17460       0.35524      0.49149  0.00s
    225       0.19774       0.37418      0.52847  0.00s
    226       0.19729       0.41038      0.48074  0.00s
    227       0.21207       0.47035      0.45088  0.00s
    228       0.14493       0.46481      0.31181  0.00s
    229       0.22095       0.44409      0.49754  0.00s
    230       0.21775       0.41623      0.52316  0.00s
    231       0.14489       0.41214      0.35155  0.00s
    232       0.18268       0.42003      0.43491  0.00s
    233       0.14634       0.40093      0.36500  0.00s
    234       0.12449       0.39561      0.31467  0.00s
    235       0.22431       0.43649      0.51391  0.00s
    236       0.15988       0.50132      0.31892  0.00s
    237       0.14909       0.49459      0.30145  0.00s
    238       0.13146       0.43407      0.30286  0.00s
    239       0.11940       0.44099      0.27076  0.00s
    240       0.18423       0.42787      0.43059  0.00s
    241       0.12592       0.45630      0.27596  0.00s
    242       0.27872       0.38570      0.72263  0.00s
    243       0.09457       0.39544      0.23915  0.00s
    244       0.11120       0.33088      0.33606  0.00s
    245       0.20293       0.32361      0.62707  0.00s
    246       0.17574       0.38565      0.45569  0.00s
    247       0.24619       0.42528      0.57889  0.00s
    248       0.17196       0.42075      0.40871  0.00s
    249       0.15630       0.36443      0.42890  0.00s
    250       0.18169       0.33758      0.53821  0.00s
    251       0.20032       0.37805      0.52989  0.00s
    252       0.16481       0.38545      0.42758  0.00s
    253       0.11965       0.44973      0.26604  0.00s
    254       0.24357       0.44526      0.54704  0.00s
    255       0.25769       0.44115      0.58413  0.00s
    256       0.15654       0.43137      0.36289  0.00s
    257       0.14450       0.41287      0.34999  0.00s
    258       0.21421       0.42114      0.50864  0.00s
    259       0.21049       0.44605      0.47189  0.00s
    260       0.15106       0.45804      0.32979  0.00s
    261       0.12586       0.44596      0.28222  0.00s
    262       0.15501       0.46483      0.33348  0.00s
    263       0.10804       0.48741      0.22167  0.00s
    264       0.11454       0.46444      0.24662  0.00s
    265       0.30716       0.43669      0.70338  0.00s
    266       0.22119       0.40813      0.54197  0.00s
    267       0.14890       0.41495      0.35885  0.00s
    268       0.16302       0.40583      0.40170  0.00s
    269       0.12400       0.39051      0.31754  0.00s
    270       0.19779       0.34990      0.56527  0.00s
    271       0.18624       0.35843      0.51960  0.00s
    272       0.13797       0.37777      0.36521  0.00s
    273       0.18197       0.38203      0.47632  0.00s
    274       0.11763       0.41305      0.28479  0.00s
    275       0.20775       0.41196      0.50430  0.00s
    276       0.20212       0.37823      0.53438  0.00s
    277       0.19213       0.39521      0.48614  0.00s
    278       0.13605       0.47247      0.28794  0.00s
    279       0.21504       0.44695      0.48113  0.00s
    280       0.14952       0.42259      0.35381  0.00s
    281       0.18918       0.37812      0.50030  0.00s
    282       0.18663       0.38395      0.48608  0.00s
    283       0.15899       0.40995      0.38783  0.00s
    284       0.15726       0.39936      0.39378  0.00s
    285       0.10907       0.40247      0.27101  0.00s
    286       0.17084       0.40635      0.42042  0.00s
    287       0.11953       0.40893      0.29230  0.00s
    288       0.17144       0.42953      0.39914  0.00s
    289       0.15096       0.41467      0.36406  0.00s
    290       0.13888       0.36702      0.37840  0.00s
    291       0.11618       0.38770      0.29966  0.00s
    292       0.14910       0.41726      0.35733  0.00s
    293       0.15269       0.45942      0.33235  0.00s
    294       0.12731       0.49315      0.25815  0.00s
    295       0.13822       0.49839      0.27735  0.00s
    296       0.15629       0.50204      0.31131  0.00s
    297       0.14306       0.49288      0.29025  0.00s
    298       0.12700       0.47170      0.26925  0.00s
    299       0.16062       0.46547      0.34507  0.00s
    300       0.19869       0.47675      0.41675  0.00s
    301       0.13293       0.51149      0.25989  0.00s
    302       0.15508       0.49533      0.31309  0.00s
    303       0.13728       0.46501      0.29521  0.00s
    304       0.21131       0.41828      0.50519  0.00s
    305       0.16777       0.43430      0.38630  0.00s
    306       0.12766       0.39521      0.32303  0.00s
    307       0.15115       0.37284      0.40541  0.00s
    308       0.11492       0.36825      0.31207  0.00s
    309       0.23844       0.40569      0.58775  0.00s
    310       0.08760       0.40415      0.21674  0.00s
    311       0.09866       0.39403      0.25039  0.00s
    312       0.14412       0.41539      0.34695  0.00s
    313       0.11898       0.42987      0.27677  0.00s
    314       0.13703       0.40691      0.33675  0.00s
    315       0.11172       0.41193      0.27121  0.00s
    316       0.19763       0.43815      0.45105  0.00s
    317       0.18425       0.44145      0.41737  0.00s
    318       0.18067       0.43614      0.41424  0.00s
    319       0.17902       0.41748      0.42880  0.00s
    320       0.13434       0.41341      0.32495  0.00s
    321       0.22994       0.40996      0.56089  0.00s
    322       0.15464       0.42608      0.36293  0.00s
    323       0.11553       0.46292      0.24957  0.00s
    324       0.14368       0.41745      0.34419  0.00s
    325       0.19690       0.42360      0.46482  0.00s
    326       0.12352       0.42472      0.29083  0.00s
    327       0.16101       0.42050      0.38292  0.00s
    328       0.07361       0.44160      0.16669  0.00s
    329       0.12520       0.43122      0.29035  0.00s
    330       0.17651       0.43139      0.40917  0.00s
    331       0.21932       0.38731      0.56628  0.00s
    332       0.16294       0.33167      0.49128  0.00s
    333       0.18145       0.39247      0.46232  0.00s
    334       0.10546       0.40430      0.26085  0.00s
    335       0.17985       0.42179      0.42640  0.00s
    336       0.18062       0.43645      0.41384  0.00s
    337       0.16260       0.41565      0.39119  0.00s
    338       0.13172       0.41287      0.31904  0.00s
    339       0.10307       0.42702      0.24137  0.00s
    340       0.08668       0.42864      0.20222  0.00s
    341       0.08041       0.43068      0.18670  0.00s
    342       0.18792       0.42528      0.44188  0.00s
    343       0.14743       0.44011      0.33499  0.00s
    344       0.11225       0.41332      0.27159  0.00s
    345       0.15434       0.40542      0.38068  0.00s
    346       0.17948       0.46262      0.38796  0.00s
    347       0.18471       0.50175      0.36813  0.00s
    348       0.10156       0.47952      0.21179  0.00s
    349       0.13969       0.49037      0.28487  0.00s
    350       0.14797       0.47688      0.31029  0.00s
    351       0.17977       0.48958      0.36719  0.00s
    352       0.14430       0.43662      0.33050  0.00s
    353       0.17794       0.43068      0.41317  0.00s
    354       0.12091       0.42317      0.28573  0.00s
    355       0.11057       0.40002      0.27642  0.00s
    356       0.15169       0.37661      0.40279  0.00s
    357       0.15511       0.37817      0.41017  0.00s
    358       0.13248       0.35061      0.37784  0.00s
    359       0.15361       0.35208      0.43631  0.00s
    360       0.21937       0.43448      0.50491  0.00s
    361       0.14466       0.52781      0.27407  0.00s
    362       0.14350       0.54853      0.26161  0.00s
    363       0.11797       0.51132      0.23072  0.00s
    364       0.12622       0.48274      0.26146  0.00s
    365       0.14322       0.42058      0.34054  0.00s
    366       0.18183       0.43924      0.41396  0.00s
    367       0.11924       0.44949      0.26529  0.00s
    368       0.20143       0.44426      0.45340  0.00s
    369       0.12251       0.43300      0.28293  0.00s
    370       0.17353       0.40403      0.42951  0.00s
    371       0.16430       0.40331      0.40737  0.00s
    372       0.11625       0.36916      0.31490  0.00s
    373       0.13061       0.38853      0.33617  0.00s
    374       0.16136       0.36536      0.44165  0.00s
    375       0.11228       0.38959      0.28820  0.00s
    376       0.14427       0.43218      0.33383  0.00s
    377       0.12645       0.43524      0.29054  0.00s
    378       0.24974       0.43096      0.57950  0.00s
    379       0.14375       0.39932      0.35999  0.00s
    380       0.14485       0.43276      0.33471  0.00s
    381       0.16448       0.43862      0.37499  0.00s
    382       0.13428       0.41424      0.32416  0.00s
    383       0.16880       0.48333      0.34924  0.00s
    384       0.09015       0.50689      0.17786  0.00s
    385       0.17057       0.50407      0.33839  0.00s
    386       0.21560       0.50450      0.42735  0.00s
    387       0.18626       0.47157      0.39497  0.00s
    388       0.14427       0.45289      0.31856  0.00s
    389       0.09454       0.41243      0.22923  0.00s
    390       0.10682       0.35913      0.29745  0.00s
    391       0.17205       0.34496      0.49877  0.00s
    392       0.13486       0.38705      0.34844  0.00s
    393       0.14818       0.42877      0.34561  0.00s
    394       0.13179       0.40983      0.32157  0.00s
    395       0.09048       0.40433      0.22377  0.00s
    396       0.12786       0.38695      0.33043  0.00s
    397       0.14500       0.37406      0.38763  0.00s
    398       0.18746       0.42545      0.44061  0.00s
    399       0.13653       0.46100      0.29617  0.00s
    400       0.07629       0.45392      0.16807  0.00s
    401       0.13193       0.41358      0.31900  0.00s
    402       0.16322       0.41918      0.38938  0.00s
    403       0.10212       0.40413      0.25270  0.00s
    404       0.07205       0.40300      0.17879  0.00s
    405       0.17966       0.42586      0.42187  0.00s
    406       0.11136       0.42729      0.26062  0.00s
    407       0.09570       0.43143      0.22183  0.00s
    408       0.12855       0.42813      0.30027  0.00s
    409       0.10288       0.42617      0.24141  0.00s
    410       0.20182       0.41628      0.48481  0.00s
    411       0.12772       0.40389      0.31622  0.00s
    412       0.22371       0.38572      0.57999  0.00s
    413       0.10472       0.36037      0.29058  0.00s
    414       0.14928       0.33752      0.44227  0.00s
    415       0.13208       0.35018      0.37717  0.00s
    416       0.12945       0.33669      0.38448  0.00s
    417       0.22385       0.36231      0.61783  0.00s
    418       0.16967       0.41171      0.41211  0.00s
    419       0.14182       0.37876      0.37442  0.00s
    420       0.13669       0.43200      0.31641  0.00s
    421       0.21925       0.50310      0.43580  0.00s
    422       0.16268       0.52194      0.31169  0.00s
    423       0.10304       0.46837      0.21999  0.00s
    424       0.19968       0.44780      0.44592  0.00s
    425       0.13362       0.43315      0.30848  0.00s
    426       0.14067       0.44298      0.31755  0.00s
    427       0.12770       0.39726      0.32145  0.00s
    428       0.30254       0.43220      0.70001  0.00s
    429       0.13089       0.41773      0.31333  0.00s
    430       0.10269       0.41009      0.25041  0.00s
    431       0.17775       0.40585      0.43797  0.00s
    432       0.16703       0.42900      0.38935  0.00s
    433       0.21834       0.47288      0.46172  0.00s
    434       0.15158       0.42967      0.35278  0.00s
    435       0.08360       0.39088      0.21389  0.00s
    436       0.10412       0.38909      0.26760  0.00s
    437       0.10187       0.43261      0.23547  0.00s
    438       0.08783       0.40766      0.21545  0.00s
    439       0.22014       0.41474      0.53077  0.00s
    440       0.08805       0.38833      0.22675  0.00s
    441       0.18845       0.39237      0.48029  0.00s
    442       0.12879       0.40229      0.32014  0.00s
    443       0.16524       0.42054      0.39293  0.00s
    444       0.14565       0.38698      0.37638  0.00s
    445       0.12258       0.41202      0.29750  0.00s
    446       0.09884       0.41337      0.23911  0.00s
    447       0.12202       0.41453      0.29435  0.00s
    448       0.15602       0.39282      0.39719  0.00s
    449       0.18891       0.41240      0.45809  0.00s
    450       0.13180       0.43664      0.30184  0.00s
    451       0.15972       0.39464      0.40471  0.00s
    452       0.15344       0.38024      0.40354  0.00s
    453       0.13572       0.36991      0.36691  0.00s
    454       0.14937       0.39960      0.37378  0.00s
    455       0.22317       0.41395      0.53912  0.00s
    456       0.15609       0.35947      0.43421  0.00s
    457       0.20053       0.40510      0.49503  0.00s
    458       0.14295       0.39691      0.36015  0.00s
    459       0.10816       0.37199      0.29077  0.00s
    460       0.14829       0.37979      0.39046  0.00s
    461       0.23266       0.38832      0.59914  0.00s
    462       0.09924       0.39591      0.25066  0.00s
    463       0.22525       0.37517      0.60040  0.00s
    464       0.15090       0.39636      0.38072  0.00s
    465       0.12370       0.36297      0.34079  0.00s
    466       0.12268       0.34934      0.35119  0.00s
    467       0.09910       0.37316      0.26557  0.00s
    468       0.09775       0.34401      0.28414  0.00s
    469       0.17952       0.32719      0.54866  0.00s
    470       0.14320       0.36042      0.39732  0.00s
    471       0.20605       0.45740      0.45049  0.00s
    472       0.12161       0.41507      0.29298  0.00s
    473       0.17274       0.37693      0.45827  0.00s
    474       0.12220       0.40690      0.30033  0.00s
    475       0.13222       0.41404      0.31933  0.00s
    476       0.19081       0.48656      0.39216  0.00s
    477       0.13357       0.47658      0.28027  0.00s
    478       0.08921       0.44526      0.20035  0.00s
    479       0.12012       0.40888      0.29378  0.00s
    480       0.13888       0.37450      0.37085  0.00s
    481       0.15324       0.37782      0.40560  0.00s
    482       0.25653       0.35523      0.72214  0.00s
    483       0.12568       0.35882      0.35027  0.00s
    484       0.15253       0.35128      0.43421  0.00s
    485       0.20506       0.42219      0.48570  0.00s
    486       0.13871       0.46994      0.29516  0.00s
    487       0.19386       0.42608      0.45499  0.00s
    488       0.18317       0.36988      0.49522  0.00s
    489       0.11097       0.34740      0.31945  0.00s
    490       0.18076       0.34716      0.52068  0.00s
    491       0.21016       0.42074      0.49951  0.00s
    492       0.07795       0.41984      0.18566  0.00s
    493       0.14279       0.41075      0.34762  0.00s
    494       0.15156       0.46324      0.32718  0.00s
    495       0.12480       0.46665      0.26745  0.00s
    496       0.18079       0.41748      0.43306  0.00s
    497       0.10472       0.42136      0.24853  0.00s
    498       0.17465       0.42295      0.41294  0.00s
    499       0.12050       0.40118      0.30037  0.00s
    500       0.14382       0.40962      0.35109  0.00s
Out[184]:
NeuralNet(X_tensor_type=None,
     batch_iterator_test=<nolearn.lasagne.base.BatchIterator object at 0x7f4bf326e0b8>,
     batch_iterator_train=<nolearn.lasagne.base.BatchIterator object at 0x7f4bf326e908>,
     custom_score=None, dropout1_p=0.5,
     hidden1_nonlinearity=<function tanh at 0x7f4ba9069f28>,
     hidden1_num_units=151, input_shape=(None, 301),
     layers=[('input', <class 'lasagne.layers.input.InputLayer'>), ('hidden1', <class 'lasagne.layers.dense.DenseLayer'>), ('dropout1', <class 'lasagne.layers.noise.DropoutLayer'>), ('nonlinear', <class 'lasagne.layers.dense.NonlinearityLayer'>), ('output', <class 'lasagne.layers.dense.DenseLayer'>)],
     loss=None, max_epochs=500, more_params={},
     objective=<function objective at 0x7f4bf3270510>,
     objective_loss_function=<function squared_error at 0x7f4ba8bcd840>,
     on_epoch_finished=[<nolearn.lasagne.handlers.PrintLog object at 0x7f4b481a0400>],
     on_training_finished=[],
     on_training_started=[<nolearn.lasagne.handlers.PrintLayerInfo object at 0x7f4b481a04a8>],
     output_nonlinearity=None, output_num_units=1, regression=True,
     train_split=<nolearn.lasagne.base.TrainSplit object at 0x7f4bf326e6a0>,
     update=<function nesterov_momentum at 0x7f4ba8bd40d0>,
     update_learning_rate=0.01, update_momentum=0.95,
     use_label_encoder=False, verbose=1,
     y_tensor_type=TensorType(float32, matrix))

In [185]:
nn1_preds = net1.predict(x_test)
nn1_mse = float(mean_squared_error(nn1_preds, y_test))

plt.figure(figsize=(3,3))
plt.scatter(y_test, nn1_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Neural Network'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(nn1_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [186]:
sps.pearsonr(nn1_preds, y_test.reshape(y_test.shape[0],1))


Out[186]:
(array([ 0.43451592], dtype=float32), array([ 0.04903143], dtype=float32))
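
net1.predict returns an (n, 1) array, hence the reshape of y_test above. Equivalently (a sketch), flattening the predictions gives scalar r and p values:

print(sps.pearsonr(nn1_preds.flatten(), np.asarray(y_test)))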
