In [1]:
import pandas as pd
import numpy as np
import nolearn
import matplotlib.pyplot as plt
import seaborn
import sklearn.linear_model as lm
import scipy.stats as sps
import math

from Bio import SeqIO
from collections import Counter
from decimal import Decimal
from lasagne import layers, nonlinearities
from lasagne.updates import nesterov_momentum
from lasagne import layers
from nolearn.lasagne import NeuralNet
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.svm import SVR

seaborn.set_style('white')
seaborn.set_context('poster')

%matplotlib inline


Using gpu device 0: Quadro 2000

In [7]:
# Read in the protease inhibitor data
widths = [8]
widths.extend([4]*8)
widths.extend([4]*99)
data = pd.read_csv('hiv-nrt-data.csv', index_col='SeqID')
data


Out[7]:
3TC ABC AZT D4T DDI TDF P1 P2 P3 P4 ... P231 P232 P233 P234 P235 P236 P237 P238 P239 P240
SeqID
2997 200.00 4.3 3.9 1.4 1.2 NaN - - - - ... - - - - - - - - - -
4388 200.00 8.1 4.7 1.8 1.7 NaN - - - - ... - - - - - - - - - .
4427 1.40 1.1 28.0 1.0 0.8 1.9 - - - - ... - - - - - - - - - -
4487 1.80 1.5 7.1 1.2 1.1 1.1 - - - - ... - - - - - - - - - -
4539 200.00 3.7 4.5 2.0 1.3 NaN - - - - ... . . . . . . . . . .
4663 200.00 7.1 0.8 1.3 1.9 NaN - - - - ... - - - - - - - - - -
4689 2.55 5.2 6.7 5.4 5.3 0.9 - - - - ... - - - - - - - - - -
4697 200.00 5.9 1.9 1.3 1.7 NaN - - - - ... - - - - - - - - - -
5071 200.00 6.3 5.2 1.4 1.9 0.9 - - - - ... - - - - - - - - - -
5222 200.00 5.7 0.3 1.2 2.0 NaN - - - - ... - - - - - - - - - -
5280 6.20 4.5 1000.0 2.4 1.6 5.0 - - - - ... - - - - - - - - - -
5445 3.30 4.9 177.0 2.9 1.6 2.2 - - - - ... - - - - - - - - - -
5463 3.30 6.5 40.0 9.5 9.2 1.2 - - - - ... - - - - - - - - - -
5465 200.00 7.0 77.0 2.7 2.1 NaN - - - - ... - - - - - - - - - -
5641 200.00 4.2 1.5 1.2 1.3 0.9 - - - - ... - - - - - - - - - -
5708 200.00 12.0 1000.0 7.2 2.4 3.7 - - - - ... - - - - - - - - - -
6029 7.30 18.0 2400.0 16.0 4.8 NaN - - - - ... - - - - - - - - - -
6485 1.00 0.6 0.2 0.4 0.7 0.4 - - - - ... - - - - - - - - - -
6519 200.00 2.7 0.4 0.8 1.3 NaN - - - - ... - - - LH - - - - - -
6540 1.10 1.0 0.3 0.8 1.0 0.7 - - - - ... - - - - - - - - - -
6569 3.90 5.3 216.5 4.2 1.8 NaN - - - - ... - - - - - - - - - -
6709 0.90 0.8 0.5 0.8 1.0 0.7 - - - - ... - - - - - - - - - -
6796 7.10 7.0 259.0 4.4 2.4 NaN - - - - ... - - - - - - - - - -
6820 4.80 1.6 24.0 1.8 1.4 NaN - - - - ... - - - - - - E - - -
6859 1.40 2.3 42.0 2.0 1.0 NaN - - - - ... - - - - - - - - - -
6876 200.00 9.6 35.4 3.3 2.1 NaN - - - - ... - - - - - - - - - -
7328 3.60 1.7 7.2 1.5 1.1 NaN - - - - ... - - - - - - - - - -
7347 200.00 9.9 100.5 3.9 2.1 NaN - - - - ... - - - - - - - - - -
7348 200.00 7.7 7.9 2.3 1.8 NaN - - - - ... - - - - - - - - - -
7350 200.00 4.6 0.6 1.0 1.6 NaN - - - - ... - - - - - - - - - -
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
259208 1.25 1.0 0.6 0.9 1.0 0.9 - - - - ... - - - - - - - - - -
259210 81.70 1.6 0.2 0.6 1.6 0.3 - - - - ... - - - - - - - - - -
259212 1.00 0.7 0.4 0.8 0.8 0.7 - - - - ... - - - - - - - - - -
259214 74.00 1.5 0.2 0.7 1.4 0.4 - - - - ... - - - - - - - - - -
259216 1.05 0.9 0.8 1.2 1.3 0.9 - - - - ... - - - - - - - - - -
259222 29.50 3.6 0.4 1.2 1.7 2.0 - - - - ... - - - - - - - - - -
259224 1.40 1.0 0.6 0.9 1.0 0.9 - - - - ... - - - - - - - - - -
259226 113.00 1.8 0.2 0.8 1.6 0.5 - - - - ... - - - - - - - N - -
259228 1.30 0.9 0.5 0.7 0.7 0.8 - - - - ... - - - - - - - - - -
259230 91.50 3.0 0.2 0.5 1.2 0.2 - - - - ... - - - - - - - - - -
259234 1.15 1.0 0.8 1.1 1.3 1.0 - - - - ... - - - - - - - - - -
259238 1.30 1.0 0.6 0.9 1.2 0.9 - - - - ... - - - - - - - - - -
259242 0.90 0.8 0.7 0.9 0.9 0.9 - - - - ... - - - - - - - - - -
259244 88.50 5.2 0.2 0.6 2.0 0.3 - - - - ... - - - - - - - - - -
259246 1.05 0.8 0.8 1.1 1.3 0.9 - - - - ... - - - - - - - - - -
259248 82.10 1.9 0.3 0.8 1.6 0.5 - - - - ... - - - - - - - - - -
259250 1.15 0.9 0.6 0.9 1.2 0.8 - - - - ... - - - - - - - - - -
259252 87.00 3.8 0.4 0.9 1.9 0.7 - - - - ... - - - - - - - - - -
259254 1.40 0.9 0.6 0.9 0.8 0.8 - - - - ... - - - - - - - - - -
259256 91.50 3.0 0.2 0.6 1.1 0.2 - - - - ... - - - - - - - - - -
259258 1.20 0.9 0.8 1.2 1.3 1.0 - - - - ... - - - - - - - - - -
259260 104.50 4.4 0.6 0.7 1.9 0.8 - - - - ... - - - - - - - - - -
259262 1.15 1.0 0.5 0.8 0.8 0.7 - - - - ... - - - - - - - - - -
259266 1.05 0.7 0.3 0.8 0.8 0.6 - - - - ... - - - - - - - - - -
259268 125.50 2.9 0.2 0.7 1.2 0.4 - - - - ... - - - - - - - - - -
270334 1.15 1.4 1.2 1.4 1.2 1.1 - - - - ... - - - - - - - - - -
270335 89.40 3.2 0.8 1.6 2.4 0.8 - - - - ... - - - - - - - - - -
270336 89.40 2.8 0.5 1.4 2.2 0.6 - - - - ... - - - - - - - - - -
270337 1.20 1.4 1.8 1.4 1.2 1.2 - - - - ... - - - - - - - - - -
270338 0.70 0.8 0.4 1.0 0.8 0.7 - - - - ... - - - - - - - - - -

1498 rows × 246 columns


In [9]:
# Set the drug data columns and the amino acid data columns
drug_cols = data.columns[0:6]
feat_cols = data.columns[6:]

In [10]:
# Read in the consensus data
consensus = SeqIO.read('hiv-rt-consensus.fasta', 'fasta')

consensus_map = {i+1:letter for i, letter in enumerate(str(consensus.seq))}

In [11]:
# Because there are '-' characters in the dataset, representing consensus sequence at each of the positions, 
# they need to be replaced with the actual consensus letter.

for i, col in enumerate(feat_cols):
    # Replace '-' with the consensus letter.
    data[col] = data[col].replace({'-':consensus_map[i+1]})
    
    # Replace '.' with np.nan
    data[col] = data[col].replace({'.':np.nan})
    
    # Replace 'X' with np.nan
    data[col] = data[col].replace({'X':np.nan})

In [12]:
# Drop any feat_cols that have np.nan inside them. We don't want low quality sequences.
data.dropna(inplace=True, subset=feat_cols)

In [13]:
data


