In [1]:
import pandas as pd
import numpy as np
import nolearn
import matplotlib.pyplot as plt
import seaborn
import sklearn.linear_model as lm
import scipy.stats as sps
import math

from Bio import SeqIO
from collections import Counter
from decimal import Decimal
from lasagne import layers, nonlinearities
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.svm import SVR

seaborn.set_style('white')
seaborn.set_context('poster')

%matplotlib inline


Using gpu device 0: Quadro 2000

In [7]:
# Read in the NRTI (nucleoside reverse transcriptase inhibitor) resistance data
data = pd.read_csv('hiv-nrt-data.csv', index_col='SeqID')
data


Out[7]:
3TC ABC AZT D4T DDI TDF P1 P2 P3 P4 ... P231 P232 P233 P234 P235 P236 P237 P238 P239 P240
SeqID
2997 200.00 4.3 3.9 1.4 1.2 NaN - - - - ... - - - - - - - - - -
4388 200.00 8.1 4.7 1.8 1.7 NaN - - - - ... - - - - - - - - - .
4427 1.40 1.1 28.0 1.0 0.8 1.9 - - - - ... - - - - - - - - - -
4487 1.80 1.5 7.1 1.2 1.1 1.1 - - - - ... - - - - - - - - - -
4539 200.00 3.7 4.5 2.0 1.3 NaN - - - - ... . . . . . . . . . .
4663 200.00 7.1 0.8 1.3 1.9 NaN - - - - ... - - - - - - - - - -
4689 2.55 5.2 6.7 5.4 5.3 0.9 - - - - ... - - - - - - - - - -
4697 200.00 5.9 1.9 1.3 1.7 NaN - - - - ... - - - - - - - - - -
5071 200.00 6.3 5.2 1.4 1.9 0.9 - - - - ... - - - - - - - - - -
5222 200.00 5.7 0.3 1.2 2.0 NaN - - - - ... - - - - - - - - - -
5280 6.20 4.5 1000.0 2.4 1.6 5.0 - - - - ... - - - - - - - - - -
5445 3.30 4.9 177.0 2.9 1.6 2.2 - - - - ... - - - - - - - - - -
5463 3.30 6.5 40.0 9.5 9.2 1.2 - - - - ... - - - - - - - - - -
5465 200.00 7.0 77.0 2.7 2.1 NaN - - - - ... - - - - - - - - - -
5641 200.00 4.2 1.5 1.2 1.3 0.9 - - - - ... - - - - - - - - - -
5708 200.00 12.0 1000.0 7.2 2.4 3.7 - - - - ... - - - - - - - - - -
6029 7.30 18.0 2400.0 16.0 4.8 NaN - - - - ... - - - - - - - - - -
6485 1.00 0.6 0.2 0.4 0.7 0.4 - - - - ... - - - - - - - - - -
6519 200.00 2.7 0.4 0.8 1.3 NaN - - - - ... - - - LH - - - - - -
6540 1.10 1.0 0.3 0.8 1.0 0.7 - - - - ... - - - - - - - - - -
6569 3.90 5.3 216.5 4.2 1.8 NaN - - - - ... - - - - - - - - - -
6709 0.90 0.8 0.5 0.8 1.0 0.7 - - - - ... - - - - - - - - - -
6796 7.10 7.0 259.0 4.4 2.4 NaN - - - - ... - - - - - - - - - -
6820 4.80 1.6 24.0 1.8 1.4 NaN - - - - ... - - - - - - E - - -
6859 1.40 2.3 42.0 2.0 1.0 NaN - - - - ... - - - - - - - - - -
6876 200.00 9.6 35.4 3.3 2.1 NaN - - - - ... - - - - - - - - - -
7328 3.60 1.7 7.2 1.5 1.1 NaN - - - - ... - - - - - - - - - -
7347 200.00 9.9 100.5 3.9 2.1 NaN - - - - ... - - - - - - - - - -
7348 200.00 7.7 7.9 2.3 1.8 NaN - - - - ... - - - - - - - - - -
7350 200.00 4.6 0.6 1.0 1.6 NaN - - - - ... - - - - - - - - - -
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
259208 1.25 1.0 0.6 0.9 1.0 0.9 - - - - ... - - - - - - - - - -
259210 81.70 1.6 0.2 0.6 1.6 0.3 - - - - ... - - - - - - - - - -
259212 1.00 0.7 0.4 0.8 0.8 0.7 - - - - ... - - - - - - - - - -
259214 74.00 1.5 0.2 0.7 1.4 0.4 - - - - ... - - - - - - - - - -
259216 1.05 0.9 0.8 1.2 1.3 0.9 - - - - ... - - - - - - - - - -
259222 29.50 3.6 0.4 1.2 1.7 2.0 - - - - ... - - - - - - - - - -
259224 1.40 1.0 0.6 0.9 1.0 0.9 - - - - ... - - - - - - - - - -
259226 113.00 1.8 0.2 0.8 1.6 0.5 - - - - ... - - - - - - - N - -
259228 1.30 0.9 0.5 0.7 0.7 0.8 - - - - ... - - - - - - - - - -
259230 91.50 3.0 0.2 0.5 1.2 0.2 - - - - ... - - - - - - - - - -
259234 1.15 1.0 0.8 1.1 1.3 1.0 - - - - ... - - - - - - - - - -
259238 1.30 1.0 0.6 0.9 1.2 0.9 - - - - ... - - - - - - - - - -
259242 0.90 0.8 0.7 0.9 0.9 0.9 - - - - ... - - - - - - - - - -
259244 88.50 5.2 0.2 0.6 2.0 0.3 - - - - ... - - - - - - - - - -
259246 1.05 0.8 0.8 1.1 1.3 0.9 - - - - ... - - - - - - - - - -
259248 82.10 1.9 0.3 0.8 1.6 0.5 - - - - ... - - - - - - - - - -
259250 1.15 0.9 0.6 0.9 1.2 0.8 - - - - ... - - - - - - - - - -
259252 87.00 3.8 0.4 0.9 1.9 0.7 - - - - ... - - - - - - - - - -
259254 1.40 0.9 0.6 0.9 0.8 0.8 - - - - ... - - - - - - - - - -
259256 91.50 3.0 0.2 0.6 1.1 0.2 - - - - ... - - - - - - - - - -
259258 1.20 0.9 0.8 1.2 1.3 1.0 - - - - ... - - - - - - - - - -
259260 104.50 4.4 0.6 0.7 1.9 0.8 - - - - ... - - - - - - - - - -
259262 1.15 1.0 0.5 0.8 0.8 0.7 - - - - ... - - - - - - - - - -
259266 1.05 0.7 0.3 0.8 0.8 0.6 - - - - ... - - - - - - - - - -
259268 125.50 2.9 0.2 0.7 1.2 0.4 - - - - ... - - - - - - - - - -
270334 1.15 1.4 1.2 1.4 1.2 1.1 - - - - ... - - - - - - - - - -
270335 89.40 3.2 0.8 1.6 2.4 0.8 - - - - ... - - - - - - - - - -
270336 89.40 2.8 0.5 1.4 2.2 0.6 - - - - ... - - - - - - - - - -
270337 1.20 1.4 1.8 1.4 1.2 1.2 - - - - ... - - - - - - - - - -
270338 0.70 0.8 0.4 1.0 0.8 0.7 - - - - ... - - - - - - - - - -

1498 rows × 246 columns


In [9]:
# Set the drug data columns and the amino acid data columns
drug_cols = data.columns[0:6]
feat_cols = data.columns[6:]

In [10]:
# Read in the consensus data
consensus = SeqIO.read('hiv-rt-consensus.fasta', 'fasta')

# Map each 1-indexed alignment position to its consensus residue.
consensus_map = {i + 1: letter for i, letter in enumerate(str(consensus.seq))}
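
# Quick sanity check (a sketch, not from the original notebook): the map
# should have one entry per 1-indexed position of the consensus sequence.
assert set(consensus_map) == set(range(1, len(consensus.seq) + 1))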

In [11]:
# The '-' characters in the dataset represent the consensus residue at each
# position, so they need to be replaced with the actual consensus letter.

for i, col in enumerate(feat_cols):
    # Replace '-' with the consensus letter.
    data[col] = data[col].replace({'-':consensus_map[i+1]})
    
    # Replace '.' with np.nan
    data[col] = data[col].replace({'.':np.nan})
    
    # Replace 'X' with np.nan
    data[col] = data[col].replace({'X':np.nan})
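
# The same cleanup can be done in one pass with a nested replacement dict.
# This is an equivalent sketch to run *in place of* the loop above (it assumes
# the same consensus_map), not the notebook's original code:
# replacements = {col: {'-': consensus_map[i + 1], '.': np.nan, 'X': np.nan}
#                 for i, col in enumerate(feat_cols)}
# data = data.replace(replacements)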

In [12]:
# Drop any rows that have np.nan in the feat_cols. We don't want low quality sequences.
data.dropna(inplace=True, subset=feat_cols)

In [13]:
data


Out[13]:
3TC ABC AZT D4T DDI TDF P1 P2 P3 P4 ... P231 P232 P233 P234 P235 P236 P237 P238 P239 P240
SeqID
4427 1.40 1.1 28.0 1.0 0.8 1.9 P I S P ... G Y E L H P D K W T
4487 1.80 1.5 7.1 1.2 1.1 1.1 P I S P ... G Y E L H P D K W T
4663 200.00 7.1 0.8 1.3 1.9 NaN P I S P ... G Y E L H P D K W T
4689 2.55 5.2 6.7 5.4 5.3 0.9 P I S P ... G Y E L H P D K W T
4697 200.00 5.9 1.9 1.3 1.7 NaN P I S P ... G Y E L H P D K W T
5071 200.00 6.3 5.2 1.4 1.9 0.9 P I S P ... G Y E L H P D K W T
5222 200.00 5.7 0.3 1.2 2.0 NaN P I S P ... G Y E L H P D K W T
5280 6.20 4.5 1000.0 2.4 1.6 5.0 P I S P ... G Y E L H P D K W T
5445 3.30 4.9 177.0 2.9 1.6 2.2 P I S P ... G Y E L H P D K W T
5463 3.30 6.5 40.0 9.5 9.2 1.2 P I S P ... G Y E L H P D K W T
5465 200.00 7.0 77.0 2.7 2.1 NaN P I S P ... G Y E L H P D K W T
5641 200.00 4.2 1.5 1.2 1.3 0.9 P I S P ... G Y E L H P D K W T
5708 200.00 12.0 1000.0 7.2 2.4 3.7 P I S P ... G Y E L H P D K W T
6029 7.30 18.0 2400.0 16.0 4.8 NaN P I S P ... G Y E L H P D K W T
6485 1.00 0.6 0.2 0.4 0.7 0.4 P I S P ... G Y E L H P D K W T
6519 200.00 2.7 0.4 0.8 1.3 NaN P I S P ... G Y E LH H P D K W T
6540 1.10 1.0 0.3 0.8 1.0 0.7 P I S P ... G Y E L H P D K W T
6569 3.90 5.3 216.5 4.2 1.8 NaN P I S P ... G Y E L H P D K W T
6709 0.90 0.8 0.5 0.8 1.0 0.7 P I S P ... G Y E L H P D K W T
6796 7.10 7.0 259.0 4.4 2.4 NaN P I S P ... G Y E L H P D K W T
6820 4.80 1.6 24.0 1.8 1.4 NaN P I S P ... G Y E L H P E K W T
6859 1.40 2.3 42.0 2.0 1.0 NaN P I S P ... G Y E L H P D K W T
6876 200.00 9.6 35.4 3.3 2.1 NaN P I S P ... G Y E L H P D K W T
7328 3.60 1.7 7.2 1.5 1.1 NaN P I S P ... G Y E L H P D K W T
7347 200.00 9.9 100.5 3.9 2.1 NaN P I S P ... G Y E L H P D K W T
7348 200.00 7.7 7.9 2.3 1.8 NaN P I S P ... G Y E L H P D K W T
7350 200.00 4.6 0.6 1.0 1.6 NaN P I S P ... G Y E L H P D K W T
7360 200.00 6.7 6.5 1.6 1.4 NaN P I S P ... G Y E L H P D K W T
7364 200.00 7.8 7.9 2.3 1.7 NaN P I S P ... G Y E L H P D K W T
7377 200.00 29.3 726.0 10.1 4.1 NaN P I S P ... G Y E L H P D K W T
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
259204 1.50 1.0 0.5 0.8 0.9 0.8 P I S P ... G Y E L H P D K W T
259206 121.00 1.6 0.5 0.7 1.3 0.5 P I S P ... G Y E L H P D K W T
259208 1.25 1.0 0.6 0.9 1.0 0.9 P I S P ... G Y E L H P D K W T
259210 81.70 1.6 0.2 0.6 1.6 0.3 P I S P ... G Y E L H P D K W T
259214 74.00 1.5 0.2 0.7 1.4 0.4 P I S P ... G Y E L H P D K W T
259216 1.05 0.9 0.8 1.2 1.3 0.9 P I S P ... G Y E L H P D K W T
259222 29.50 3.6 0.4 1.2 1.7 2.0 P I S P ... G Y E L H P D K W T
259224 1.40 1.0 0.6 0.9 1.0 0.9 P I S P ... G Y E L H P D K W T
259226 113.00 1.8 0.2 0.8 1.6 0.5 P I S P ... G Y E L H P D N W T
259228 1.30 0.9 0.5 0.7 0.7 0.8 P I S P ... G Y E L H P D K W T
259230 91.50 3.0 0.2 0.5 1.2 0.2 P I S P ... G Y E L H P D K W T
259234 1.15 1.0 0.8 1.1 1.3 1.0 P I S P ... G Y E L H P D K W T
259238 1.30 1.0 0.6 0.9 1.2 0.9 P I S P ... G Y E L H P D K W T
259242 0.90 0.8 0.7 0.9 0.9 0.9 P I S P ... G Y E L H P D K W T
259244 88.50 5.2 0.2 0.6 2.0 0.3 P I S P ... G Y E L H P D K W T
259246 1.05 0.8 0.8 1.1 1.3 0.9 P I S P ... G Y E L H P D K W T
259248 82.10 1.9 0.3 0.8 1.6 0.5 P I S P ... G Y E L H P D K W T
259250 1.15 0.9 0.6 0.9 1.2 0.8 P I S P ... G Y E L H P D K W T
259252 87.00 3.8 0.4 0.9 1.9 0.7 P I S P ... G Y E L H P D K W T
259254 1.40 0.9 0.6 0.9 0.8 0.8 P I S P ... G Y E L H P D K W T
259256 91.50 3.0 0.2 0.6 1.1 0.2 P I S P ... G Y E L H P D K W T
259258 1.20 0.9 0.8 1.2 1.3 1.0 P I S P ... G Y E L H P D K W T
259260 104.50 4.4 0.6 0.7 1.9 0.8 P I S P ... G Y E L H P D K W T
259266 1.05 0.7 0.3 0.8 0.8 0.6 P I S P ... G Y E L H P D K W T
259268 125.50 2.9 0.2 0.7 1.2 0.4 P I S P ... G Y E L H P D K W T
270334 1.15 1.4 1.2 1.4 1.2 1.1 P I S P ... G Y E L H P D K W T
270335 89.40 3.2 0.8 1.6 2.4 0.8 P I S P ... G Y E L H P D K W T
270336 89.40 2.8 0.5 1.4 2.2 0.6 P I S P ... G Y E L H P D K W T
270337 1.20 1.4 1.8 1.4 1.2 1.2 P I S P ... G Y E L H P D K W T
270338 0.70 0.8 0.4 1.0 0.8 0.7 P I S P ... G Y E L H P D K W T

1263 rows × 246 columns


In [14]:
# Drop any feat_cols that are completely conserved.

# The nonconserved_cols list will serve as a convenient selector for the
# X (feature) data from the original dataframe.
nonconserved_cols = []
for col in feat_cols:
    if len(pd.unique(data[col])) == 1:
        data.drop(col, axis=1, inplace=True)
        
    else:
        nonconserved_cols.append(col)
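
# Equivalent one-pass sketch (to run in place of the loop above, before any
# columns are dropped): keep only columns with more than one observed residue.
# nonconserved_cols = [col for col in feat_cols if data[col].nunique() > 1]
# data = data[list(drug_cols) + nonconserved_cols]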

In [15]:
drug_cols


Out[15]:
Index(['3TC', 'ABC', 'AZT', 'D4T', 'DDI', 'TDF'], dtype='object')

In [136]:
def x_equals_y(y_test):
    """
    Return an integer range spanning the minimum to the maximum of y_test.

    Used to draw the y = x reference line in the plots below.
    """
    floor = math.floor(np.min(y_test))
    ceil = math.ceil(np.max(y_test))
    x_eq_y = range(floor, ceil + 1)  # +1 so the line reaches the maximum
    return x_eq_y

TWOPLACES = Decimal(10) ** -2
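
# TWOPLACES is used with Decimal.quantize to round the MSE annotations on the
# plots below to two decimal places, e.g.:
# Decimal(0.132429).quantize(TWOPLACES)  # -> Decimal('0.13')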

In [152]:
colnum = 3  # selects the drug to model; drug_cols[3] is D4T

drug_df = pd.DataFrame()
drug_df[drug_cols[colnum]] = data[drug_cols[colnum]]
drug_df[nonconserved_cols] = data[nonconserved_cols]
# Drop sequences with ambiguous multi-residue calls (e.g. 'LH') at any position.
for col in nonconserved_cols:
    drug_df[col] = drug_df[col].apply(lambda x: np.nan if len(x) > 1 else x)
drug_df.dropna(inplace=True)

drug_X = drug_df[nonconserved_cols]
# Log-transform the fold-resistance values.
drug_Y = drug_df[drug_cols[colnum]].apply(np.log)
# drug_Y.values

In [153]:
from isoelectric_point import isoelectric_points
from molecular_weight import molecular_weights

# Encode each residue by its isoelectric point (pI); 7 is neutral.
drug_X_pi = drug_X.replace(isoelectric_points)

# Encode each residue by its molecular weight (MW).
drug_X_mw = drug_X.replace(molecular_weights)

# One-hot encode (binarize) the drug_X matrix.
from sklearn.preprocessing import LabelBinarizer
drug_X_bi = pd.DataFrame()
binarizers = dict()

for col in drug_X.columns:
    lb = LabelBinarizer()
    binarized_cols = lb.fit_transform(drug_X[col])
    if len(lb.classes_) == 2:
        # Two classes collapse to a single 0/1 column.
        drug_X_bi[col] = pd.Series(binarized_cols[:,0])
    else:
        # One column per residue, named e.g. 'P3_C'.
        for i, c in enumerate(lb.classes_):
            drug_X_bi[col + '_' + c] = binarized_cols[:,i]
        
drug_X_bi


Out[153]:
P1 P2 P3_C P3_N P3_S P4_P P4_S P4_T P5_I P5_S ... P237_E P237_N P238_K P238_R P238_T P239 P240_A P240_E P240_K P240_T
0 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
1 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
2 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
3 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
4 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
5 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
6 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
7 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
8 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
9 1 0 0 0 1 1 0 0 1 0 ... 1 0 1 0 0 1 0 0 0 1
10 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
11 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
12 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
13 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
14 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
15 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
16 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
17 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
18 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
19 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
20 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
21 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
22 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
23 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
24 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
25 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
26 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
27 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
28 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
29 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
428 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
429 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
430 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
431 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
432 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
433 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
434 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
435 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
436 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
437 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
438 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
439 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
440 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
441 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
442 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
443 1 0 0 0 1 0 1 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
444 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
445 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
446 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
447 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
448 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
449 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
450 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
451 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
452 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
453 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
454 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
455 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
456 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1
457 1 0 0 0 1 1 0 0 1 0 ... 0 0 1 0 0 1 0 0 0 1

458 rows × 482 columns
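
For reference, pandas' get_dummies produces a similar one-hot encoding in a single call (a hedged sketch, not the notebook's approach; column names follow the same 'P3_C' pattern, though two-class positions keep both indicator columns rather than collapsing to one):

drug_X_dummies = pd.get_dummies(drug_X)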


In [154]:
fig = plt.figure(figsize=(3,3))
drug_Y.hist(grid=False)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('{0} Distribution'.format(drug_cols[colnum]))


Out[154]:
<matplotlib.text.Text at 0x7fa5c215d4e0>

In [155]:
# Here, let's try the Random Forest Regressor. This will be the baseline.
x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

rfr = RandomForestRegressor(n_estimators=500, n_jobs=-1, oob_score=True)

rfr.fit(x_train, y_train)
rfr_preds = rfr.predict(x_test)
rfr_mse = mean_squared_error(rfr_preds, y_test)
print(rfr.score(x_test, y_test), rfr_mse)
# print(rfr.oob_score_)
sps.pearsonr(rfr_preds, y_test)  # NB: mid-cell expression, so its value is discarded

plt.figure(figsize=(3,3))
plt.scatter(y_test, rfr_preds)
plt.title('{0} Random Forest'.format(drug_cols[colnum]))
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(rfr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.786368305999 0.132429943561
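
The scatter-plot block above recurs nearly verbatim for every model below; a small helper (hypothetical, not in the original notebook) would factor it out:

def plot_preds(y_test, preds, mse, title):
    """Scatter actual vs. predicted values with an MSE annotation and y = x line."""
    plt.figure(figsize=(3,3))
    plt.scatter(y_test, preds)
    plt.xlabel('Actual')
    plt.ylabel('Predicted')
    plt.title(title)
    plt.gca().set_aspect('equal', 'datalim')
    plt.annotate(s='mse: {0}'.format(str(Decimal(mse).quantize(TWOPLACES))),
                 xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
    plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
    plt.show()

Each model cell could then call, e.g., plot_preds(y_test, rfr_preds, rfr_mse, '{0} Random Forest'.format(drug_cols[colnum])).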

In [156]:
plt.bar(range(len(rfr.feature_importances_)), rfr.feature_importances_)
plt.xlabel('Position')
plt.ylabel('Relative Importance')
plt.title('{0} Random Forest'.format(drug_cols[colnum]))


Out[156]:
<matplotlib.text.Text at 0x7fa5c20d3b38>

In [157]:
# Get back the importance of each feature.
feat_impt = [(p, i) for p, i in zip(drug_X_bi.columns, rfr.feature_importances_)]
sorted(feat_impt, key=lambda x:x[1], reverse=True)


Out[157]:
[('P151', 0.16251462938669017),
 ('P215_T', 0.16006592453195229),
 ('P62', 0.089598402308289846),
 ('P210_W', 0.04805732919349797),
 ('P69_#', 0.043176594447178349),
 ('P210_L', 0.03820719628644869),
 ('P116', 0.036923518809471458),
 ('P41_L', 0.027512755954099771),
 ('P77', 0.018855167959727394),
 ('P67_D', 0.018461063073661806),
 ('P70_K', 0.016167279408335643),
 ('P40_E', 0.010868828402363542),
 ('P215_D', 0.0097352909437584643),
 ('P40_F', 0.0095284336496718729),
 ('P118', 0.0093069566910762534),
 ('P69_T', 0.0091408832193107317),
 ('P184_V', 0.0089808668340186793),
 ('P65', 0.0078659771260242446),
 ('P67_N', 0.0078486143018055513),
 ('P70_N', 0.0076543349919679102),
 ('P70_d', 0.0072499899020392116),
 ('P215_Y', 0.0071962504233967191),
 ('P223_N', 0.0062562829205002248),
 ('P43_E', 0.0060866113487613007),
 ('P73', 0.0060846705519703387),
 ('P70_R', 0.0060199337114942904),
 ('P139_S', 0.0053925554834850821),
 ('P41_M', 0.0052268298420678544),
 ('P184_M', 0.0047365291124894888),
 ('P135_T', 0.0047257702295656111),
 ('P208', 0.0045941531295171948),
 ('P219_K', 0.0042685057849632061),
 ('P75_V', 0.0040200811033691917),
 ('P218', 0.0038741506470258355),
 ('P174_R', 0.0038138763719247435),
 ('P58_N', 0.0036469640236670266),
 ('P190_A', 0.0034880680895961361),
 ('P181_C', 0.0034809909691663351),
 ('P203_K', 0.0034169068828247013),
 ('P181_Y', 0.0034023562884997189),
 ('P74_V', 0.0033380187605316792),
 ('P138_K', 0.0032196758142861827),
 ('P214', 0.0030899144516060234),
 ('P74_L', 0.0030685216674645865),
 ('P207_E', 0.0027232206439434365),
 ('P39_T', 0.0026421469117063161),
 ('P75_T', 0.0025827401321966976),
 ('P135_V', 0.0025695161872389847),
 ('P200_T', 0.0025636478473387938),
 ('P139_T', 0.0024212663266235321),
 ('P207_Q', 0.002402825396636144),
 ('P64_H', 0.002383787045730263),
 ('P123_E', 0.0023502644301236194),
 ('P69_D', 0.0023465037879284579),
 ('P190_G', 0.0022155694228006011),
 ('P122_E', 0.002213964890623938),
 ('P211_R', 0.0021859602876278584),
 ('P102_Q', 0.0021678505224306203),
 ('P44_E', 0.0020674696042807359),
 ('P215_F', 0.0019449153773165285),
 ('P219_E', 0.0019272493402581862),
 ('P203_E', 0.0018550549534186885),
 ('P196_E', 0.0018383325223303162),
 ('P122_K', 0.0018347234940733837),
 ('P178_I', 0.001806222233384819),
 ('P123_D', 0.0017749513060090457),
 ('P69_N', 0.0017700664947200423),
 ('P135_I', 0.0017449729362553255),
 ('P75_I', 0.0017440799295997303),
 ('P103_K', 0.0017227293956739301),
 ('P219_R', 0.0016741067913017835),
 ('P43_K', 0.0016735102186183708),
 ('P102_K', 0.0016675568335350174),
 ('P228_H', 0.0016663261354250669),
 ('P196_G', 0.0016541932141252525),
 ('P200_A', 0.0016447160208774292),
 ('P162_S', 0.00159023188765284),
 ('P138_E', 0.001478400658828117),
 ('P177_E', 0.0014627745557714018),
 ('P219_Q', 0.0014482811153136784),
 ('P35_T', 0.0014381497441807549),
 ('P188_L', 0.0013285168302978377),
 ('P211_K', 0.0013093825552678343),
 ('P103_N', 0.001248617556734215),
 ('P179_I', 0.0012328551602509344),
 ('P162_C', 0.0011982045240013577),
 ('P228_L', 0.0011635428973818041),
 ('P178_L', 0.0011481025443574502),
 ('P238_K', 0.0011445712967038549),
 ('P35_V', 0.0011024788869958006),
 ('P146', 0.0010982585955854973),
 ('P64_K', 0.0010920025818014561),
 ('P177_D', 0.0010905129160233638),
 ('P178_F', 0.0010693452316058952),
 ('P39_A', 0.0010484497888503823),
 ('P20_K', 0.0010475814864030529),
 ('P101_E', 0.0010354309956481218),
 ('P101_K', 0.0010101778830088773),
 ('P188_Y', 0.000998000612165823),
 ('P228_R', 0.00097952916267022867),
 ('P68_G', 0.00096397487205400543),
 ('P108', 0.00095849356763606428),
 ('P115', 0.00094286952675548899),
 ('P162_A', 0.00093144483875862323),
 ('P238_T', 0.00090917464128133089),
 ('P60_I', 0.00090070518000247117),
 ('P67_E', 0.00089224907758878632),
 ('P68_S', 0.00086652663411294559),
 ('P179_V', 0.00086618235778399889),
 ('P83', 0.00085505255029670385),
 ('P103_S', 0.00085144885179626469),
 ('P121_Y', 0.00085126753806842744),
 ('P20_R', 0.00079379237648687486),
 ('P203_D', 0.00077172511141351081),
 ('P202', 0.00075929783491084622),
 ('P180', 0.00073909838686148088),
 ('P174_Q', 0.00073874043164542686),
 ('P200_I', 0.00073162255681505047),
 ('P123_N', 0.00070153937678211304),
 ('P178_M', 0.00069111021143699323),
 ('P60_V', 0.00068919471690468239),
 ('P207_N', 0.00067385270252314779),
 ('P142_I', 0.00065640298445170553),
 ('P227_L', 0.00065024723775872889),
 ('P35_I', 0.00064649900469493577),
 ('P44_D', 0.00064010540730884949),
 ('P90', 0.00063989615343746926),
 ('P106_V', 0.00063507207005211491),
 ('P98_G', 0.00062887490465580531),
 ('P172', 0.00061024702564071141),
 ('P28_E', 0.00059981013085939265),
 ('P98_A', 0.00057534786487980669),
 ('P35_M', 0.00056911894252802164),
 ('P142_V', 0.00055137585045650437),
 ('P166_R', 0.00054977243910318146),
 ('P237_D', 0.00054775854075780049),
 ('P179_F', 0.00052963131097299315),
 ('P49', 0.00052512110059578228),
 ('P16', 0.00049606463759768607),
 ('P138_G', 0.000486103946357874),
 ('P5_V', 0.00045940617868891722),
 ('P219_H', 0.00045381112961015966),
 ('P197_Q', 0.00044784706876806028),
 ('P227_F', 0.0004473313209620604),
 ('P5_I', 0.00042971435742582422),
 ('P211_N', 0.00042786537379451052),
 ('P215_S', 0.00041865725477986172),
 ('P89', 0.00041764304987800856),
 ('P223_K', 0.00040976773738635435),
 ('P68_T', 0.00039561065486796288),
 ('P43_Q', 0.00037804569690291247),
 ('P44_A', 0.00037484028070881067),
 ('P139_A', 0.00035590141937412467),
 ('P219_N', 0.00035154342032472273),
 ('P123_I', 0.00034953177682979503),
 ('P39_E', 0.00034730857148287779),
 ('P135_L', 0.00032111616800641442),
 ('P184_I', 0.00031939754644173097),
 ('P121_D', 0.00031732579512637157),
 ('P173_K', 0.00031615518762627149),
 ('P221', 0.00031455011193028858),
 ('P6_K', 0.00030830883055275345),
 ('P181_I', 0.00030233783360660249),
 ('P98_S', 0.00028876819567226215),
 ('P32_K', 0.00028124417542713646),
 ('P166_K', 0.0002808748111850798),
 ('P11_R', 0.00028085164193501294),
 ('P67_G', 0.00028084776657503395),
 ('P142_T', 0.00028032331721691682),
 ('P237_E', 0.00027972256099027565),
 ('P162_D', 0.00027099659813596829),
 ('P197_E', 0.00026770517690369783),
 ('P207_D', 0.00026727523712224109),
 ('P135_M', 0.00026651589804190277),
 ('P33_V', 0.00026081957296777335),
 ('P11_K', 0.0002544333961544424),
 ('P106_A', 0.00025252978521046244),
 ('P169_D', 0.00025055970776281003),
 ('P138_A', 0.00024359386375885614),
 ('P211_S', 0.00024350310232222971),
 ('P173_R', 0.00024258349397140316),
 ('P1', 0.00023332647291195526),
 ('P27', 0.00023226793237666573),
 ('P179_D', 0.00023137107928803926),
 ('P58_T', 0.00022655002519174334),
 ('P100', 0.00022014608755782165),
 ('P190_C', 0.00020714297387020066),
 ('P207_G', 0.00020647995815651133),
 ('P211_T', 0.00020404825050321954),
 ('P230', 0.00020287080602791622),
 ('P101_Q', 0.00020120962784956936),
 ('P157', 0.00020000058336736433),
 ('P190_S', 0.00019419345405297073),
 ('P74_I', 0.00018296967171673874),
 ('P48_S', 0.0001775163749034273),
 ('P104_R', 0.00017244319760841673),
 ('P111', 0.000167731611715598),
 ('P174_K', 0.0001671768480088083),
 ('P31_V', 0.00016695766584317567),
 ('P33_A', 0.00016554905591164114),
 ('P109', 0.00016410756193115599),
 ('P64_R', 0.00016334690794083678),
 ('P101_H', 0.00016073242777569233),
 ('P4_S', 0.000159582501610235),
 ('P173_Q', 0.00015518873743907723),
 ('P28_K', 0.00015296418244316532),
 ('P169_E', 0.00015235216899203938),
 ('P32_E', 0.00015230190426268028),
 ('P48_T', 0.00014931600475534391),
 ('P101_R', 0.00014790045863572254),
 ('P173_N', 0.00014770509449362994),
 ('P4_P', 0.0001469270609201563),
 ('P239', 0.00014672393892225505),
 ('P106_I', 0.00014256246536649794),
 ('P13', 0.0001413177624044851),
 ('P31_I', 0.00013858828937927102),
 ('P121_H', 0.00013747673450159318),
 ('P215_V', 0.000136070845889713),
 ('P37', 0.00013531876061213566),
 ('P207_K', 0.00013337564255884132),
 ('P200_R', 0.00013174523839097946),
 ('P223_Q', 0.00013096240494140797),
 ('P28_A', 0.00012548340756437777),
 ('P138_Q', 0.00011532785441760907),
 ('P207_A', 0.00011039518865872977),
 ('P197_K', 0.00010950928712528276),
 ('P6_E', 0.00010824725552461508),
 ('P43_N', 0.00010378416513680041),
 ('P139_R', 9.9636100065719714e-05),
 ('P3_S', 9.8314835546991982e-05),
 ('P47_V', 9.8145633287732452e-05),
 ('P104_K', 9.2512522314128646e-05),
 ('P207_H', 9.1557056511385934e-05),
 ('P35_R', 8.5729676593072237e-05),
 ('P224_D', 8.5043931640428925e-05),
 ('P224_E', 8.277277434848758e-05),
 ('P163', 8.2429008848087891e-05),
 ('P63', 8.0182907627834556e-05),
 ('P179_E', 7.90508869786849e-05),
 ('P39_K', 7.8996196163410209e-05),
 ('P5_S', 7.8763581406593054e-05),
 ('P123_S', 7.8306273276253831e-05),
 ('P122_P', 7.7440964987469778e-05),
 ('P6_D', 7.6734583529687543e-05),
 ('P165_T', 7.5866324719554409e-05),
 ('P225', 7.5306144414764273e-05),
 ('P85', 7.4260000014479368e-05),
 ('P200_E', 7.4010491294990887e-05),
 ('P195_L', 7.3414017088698944e-05),
 ('P32_T', 7.2079978204739242e-05),
 ('P195_I', 7.1311497389698543e-05),
 ('P197_L', 7.0086870127877187e-05),
 ('P211_Q', 6.8967116529355059e-05),
 ('P224_K', 6.8384307989135959e-05),
 ('P166_T', 6.793157761611871e-05),
 ('P69_S', 6.5746514933650235e-05),
 ('P101_P', 6.5001407055795941e-05),
 ('P132', 6.4888442329120034e-05),
 ('P198', 6.1937193164967217e-05),
 ('P35_K', 6.1207994804328268e-05),
 ('P238_R', 6.0920435106250203e-05),
 ('P223_E', 6.089911496557083e-05),
 ('P215_I', 6.0291401747136104e-05),
 ('P102_R', 6.0072874363770748e-05),
 ('P174_H', 5.8923496579779993e-05),
 ('P122_Q', 5.8823392940486941e-05),
 ('P82', 5.7365951036787005e-05),
 ('P174_L', 5.4329223948507411e-05),
 ('P88_C', 5.3129844946612595e-05),
 ('P43_A', 5.1403624745991304e-05),
 ('P237_N', 5.057886566871371e-05),
 ('P215_L', 5.0287989972268171e-05),
 ('P204_E', 4.895528866502371e-05),
 ('P50', 4.6616079559863745e-05),
 ('P170', 4.5277446250194471e-05),
 ('P64_Y', 4.4686356333231143e-05),
 ('P181_V', 4.4139566915812721e-05),
 ('P39_N', 4.3575327951873241e-05),
 ('P35_L', 4.3280729834555589e-05),
 ('P177_G', 4.2939158480663363e-05),
 ('P69_G', 4.2712037213690829e-05),
 ('P7', 4.2577869169305751e-05),
 ('P3_N', 4.253990404912848e-05),
 ('P3_C', 4.0793962533628899e-05),
 ('P80', 4.0738013528253687e-05),
 ('P195_M', 4.0549062820673448e-05),
 ('P103_R', 4.0535151027103161e-05),
 ('P162_Y', 4.0502921610321552e-05),
 ('P46', 3.938733629927549e-05),
 ('P240_T', 3.922657631975248e-05),
 ('P104_N', 3.8375956322611281e-05),
 ('P174_E', 3.7667347299986625e-05),
 ('P43_R', 3.6967727822854313e-05),
 ('P106_M', 3.6706555739788651e-05),
 ('P75_M', 3.620893936993806e-05),
 ('P240_E', 3.496513401654536e-05),
 ('P135_K', 3.4296684429088164e-05),
 ('P21_V', 3.4283499227464581e-05),
 ('P28_G', 3.1606039987253433e-05),
 ('P211_A', 3.1339152264936688e-05),
 ('P165_I', 3.1264606669730081e-05),
 ('P211_G', 3.0904875189993984e-05),
 ('P88_W', 3.0696233897704535e-05),
 ('P39_I', 3.0225356134305558e-05),
 ('P200_K', 2.96926806919556e-05),
 ('P101_I', 2.860147433881264e-05),
 ('P22', 2.773397225078223e-05),
 ('P158', 2.6961852352201984e-05),
 ('P21_I', 2.6890565857257042e-05),
 ('P210_S', 2.6380376338860684e-05),
 ('P21_A', 2.5903507060762843e-05),
 ('P188_H', 2.5870275556699503e-05),
 ('P34', 2.5354038673889142e-05),
 ('P47_I', 2.5133841977865804e-05),
 ('P69_K', 2.5063630726307757e-05),
 ('P67_~', 2.4946710951120842e-05),
 ('P212', 2.4892262810345553e-05),
 ('P204_D', 2.4279655242071229e-05),
 ('P173_T', 2.4235630046209914e-05),
 ('P39_L', 2.3979867291791524e-05),
 ('P173_S', 2.3888613063677351e-05),
 ('P35_E', 2.3840376582614944e-05),
 ('P177_N', 2.3821915553087907e-05),
 ('P165_A', 2.325313210508823e-05),
 ('P39_S', 2.3180720272361477e-05),
 ('P31_L', 2.258466535340127e-05),
 ('P159_L', 2.2250669326182636e-05),
 ('P228_Q', 2.0646478625955823e-05),
 ('P189', 2.0242772947484979e-05),
 ('P200_S', 1.9167348467636126e-05),
 ('P107', 1.8460370565011147e-05),
 ('P58_G', 1.841829731588697e-05),
 ('P197_P', 1.7876763272087301e-05),
 ('P35_Q', 1.6521125261421695e-05),
 ('P102_M', 1.6337589824524213e-05),
 ('P190_E', 1.6089060311500773e-05),
 ('P122_A', 1.5362032752082303e-05),
 ('P204_N', 1.5026234928358834e-05),
 ('P6_G', 1.499251790325744e-05),
 ('P14', 1.4951797644508774e-05),
 ('P195_T', 1.3342347447179275e-05),
 ('P240_K', 1.3024231131735054e-05),
 ('P96', 1.2531349596052805e-05),
 ('P176_P', 1.2260048801154113e-05),
 ('P75_L', 1.2059981516383563e-05),
 ('P145', 1.1937069946213749e-05),
 ('P20_T', 1.1243103452339689e-05),
 ('P88_G', 1.1148211982685336e-05),
 ('P204_Q', 1.0980630032656169e-05),
 ('P113', 1.0272818990464138e-05),
 ('P66', 1.0040123196221622e-05),
 ('P48_V', 9.889363680492432e-06),
 ('P215_C', 9.6387834023666912e-06),
 ('P9', 9.6176332577349102e-06),
 ('P215_N', 9.2117924818115869e-06),
 ('P40_D', 8.2212612156804012e-06),
 ('P159_I', 8.1435250139559716e-06),
 ('P159_V', 8.0165716589835687e-06),
 ('P228_F', 7.9687525674008383e-06),
 ('P210_F', 7.8842792091922328e-06),
 ('P36', 7.3820368195924407e-06),
 ('P128', 6.6159503856832758e-06),
 ('P127', 6.4648237411951987e-06),
 ('P161', 6.170217236359231e-06),
 ('P200_V', 5.9395016760378583e-06),
 ('P173_E', 5.6872943675120443e-06),
 ('P11_T', 5.5186160983299495e-06),
 ('P70_G', 5.2656507077404917e-06),
 ('P86', 4.9187649220681258e-06),
 ('P101_A', 4.794689076618831e-06),
 ('P176_T', 4.5354871983633148e-06),
 ('P176_S', 4.4820962149087409e-06),
 ('P69_A', 4.4233211338960185e-06),
 ('P33_G', 4.1612184324786338e-06),
 ('P19', 4.0936989626473446e-06),
 ('P227_Y', 4.0702759885375919e-06),
 ('P223_R', 3.9869755081948724e-06),
 ('P200_L', 3.8689378684052706e-06),
 ('P135_R', 3.8263114032032092e-06),
 ('P69_I', 3.8148496017361344e-06),
 ('P53', 3.5117395389113594e-06),
 ('P173_I', 3.3270858084319144e-06),
 ('P123_G', 3.3176536042403781e-06),
 ('P182', 3.1849964653397612e-06),
 ('P195_K', 3.1153300879020901e-06),
 ('P164', 3.1051135760467751e-06),
 ('P8', 3.0457801694138396e-06),
 ('P40_K', 2.9025830231359228e-06),
 ('P45', 2.6725090075384944e-06),
 ('P167', 2.3928120733466434e-06),
 ('P165_L', 2.1256630019541142e-06),
 ('P173_A', 1.9348268249657992e-06),
 ('P60_A', 1.8128465735366147e-06),
 ('P199', 1.7584412182626765e-06),
 ('P144', 1.5109722083712416e-06),
 ('P196_K', 1.4619230053096696e-06),
 ('P240_A', 1.4017245065337739e-06),
 ('P2', 1.3437155057732335e-06),
 ('P35_A', 1.2998876912630738e-06),
 ('P79', 7.0993410033283175e-07),
 ('P32_R', 6.0690878075020209e-07),
 ('P102_E', 5.6410767082531577e-07),
 ('P169_K', 1.7175941527700947e-07),
 ('P236', 1.073097201017753e-07),
 ('P226', 6.0113693215943616e-09),
 ('P4_T', 0.0),
 ('P10_V', 0.0),
 ('P15_G', 0.0),
 ('P18_G', 0.0),
 ('P23_Q', 0.0),
 ('P24_W', 0.0),
 ('P25_P', 0.0),
 ('P29_E', 0.0),
 ('P30_K', 0.0),
 ('P32_N', 0.0),
 ('P38_C', 0.0),
 ('P41_I', 0.0),
 ('P42_E', 0.0),
 ('P47_F', 0.0),
 ('P52_P', 0.0),
 ('P54', 0.0),
 ('P56_Y', 0.0),
 ('P59_P', 0.0),
 ('P64_N', 0.0),
 ('P68_K', 0.0),
 ('P70_E', 0.0),
 ('P72_R', 0.0),
 ('P75_A', 0.0),
 ('P84_T', 0.0),
 ('P87_F', 0.0),
 ('P91_Q', 0.0),
 ('P93_G', 0.0),
 ('P94_I', 0.0),
 ('P95_P', 0.0),
 ('P101_S', 0.0),
 ('P105_S', 0.0),
 ('P110_D', 0.0),
 ('P117_S', 0.0),
 ('P119_P', 0.0),
 ('P122_R', 0.0),
 ('P124_F', 0.0),
 ('P125_R', 0.0),
 ('P126_K', 0.0),
 ('P130_F', 0.0),
 ('P131_T', 0.0),
 ('P134_S', 0.0),
 ('P136_N', 0.0),
 ('P137_N', 0.0),
 ('P139_I', 0.0),
 ('P140_P', 0.0),
 ('P141_G', 0.0),
 ('P143_R', 0.0),
 ('P147_N', 0.0),
 ('P149_L', 0.0),
 ('P154_K', 0.0),
 ('P162_H', 0.0),
 ('P162_N', 0.0),
 ('P171', 0.0),
 ('P175_N', 0.0),
 ('P179_L', 0.0),
 ('P187_L', 0.0),
 ('P188_C', 0.0),
 ('P188_F', 0.0),
 ('P191', 0.0),
 ('P192', 0.0),
 ('P193_L', 0.0),
 ('P194_E', 0.0),
 ('P197_I', 0.0),
 ('P200_F', 0.0),
 ('P201_K', 0.0),
 ('P205_L', 0.0),
 ('P206_R', 0.0),
 ('P213_G', 0.0),
 ('P215_E', 0.0),
 ('P216_T', 0.0),
 ('P217_P', 0.0),
 ('P219_D', 0.0),
 ('P219_T', 0.0),
 ('P222_Q', 0.0),
 ('P232_Y', 0.0),
 ('P233_E', 0.0),
 ('P234', 0.0)]
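
Because one-hot encoding splits each position across several residue columns, a per-position view can be recovered by summing importances over each position's columns (a sketch using the Counter imported at the top):

# Aggregate one-hot feature importances back to sequence positions.
pos_impt = Counter()
for name, impt in feat_impt:
    pos_impt[name.split('_')[0]] += impt
sorted(pos_impt.items(), key=lambda x: x[1], reverse=True)[:10]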

In [158]:
# # Here, let's try a parameter grid search to figure out the best hyperparameters.
# from sklearn.grid_search import GridSearchCV
# import numpy as np

# param_grid = [{'n_estimators':[100, 500, 1000],
#                #'max_features':['auto', 'sqrt', 'log2'],
#                #'min_samples_leaf':np.arange(1,20,1),
#               }]

# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)


# rfr_gs = GridSearchCV(RandomForestRegressor(), param_grid=param_grid, n_jobs=-1)
# rfr_gs.fit(x_train, y_train)
# print(rfr_gs.best_estimator_)
# print(rfr_gs.best_params_)

In [ ]:
# Try Bayesian Ridge Regression
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

brr = lm.BayesianRidge()
brr.fit(x_train, y_train)
brr_preds = brr.predict(x_test)
print(brr.score(x_test, y_test), mean_squared_error(brr_preds, y_test))
print(sps.pearsonr(brr_preds, y_test))
brr_mse = mean_squared_error(brr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, brr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Bayesian Ridge'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(brr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()


0.77269618709 0.140905268087
(0.88401111341649008, 4.0613783272294827e-39)
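
Since Bayesian ridge is a linear model, its coefficients can be inspected directly; a hedged sketch of ranking the binarized features by absolute coefficient:

# Rank features by the magnitude of their ridge coefficients.
coef_rank = sorted(zip(drug_X_bi.columns, brr.coef_),
                   key=lambda x: abs(x[1]), reverse=True)
coef_rank[:10]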

In [ ]:
# Try ARD regression
ardr = lm.ARDRegression()
ardr.fit(x_train, y_train)
ardr_preds = ardr.predict(x_test)
ardr_mse = mean_squared_error(ardr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, ardr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} ARD Regression'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(ardr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()

In [ ]:
# Try Gradient Boost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

gbr = GradientBoostingRegressor()
gbr.fit(x_train, y_train)
gbr_preds = gbr.predict(x_test)
print(gbr.score(x_test, y_test), mean_squared_error(gbr_preds, y_test))
print(sps.pearsonr(gbr_preds, y_test))
gbr_mse = mean_squared_error(gbr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, gbr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Grad. Boost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(gbr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()

In [ ]:
plt.bar(range(len(gbr.feature_importances_)), gbr.feature_importances_)

In [ ]:
# Try AdaBoost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

abr = AdaBoostRegressor()
abr.fit(x_train, y_train)
abr_preds = abr.predict(x_test)
print(abr.score(x_test, y_test), mean_squared_error(abr_preds, y_test))
print(sps.pearsonr(abr_preds, y_test))
abr_mse = mean_squared_error(abr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(x=y_test, y=abr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} AdaBoost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(abr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()

In [ ]:
plt.bar(range(len(abr.feature_importances_)), abr.feature_importances_)

In [ ]:
# Try support vector regression
svr = SVR()
svr.fit(x_train, y_train)
svr_preds = svr.predict(x_test)

svr_mse = mean_squared_error(svr_preds, y_test)

plt.figure(figsize=(3,3))
plt.scatter(y_test, svr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} SVR'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(svr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
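
SVR with the default RBF kernel and C=1.0 may underfit here; a small search sketch using the era-appropriate sklearn.grid_search module (hypothetical, mirroring the commented-out random forest search above):

from sklearn.grid_search import GridSearchCV
svr_params = {'C': [0.1, 1, 10, 100], 'kernel': ['rbf', 'linear']}
svr_gs = GridSearchCV(SVR(), param_grid=svr_params, n_jobs=-1)
svr_gs.fit(x_train, y_train)
print(svr_gs.best_params_)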

In [ ]:
# Neural Network 1 Specification: Feed Forward ANN with 1 hidden layer.
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)

x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)

net1 = NeuralNet(
    layers=[  # one hidden layer, plus dropout and a nonlinearity
        ('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('dropout1', layers.DropoutLayer),
        #('hidden2', layers.DenseLayer),
        #('dropout2', layers.DropoutLayer),
        ('nonlinear', layers.NonlinearityLayer),
        ('output', layers.DenseLayer),
        ],
    # layer parameters:
    input_shape=(None, x_train.shape[1]),  # one input unit per binarized feature
    hidden1_num_units=math.ceil(x_train.shape[1] / 2),  # number of units in hidden layer
    hidden1_nonlinearity=nonlinearities.tanh,
    dropout1_p = 0.5,
    #hidden2_num_units=math.ceil(x_train.shape[1] / 2),
    #dropout2_p = 0.5,
    output_nonlinearity=None,  # output layer uses identity function
    output_num_units=1,  # a single target value: log fold resistance
    
    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.95,

    regression=True,  # flag to indicate we're dealing with a regression problem
    max_epochs=500,  # we want to train this many epochs
    verbose=1,
    )
net1.fit(x_train.values, y_train.values)
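
Training runs for a fixed 500 epochs; nolearn also accepts on_epoch_finished callbacks, so an early-stopping handler in the style of the nolearn tutorials could be passed instead (a sketch; EarlyStopping is not part of nolearn itself):

class EarlyStopping(object):
    """Stop training when validation loss hasn't improved for `patience` epochs."""
    def __init__(self, patience=50):
        self.patience = patience
        self.best_valid = np.inf
        self.best_epoch = 0

    def __call__(self, nn, train_history):
        current = train_history[-1]['valid_loss']
        epoch = train_history[-1]['epoch']
        if current < self.best_valid:
            self.best_valid = current
            self.best_epoch = epoch
        elif self.best_epoch + self.patience < epoch:
            raise StopIteration()

# e.g. NeuralNet(..., on_epoch_finished=[EarlyStopping(patience=50)])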

In [ ]:
nn1_preds = net1.predict(x_test)
nn1_mse = float(mean_squared_error(nn1_preds, y_test))

plt.figure(figsize=(3,3))
plt.scatter(y_test, nn1_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Neural Network'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(nn1_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()

In [ ]:
sps.pearsonr(nn1_preds.flatten(), y_test)