Out[13]:
3TC ABC AZT D4T DDI TDF P1 P2 P3 P4 ... P231 P232 P233 P234 P235 P236 P237 P238 P239 P240
SeqID
4427 1.40 1.1 28.0 1.0 0.8 1.9 P I S P ... G Y E L H P D K W T
4487 1.80 1.5 7.1 1.2 1.1 1.1 P I S P ... G Y E L H P D K W T
4663 200.00 7.1 0.8 1.3 1.9 NaN P I S P ... G Y E L H P D K W T
4689 2.55 5.2 6.7 5.4 5.3 0.9 P I S P ... G Y E L H P D K W T
4697 200.00 5.9 1.9 1.3 1.7 NaN P I S P ... G Y E L H P D K W T
5071 200.00 6.3 5.2 1.4 1.9 0.9 P I S P ... G Y E L H P D K W T
5222 200.00 5.7 0.3 1.2 2.0 NaN P I S P ... G Y E L H P D K W T
5280 6.20 4.5 1000.0 2.4 1.6 5.0 P I S P ... G Y E L H P D K W T
5445 3.30 4.9 177.0 2.9 1.6 2.2 P I S P ... G Y E L H P D K W T
5463 3.30 6.5 40.0 9.5 9.2 1.2 P I S P ... G Y E L H P D K W T
5465 200.00 7.0 77.0 2.7 2.1 NaN P I S P ... G Y E L H P D K W T
5641 200.00 4.2 1.5 1.2 1.3 0.9 P I S P ... G Y E L H P D K W T
5708 200.00 12.0 1000.0 7.2 2.4 3.7 P I S P ... G Y E L H P D K W T
6029 7.30 18.0 2400.0 16.0 4.8 NaN P I S P ... G Y E L H P D K W T
6485 1.00 0.6 0.2 0.4 0.7 0.4 P I S P ... G Y E L H P D K W T
6519 200.00 2.7 0.4 0.8 1.3 NaN P I S P ... G Y E LH H P D K W T
6540 1.10 1.0 0.3 0.8 1.0 0.7 P I S P ... G Y E L H P D K W T
6569 3.90 5.3 216.5 4.2 1.8 NaN P I S P ... G Y E L H P D K W T
6709 0.90 0.8 0.5 0.8 1.0 0.7 P I S P ... G Y E L H P D K W T
6796 7.10 7.0 259.0 4.4 2.4 NaN P I S P ... G Y E L H P D K W T
6820 4.80 1.6 24.0 1.8 1.4 NaN P I S P ... G Y E L H P E K W T
6859 1.40 2.3 42.0 2.0 1.0 NaN P I S P ... G Y E L H P D K W T
6876 200.00 9.6 35.4 3.3 2.1 NaN P I S P ... G Y E L H P D K W T
7328 3.60 1.7 7.2 1.5 1.1 NaN P I S P ... G Y E L H P D K W T
7347 200.00 9.9 100.5 3.9 2.1 NaN P I S P ... G Y E L H P D K W T
7348 200.00 7.7 7.9 2.3 1.8 NaN P I S P ... G Y E L H P D K W T
7350 200.00 4.6 0.6 1.0 1.6 NaN P I S P ... G Y E L H P D K W T
7360 200.00 6.7 6.5 1.6 1.4 NaN P I S P ... G Y E L H P D K W T
7364 200.00 7.8 7.9 2.3 1.7 NaN P I S P ... G Y E L H P D K W T
7377 200.00 29.3 726.0 10.1 4.1 NaN P I S P ... G Y E L H P D K W T
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
259204 1.50 1.0 0.5 0.8 0.9 0.8 P I S P ... G Y E L H P D K W T
259206 121.00 1.6 0.5 0.7 1.3 0.5 P I S P ... G Y E L H P D K W T
259208 1.25 1.0 0.6 0.9 1.0 0.9 P I S P ... G Y E L H P D K W T
259210 81.70 1.6 0.2 0.6 1.6 0.3 P I S P ... G Y E L H P D K W T
259214 74.00 1.5 0.2 0.7 1.4 0.4 P I S P ... G Y E L H P D K W T
259216 1.05 0.9 0.8 1.2 1.3 0.9 P I S P ... G Y E L H P D K W T
259222 29.50 3.6 0.4 1.2 1.7 2.0 P I S P ... G Y E L H P D K W T
259224 1.40 1.0 0.6 0.9 1.0 0.9 P I S P ... G Y E L H P D K W T
259226 113.00 1.8 0.2 0.8 1.6 0.5 P I S P ... G Y E L H P D N W T
259228 1.30 0.9 0.5 0.7 0.7 0.8 P I S P ... G Y E L H P D K W T
259230 91.50 3.0 0.2 0.5 1.2 0.2 P I S P ... G Y E L H P D K W T
259234 1.15 1.0 0.8 1.1 1.3 1.0 P I S P ... G Y E L H P D K W T
259238 1.30 1.0 0.6 0.9 1.2 0.9 P I S P ... G Y E L H P D K W T
259242 0.90 0.8 0.7 0.9 0.9 0.9 P I S P ... G Y E L H P D K W T
259244 88.50 5.2 0.2 0.6 2.0 0.3 P I S P ... G Y E L H P D K W T
259246 1.05 0.8 0.8 1.1 1.3 0.9 P I S P ... G Y E L H P D K W T
259248 82.10 1.9 0.3 0.8 1.6 0.5 P I S P ... G Y E L H P D K W T
259250 1.15 0.9 0.6 0.9 1.2 0.8 P I S P ... G Y E L H P D K W T
259252 87.00 3.8 0.4 0.9 1.9 0.7 P I S P ... G Y E L H P D K W T
259254 1.40 0.9 0.6 0.9 0.8 0.8 P I S P ... G Y E L H P D K W T
259256 91.50 3.0 0.2 0.6 1.1 0.2 P I S P ... G Y E L H P D K W T
259258 1.20 0.9 0.8 1.2 1.3 1.0 P I S P ... G Y E L H P D K W T
259260 104.50 4.4 0.6 0.7 1.9 0.8 P I S P ... G Y E L H P D K W T
259266 1.05 0.7 0.3 0.8 0.8 0.6 P I S P ... G Y E L H P D K W T
259268 125.50 2.9 0.2 0.7 1.2 0.4 P I S P ... G Y E L H P D K W T
270334 1.15 1.4 1.2 1.4 1.2 1.1 P I S P ... G Y E L H P D K W T
270335 89.40 3.2 0.8 1.6 2.4 0.8 P I S P ... G Y E L H P D K W T
270336 89.40 2.8 0.5 1.4 2.2 0.6 P I S P ... G Y E L H P D K W T
270337 1.20 1.4 1.8 1.4 1.2 1.2 P I S P ... G Y E L H P D K W T
270338 0.70 0.8 0.4 1.0 0.8 0.7 P I S P ... G Y E L H P D K W T

1263 rows × 246 columns


In [14]:
# Drop any feat_cols that are completely conserved.

# The nonconserved_cols list will serve as a convenient selector for the X- data from the 
# original dataframe.
nonconserved_cols = []
for col in feat_cols:
    if len(pd.unique(data[col])) == 1:
        data.drop(col, axis=1, inplace=True)
        
    else:
        nonconserved_cols.append(col)

In [15]:
drug_cols


Out[15]:
Index(['3TC', 'ABC', 'AZT', 'D4T', 'DDI', 'TDF'], dtype='object')

In [136]:
def x_equals_y(y_test):
    """
    A function that returns a range from minimum to maximum of y_test.
    
    Used below in the plotting below.
    """
    floor = math.floor(np.min(y_test))
    ceil = math.ceil(np.max(y_test))
    x_eq_y = range(floor, ceil)
    return x_eq_y

TWOPLACES = Decimal(10) ** -2

In [152]:
colnum = 3

drug_df = pd.DataFrame()
drug_df[drug_cols[colnum]] = data[drug_cols[colnum]]
drug_df[nonconserved_cols] = data[nonconserved_cols]
for col in nonconserved_cols:
    drug_df[col] = drug_df[col].apply(lambda x: np.nan if len(x) > 1 else x)
drug_df.dropna(inplace=True)

drug_X = drug_df[nonconserved_cols]
drug_Y = drug_df[drug_cols[colnum]].apply(lambda x:np.log(x))
# drug_Y.values

In [153]:
from isoelectric_point import isoelectric_points
from molecular_weight import molecular_weights

# Standardize pI matrix. 7 is neutral
drug_X_pi = drug_X.replace(isoelectric_points)

# Standardize MW matrix.
drug_X_mw = drug_X.replace(molecular_weights)

# Binarize drug_X matrix.
from sklearn.preprocessing import LabelBinarizer
drug_X_bi = pd.DataFrame()
binarizers = dict()

for col in drug_X.columns:
    lb = LabelBinarizer()
    binarized_cols = lb.fit_transform(drug_X[col])
    # print(binarized_cols)
    if len(lb.classes_) == 2:
        # print(binarized_cols)
        drug_X_bi[col] = pd.Series(binarized_cols[:,0])
    else:
        for i, c in enumerate(lb.classes_):
            # print(col + c)
            # print(binarized_cols[:,i])

            drug_X_bi[col + '_' + c] = binarized_cols[:,i]
        
drug_X_bi


Out[153]:
P1 P2 P3_C P3_N P3_S P4_P P4_S P4_T P5_I P5_S ... P237_E P237_N P238_K P238_R P238_T P239 P240_A P240_E P240_K P240_T
0 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
1 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
2 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
3 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
4 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
5 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
6 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
7 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
8 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
9 1 0 0 0 1 1 0 0 1 0 ... 1 0 1 0 0 1 0 0 0 1
10 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
11 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
12 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
13 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
14 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
15 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
16 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
17 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
18 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
19 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
20 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
21 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
22 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
23 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
24 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
25 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
26 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
27 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
28 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
29 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
428 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
429 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
430 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
431 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
432 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
433 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
434 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
435 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
436 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
437 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
438 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
439 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
440 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
441 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
442 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
443 1 0 0 0 1 0 1 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
444 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
445 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
446 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
447 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
448 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
449 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
450 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
451 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
452 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
453 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
454 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
455 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
456 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
457 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1

458 rows × 482 columns


In [154]:
fig = plt.figure(figsize=(3,3))
drug_Y.hist(grid=False)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('{0} Distribution'.format(drug_cols[colnum]))


Out[154]:
<matplotlib.text.Text at 0x7fa5c215d4e0>

In [155]:
# Here, let's try the Random Forest Regressor. This will be the baseline.
x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

rfr = RandomForestRegressor(n_estimators=500, n_jobs=-1, oob_score=True)

rfr.fit(x_train, y_train)
rfr_preds = rfr.predict(x_test)
print(rfr.score(x_test, y_test), mean_squared_error(rfr_preds, y_test))
rfr_mse = mean_squared_error(rfr_preds, y_test)
# print(rfr.oob_score_)
sps.pearsonr(rfr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, rfr_preds,)
plt.title('{0} Random Forest'.format(drug_cols[colnum]))
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(rfr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.786368305999 0.132429943561

In [156]:
plt.bar(range(len(rfr.feature_importances_)), rfr.feature_importances_)
plt.xlabel('Position')
plt.ylabel('Relative Importance')
plt.title('{0} Random Forest'.format(drug_cols[colnum]))


Out[156]:
<matplotlib.text.Text at 0x7fa5c20d3b38>

In [157]:
# Get back the importance of each feature.
feat_impt = [(p, i) for p, i in zip(drug_X_bi.columns, rfr.feature_importances_)]
sorted(feat_impt, key=lambda x:x[1], reverse=True)


Out[157]:
[('P151', 0.16251462938669017),
 ('P215_T', 0.16006592453195229),
 ('P62', 0.089598402308289846),
 ('P210_W', 0.04805732919349797),
 ('P69_#', 0.043176594447178349),
 ('P210_L', 0.03820719628644869),
 ('P116', 0.036923518809471458),
 ('P41_L', 0.027512755954099771),
 ('P77', 0.018855167959727394),
 ('P67_D', 0.018461063073661806),
 ('P70_K', 0.016167279408335643),
 ('P40_E', 0.010868828402363542),
 ('P215_D', 0.0097352909437584643),
 ('P40_F', 0.0095284336496718729),
 ('P118', 0.0093069566910762534),
 ('P69_T', 0.0091408832193107317),
 ('P184_V', 0.0089808668340186793),
 ('P65', 0.0078659771260242446),
 ('P67_N', 0.0078486143018055513),
 ('P70_N', 0.0076543349919679102),
 ('P70_d', 0.0072499899020392116),
 ('P215_Y', 0.0071962504233967191),
 ('P223_N', 0.0062562829205002248),
 ('P43_E', 0.0060866113487613007),
 ('P73', 0.0060846705519703387),
 ('P70_R', 0.0060199337114942904),
 ('P139_S', 0.0053925554834850821),
 ('P41_M', 0.0052268298420678544),
 ('P184_M', 0.0047365291124894888),
 ('P135_T', 0.0047257702295656111),
 ('P208', 0.0045941531295171948),
 ('P219_K', 0.0042685057849632061),
 ('P75_V', 0.0040200811033691917),
 ('P218', 0.0038741506470258355),
 ('P174_R', 0.0038138763719247435),
 ('P58_N', 0.0036469640236670266),
 ('P190_A', 0.0034880680895961361),
 ('P181_C', 0.0034809909691663351),
 ('P203_K', 0.0034169068828247013),
 ('P181_Y', 0.0034023562884997189),
 ('P74_V', 0.0033380187605316792),
 ('P138_K', 0.0032196758142861827),
 ('P214', 0.0030899144516060234),
 ('P74_L', 0.0030685216674645865),
 ('P207_E', 0.0027232206439434365),
 ('P39_T', 0.0026421469117063161),
 ('P75_T', 0.0025827401321966976),
 ('P135_V', 0.0025695161872389847),
 ('P200_T', 0.0025636478473387938),
 ('P139_T', 0.0024212663266235321),
 ('P207_Q', 0.002402825396636144),
 ('P64_H', 0.002383787045730263),
 ('P123_E', 0.0023502644301236194),
 ('P69_D', 0.0023465037879284579),
 ('P190_G', 0.0022155694228006011),
 ('P122_E', 0.002213964890623938),
 ('P211_R', 0.0021859602876278584),
 ('P102_Q', 0.0021678505224306203),
 ('P44_E', 0.0020674696042807359),
 ('P215_F', 0.0019449153773165285),
 ('P219_E', 0.0019272493402581862),
 ('P203_E', 0.0018550549534186885),
 ('P196_E', 0.0018383325223303162),
 ('P122_K', 0.0018347234940733837),
 ('P178_I', 0.001806222233384819),
 ('P123_D', 0.0017749513060090457),
 ('P69_N', 0.0017700664947200423),
 ('P135_I', 0.0017449729362553255),
 ('P75_I', 0.0017440799295997303),
 ('P103_K', 0.0017227293956739301),
 ('P219_R', 0.0016741067913017835),
 ('P43_K', 0.0016735102186183708),
 ('P102_K', 0.0016675568335350174),
 ('P228_H', 0.0016663261354250669),
 ('P196_G', 0.0016541932141252525),
 ('P200_A', 0.0016447160208774292),
 ('P162_S', 0.00159023188765284),
 ('P138_E', 0.001478400658828117),
 ('P177_E', 0.0014627745557714018),
 ('P219_Q', 0.0014482811153136784),
 ('P35_T', 0.0014381497441807549),
 ('P188_L', 0.0013285168302978377),
 ('P211_K', 0.0013093825552678343),
 ('P103_N', 0.001248617556734215),
 ('P179_I', 0.0012328551602509344),
 ('P162_C', 0.0011982045240013577),
 ('P228_L', 0.0011635428973818041),
 ('P178_L', 0.0011481025443574502),
 ('P238_K', 0.0011445712967038549),
 ('P35_V', 0.0011024788869958006),
 ('P146', 0.0010982585955854973),
 ('P64_K', 0.0010920025818014561),
 ('P177_D', 0.0010905129160233638),
 ('P178_F', 0.0010693452316058952),
 ('P39_A', 0.0010484497888503823),
 ('P20_K', 0.0010475814864030529),
 ('P101_E', 0.0010354309956481218),
 ('P101_K', 0.0010101778830088773),
 ('P188_Y', 0.000998000612165823),
 ('P228_R', 0.00097952916267022867),
 ('P68_G', 0.00096397487205400543),
 ('P108', 0.00095849356763606428),
 ('P115', 0.00094286952675548899),
 ('P162_A', 0.00093144483875862323),
 ('P238_T', 0.00090917464128133089),
 ('P60_I', 0.00090070518000247117),
 ('P67_E', 0.00089224907758878632),
 ('P68_S', 0.00086652663411294559),
 ('P179_V', 0.00086618235778399889),
 ('P83', 0.00085505255029670385),
 ('P103_S', 0.00085144885179626469),
 ('P121_Y', 0.00085126753806842744),
 ('P20_R', 0.00079379237648687486),
 ('P203_D', 0.00077172511141351081),
 ('P202', 0.00075929783491084622),
 ('P180', 0.00073909838686148088),
 ('P174_Q', 0.00073874043164542686),
 ('P200_I', 0.00073162255681505047),
 ('P123_N', 0.00070153937678211304),
 ('P178_M', 0.00069111021143699323),
 ('P60_V', 0.00068919471690468239),
 ('P207_N', 0.00067385270252314779),
 ('P142_I', 0.00065640298445170553),
 ('P227_L', 0.00065024723775872889),
 ('P35_I', 0.00064649900469493577),
 ('P44_D', 0.00064010540730884949),
 ('P90', 0.00063989615343746926),
 ('P106_V', 0.00063507207005211491),
 ('P98_G', 0.00062887490465580531),
 ('P172', 0.00061024702564071141),
 ('P28_E', 0.00059981013085939265),
 ('P98_A', 0.00057534786487980669),
 ('P35_M', 0.00056911894252802164),
 ('P142_V', 0.00055137585045650437),
 ('P166_R', 0.00054977243910318146),
 ('P237_D', 0.00054775854075780049),
 ('P179_F', 0.00052963131097299315),
 ('P49', 0.00052512110059578228),
 ('P16', 0.00049606463759768607),
 ('P138_G', 0.000486103946357874),
 ('P5_V', 0.00045940617868891722),
 ('P219_H', 0.00045381112961015966),
 ('P197_Q', 0.00044784706876806028),
 ('P227_F', 0.0004473313209620604),
 ('P5_I', 0.00042971435742582422),
 ('P211_N', 0.00042786537379451052),
 ('P215_S', 0.00041865725477986172),
 ('P89', 0.00041764304987800856),
 ('P223_K', 0.00040976773738635435),
 ('P68_T', 0.00039561065486796288),
 ('P43_Q', 0.00037804569690291247),
 ('P44_A', 0.00037484028070881067),
 ('P139_A', 0.00035590141937412467),
 ('P219_N', 0.00035154342032472273),
 ('P123_I', 0.00034953177682979503),
 ('P39_E', 0.00034730857148287779),
 ('P135_L', 0.00032111616800641442),
 ('P184_I', 0.00031939754644173097),
 ('P121_D', 0.00031732579512637157),
 ('P173_K', 0.00031615518762627149),
 ('P221', 0.00031455011193028858),
 ('P6_K', 0.00030830883055275345),
 ('P181_I', 0.00030233783360660249),
 ('P98_S', 0.00028876819567226215),
 ('P32_K', 0.00028124417542713646),
 ('P166_K', 0.0002808748111850798),
 ('P11_R', 0.00028085164193501294),
 ('P67_G', 0.00028084776657503395),
 ('P142_T', 0.00028032331721691682),
 ('P237_E', 0.00027972256099027565),
 ('P162_D', 0.00027099659813596829),
 ('P197_E', 0.00026770517690369783),
 ('P207_D', 0.00026727523712224109),
 ('P135_M', 0.00026651589804190277),
 ('P33_V', 0.00026081957296777335),
 ('P11_K', 0.0002544333961544424),
 ('P106_A', 0.00025252978521046244),
 ('P169_D', 0.00025055970776281003),
 ('P138_A', 0.00024359386375885614),
 ('P211_S', 0.00024350310232222971),
 ('P173_R', 0.00024258349397140316),
 ('P1', 0.00023332647291195526),
 ('P27', 0.00023226793237666573),
 ('P179_D', 0.00023137107928803926),
 ('P58_T', 0.00022655002519174334),
 ('P100', 0.00022014608755782165),
 ('P190_C', 0.00020714297387020066),
 ('P207_G', 0.00020647995815651133),
 ('P211_T', 0.00020404825050321954),
 ('P230', 0.00020287080602791622),
 ('P101_Q', 0.00020120962784956936),
 ('P157', 0.00020000058336736433),
 ('P190_S', 0.00019419345405297073),
 ('P74_I', 0.00018296967171673874),
 ('P48_S', 0.0001775163749034273),
 ('P104_R', 0.00017244319760841673),
 ('P111', 0.000167731611715598),
 ('P174_K', 0.0001671768480088083),
 ('P31_V', 0.00016695766584317567),
 ('P33_A', 0.00016554905591164114),
 ('P109', 0.00016410756193115599),
 ('P64_R', 0.00016334690794083678),
 ('P101_H', 0.00016073242777569233),
 ('P4_S', 0.000159582501610235),
 ('P173_Q', 0.00015518873743907723),
 ('P28_K', 0.00015296418244316532),
 ('P169_E', 0.00015235216899203938),
 ('P32_E', 0.00015230190426268028),
 ('P48_T', 0.00014931600475534391),
 ('P101_R', 0.00014790045863572254),
 ('P173_N', 0.00014770509449362994),
 ('P4_P', 0.0001469270609201563),
 ('P239', 0.00014672393892225505),
 ('P106_I', 0.00014256246536649794),
 ('P13', 0.0001413177624044851),
 ('P31_I', 0.00013858828937927102),
 ('P121_H', 0.00013747673450159318),
 ('P215_V', 0.000136070845889713),
 ('P37', 0.00013531876061213566),
 ('P207_K', 0.00013337564255884132),
 ('P200_R', 0.00013174523839097946),
 ('P223_Q', 0.00013096240494140797),
 ('P28_A', 0.00012548340756437777),
 ('P138_Q', 0.00011532785441760907),
 ('P207_A', 0.00011039518865872977),
 ('P197_K', 0.00010950928712528276),
 ('P6_E', 0.00010824725552461508),
 ('P43_N', 0.00010378416513680041),
 ('P139_R', 9.9636100065719714e-05),
 ('P3_S', 9.8314835546991982e-05),
 ('P47_V', 9.8145633287732452e-05),
 ('P104_K', 9.2512522314128646e-05),
 ('P207_H', 9.1557056511385934e-05),
 ('P35_R', 8.5729676593072237e-05),
 ('P224_D', 8.5043931640428925e-05),
 ('P224_E', 8.277277434848758e-05),
 ('P163', 8.2429008848087891e-05),
 ('P63', 8.0182907627834556e-05),
 ('P179_E', 7.90508869786849e-05),
 ('P39_K', 7.8996196163410209e-05),
 ('P5_S', 7.8763581406593054e-05),
 ('P123_S', 7.8306273276253831e-05),
 ('P122_P', 7.7440964987469778e-05),
 ('P6_D', 7.6734583529687543e-05),
 ('P165_T', 7.5866324719554409e-05),
 ('P225', 7.5306144414764273e-05),
 ('P85', 7.4260000014479368e-05),
 ('P200_E', 7.4010491294990887e-05),
 ('P195_L', 7.3414017088698944e-05),
 ('P32_T', 7.2079978204739242e-05),
 ('P195_I', 7.1311497389698543e-05),
 ('P197_L', 7.0086870127877187e-05),
 ('P211_Q', 6.8967116529355059e-05),
 ('P224_K', 6.8384307989135959e-05),
 ('P166_T', 6.793157761611871e-05),
 ('P69_S', 6.5746514933650235e-05),
 ('P101_P', 6.5001407055795941e-05),
 ('P132', 6.4888442329120034e-05),
 ('P198', 6.1937193164967217e-05),
 ('P35_K', 6.1207994804328268e-05),
 ('P238_R', 6.0920435106250203e-05),
 ('P223_E', 6.089911496557083e-05),
 ('P215_I', 6.0291401747136104e-05),
 ('P102_R', 6.0072874363770748e-05),
 ('P174_H', 5.8923496579779993e-05),
 ('P122_Q', 5.8823392940486941e-05),
 ('P82', 5.7365951036787005e-05),
 ('P174_L', 5.4329223948507411e-05),
 ('P88_C', 5.3129844946612595e-05),
 ('P43_A', 5.1403624745991304e-05),
 ('P237_N', 5.057886566871371e-05),
 ('P215_L', 5.0287989972268171e-05),
 ('P204_E', 4.895528866502371e-05),
 ('P50', 4.6616079559863745e-05),
 ('P170', 4.5277446250194471e-05),
 ('P64_Y', 4.4686356333231143e-05),
 ('P181_V', 4.4139566915812721e-05),
 ('P39_N', 4.3575327951873241e-05),
 ('P35_L', 4.3280729834555589e-05),
 ('P177_G', 4.2939158480663363e-05),
 ('P69_G', 4.2712037213690829e-05),
 ('P7', 4.2577869169305751e-05),
 ('P3_N', 4.253990404912848e-05),
 ('P3_C', 4.0793962533628899e-05),
 ('P80', 4.0738013528253687e-05),
 ('P195_M', 4.0549062820673448e-05),
 ('P103_R', 4.0535151027103161e-05),
 ('P162_Y', 4.0502921610321552e-05),
 ('P46', 3.938733629927549e-05),
 ('P240_T', 3.922657631975248e-05),
 ('P104_N', 3.8375956322611281e-05),
 ('P174_E', 3.7667347299986625e-05),
 ('P43_R', 3.6967727822854313e-05),
 ('P106_M', 3.6706555739788651e-05),
 ('P75_M', 3.620893936993806e-05),
 ('P240_E', 3.496513401654536e-05),
 ('P135_K', 3.4296684429088164e-05),
 ('P21_V', 3.4283499227464581e-05),
 ('P28_G', 3.1606039987253433e-05),
 ('P211_A', 3.1339152264936688e-05),
 ('P165_I', 3.1264606669730081e-05),
 ('P211_G', 3.0904875189993984e-05),
 ('P88_W', 3.0696233897704535e-05),
 ('P39_I', 3.0225356134305558e-05),
 ('P200_K', 2.96926806919556e-05),
 ('P101_I', 2.860147433881264e-05),
 ('P22', 2.773397225078223e-05),
 ('P158', 2.6961852352201984e-05),
 ('P21_I', 2.6890565857257042e-05),
 ('P210_S', 2.6380376338860684e-05),
 ('P21_A', 2.5903507060762843e-05),
 ('P188_H', 2.5870275556699503e-05),
 ('P34', 2.5354038673889142e-05),
 ('P47_I', 2.5133841977865804e-05),
 ('P69_K', 2.5063630726307757e-05),
 ('P67_~', 2.4946710951120842e-05),
 ('P212', 2.4892262810345553e-05),
 ('P204_D', 2.4279655242071229e-05),
 ('P173_T', 2.4235630046209914e-05),
 ('P39_L', 2.3979867291791524e-05),
 ('P173_S', 2.3888613063677351e-05),
 ('P35_E', 2.3840376582614944e-05),
 ('P177_N', 2.3821915553087907e-05),
 ('P165_A', 2.325313210508823e-05),
 ('P39_S', 2.3180720272361477e-05),
 ('P31_L', 2.258466535340127e-05),
 ('P159_L', 2.2250669326182636e-05),
 ('P228_Q', 2.0646478625955823e-05),
 ('P189', 2.0242772947484979e-05),
 ('P200_S', 1.9167348467636126e-05),
 ('P107', 1.8460370565011147e-05),
 ('P58_G', 1.841829731588697e-05),
 ('P197_P', 1.7876763272087301e-05),
 ('P35_Q', 1.6521125261421695e-05),
 ('P102_M', 1.6337589824524213e-05),
 ('P190_E', 1.6089060311500773e-05),
 ('P122_A', 1.5362032752082303e-05),
 ('P204_N', 1.5026234928358834e-05),
 ('P6_G', 1.499251790325744e-05),
 ('P14', 1.4951797644508774e-05),
 ('P195_T', 1.3342347447179275e-05),
 ('P240_K', 1.3024231131735054e-05),
 ('P96', 1.2531349596052805e-05),
 ('P176_P', 1.2260048801154113e-05),
 ('P75_L', 1.2059981516383563e-05),
 ('P145', 1.1937069946213749e-05),
 ('P20_T', 1.1243103452339689e-05),
 ('P88_G', 1.1148211982685336e-05),
 ('P204_Q', 1.0980630032656169e-05),
 ('P113', 1.0272818990464138e-05),
 ('P66', 1.0040123196221622e-05),
 ('P48_V', 9.889363680492432e-06),
 ('P215_C', 9.6387834023666912e-06),
 ('P9', 9.6176332577349102e-06),
 ('P215_N', 9.2117924818115869e-06),
 ('P40_D', 8.2212612156804012e-06),
 ('P159_I', 8.1435250139559716e-06),
 ('P159_V', 8.0165716589835687e-06),
 ('P228_F', 7.9687525674008383e-06),
 ('P210_F', 7.8842792091922328e-06),
 ('P36', 7.3820368195924407e-06),
 ('P128', 6.6159503856832758e-06),
 ('P127', 6.4648237411951987e-06),
 ('P161', 6.170217236359231e-06),
 ('P200_V', 5.9395016760378583e-06),
 ('P173_E', 5.6872943675120443e-06),
 ('P11_T', 5.5186160983299495e-06),
 ('P70_G', 5.2656507077404917e-06),
 ('P86', 4.9187649220681258e-06),
 ('P101_A', 4.794689076618831e-06),
 ('P176_T', 4.5354871983633148e-06),
 ('P176_S', 4.4820962149087409e-06),
 ('P69_A', 4.4233211338960185e-06),
 ('P33_G', 4.1612184324786338e-06),
 ('P19', 4.0936989626473446e-06),
 ('P227_Y', 4.0702759885375919e-06),
 ('P223_R', 3.9869755081948724e-06),
 ('P200_L', 3.8689378684052706e-06),
 ('P135_R', 3.8263114032032092e-06),
 ('P69_I', 3.8148496017361344e-06),
 ('P53', 3.5117395389113594e-06),
 ('P173_I', 3.3270858084319144e-06),
 ('P123_G', 3.3176536042403781e-06),
 ('P182', 3.1849964653397612e-06),
 ('P195_K', 3.1153300879020901e-06),
 ('P164', 3.1051135760467751e-06),
 ('P8', 3.0457801694138396e-06),
 ('P40_K', 2.9025830231359228e-06),
 ('P45', 2.6725090075384944e-06),
 ('P167', 2.3928120733466434e-06),
 ('P165_L', 2.1256630019541142e-06),
 ('P173_A', 1.9348268249657992e-06),
 ('P60_A', 1.8128465735366147e-06),
 ('P199', 1.7584412182626765e-06),
 ('P144', 1.5109722083712416e-06),
 ('P196_K', 1.4619230053096696e-06),
 ('P240_A', 1.4017245065337739e-06),
 ('P2', 1.3437155057732335e-06),
 ('P35_A', 1.2998876912630738e-06),
 ('P79', 7.0993410033283175e-07),
 ('P32_R', 6.0690878075020209e-07),
 ('P102_E', 5.6410767082531577e-07),
 ('P169_K', 1.7175941527700947e-07),
 ('P236', 1.073097201017753e-07),
 ('P226', 6.0113693215943616e-09),
 ('P4_T', 0.0),
 ('P10_V', 0.0),
 ('P15_G', 0.0),
 ('P18_G', 0.0),
 ('P23_Q', 0.0),
 ('P24_W', 0.0),
 ('P25_P', 0.0),
 ('P29_E', 0.0),
 ('P30_K', 0.0),
 ('P32_N', 0.0),
 ('P38_C', 0.0),
 ('P41_I', 0.0),
 ('P42_E', 0.0),
 ('P47_F', 0.0),
 ('P52_P', 0.0),
 ('P54', 0.0),
 ('P56_Y', 0.0),
 ('P59_P', 0.0),
 ('P64_N', 0.0),
 ('P68_K', 0.0),
 ('P70_E', 0.0),
 ('P72_R', 0.0),
 ('P75_A', 0.0),
 ('P84_T', 0.0),
 ('P87_F', 0.0),
 ('P91_Q', 0.0),
 ('P93_G', 0.0),
 ('P94_I', 0.0),
 ('P95_P', 0.0),
 ('P101_S', 0.0),
 ('P105_S', 0.0),
 ('P110_D', 0.0),
 ('P117_S', 0.0),
 ('P119_P', 0.0),
 ('P122_R', 0.0),
 ('P124_F', 0.0),
 ('P125_R', 0.0),
 ('P126_K', 0.0),
 ('P130_F', 0.0),
 ('P131_T', 0.0),
 ('P134_S', 0.0),
 ('P136_N', 0.0),
 ('P137_N', 0.0),
 ('P139_I', 0.0),
 ('P140_P', 0.0),
 ('P141_G', 0.0),
 ('P143_R', 0.0),
 ('P147_N', 0.0),
 ('P149_L', 0.0),
 ('P154_K', 0.0),
 ('P162_H', 0.0),
 ('P162_N', 0.0),
 ('P171', 0.0),
 ('P175_N', 0.0),
 ('P179_L', 0.0),
 ('P187_L', 0.0),
 ('P188_C', 0.0),
 ('P188_F', 0.0),
 ('P191', 0.0),
 ('P192', 0.0),
 ('P193_L', 0.0),
 ('P194_E', 0.0),
 ('P197_I', 0.0),
 ('P200_F', 0.0),
 ('P201_K', 0.0),
 ('P205_L', 0.0),
 ('P206_R', 0.0),
 ('P213_G', 0.0),
 ('P215_E', 0.0),
 ('P216_T', 0.0),
 ('P217_P', 0.0),
 ('P219_D', 0.0),
 ('P219_T', 0.0),
 ('P222_Q', 0.0),
 ('P232_Y', 0.0),
 ('P233_E', 0.0),
 ('P234', 0.0)]

In [158]:
# # Here, let's try a parameter grid search, to figure out what would be the best 
# from sklearn.grid_search import GridSearchCV
# import numpy as np

# param_grid = [{'n_estimators':[100, 500, 1000],
#                #'max_features':['auto', 'sqrt', 'log2'],
#                #'min_samples_leaf':np.arange(1,20,1),
#               }]

# x_train, x_test, y_train, y_test = train_test_split(fpv_X_bi, fpv_Y)


# rfr_gs = GridSearchCV(RandomForestRegressor(), param_grid=param_grid, n_jobs=-1)
# rfr_gs.fit(x_train, y_train)
# print(rfr_gs.best_estimator_)
# print(rfr_gs.best_params_)

In [159]:
# Try Bayesian Ridge Regression
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

brr = lm.BayesianRidge()
brr.fit(x_train, y_train)
brr_preds = brr.predict(x_test)
print(brr.score(x_test, y_test), mean_squared_error(brr_preds, y_test))
print(sps.pearsonr(brr_preds, y_test))
brr_mse = mean_squared_error(brr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, brr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Bayesian Ridge'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(brr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.77269618709 0.140905268087
(0.88401111341649008, 4.0613783272294827e-39)

In [160]:
# Try ARD regression
ardr = lm.ARDRegression()
ardr.fit(x_train, y_train)
ardr_preds = ardr.predict(x_test)
ardr_mse = mean_squared_error(ardr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, ardr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} ARD Regression'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(ardr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [161]:
# Try Gradient Boost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

gbr = GradientBoostingRegressor()
gbr.fit(x_train, y_train)
gbr_preds = gbr.predict(x_test)
print(gbr.score(x_test, y_test), mean_squared_error(gbr_preds, y_test))
print(sps.pearsonr(gbr_preds, y_test))
gbr_mse = mean_squared_error(gbr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, gbr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Grad. Boost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(gbr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.800344680916 0.123766011227
(0.8991383071394452, 2.3372440847266272e-42)

In [162]:
plt.bar(range(len(gbr.feature_importances_)), gbr.feature_importances_)


Out[162]:
<Container object of 482 artists>

In [163]:
# Try AdaBoost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

abr = AdaBoostRegressor()
abr.fit(x_train, y_train)
abr_preds = abr.predict(x_test)
print(abr.score(x_test, y_test), mean_squared_error(abr_preds, y_test))
print(sps.pearsonr(abr_preds, y_test))
abr_mse = mean_squared_error(abr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(x=y_test, y=abr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} AdaBoost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(abr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.701945703317 0.184763379201
(0.87256175470337738, 5.9504271811245195e-37)

In [164]:
plt.bar(range(len(abr.feature_importances_)), abr.feature_importances_)


Out[164]:
<Container object of 482 artists>

In [165]:
# Try support vector regression
svr = SVR()
svr.fit(x_train, y_train)
svr_preds = svr.predict(x_test)

svr_mse = mean_squared_error(svr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, svr_preds, )
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} SVR'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(svr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [166]:
# Neural Network 1 Specification: Feed Forward ANN with 1 hidden layer.
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)

net1 = NeuralNet(
    layers=[  # three layers: one hidden layer
        ('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('dropout1', layers.DropoutLayer),
        #('hidden2', layers.DenseLayer),
        #('dropout2', layers.DropoutLayer),
        ('nonlinear', layers.NonlinearityLayer),
        ('output', layers.DenseLayer),
        ],
    # layer parameters:
    input_shape=(None, x_train.shape[1]),  # 
    hidden1_num_units=math.ceil(x_train.shape[1] / 2),  # number of units in hidden layer
    hidden1_nonlinearity=nonlinearities.tanh,
    dropout1_p = 0.5,
    #hidden2_num_units=math.ceil(x_train.shape[1] / 2),
    #dropout2_p = 0.5,
    output_nonlinearity=None,  # output layer uses identity function
    output_num_units=1,  # 30 target values
    
    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.95,

    regression=True,  # flag to indicate we're dealing with regression problem
    max_epochs=500,  # we want to train this many epochs
    verbose=1,
    )
net1.fit(x_train.values, y_train.values)


# Neural Network with 116645 learnable parameters

## Layer information

  #  name         size
---  ---------  ------
  0  input         482
  1  hidden1       241
  2  dropout1      241
  3  nonlinear     241
  4  output          1

  epoch    train loss    valid loss    train/val  dur
-------  ------------  ------------  -----------  -----
      1       1.19743       1.68213      0.71185  0.01s
      2       1.95433       0.72601      2.69189  0.01s
      3       0.94574       0.30878      3.06285  0.01s
      4       0.66858       0.31704      2.10884  0.01s
      5       0.53894       0.26701      2.01844  0.01s
      6       0.50842       0.25242      2.01423  0.01s
      7       0.40975       0.24942      1.64282  0.01s
      8       0.50461       0.23500      2.14729  0.01s
      9       0.31233       0.32754      0.95357  0.01s
     10       0.38249       0.29446      1.29897  0.01s
     11       0.29550       0.25208      1.17227  0.01s
     12       0.24624       0.22847      1.07778  0.01s
     13       0.23960       0.21291      1.12538  0.01s
     14       0.20809       0.21113      0.98563  0.01s
     15       0.26575       0.24637      1.07865  0.01s
     16       0.24770       0.25892      0.95665  0.01s
     17       0.24705       0.21733      1.13673  0.01s
     18       0.21937       0.19826      1.10649  0.01s
     19       0.16858       0.20001      0.84284  0.01s
     20       0.22139       0.19664      1.12585  0.01s
     21       0.20173       0.19444      1.03745  0.01s
     22       0.30861       0.19095      1.61618  0.01s
     23       0.17062       0.19223      0.88759  0.01s
     24       0.22889       0.18154      1.26077  0.01s
     25       0.15896       0.20431      0.77802  0.01s
     26       0.16305       0.18917      0.86192  0.01s
     27       0.17398       0.19096      0.91108  0.01s
     28       0.15817       0.17914      0.88294  0.01s
     29       0.15098       0.18068      0.83563  0.01s
     30       0.12210       0.20242      0.60322  0.01s
     31       0.15813       0.20358      0.77673  0.01s
     32       0.15006       0.19258      0.77922  0.01s
     33       0.13211       0.20734      0.63715  0.01s
     34       0.17773       0.18803      0.94521  0.01s
     35       0.12933       0.17848      0.72459  0.01s
     36       0.14602       0.18083      0.80747  0.01s
     37       0.13874       0.18274      0.75920  0.01s
     38       0.11998       0.17328      0.69239  0.01s
     39       0.12378       0.16272      0.76070  0.01s
     40       0.14595       0.16904      0.86339  0.01s
     41       0.15521       0.16610      0.93445  0.01s
     42       0.13233       0.17003      0.77824  0.01s
     43       0.16456       0.18901      0.87061  0.01s
     44       0.32841       0.23504      1.39725  0.01s
     45       0.20293       0.19713      1.02946  0.01s
     46       0.15822       0.17463      0.90599  0.01s
     47       0.19039       0.15962      1.19283  0.01s
     48       0.14678       0.14900      0.98507  0.01s
     49       0.24509       0.15656      1.56549  0.01s
     50       0.15942       0.17256      0.92382  0.01s
     51       0.11750       0.20041      0.58629  0.01s
     52       0.11213       0.17061      0.65726  0.01s
     53       0.12890       0.14601      0.88285  0.01s
     54       0.11421       0.14886      0.76725  0.01s
     55       0.15005       0.16531      0.90768  0.01s
     56       0.19297       0.16276      1.18561  0.01s
     57       0.11098       0.16149      0.68726  0.01s
     58       0.11118       0.15150      0.73384  0.01s
     59       0.19944       0.18011      1.10728  0.01s
     60       0.11480       0.20267      0.56644  0.01s
     61       0.13112       0.18593      0.70524  0.01s
     62       0.10427       0.16075      0.64865  0.01s
     63       0.11867       0.15690      0.75631  0.01s
     64       0.08913       0.16179      0.55090  0.01s
     65       0.12996       0.15242      0.85270  0.01s
     66       0.09034       0.14482      0.62380  0.01s
     67       0.09174       0.14283      0.64229  0.01s
     68       0.11042       0.16092      0.68613  0.01s
     69       0.07618       0.16674      0.45686  0.01s
     70       0.07536       0.15406      0.48917  0.01s
     71       0.09851       0.14473      0.68067  0.01s
     72       0.10116       0.14914      0.67834  0.01s
     73       0.08253       0.14524      0.56819  0.01s
     74       0.11012       0.14710      0.74863  0.01s
     75       0.09852       0.13707      0.71875  0.01s
     76       0.09305       0.13483      0.69012  0.01s
     77       0.10088       0.16733      0.60284  0.01s
     78       0.09275       0.18086      0.51285  0.01s
     79       0.08869       0.17111      0.51831  0.01s
     80       0.16047       0.15100      1.06270  0.01s
     81       0.11934       0.14352      0.83149  0.01s
     82       0.19195       0.16509      1.16270  0.01s
     83       0.10747       0.14800      0.72615  0.01s
     84       0.06586       0.13695      0.48092  0.01s
     85       0.11371       0.14434      0.78779  0.01s
     86       0.08070       0.14007      0.57610  0.01s
     87       0.10536       0.12969      0.81239  0.01s
     88       0.07028       0.14674      0.47893  0.01s
     89       0.18234       0.14641      1.24542  0.01s
     90       0.08129       0.14214      0.57190  0.01s
     91       0.11248       0.15048      0.74750  0.01s
     92       0.08253       0.14805      0.55746  0.01s
     93       0.09572       0.14202      0.67397  0.01s
     94       0.19924       0.13006      1.53194  0.01s
     95       0.11625       0.12150      0.95679  0.01s
     96       0.06974       0.12586      0.55412  0.01s
     97       0.09992       0.14316      0.69797  0.01s
     98       0.11114       0.13600      0.81719  0.01s
     99       0.08876       0.13129      0.67604  0.01s
    100       0.13738       0.12904      1.06461  0.01s
    101       0.10018       0.12606      0.79473  0.01s
    102       0.12242       0.13659      0.89629  0.01s
    103       0.10089       0.13986      0.72137  0.01s
    104       0.09595       0.13154      0.72948  0.01s
    105       0.12665       0.12320      1.02796  0.01s
    106       0.12632       0.11770      1.07331  0.01s
    107       0.09772       0.12210      0.80030  0.01s
    108       0.10360       0.12519      0.82754  0.01s
    109       0.13190       0.12772      1.03268  0.01s
    110       0.08357       0.12930      0.64634  0.01s
    111       0.08338       0.14676      0.56812  0.01s
    112       0.18342       0.13090      1.40130  0.01s
    113       0.08833       0.12613      0.70033  0.01s
    114       0.10523       0.16371      0.64278  0.01s
    115       0.28506       0.12587      2.26470  0.01s
    116       0.08084       0.12635      0.63982  0.01s
    117       0.12764       0.12970      0.98409  0.01s
    118       0.07287       0.13624      0.53482  0.01s
    119       0.17367       0.13381      1.29789  0.01s
    120       0.12418       0.12860      0.96563  0.01s
    121       0.07062       0.10958      0.64443  0.01s
    122       0.14090       0.11285      1.24861  0.01s
    123       0.15817       0.14505      1.09049  0.01s
    124       0.09398       0.12590      0.74649  0.01s
    125       0.11679       0.11381      1.02621  0.01s
    126       0.12874       0.13595      0.94694  0.01s
    127       0.06756       0.13883      0.48666  0.01s
    128       0.15036       0.14205      1.05849  0.01s
    129       0.09346       0.12188      0.76682  0.01s
    130       0.11129       0.11641      0.95598  0.01s
    131       0.10106       0.11543      0.87547  0.01s
    132       0.08539       0.12200      0.69989  0.01s
    133       0.05179       0.12845      0.40314  0.01s
    134       0.11935       0.13800      0.86488  0.01s
    135       0.09000       0.14137      0.63662  0.01s
    136       0.11975       0.15208      0.78741  0.01s
    137       0.08219       0.13515      0.60816  0.01s
    138       0.06151       0.13946      0.44104  0.01s
    139       0.08918       0.14164      0.62963  0.01s
    140       0.08815       0.13352      0.66018  0.01s
    141       0.10486       0.12188      0.86042  0.01s
    142       0.08883       0.13003      0.68310  0.01s
    143       0.06719       0.13224      0.50808  0.01s
    144       0.08657       0.12857      0.67328  0.01s
    145       0.08456       0.11816      0.71562  0.01s
    146       0.12163       0.11974      1.01582  0.01s
    147       0.05964       0.12404      0.48080  0.01s
    148       0.12442       0.10907      1.14071  0.01s
    149       0.07411       0.11275      0.65728  0.01s
    150       0.07070       0.11459      0.61698  0.01s
    151       0.08252       0.13039      0.63285  0.01s
    152       0.22600       0.10907      2.07195  0.01s
    153       0.08475       0.10037      0.84439  0.01s
    154       0.13983       0.10329      1.35380  0.01s
    155       0.21503       0.11749      1.83023  0.01s
    156       0.09434       0.14672      0.64297  0.01s
    157       0.08061       0.13848      0.58210  0.01s
    158       0.11020       0.14070      0.78323  0.01s
    159       0.07640       0.15349      0.49777  0.01s
    160       0.10059       0.14362      0.70033  0.01s
    161       0.17161       0.13071      1.31297  0.01s
    162       0.09409       0.12029      0.78218  0.01s
    163       0.15377       0.11169      1.37676  0.01s
    164       0.13956       0.10527      1.32576  0.01s
    165       0.08953       0.11600      0.77181  0.01s
    166       0.14288       0.12357      1.15634  0.01s
    167       0.12022       0.11994      1.00238  0.01s
    168       0.07633       0.12713      0.60041  0.01s
    169       0.09462       0.12695      0.74531  0.01s
    170       0.20254       0.12086      1.67580  0.01s
    171       0.15304       0.11058      1.38389  0.01s
    172       0.11511       0.11520      0.99919  0.01s
    173       0.07404       0.12196      0.60705  0.01s
    174       0.16150       0.15179      1.06398  0.01s
    175       0.08393       0.14864      0.56463  0.01s
    176       0.11697       0.13070      0.89489  0.01s
    177       0.16888       0.11628      1.45238  0.01s
    178       0.08649       0.11871      0.72856  0.01s
    179       0.10527       0.13650      0.77119  0.01s
    180       0.13215       0.15267      0.86561  0.01s
    181       0.11524       0.14075      0.81879  0.01s
    182       0.12995       0.12186      1.06637  0.01s
    183       0.13974       0.11200      1.24773  0.01s
    184       0.11034       0.10854      1.01658  0.01s
    185       0.08954       0.12519      0.71520  0.01s
    186       0.09614       0.13197      0.72849  0.01s
    187       0.11542       0.13394      0.86171  0.01s
    188       0.18394       0.12831      1.43355  0.01s
    189       0.09980       0.13187      0.75678  0.01s
    190       0.08412       0.12378      0.67955  0.01s
    191       0.10366       0.11898      0.87120  0.01s
    192       0.11733       0.11964      0.98062  0.01s
    193       0.21819       0.11653      1.87243  0.01s
    194       0.08570       0.11871      0.72198  0.01s
    195       0.27932       0.11683      2.39073  0.01s
    196       0.12238       0.11941      1.02483  0.01s
    197       0.10302       0.12086      0.85238  0.01s
    198       0.23725       0.12547      1.89097  0.01s
    199       0.11024       0.11744      0.93872  0.01s
    200       0.10038       0.11068      0.90688  0.01s
    201       0.09320       0.11110      0.83889  0.01s
    202       0.16984       0.12371      1.37288  0.01s
    203       0.13055       0.11477      1.13744  0.01s
    204       0.11040       0.11653      0.94740  0.01s
    205       0.14614       0.10772      1.35668  0.01s
    206       0.13186       0.10909      1.20868  0.01s
    207       0.09123       0.11366      0.80265  0.01s
    208       0.08160       0.13464      0.60607  0.01s
    209       0.08654       0.13987      0.61867  0.01s
    210       0.08025       0.13004      0.61715  0.01s
    211       0.09097       0.12227      0.74403  0.01s
    212       0.12589       0.13893      0.90616  0.01s
    213       0.14856       0.14380      1.03314  0.01s
    214       0.12276       0.12380      0.99156  0.01s
    215       0.11822       0.11806      1.00135  0.01s
    216       0.10319       0.11564      0.89234  0.01s
    217       0.11402       0.11674      0.97668  0.01s
    218       0.27109       0.11497      2.35795  0.01s
    219       0.09970       0.12409      0.80348  0.01s
    220       0.14090       0.11524      1.22273  0.01s
    221       0.12857       0.12957      0.99222  0.01s
    222       0.07628       0.12505      0.61001  0.01s
    223       0.17852       0.11901      1.49997  0.01s
    224       0.11503       0.10621      1.08311  0.01s
    225       0.19264       0.11913      1.61707  0.01s
    226       0.07438       0.13171      0.56471  0.01s
    227       0.16558       0.13311      1.24386  0.01s
    228       0.11467       0.13764      0.83310  0.01s
    229       0.09415       0.12420      0.75808  0.01s
    230       0.22680       0.12710      1.78445  0.01s
    231       0.34571       0.11241      3.07541  0.01s
    232       0.10193       0.11659      0.87422  0.01s
    233       0.09103       0.11723      0.77646  0.01s
    234       0.10405       0.12078      0.86146  0.01s
    235       0.10594       0.11446      0.92553  0.01s
    236       0.18869       0.11008      1.71401  0.01s
    237       0.11835       0.12989      0.91113  0.01s
    238       0.13303       0.12998      1.02346  0.01s
    239       0.10228       0.11656      0.87746  0.01s
    240       0.12744       0.11222      1.13569  0.01s
    241       0.12617       0.10443      1.20820  0.01s
    242       0.11751       0.11067      1.06178  0.01s
    243       0.11329       0.11561      0.97993  0.01s
    244       0.07801       0.11557      0.67497  0.01s
    245       0.07290       0.11985      0.60827  0.01s
    246       0.09177       0.12275      0.74764  0.01s
    247       0.16516       0.11076      1.49116  0.01s
    248       0.14838       0.10435      1.42195  0.01s
    249       0.14063       0.12676      1.10937  0.01s
    250       0.08579       0.12798      0.67035  0.01s
    251       0.10262       0.12816      0.80075  0.01s
    252       0.22056       0.12011      1.83630  0.01s
    253       0.11093       0.11132      0.99655  0.01s
    254       0.08826       0.11225      0.78626  0.01s
    255       0.12252       0.11137      1.10011  0.01s
    256       0.46803       0.10848      4.31445  0.01s
    257       0.14076       0.10959      1.28442  0.01s
    258       0.14253       0.10520      1.35483  0.01s
    259       0.11187       0.11667      0.95887  0.01s
    260       0.20481       0.10591      1.93380  0.01s
    261       0.10244       0.10637      0.96308  0.01s
    262       0.16964       0.10803      1.57036  0.01s
    263       0.29071       0.11017      2.63867  0.01s
    264       0.07704       0.11463      0.67203  0.01s
    265       0.06943       0.12622      0.55009  0.01s
    266       0.07959       0.12558      0.63378  0.01s
    267       0.20754       0.12320      1.68458  0.01s
    268       0.12011       0.11356      1.05770  0.01s
    269       0.12476       0.11536      1.08154  0.01s
    270       0.17503       0.11010      1.58978  0.01s
    271       0.10125       0.10330      0.98007  0.01s
    272       0.14162       0.11844      1.19563  0.01s
    273       0.16627       0.11906      1.39654  0.01s
    274       0.12680       0.11905      1.06513  0.01s
    275       0.12465       0.11530      1.08104  0.01s
    276       0.11628       0.10806      1.07604  0.01s
    277       0.07748       0.10359      0.74797  0.01s
    278       0.15013       0.10246      1.46517  0.01s
    279       0.10133       0.10362      0.97792  0.01s
    280       0.11667       0.11448      1.01914  0.01s
    281       0.09241       0.11514      0.80261  0.01s
    282       0.24398       0.11741      2.07809  0.01s
    283       0.11680       0.11378      1.02655  0.01s
    284       0.08042       0.11638      0.69101  0.01s
    285       0.14535       0.10724      1.35532  0.01s
    286       0.24521       0.10983      2.23260  0.01s
    287       0.09396       0.11252      0.83508  0.01s
    288       0.08388       0.12504      0.67082  0.01s
    289       0.09342       0.13650      0.68435  0.01s
    290       0.08754       0.12773      0.68532  0.01s
    291       0.18065       0.12100      1.49297  0.01s
    292       0.09112       0.11719      0.77754  0.01s
    293       0.12370       0.11364      1.08859  0.01s
    294       0.09218       0.10652      0.86536  0.01s
    295       0.10706       0.10734      0.99744  0.01s
    296       0.07818       0.11892      0.65741  0.01s
    297       0.12493       0.12213      1.02295  0.01s
    298       0.24867       0.12240      2.03163  0.01s
    299       0.24085       0.12847      1.87484  0.01s
    300       0.22015       0.12969      1.69757  0.01s
    301       0.27346       0.12789      2.13818  0.01s
    302       0.12892       0.12044      1.07040  0.01s
    303       0.10835       0.11448      0.94647  0.01s
    304       0.12714       0.11780      1.07929  0.01s
    305       0.17920       0.11530      1.55419  0.01s
    306       0.11329       0.12292      0.92161  0.01s
    307       0.07553       0.11891      0.63522  0.01s
    308       0.13481       0.11160      1.20799  0.01s
    309       0.10359       0.10694      0.96872  0.01s
    310       0.07985       0.10992      0.72646  0.01s
    311       0.11029       0.10577      1.04277  0.01s
    312       0.16066       0.10568      1.52025  0.01s
    313       0.11485       0.10818      1.06166  0.01s
    314       0.07152       0.11596      0.61681  0.01s
    315       0.18642       0.12340      1.51066  0.01s
    316       0.25431       0.11513      2.20886  0.01s
    317       0.07030       0.11237      0.62558  0.01s
    318       0.08992       0.12206      0.73670  0.01s
    319       0.10449       0.12094      0.86394  0.01s
    320       0.11660       0.10877      1.07195  0.01s
    321       0.08406       0.10463      0.80346  0.01s
    322       0.08647       0.11436      0.75617  0.01s
    323       0.12884       0.11228      1.14749  0.01s
    324       0.08573       0.10950      0.78296  0.01s
    325       0.08238       0.11099      0.74220  0.01s
    326       0.07104       0.10458      0.67925  0.01s
    327       0.13142       0.10535      1.24749  0.01s
    328       0.12778       0.11225      1.13832  0.01s
    329       0.16600       0.11924      1.39215  0.01s
    330       0.11120       0.12677      0.87719  0.01s
    331       0.10203       0.12839      0.79470  0.01s
    332       0.22979       0.11863      1.93708  0.01s
    333       0.19929       0.10867      1.83399  0.01s
    334       0.06854       0.10215      0.67095  0.01s
    335       0.10265       0.10456      0.98169  0.01s
    336       0.14792       0.10731      1.37846  0.01s
    337       0.20062       0.12097      1.65844  0.01s
    338       0.09548       0.12428      0.76824  0.01s
    339       0.11293       0.11852      0.95284  0.01s
    340       0.13741       0.11049      1.24371  0.01s
    341       0.10877       0.11369      0.95672  0.01s
    342       0.10017       0.11329      0.88421  0.01s
    343       0.12561       0.11486      1.09352  0.01s
    344       0.16905       0.11363      1.48779  0.01s
    345       0.15755       0.10594      1.48713  0.01s
    346       0.09806       0.10779      0.90980  0.01s
    347       0.16437       0.11106      1.47998  0.01s
    348       0.11702       0.11615      1.00744  0.01s
    349       0.31535       0.12388      2.54559  0.01s
    350       0.08488       0.12407      0.68414  0.01s
    351       0.10687       0.11902      0.89794  0.01s
    352       0.09052       0.11557      0.78328  0.01s
    353       0.09330       0.11175      0.83487  0.01s
    354       0.27415       0.10993      2.49389  0.01s
    355       0.12707       0.12385      1.02607  0.01s
    356       0.19793       0.12569      1.57478  0.01s
    357       0.34713       0.12064      2.87739  0.01s
    358       0.10644       0.12088      0.88053  0.01s
    359       0.09510       0.10983      0.86587  0.01s
    360       0.16441       0.10088      1.62978  0.01s
    361       0.11593       0.11305      1.02554  0.01s
    362       0.11494       0.12057      0.95329  0.01s
    363       0.12216       0.11234      1.08741  0.01s
    364       0.13576       0.12015      1.12992  0.01s
    365       0.14869       0.11554      1.28688  0.01s
    366       0.10738       0.12143      0.88432  0.01s
    367       0.16793       0.10595      1.58495  0.01s
    368       0.10210       0.10116      1.00929  0.01s
    369       0.08089       0.10337      0.78250  0.01s
    370       0.11187       0.10552      1.06021  0.01s
    371       0.14882       0.10535      1.41259  0.01s
    372       0.09565       0.10372      0.92219  0.01s
    373       0.11587       0.10403      1.11386  0.01s
    374       0.12957       0.10684      1.21273  0.01s
    375       0.12155       0.11439      1.06260  0.01s
    376       0.13024       0.11501      1.13242  0.01s
    377       0.10404       0.10944      0.95068  0.01s
    378       0.09425       0.10831      0.87016  0.01s
    379       0.12229       0.11896      1.02794  0.01s
    380       0.27788       0.11709      2.37327  0.01s
    381       0.13730       0.10770      1.27481  0.01s
    382       0.06769       0.10408      0.65034  0.01s
    383       0.09990       0.10043      0.99476  0.01s
    384       0.12530       0.10382      1.20684  0.01s
    385       0.21650       0.11164      1.93926  0.01s
    386       0.10334       0.12068      0.85628  0.01s
    387       0.15810       0.11898      1.32879  0.01s
    388       0.18800       0.10597      1.77414  0.01s
    389       0.11202       0.11016      1.01683  0.01s
    390       0.10055       0.11460      0.87738  0.01s
    391       0.09213       0.11409      0.80750  0.01s
    392       0.08892       0.11820      0.75232  0.01s
    393       0.10874       0.11569      0.93992  0.01s
    394       0.16282       0.11167      1.45812  0.01s
    395       0.12754       0.11284      1.13032  0.01s
    396       0.13915       0.12221      1.13862  0.01s
    397       0.24500       0.12010      2.03998  0.01s
    398       0.20603       0.11635      1.77081  0.01s
    399       0.16430       0.11840      1.38771  0.01s
    400       0.09977       0.11417      0.87387  0.01s
    401       0.12946       0.11064      1.17013  0.01s
    402       0.23283       0.10917      2.13271  0.01s
    403       0.16833       0.11238      1.49794  0.01s
    404       0.11248       0.10762      1.04519  0.01s
    405       0.12317       0.10358      1.18914  0.01s
    406       0.33181       0.10045      3.30333  0.01s
    407       0.22352       0.10111      2.21060  0.01s
    408       0.20221       0.10430      1.93880  0.01s
    409       0.24501       0.10388      2.35846  0.01s
    410       0.11376       0.10717      1.06150  0.01s
    411       0.09118       0.11045      0.82554  0.01s
    412       0.17486       0.12949      1.35032  0.01s
    413       0.22003       0.11184      1.96732  0.01s
    414       0.12543       0.11305      1.10949  0.01s
    415       0.09908       0.11581      0.85553  0.01s
    416       0.07564       0.11651      0.64925  0.01s
    417       0.11281       0.10889      1.03600  0.01s
    418       0.09717       0.11091      0.87610  0.01s
    419       0.14669       0.11229      1.30634  0.01s
    420       0.28603       0.11228      2.54747  0.01s
    421       0.09208       0.11689      0.78778  0.01s
    422       0.11471       0.11832      0.96951  0.01s
    423       0.13615       0.10567      1.28846  0.01s
    424       0.10088       0.10536      0.95740  0.01s
    425       0.12328       0.10732      1.14873  0.01s
    426       0.22350       0.12054      1.85416  0.01s
    427       0.09210       0.11415      0.80685  0.01s
    428       0.15362       0.11322      1.35680  0.01s
    429       0.11331       0.10815      1.04770  0.01s
    430       0.12148       0.10818      1.12301  0.01s
    431       0.09257       0.10700      0.86507  0.01s
    432       0.09705       0.10703      0.90672  0.01s
    433       0.07938       0.10405      0.76288  0.01s
    434       0.22082       0.10631      2.07709  0.01s
    435       0.10758       0.10861      0.99053  0.01s
    436       0.10711       0.10276      1.04230  0.01s
    437       0.14319       0.10121      1.41477  0.01s
    438       0.43683       0.09328      4.68290  0.01s
    439       0.11782       0.10140      1.16196  0.01s
    440       0.21246       0.11188      1.89904  0.01s
    441       0.19481       0.11114      1.75289  0.01s
    442       0.06085       0.10605      0.57379  0.01s
    443       0.12157       0.10399      1.16899  0.01s
    444       0.13614       0.09945      1.36895  0.01s
    445       0.13134       0.10366      1.26697  0.01s
    446       0.15359       0.10963      1.40101  0.01s
    447       0.09094       0.10463      0.86914  0.01s
    448       0.09409       0.12492      0.75320  0.01s
    449       0.12303       0.11162      1.10221  0.01s
    450       0.14937       0.10294      1.45100  0.01s
    451       0.10203       0.09828      1.03812  0.01s
    452       0.13967       0.09738      1.43420  0.01s
    453       0.14812       0.10017      1.47873  0.01s
    454       0.11346       0.11001      1.03138  0.01s
    455       0.14265       0.11476      1.24295  0.01s
    456       0.12859       0.11210      1.14717  0.01s
    457       0.11018       0.11003      1.00134  0.01s
    458       0.13159       0.10225      1.28691  0.01s
    459       0.09561       0.10233      0.93435  0.01s
    460       0.09103       0.10507      0.86632  0.01s
    461       0.11055       0.09492      1.16473  0.01s
    462       0.12157       0.09544      1.27374  0.01s
    463       0.07093       0.10482      0.67671  0.01s
    464       0.18711       0.11349      1.64871  0.01s
    465       0.14038       0.10664      1.31639  0.01s
    466       0.12566       0.10897      1.15320  0.01s
    467       0.19319       0.11205      1.72406  0.01s
    468       0.09147       0.10147      0.90147  0.01s
    469       0.14029       0.11045      1.27019  0.01s
    470       0.22488       0.11149      2.01710  0.01s
    471       0.10785       0.10846      0.99440  0.01s
    472       0.10929       0.10316      1.05941  0.01s
    473       0.17153       0.10184      1.68436  0.01s
    474       0.14045       0.10790      1.30172  0.01s
    475       0.19863       0.12198      1.62838  0.01s
    476       0.13273       0.11623      1.14202  0.01s
    477       0.23120       0.10988      2.10402  0.01s
    478       0.12489       0.10002      1.24865  0.01s
    479       0.17528       0.09908      1.76904  0.01s
    480       0.07944       0.10917      0.72769  0.01s
    481       0.09229       0.11801      0.78204  0.01s
    482       0.08539       0.11923      0.71624  0.01s
    483       0.12025       0.11856      1.01427  0.01s
    484       0.08485       0.11523      0.73631  0.01s
    485       0.29845       0.10880      2.74312  0.01s
    486       0.10727       0.10669      1.00549  0.01s
    487       0.11725       0.11029      1.06315  0.01s
    488       0.13444       0.11273      1.19267  0.01s
    489       0.10725       0.12235      0.87660  0.01s
    490       0.09939       0.11961      0.83092  0.01s
    491       0.06879       0.11396      0.60364  0.01s
    492       0.10803       0.11315      0.95476  0.01s
    493       0.11545       0.11353      1.01689  0.01s
    494       0.12147       0.11532      1.05333  0.01s
    495       0.18089       0.12079      1.49757  0.01s
    496       0.23218       0.11378      2.04061  0.01s
    497       0.09412       0.11941      0.78820  0.01s
    498       0.10273       0.10618      0.96755  0.01s
    499       0.14538       0.11019      1.31939  0.01s
    500       0.13773       0.11806      1.16664  0.01s
Out[166]:
NeuralNet(X_tensor_type=None,
     batch_iterator_test=<nolearn.lasagne.base.BatchIterator object at 0x7fa67253e358>,
     batch_iterator_train=<nolearn.lasagne.base.BatchIterator object at 0x7fa67253e550>,
     custom_score=None, dropout1_p=0.5,
     hidden1_nonlinearity=<function tanh at 0x7fa62838ff28>,
     hidden1_num_units=241, input_shape=(None, 482),
     layers=[('input', <class 'lasagne.layers.input.InputLayer'>), ('hidden1', <class 'lasagne.layers.dense.DenseLayer'>), ('dropout1', <class 'lasagne.layers.noise.DropoutLayer'>), ('nonlinear', <class 'lasagne.layers.dense.NonlinearityLayer'>), ('output', <class 'lasagne.layers.dense.DenseLayer'>)],
     loss=None, max_epochs=500, more_params={},
     objective=<function objective at 0x7fa672540510>,
     objective_loss_function=<function squared_error at 0x7fa6280fc840>,
     on_epoch_finished=[<nolearn.lasagne.handlers.PrintLog object at 0x7fa5c2153b00>],
     on_training_finished=[],
     on_training_started=[<nolearn.lasagne.handlers.PrintLayerInfo object at 0x7fa5c2153710>],
     output_nonlinearity=None, output_num_units=1, regression=True,
     train_split=<nolearn.lasagne.base.TrainSplit object at 0x7fa67253e7b8>,
     update=<function nesterov_momentum at 0x7fa6281020d0>,
     update_learning_rate=0.01, update_momentum=0.95,
     use_label_encoder=False, verbose=1,
     y_tensor_type=TensorType(float32, matrix))

In [167]:
nn1_preds = net1.predict(x_test)
nn1_mse = float(mean_squared_error(nn1_preds, y_test))

plt.figure(figsize=(3,3))
plt.scatter(y_test, nn1_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Neural Network'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(nn1_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()



In [168]:
sps.pearsonr(nn1_preds, y_test.reshape(y_test.shape[0],1))


Out[168]:
(array([ 0.87914544], dtype=float32), array([  3.59665475e-38], dtype=float32))

In [ ]:


In [ ]:


In [ ]: