In [1]:
import pandas as pd
import numpy as np
import nolearn
import matplotlib.pyplot as plt
import seaborn
import sklearn.linear_model as lm
import scipy.stats as sps
import math
from Bio import SeqIO
from collections import Counter
from decimal import Decimal
from lasagne import layers, nonlinearities
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.svm import SVR
seaborn.set_style('white')
seaborn.set_context('poster')
%matplotlib inline
Using gpu device 0: Quadro 2000
In [7]:
# Read in the reverse transcriptase inhibitor (NRTI) resistance data.
data = pd.read_csv('hiv-nrt-data.csv', index_col='SeqID')
data
Out[7]:
          3TC  ABC   AZT  D4T  DDI  TDF P1 P2 P3 P4 ... P237 P238 P239 P240
SeqID
2997   200.00  4.3   3.9  1.4  1.2  NaN  -  -  -  - ...    -    -    -    -
4388   200.00  8.1   4.7  1.8  1.7  NaN  -  -  -  - ...    -    -    -    -
4427     1.40  1.1  28.0  1.0  0.8  1.9  -  -  -  - ...    -    -    -    -
...       ...  ...   ...  ...  ...  ... .. .. .. .. ...  ...  ...  ...  ...
270338   0.70  0.8   0.4  1.0  0.8  0.7  -  -  -  - ...    -    -    -    -

[1498 rows x 246 columns]
-
-
-
-
-
-
-
-
-
6029
7.30
18.0
2400.0
16.0
4.8
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6485
1.00
0.6
0.2
0.4
0.7
0.4
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6519
200.00
2.7
0.4
0.8
1.3
NaN
-
-
-
-
...
-
-
-
LH
-
-
-
-
-
-
6540
1.10
1.0
0.3
0.8
1.0
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6569
3.90
5.3
216.5
4.2
1.8
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6709
0.90
0.8
0.5
0.8
1.0
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6796
7.10
7.0
259.0
4.4
2.4
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6820
4.80
1.6
24.0
1.8
1.4
NaN
-
-
-
-
...
-
-
-
-
-
-
E
-
-
-
6859
1.40
2.3
42.0
2.0
1.0
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
6876
200.00
9.6
35.4
3.3
2.1
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
7328
3.60
1.7
7.2
1.5
1.1
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
7347
200.00
9.9
100.5
3.9
2.1
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
7348
200.00
7.7
7.9
2.3
1.8
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
7350
200.00
4.6
0.6
1.0
1.6
NaN
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
...
259208
1.25
1.0
0.6
0.9
1.0
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259210
81.70
1.6
0.2
0.6
1.6
0.3
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259212
1.00
0.7
0.4
0.8
0.8
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259214
74.00
1.5
0.2
0.7
1.4
0.4
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259216
1.05
0.9
0.8
1.2
1.3
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259222
29.50
3.6
0.4
1.2
1.7
2.0
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259224
1.40
1.0
0.6
0.9
1.0
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259226
113.00
1.8
0.2
0.8
1.6
0.5
-
-
-
-
...
-
-
-
-
-
-
-
N
-
-
259228
1.30
0.9
0.5
0.7
0.7
0.8
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259230
91.50
3.0
0.2
0.5
1.2
0.2
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259234
1.15
1.0
0.8
1.1
1.3
1.0
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259238
1.30
1.0
0.6
0.9
1.2
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259242
0.90
0.8
0.7
0.9
0.9
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259244
88.50
5.2
0.2
0.6
2.0
0.3
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259246
1.05
0.8
0.8
1.1
1.3
0.9
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259248
82.10
1.9
0.3
0.8
1.6
0.5
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259250
1.15
0.9
0.6
0.9
1.2
0.8
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259252
87.00
3.8
0.4
0.9
1.9
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259254
1.40
0.9
0.6
0.9
0.8
0.8
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259256
91.50
3.0
0.2
0.6
1.1
0.2
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259258
1.20
0.9
0.8
1.2
1.3
1.0
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259260
104.50
4.4
0.6
0.7
1.9
0.8
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259262
1.15
1.0
0.5
0.8
0.8
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259266
1.05
0.7
0.3
0.8
0.8
0.6
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
259268
125.50
2.9
0.2
0.7
1.2
0.4
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
270334
1.15
1.4
1.2
1.4
1.2
1.1
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
270335
89.40
3.2
0.8
1.6
2.4
0.8
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
270336
89.40
2.8
0.5
1.4
2.2
0.6
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
270337
1.20
1.4
1.8
1.4
1.2
1.2
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
270338
0.70
0.8
0.4
1.0
0.8
0.7
-
-
-
-
...
-
-
-
-
-
-
-
-
-
-
1498 rows × 246 columns
In [9]:
# Set the drug data columns and the amino acid data columns
drug_cols = data.columns[0:6]
feat_cols = data.columns[6:]
In [10]:
# Read in the consensus data
consensus = SeqIO.read('hiv-rt-consensus.fasta', 'fasta')
consensus_map = {i+1:letter for i, letter in enumerate(str(consensus.seq))}
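For intuition, here is the same comprehension run on a made-up three-letter sequence (the real map comes from the FASTA consensus above; 'PIS' happens to match the first three consensus residues visible in the tables below):

toy_map = {i + 1: letter for i, letter in enumerate('PIS')}
print(toy_map)  # {1: 'P', 2: 'I', 3: 'S'} -- positions are 1-based, matching the P1..P240 column names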
In [11]:
# The '-' characters in the dataset stand for the consensus residue at each
# position, so replace them with the actual consensus letter.
for i, col in enumerate(feat_cols):
    # Replace '-' with the consensus letter.
    data[col] = data[col].replace({'-': consensus_map[i+1]})
    # Replace '.' and 'X' (missing or ambiguous calls) with np.nan.
    data[col] = data[col].replace({'.': np.nan})
    data[col] = data[col].replace({'X': np.nan})
In [12]:
# Drop any rows that have np.nan in the feature columns. We don't want low-quality sequences.
data.dropna(inplace=True, subset=feat_cols)
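A toy check (hypothetical two-row frame) of the subset behaviour: only NaNs in the listed columns trigger a drop.

toy = pd.DataFrame({'drug': [np.nan, 2.0], 'P1': ['K', np.nan]})
print(toy.dropna(subset=['P1']))  # keeps row 0 despite its NaN in 'drug'; drops row 1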
In [13]:
data
Out[13]:
          3TC  ABC   AZT  D4T  DDI  TDF P1 P2 P3 P4 ... P237 P238 P239 P240
SeqID
4427     1.40  1.1  28.0  1.0  0.8  1.9  P  I  S  P ...    D    K    W    T
4487     1.80  1.5   7.1  1.2  1.1  1.1  P  I  S  P ...    D    K    W    T
4663   200.00  7.1   0.8  1.3  1.9  NaN  P  I  S  P ...    D    K    W    T
...       ...  ...   ...  ...  ...  ... .. .. .. .. ...  ...  ...  ...  ...
270338   0.70  0.8   0.4  1.0  0.8  0.7  P  I  S  P ...    D    K    W    T

[1263 rows x 246 columns]
In [14]:
# Drop any feat_cols that are completely conserved.
# The nonconserved_cols list will serve as a convenient selector for the X data
# from the original dataframe.
nonconserved_cols = []
for col in feat_cols:
    if len(pd.unique(data[col])) == 1:
        data.drop(col, axis=1, inplace=True)
    else:
        nonconserved_cols.append(col)
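A quick sanity check of the conservation filter, on a hypothetical mini-frame:

toy = pd.DataFrame({'P1': ['P', 'P', 'P'], 'P2': ['I', 'L', 'I']})
print([c for c in toy.columns if len(pd.unique(toy[c])) > 1])  # ['P2'] -- only P2 varies, so only P2 is kept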
In [15]:
drug_cols
Out[15]:
Index(['3TC', 'ABC', 'AZT', 'D4T', 'DDI', 'TDF'], dtype='object')
In [136]:
def x_equals_y(y_test):
    """
    Return a range spanning floor(min(y_test)) to ceil(max(y_test)).
    Used to draw the y = x reference line in the plots below.
    """
    floor = math.floor(np.min(y_test))
    ceil = math.ceil(np.max(y_test))
    return range(floor, ceil)

TWOPLACES = Decimal(10) ** -2
In [152]:
colnum = 3
drug_df = pd.DataFrame()
drug_df[drug_cols[colnum]] = data[drug_cols[colnum]]
drug_df[nonconserved_cols] = data[nonconserved_cols]
# Drop rows with ambiguous calls (more than one letter at a position).
for col in nonconserved_cols:
    drug_df[col] = drug_df[col].apply(lambda x: np.nan if len(x) > 1 else x)
drug_df.dropna(inplace=True)
drug_X = drug_df[nonconserved_cols]
drug_Y = drug_df[drug_cols[colnum]].apply(lambda x: np.log(x))
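The fold-resistance values span several orders of magnitude (roughly 0.2 to 2400 in the table above), which is why the log transform is applied; a quick illustration:

print(np.log([0.2, 1.0, 2400.0]))  # approx [-1.61, 0.0, 7.78] -- a four-orders-of-magnitude spread compresses to ~9 log units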
In [153]:
from isoelectric_point import isoelectric_points
from molecular_weight import molecular_weights
# Standardize pI matrix; pH 7 is neutral.
drug_X_pi = drug_X.replace(isoelectric_points)
# Standardize MW matrix.
drug_X_mw = drug_X.replace(molecular_weights)

# Binarize the drug_X matrix.
from sklearn.preprocessing import LabelBinarizer
drug_X_bi = pd.DataFrame()
binarizers = dict()
for col in drug_X.columns:
    lb = LabelBinarizer()
    binarized_cols = lb.fit_transform(drug_X[col])
    if len(lb.classes_) == 2:
        # Two classes collapse to a single 0/1 column.
        drug_X_bi[col] = pd.Series(binarized_cols[:, 0])
    else:
        # k > 2 classes expand to one indicator column per class.
        for i, c in enumerate(lb.classes_):
            drug_X_bi[col + '_' + c] = binarized_cols[:, i]
drug_X_bi
Out[153]:
     P1  P2  P3_C  P3_N  P3_S  P4_P ... P239  P240_A  P240_E  P240_K  P240_T
0     1   0     0     0     1     1 ...    1       0       0       0       1
1     1   0     0     0     1     1 ...    1       0       0       0       1
2     1   0     0     0     1     1 ...    1       0       0       0       1
..   ..  ..   ...   ...   ...   ... ...  ...     ...     ...     ...     ...
457   1   0     0     0     1     1 ...    1       0       0       0       1

[458 rows x 482 columns]
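To make the column naming above concrete, here is a minimal sketch of how LabelBinarizer behaves (toy letters, not the real data): with exactly two classes it returns a single 0/1 column (hence plain names like P1), while three or more classes yield one indicator column per class (hence suffixed names like P3_C, P3_N, P3_S).

lb = LabelBinarizer()
print(lb.fit_transform(['L', 'H', 'L']))  # [[1] [0] [1]] -- one column; classes_ is ['H' 'L']
print(lb.fit_transform(['C', 'N', 'S']))  # 3x3 indicator matrix -- one column per class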
In [154]:
fig = plt.figure(figsize=(3,3))
drug_Y.hist(grid=False)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('{0} Distribution'.format(drug_cols[colnum]))
Out[154]:
<matplotlib.text.Text at 0x7fa5c215d4e0>
In [155]:
# Here, let's try the Random Forest Regressor. This will be the baseline.
x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
rfr = RandomForestRegressor(n_estimators=500, n_jobs=-1, oob_score=True)
rfr.fit(x_train, y_train)
rfr_preds = rfr.predict(x_test)
print(rfr.score(x_test, y_test), mean_squared_error(rfr_preds, y_test))
rfr_mse = mean_squared_error(rfr_preds, y_test)
# print(rfr.oob_score_)
sps.pearsonr(rfr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(y_test, rfr_preds)
plt.title('{0} Random Forest'.format(drug_cols[colnum]))
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(rfr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
0.786368305999 0.132429943561
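A single train/test split gives a noisy estimate of performance. A cross-validated sketch, using the same (old-style) sklearn.cross_validation module imported at the top; cv=5 and n_estimators=100 are arbitrary choices here:

from sklearn.cross_validation import cross_val_score
cv_scores = cross_val_score(RandomForestRegressor(n_estimators=100, n_jobs=-1),
                            drug_X_bi.values, drug_Y.values, cv=5)
print(cv_scores.mean(), cv_scores.std())  # mean and spread of R^2 across folds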
In [156]:
plt.bar(range(len(rfr.feature_importances_)), rfr.feature_importances_)
plt.xlabel('Position')
plt.ylabel('Relative Importance')
plt.title('{0} Random Forest'.format(drug_cols[colnum]))
Out[156]:
<matplotlib.text.Text at 0x7fa5c20d3b38>
In [157]:
# Get back the importance of each feature.
feat_impt = [(p, i) for p, i in zip(drug_X_bi.columns, rfr.feature_importances_)]
sorted(feat_impt, key=lambda x:x[1], reverse=True)
Out[157]:
[('P151', 0.16251462938669017),
('P215_T', 0.16006592453195229),
('P62', 0.089598402308289846),
('P210_W', 0.04805732919349797),
('P69_#', 0.043176594447178349),
('P210_L', 0.03820719628644869),
('P116', 0.036923518809471458),
('P41_L', 0.027512755954099771),
('P77', 0.018855167959727394),
('P67_D', 0.018461063073661806),
('P70_K', 0.016167279408335643),
('P40_E', 0.010868828402363542),
('P215_D', 0.0097352909437584643),
('P40_F', 0.0095284336496718729),
('P118', 0.0093069566910762534),
('P69_T', 0.0091408832193107317),
('P184_V', 0.0089808668340186793),
('P65', 0.0078659771260242446),
('P67_N', 0.0078486143018055513),
('P70_N', 0.0076543349919679102),
('P70_d', 0.0072499899020392116),
('P215_Y', 0.0071962504233967191),
('P223_N', 0.0062562829205002248),
('P43_E', 0.0060866113487613007),
('P73', 0.0060846705519703387),
('P70_R', 0.0060199337114942904),
('P139_S', 0.0053925554834850821),
('P41_M', 0.0052268298420678544),
('P184_M', 0.0047365291124894888),
('P135_T', 0.0047257702295656111),
('P208', 0.0045941531295171948),
('P219_K', 0.0042685057849632061),
('P75_V', 0.0040200811033691917),
('P218', 0.0038741506470258355),
('P174_R', 0.0038138763719247435),
('P58_N', 0.0036469640236670266),
('P190_A', 0.0034880680895961361),
('P181_C', 0.0034809909691663351),
('P203_K', 0.0034169068828247013),
('P181_Y', 0.0034023562884997189),
('P74_V', 0.0033380187605316792),
('P138_K', 0.0032196758142861827),
('P214', 0.0030899144516060234),
('P74_L', 0.0030685216674645865),
('P207_E', 0.0027232206439434365),
('P39_T', 0.0026421469117063161),
('P75_T', 0.0025827401321966976),
('P135_V', 0.0025695161872389847),
('P200_T', 0.0025636478473387938),
('P139_T', 0.0024212663266235321),
('P207_Q', 0.002402825396636144),
('P64_H', 0.002383787045730263),
('P123_E', 0.0023502644301236194),
('P69_D', 0.0023465037879284579),
('P190_G', 0.0022155694228006011),
('P122_E', 0.002213964890623938),
('P211_R', 0.0021859602876278584),
('P102_Q', 0.0021678505224306203),
('P44_E', 0.0020674696042807359),
('P215_F', 0.0019449153773165285),
('P219_E', 0.0019272493402581862),
('P203_E', 0.0018550549534186885),
('P196_E', 0.0018383325223303162),
('P122_K', 0.0018347234940733837),
('P178_I', 0.001806222233384819),
('P123_D', 0.0017749513060090457),
('P69_N', 0.0017700664947200423),
('P135_I', 0.0017449729362553255),
('P75_I', 0.0017440799295997303),
('P103_K', 0.0017227293956739301),
('P219_R', 0.0016741067913017835),
('P43_K', 0.0016735102186183708),
('P102_K', 0.0016675568335350174),
('P228_H', 0.0016663261354250669),
('P196_G', 0.0016541932141252525),
('P200_A', 0.0016447160208774292),
('P162_S', 0.00159023188765284),
('P138_E', 0.001478400658828117),
('P177_E', 0.0014627745557714018),
('P219_Q', 0.0014482811153136784),
('P35_T', 0.0014381497441807549),
('P188_L', 0.0013285168302978377),
('P211_K', 0.0013093825552678343),
('P103_N', 0.001248617556734215),
('P179_I', 0.0012328551602509344),
('P162_C', 0.0011982045240013577),
('P228_L', 0.0011635428973818041),
('P178_L', 0.0011481025443574502),
('P238_K', 0.0011445712967038549),
('P35_V', 0.0011024788869958006),
('P146', 0.0010982585955854973),
('P64_K', 0.0010920025818014561),
('P177_D', 0.0010905129160233638),
('P178_F', 0.0010693452316058952),
('P39_A', 0.0010484497888503823),
('P20_K', 0.0010475814864030529),
('P101_E', 0.0010354309956481218),
('P101_K', 0.0010101778830088773),
('P188_Y', 0.000998000612165823),
('P228_R', 0.00097952916267022867),
('P68_G', 0.00096397487205400543),
('P108', 0.00095849356763606428),
('P115', 0.00094286952675548899),
('P162_A', 0.00093144483875862323),
('P238_T', 0.00090917464128133089),
('P60_I', 0.00090070518000247117),
('P67_E', 0.00089224907758878632),
('P68_S', 0.00086652663411294559),
('P179_V', 0.00086618235778399889),
('P83', 0.00085505255029670385),
('P103_S', 0.00085144885179626469),
('P121_Y', 0.00085126753806842744),
('P20_R', 0.00079379237648687486),
('P203_D', 0.00077172511141351081),
('P202', 0.00075929783491084622),
('P180', 0.00073909838686148088),
('P174_Q', 0.00073874043164542686),
('P200_I', 0.00073162255681505047),
('P123_N', 0.00070153937678211304),
('P178_M', 0.00069111021143699323),
('P60_V', 0.00068919471690468239),
('P207_N', 0.00067385270252314779),
('P142_I', 0.00065640298445170553),
('P227_L', 0.00065024723775872889),
('P35_I', 0.00064649900469493577),
('P44_D', 0.00064010540730884949),
('P90', 0.00063989615343746926),
('P106_V', 0.00063507207005211491),
('P98_G', 0.00062887490465580531),
('P172', 0.00061024702564071141),
('P28_E', 0.00059981013085939265),
('P98_A', 0.00057534786487980669),
('P35_M', 0.00056911894252802164),
('P142_V', 0.00055137585045650437),
('P166_R', 0.00054977243910318146),
('P237_D', 0.00054775854075780049),
('P179_F', 0.00052963131097299315),
('P49', 0.00052512110059578228),
('P16', 0.00049606463759768607),
('P138_G', 0.000486103946357874),
('P5_V', 0.00045940617868891722),
('P219_H', 0.00045381112961015966),
('P197_Q', 0.00044784706876806028),
('P227_F', 0.0004473313209620604),
('P5_I', 0.00042971435742582422),
('P211_N', 0.00042786537379451052),
('P215_S', 0.00041865725477986172),
('P89', 0.00041764304987800856),
('P223_K', 0.00040976773738635435),
('P68_T', 0.00039561065486796288),
('P43_Q', 0.00037804569690291247),
('P44_A', 0.00037484028070881067),
('P139_A', 0.00035590141937412467),
('P219_N', 0.00035154342032472273),
('P123_I', 0.00034953177682979503),
('P39_E', 0.00034730857148287779),
('P135_L', 0.00032111616800641442),
('P184_I', 0.00031939754644173097),
('P121_D', 0.00031732579512637157),
('P173_K', 0.00031615518762627149),
('P221', 0.00031455011193028858),
('P6_K', 0.00030830883055275345),
('P181_I', 0.00030233783360660249),
('P98_S', 0.00028876819567226215),
('P32_K', 0.00028124417542713646),
('P166_K', 0.0002808748111850798),
('P11_R', 0.00028085164193501294),
('P67_G', 0.00028084776657503395),
('P142_T', 0.00028032331721691682),
('P237_E', 0.00027972256099027565),
('P162_D', 0.00027099659813596829),
('P197_E', 0.00026770517690369783),
('P207_D', 0.00026727523712224109),
('P135_M', 0.00026651589804190277),
('P33_V', 0.00026081957296777335),
('P11_K', 0.0002544333961544424),
('P106_A', 0.00025252978521046244),
('P169_D', 0.00025055970776281003),
('P138_A', 0.00024359386375885614),
('P211_S', 0.00024350310232222971),
('P173_R', 0.00024258349397140316),
('P1', 0.00023332647291195526),
('P27', 0.00023226793237666573),
('P179_D', 0.00023137107928803926),
('P58_T', 0.00022655002519174334),
('P100', 0.00022014608755782165),
('P190_C', 0.00020714297387020066),
('P207_G', 0.00020647995815651133),
('P211_T', 0.00020404825050321954),
('P230', 0.00020287080602791622),
('P101_Q', 0.00020120962784956936),
('P157', 0.00020000058336736433),
('P190_S', 0.00019419345405297073),
('P74_I', 0.00018296967171673874),
('P48_S', 0.0001775163749034273),
('P104_R', 0.00017244319760841673),
('P111', 0.000167731611715598),
('P174_K', 0.0001671768480088083),
('P31_V', 0.00016695766584317567),
('P33_A', 0.00016554905591164114),
('P109', 0.00016410756193115599),
('P64_R', 0.00016334690794083678),
('P101_H', 0.00016073242777569233),
('P4_S', 0.000159582501610235),
('P173_Q', 0.00015518873743907723),
('P28_K', 0.00015296418244316532),
('P169_E', 0.00015235216899203938),
('P32_E', 0.00015230190426268028),
('P48_T', 0.00014931600475534391),
('P101_R', 0.00014790045863572254),
('P173_N', 0.00014770509449362994),
('P4_P', 0.0001469270609201563),
('P239', 0.00014672393892225505),
('P106_I', 0.00014256246536649794),
('P13', 0.0001413177624044851),
('P31_I', 0.00013858828937927102),
('P121_H', 0.00013747673450159318),
('P215_V', 0.000136070845889713),
('P37', 0.00013531876061213566),
('P207_K', 0.00013337564255884132),
('P200_R', 0.00013174523839097946),
('P223_Q', 0.00013096240494140797),
('P28_A', 0.00012548340756437777),
('P138_Q', 0.00011532785441760907),
('P207_A', 0.00011039518865872977),
('P197_K', 0.00010950928712528276),
('P6_E', 0.00010824725552461508),
('P43_N', 0.00010378416513680041),
('P139_R', 9.9636100065719714e-05),
('P3_S', 9.8314835546991982e-05),
('P47_V', 9.8145633287732452e-05),
('P104_K', 9.2512522314128646e-05),
('P207_H', 9.1557056511385934e-05),
('P35_R', 8.5729676593072237e-05),
('P224_D', 8.5043931640428925e-05),
('P224_E', 8.277277434848758e-05),
('P163', 8.2429008848087891e-05),
('P63', 8.0182907627834556e-05),
('P179_E', 7.90508869786849e-05),
('P39_K', 7.8996196163410209e-05),
('P5_S', 7.8763581406593054e-05),
('P123_S', 7.8306273276253831e-05),
('P122_P', 7.7440964987469778e-05),
('P6_D', 7.6734583529687543e-05),
('P165_T', 7.5866324719554409e-05),
('P225', 7.5306144414764273e-05),
('P85', 7.4260000014479368e-05),
('P200_E', 7.4010491294990887e-05),
('P195_L', 7.3414017088698944e-05),
('P32_T', 7.2079978204739242e-05),
('P195_I', 7.1311497389698543e-05),
('P197_L', 7.0086870127877187e-05),
('P211_Q', 6.8967116529355059e-05),
('P224_K', 6.8384307989135959e-05),
('P166_T', 6.793157761611871e-05),
('P69_S', 6.5746514933650235e-05),
('P101_P', 6.5001407055795941e-05),
('P132', 6.4888442329120034e-05),
('P198', 6.1937193164967217e-05),
('P35_K', 6.1207994804328268e-05),
('P238_R', 6.0920435106250203e-05),
('P223_E', 6.089911496557083e-05),
('P215_I', 6.0291401747136104e-05),
('P102_R', 6.0072874363770748e-05),
('P174_H', 5.8923496579779993e-05),
('P122_Q', 5.8823392940486941e-05),
('P82', 5.7365951036787005e-05),
('P174_L', 5.4329223948507411e-05),
('P88_C', 5.3129844946612595e-05),
('P43_A', 5.1403624745991304e-05),
('P237_N', 5.057886566871371e-05),
('P215_L', 5.0287989972268171e-05),
('P204_E', 4.895528866502371e-05),
('P50', 4.6616079559863745e-05),
('P170', 4.5277446250194471e-05),
('P64_Y', 4.4686356333231143e-05),
('P181_V', 4.4139566915812721e-05),
('P39_N', 4.3575327951873241e-05),
('P35_L', 4.3280729834555589e-05),
('P177_G', 4.2939158480663363e-05),
('P69_G', 4.2712037213690829e-05),
('P7', 4.2577869169305751e-05),
('P3_N', 4.253990404912848e-05),
('P3_C', 4.0793962533628899e-05),
('P80', 4.0738013528253687e-05),
('P195_M', 4.0549062820673448e-05),
('P103_R', 4.0535151027103161e-05),
('P162_Y', 4.0502921610321552e-05),
('P46', 3.938733629927549e-05),
('P240_T', 3.922657631975248e-05),
('P104_N', 3.8375956322611281e-05),
('P174_E', 3.7667347299986625e-05),
('P43_R', 3.6967727822854313e-05),
('P106_M', 3.6706555739788651e-05),
('P75_M', 3.620893936993806e-05),
('P240_E', 3.496513401654536e-05),
('P135_K', 3.4296684429088164e-05),
('P21_V', 3.4283499227464581e-05),
('P28_G', 3.1606039987253433e-05),
('P211_A', 3.1339152264936688e-05),
('P165_I', 3.1264606669730081e-05),
('P211_G', 3.0904875189993984e-05),
('P88_W', 3.0696233897704535e-05),
('P39_I', 3.0225356134305558e-05),
('P200_K', 2.96926806919556e-05),
('P101_I', 2.860147433881264e-05),
('P22', 2.773397225078223e-05),
('P158', 2.6961852352201984e-05),
('P21_I', 2.6890565857257042e-05),
('P210_S', 2.6380376338860684e-05),
('P21_A', 2.5903507060762843e-05),
('P188_H', 2.5870275556699503e-05),
('P34', 2.5354038673889142e-05),
('P47_I', 2.5133841977865804e-05),
('P69_K', 2.5063630726307757e-05),
('P67_~', 2.4946710951120842e-05),
('P212', 2.4892262810345553e-05),
('P204_D', 2.4279655242071229e-05),
('P173_T', 2.4235630046209914e-05),
('P39_L', 2.3979867291791524e-05),
('P173_S', 2.3888613063677351e-05),
('P35_E', 2.3840376582614944e-05),
('P177_N', 2.3821915553087907e-05),
('P165_A', 2.325313210508823e-05),
('P39_S', 2.3180720272361477e-05),
('P31_L', 2.258466535340127e-05),
('P159_L', 2.2250669326182636e-05),
('P228_Q', 2.0646478625955823e-05),
('P189', 2.0242772947484979e-05),
('P200_S', 1.9167348467636126e-05),
('P107', 1.8460370565011147e-05),
('P58_G', 1.841829731588697e-05),
('P197_P', 1.7876763272087301e-05),
('P35_Q', 1.6521125261421695e-05),
('P102_M', 1.6337589824524213e-05),
('P190_E', 1.6089060311500773e-05),
('P122_A', 1.5362032752082303e-05),
('P204_N', 1.5026234928358834e-05),
('P6_G', 1.499251790325744e-05),
('P14', 1.4951797644508774e-05),
('P195_T', 1.3342347447179275e-05),
('P240_K', 1.3024231131735054e-05),
('P96', 1.2531349596052805e-05),
('P176_P', 1.2260048801154113e-05),
('P75_L', 1.2059981516383563e-05),
('P145', 1.1937069946213749e-05),
('P20_T', 1.1243103452339689e-05),
('P88_G', 1.1148211982685336e-05),
('P204_Q', 1.0980630032656169e-05),
('P113', 1.0272818990464138e-05),
('P66', 1.0040123196221622e-05),
('P48_V', 9.889363680492432e-06),
('P215_C', 9.6387834023666912e-06),
('P9', 9.6176332577349102e-06),
('P215_N', 9.2117924818115869e-06),
('P40_D', 8.2212612156804012e-06),
('P159_I', 8.1435250139559716e-06),
('P159_V', 8.0165716589835687e-06),
('P228_F', 7.9687525674008383e-06),
('P210_F', 7.8842792091922328e-06),
('P36', 7.3820368195924407e-06),
('P128', 6.6159503856832758e-06),
('P127', 6.4648237411951987e-06),
('P161', 6.170217236359231e-06),
('P200_V', 5.9395016760378583e-06),
('P173_E', 5.6872943675120443e-06),
('P11_T', 5.5186160983299495e-06),
('P70_G', 5.2656507077404917e-06),
('P86', 4.9187649220681258e-06),
('P101_A', 4.794689076618831e-06),
('P176_T', 4.5354871983633148e-06),
('P176_S', 4.4820962149087409e-06),
('P69_A', 4.4233211338960185e-06),
('P33_G', 4.1612184324786338e-06),
('P19', 4.0936989626473446e-06),
('P227_Y', 4.0702759885375919e-06),
('P223_R', 3.9869755081948724e-06),
('P200_L', 3.8689378684052706e-06),
('P135_R', 3.8263114032032092e-06),
('P69_I', 3.8148496017361344e-06),
('P53', 3.5117395389113594e-06),
('P173_I', 3.3270858084319144e-06),
('P123_G', 3.3176536042403781e-06),
('P182', 3.1849964653397612e-06),
('P195_K', 3.1153300879020901e-06),
('P164', 3.1051135760467751e-06),
('P8', 3.0457801694138396e-06),
('P40_K', 2.9025830231359228e-06),
('P45', 2.6725090075384944e-06),
('P167', 2.3928120733466434e-06),
('P165_L', 2.1256630019541142e-06),
('P173_A', 1.9348268249657992e-06),
('P60_A', 1.8128465735366147e-06),
('P199', 1.7584412182626765e-06),
('P144', 1.5109722083712416e-06),
('P196_K', 1.4619230053096696e-06),
('P240_A', 1.4017245065337739e-06),
('P2', 1.3437155057732335e-06),
('P35_A', 1.2998876912630738e-06),
('P79', 7.0993410033283175e-07),
('P32_R', 6.0690878075020209e-07),
('P102_E', 5.6410767082531577e-07),
('P169_K', 1.7175941527700947e-07),
('P236', 1.073097201017753e-07),
('P226', 6.0113693215943616e-09),
('P4_T', 0.0),
('P10_V', 0.0),
('P15_G', 0.0),
('P18_G', 0.0),
('P23_Q', 0.0),
('P24_W', 0.0),
('P25_P', 0.0),
('P29_E', 0.0),
('P30_K', 0.0),
('P32_N', 0.0),
('P38_C', 0.0),
('P41_I', 0.0),
('P42_E', 0.0),
('P47_F', 0.0),
('P52_P', 0.0),
('P54', 0.0),
('P56_Y', 0.0),
('P59_P', 0.0),
('P64_N', 0.0),
('P68_K', 0.0),
('P70_E', 0.0),
('P72_R', 0.0),
('P75_A', 0.0),
('P84_T', 0.0),
('P87_F', 0.0),
('P91_Q', 0.0),
('P93_G', 0.0),
('P94_I', 0.0),
('P95_P', 0.0),
('P101_S', 0.0),
('P105_S', 0.0),
('P110_D', 0.0),
('P117_S', 0.0),
('P119_P', 0.0),
('P122_R', 0.0),
('P124_F', 0.0),
('P125_R', 0.0),
('P126_K', 0.0),
('P130_F', 0.0),
('P131_T', 0.0),
('P134_S', 0.0),
('P136_N', 0.0),
('P137_N', 0.0),
('P139_I', 0.0),
('P140_P', 0.0),
('P141_G', 0.0),
('P143_R', 0.0),
('P147_N', 0.0),
('P149_L', 0.0),
('P154_K', 0.0),
('P162_H', 0.0),
('P162_N', 0.0),
('P171', 0.0),
('P175_N', 0.0),
('P179_L', 0.0),
('P187_L', 0.0),
('P188_C', 0.0),
('P188_F', 0.0),
('P191', 0.0),
('P192', 0.0),
('P193_L', 0.0),
('P194_E', 0.0),
('P197_I', 0.0),
('P200_F', 0.0),
('P201_K', 0.0),
('P205_L', 0.0),
('P206_R', 0.0),
('P213_G', 0.0),
('P215_E', 0.0),
('P216_T', 0.0),
('P217_P', 0.0),
('P219_D', 0.0),
('P219_T', 0.0),
('P222_Q', 0.0),
('P232_Y', 0.0),
('P233_E', 0.0),
('P234', 0.0)]
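The full ranking is long; a small helper sketch (top_k is an arbitrary cutoff) to show just the strongest features:

top_k = 10  # arbitrary cutoff
for name, importance in sorted(feat_impt, key=lambda x: x[1], reverse=True)[:top_k]:
    print('{0}\t{1:.4f}'.format(name, importance))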
In [158]:
# # Here, let's try a parameter grid search to figure out the best hyperparameters.
# from sklearn.grid_search import GridSearchCV
# param_grid = [{'n_estimators': [100, 500, 1000],
#                # 'max_features': ['auto', 'sqrt', 'log2'],
#                # 'min_samples_leaf': np.arange(1, 20, 1),
#               }]
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
# rfr_gs = GridSearchCV(RandomForestRegressor(), param_grid=param_grid, n_jobs=-1)
# rfr_gs.fit(x_train, y_train)
# print(rfr_gs.best_estimator_)
# print(rfr_gs.best_params_)
In [159]:
# Try Bayesian Ridge Regression
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
brr = lm.BayesianRidge()
brr.fit(x_train, y_train)
brr_preds = brr.predict(x_test)
print(brr.score(x_test, y_test), mean_squared_error(brr_preds, y_test))
print(sps.pearsonr(brr_preds, y_test))
brr_mse = mean_squared_error(brr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(y_test, brr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Bayesian Ridge'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(brr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
0.77269618709 0.140905268087
(0.88401111341649008, 4.0613783272294827e-39)
In [160]:
# Try ARD regression
ardr = lm.ARDRegression()
ardr.fit(x_train, y_train)
ardr_preds = ardr.predict(x_test)
ardr_mse = mean_squared_error(ardr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(y_test, ardr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} ARD Regression'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(ardr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
In [161]:
# Try Gradient Boost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
gbr = GradientBoostingRegressor()
gbr.fit(x_train, y_train)
gbr_preds = gbr.predict(x_test)
print(gbr.score(x_test, y_test), mean_squared_error(gbr_preds, y_test))
print(sps.pearsonr(gbr_preds, y_test))
gbr_mse = mean_squared_error(gbr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(y_test, gbr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Grad. Boost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(gbr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
0.800344680916 0.123766011227
(0.8991383071394452, 2.3372440847266272e-42)
In [162]:
plt.bar(range(len(gbr.feature_importances_)), gbr.feature_importances_)
Out[162]:
<Container object of 482 artists>
In [163]:
# Try AdaBoost
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
abr = AdaBoostRegressor()
abr.fit(x_train, y_train)
abr_preds = abr.predict(x_test)
print(abr.score(x_test, y_test), mean_squared_error(abr_preds, y_test))
print(sps.pearsonr(abr_preds, y_test))
abr_mse = mean_squared_error(abr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(x=y_test, y=abr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} AdaBoost'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(abr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
0.701945703317 0.184763379201
(0.87256175470337738, 5.9504271811245195e-37)
In [164]:
plt.bar(range(len(abr.feature_importances_)), abr.feature_importances_)
Out[164]:
<Container object of 482 artists>
In [165]:
# Try support vector regression
svr = SVR()
svr.fit(x_train, y_train)
svr_preds = svr.predict(x_test)
svr_mse = mean_squared_error(svr_preds, y_test)
plt.figure(figsize=(3,3))
plt.scatter(y_test, svr_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} SVR'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(svr_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
In [166]:
# Neural Network 1 Specification: feed-forward ANN with 1 hidden layer.
# x_train, x_test, y_train, y_test = train_test_split(drug_X_bi, drug_Y)
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)
net1 = NeuralNet(
    layers=[  # input -> dense hidden -> dropout -> nonlinearity -> output
        ('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('dropout1', layers.DropoutLayer),
        # ('hidden2', layers.DenseLayer),
        # ('dropout2', layers.DropoutLayer),
        ('nonlinear', layers.NonlinearityLayer),
        ('output', layers.DenseLayer),
    ],
    # layer parameters:
    input_shape=(None, x_train.shape[1]),
    hidden1_num_units=math.ceil(x_train.shape[1] / 2),  # number of units in the hidden layer
    hidden1_nonlinearity=nonlinearities.tanh,
    dropout1_p=0.5,
    # hidden2_num_units=math.ceil(x_train.shape[1] / 2),
    # dropout2_p=0.5,
    output_nonlinearity=None,  # output layer uses the identity function
    output_num_units=1,  # a single regression target
    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.95,
    regression=True,  # flag to indicate we're dealing with a regression problem
    max_epochs=500,  # we want to train this many epochs
    verbose=1,
)
net1.fit(x_train.values, y_train.values)
# Neural Network with 116645 learnable parameters

## Layer information

  #  name       size
---  ---------  ------
  0  input      482
  1  hidden1    241
  2  dropout1   241
  3  nonlinear  241
  4  output       1
epoch train loss valid loss train/val dur
------- ------------ ------------ ----------- -----
1 1.19743 1.68213 0.71185 0.01s
2 1.95433 0.72601 2.69189 0.01s
3 0.94574 0.30878 3.06285 0.01s
4 0.66858 0.31704 2.10884 0.01s
5 0.53894 0.26701 2.01844 0.01s
6 0.50842 0.25242 2.01423 0.01s
7 0.40975 0.24942 1.64282 0.01s
8 0.50461 0.23500 2.14729 0.01s
9 0.31233 0.32754 0.95357 0.01s
10 0.38249 0.29446 1.29897 0.01s
11 0.29550 0.25208 1.17227 0.01s
12 0.24624 0.22847 1.07778 0.01s
13 0.23960 0.21291 1.12538 0.01s
14 0.20809 0.21113 0.98563 0.01s
15 0.26575 0.24637 1.07865 0.01s
16 0.24770 0.25892 0.95665 0.01s
17 0.24705 0.21733 1.13673 0.01s
18 0.21937 0.19826 1.10649 0.01s
19 0.16858 0.20001 0.84284 0.01s
20 0.22139 0.19664 1.12585 0.01s
21 0.20173 0.19444 1.03745 0.01s
22 0.30861 0.19095 1.61618 0.01s
23 0.17062 0.19223 0.88759 0.01s
24 0.22889 0.18154 1.26077 0.01s
25 0.15896 0.20431 0.77802 0.01s
26 0.16305 0.18917 0.86192 0.01s
27 0.17398 0.19096 0.91108 0.01s
28 0.15817 0.17914 0.88294 0.01s
29 0.15098 0.18068 0.83563 0.01s
30 0.12210 0.20242 0.60322 0.01s
31 0.15813 0.20358 0.77673 0.01s
32 0.15006 0.19258 0.77922 0.01s
33 0.13211 0.20734 0.63715 0.01s
34 0.17773 0.18803 0.94521 0.01s
35 0.12933 0.17848 0.72459 0.01s
36 0.14602 0.18083 0.80747 0.01s
37 0.13874 0.18274 0.75920 0.01s
38 0.11998 0.17328 0.69239 0.01s
39 0.12378 0.16272 0.76070 0.01s
40 0.14595 0.16904 0.86339 0.01s
41 0.15521 0.16610 0.93445 0.01s
42 0.13233 0.17003 0.77824 0.01s
43 0.16456 0.18901 0.87061 0.01s
44 0.32841 0.23504 1.39725 0.01s
45 0.20293 0.19713 1.02946 0.01s
46 0.15822 0.17463 0.90599 0.01s
47 0.19039 0.15962 1.19283 0.01s
48 0.14678 0.14900 0.98507 0.01s
49 0.24509 0.15656 1.56549 0.01s
50 0.15942 0.17256 0.92382 0.01s
51 0.11750 0.20041 0.58629 0.01s
52 0.11213 0.17061 0.65726 0.01s
53 0.12890 0.14601 0.88285 0.01s
54 0.11421 0.14886 0.76725 0.01s
55 0.15005 0.16531 0.90768 0.01s
56 0.19297 0.16276 1.18561 0.01s
57 0.11098 0.16149 0.68726 0.01s
58 0.11118 0.15150 0.73384 0.01s
59 0.19944 0.18011 1.10728 0.01s
60 0.11480 0.20267 0.56644 0.01s
61 0.13112 0.18593 0.70524 0.01s
62 0.10427 0.16075 0.64865 0.01s
63 0.11867 0.15690 0.75631 0.01s
64 0.08913 0.16179 0.55090 0.01s
65 0.12996 0.15242 0.85270 0.01s
66 0.09034 0.14482 0.62380 0.01s
67 0.09174 0.14283 0.64229 0.01s
68 0.11042 0.16092 0.68613 0.01s
69 0.07618 0.16674 0.45686 0.01s
70 0.07536 0.15406 0.48917 0.01s
71 0.09851 0.14473 0.68067 0.01s
72 0.10116 0.14914 0.67834 0.01s
73 0.08253 0.14524 0.56819 0.01s
74 0.11012 0.14710 0.74863 0.01s
75 0.09852 0.13707 0.71875 0.01s
76 0.09305 0.13483 0.69012 0.01s
77 0.10088 0.16733 0.60284 0.01s
78 0.09275 0.18086 0.51285 0.01s
79 0.08869 0.17111 0.51831 0.01s
80 0.16047 0.15100 1.06270 0.01s
81 0.11934 0.14352 0.83149 0.01s
82 0.19195 0.16509 1.16270 0.01s
83 0.10747 0.14800 0.72615 0.01s
84 0.06586 0.13695 0.48092 0.01s
85 0.11371 0.14434 0.78779 0.01s
86 0.08070 0.14007 0.57610 0.01s
87 0.10536 0.12969 0.81239 0.01s
88 0.07028 0.14674 0.47893 0.01s
89 0.18234 0.14641 1.24542 0.01s
90 0.08129 0.14214 0.57190 0.01s
91 0.11248 0.15048 0.74750 0.01s
92 0.08253 0.14805 0.55746 0.01s
93 0.09572 0.14202 0.67397 0.01s
94 0.19924 0.13006 1.53194 0.01s
95 0.11625 0.12150 0.95679 0.01s
96 0.06974 0.12586 0.55412 0.01s
97 0.09992 0.14316 0.69797 0.01s
98 0.11114 0.13600 0.81719 0.01s
99 0.08876 0.13129 0.67604 0.01s
100 0.13738 0.12904 1.06461 0.01s
101 0.10018 0.12606 0.79473 0.01s
102 0.12242 0.13659 0.89629 0.01s
103 0.10089 0.13986 0.72137 0.01s
104 0.09595 0.13154 0.72948 0.01s
105 0.12665 0.12320 1.02796 0.01s
106 0.12632 0.11770 1.07331 0.01s
107 0.09772 0.12210 0.80030 0.01s
108 0.10360 0.12519 0.82754 0.01s
109 0.13190 0.12772 1.03268 0.01s
110 0.08357 0.12930 0.64634 0.01s
111 0.08338 0.14676 0.56812 0.01s
112 0.18342 0.13090 1.40130 0.01s
113 0.08833 0.12613 0.70033 0.01s
114 0.10523 0.16371 0.64278 0.01s
115 0.28506 0.12587 2.26470 0.01s
116 0.08084 0.12635 0.63982 0.01s
117 0.12764 0.12970 0.98409 0.01s
118 0.07287 0.13624 0.53482 0.01s
119 0.17367 0.13381 1.29789 0.01s
120 0.12418 0.12860 0.96563 0.01s
121 0.07062 0.10958 0.64443 0.01s
122 0.14090 0.11285 1.24861 0.01s
123 0.15817 0.14505 1.09049 0.01s
124 0.09398 0.12590 0.74649 0.01s
125 0.11679 0.11381 1.02621 0.01s
126 0.12874 0.13595 0.94694 0.01s
127 0.06756 0.13883 0.48666 0.01s
128 0.15036 0.14205 1.05849 0.01s
129 0.09346 0.12188 0.76682 0.01s
130 0.11129 0.11641 0.95598 0.01s
131 0.10106 0.11543 0.87547 0.01s
132 0.08539 0.12200 0.69989 0.01s
133 0.05179 0.12845 0.40314 0.01s
134 0.11935 0.13800 0.86488 0.01s
135 0.09000 0.14137 0.63662 0.01s
136 0.11975 0.15208 0.78741 0.01s
137 0.08219 0.13515 0.60816 0.01s
138 0.06151 0.13946 0.44104 0.01s
139 0.08918 0.14164 0.62963 0.01s
140 0.08815 0.13352 0.66018 0.01s
141 0.10486 0.12188 0.86042 0.01s
142 0.08883 0.13003 0.68310 0.01s
143 0.06719 0.13224 0.50808 0.01s
144 0.08657 0.12857 0.67328 0.01s
145 0.08456 0.11816 0.71562 0.01s
146 0.12163 0.11974 1.01582 0.01s
147 0.05964 0.12404 0.48080 0.01s
148 0.12442 0.10907 1.14071 0.01s
149 0.07411 0.11275 0.65728 0.01s
150 0.07070 0.11459 0.61698 0.01s
151 0.08252 0.13039 0.63285 0.01s
152 0.22600 0.10907 2.07195 0.01s
153 0.08475 0.10037 0.84439 0.01s
154 0.13983 0.10329 1.35380 0.01s
155 0.21503 0.11749 1.83023 0.01s
156 0.09434 0.14672 0.64297 0.01s
157 0.08061 0.13848 0.58210 0.01s
158 0.11020 0.14070 0.78323 0.01s
159 0.07640 0.15349 0.49777 0.01s
160 0.10059 0.14362 0.70033 0.01s
161 0.17161 0.13071 1.31297 0.01s
162 0.09409 0.12029 0.78218 0.01s
163 0.15377 0.11169 1.37676 0.01s
164 0.13956 0.10527 1.32576 0.01s
165 0.08953 0.11600 0.77181 0.01s
166 0.14288 0.12357 1.15634 0.01s
167 0.12022 0.11994 1.00238 0.01s
168 0.07633 0.12713 0.60041 0.01s
169 0.09462 0.12695 0.74531 0.01s
170 0.20254 0.12086 1.67580 0.01s
171 0.15304 0.11058 1.38389 0.01s
172 0.11511 0.11520 0.99919 0.01s
173 0.07404 0.12196 0.60705 0.01s
174 0.16150 0.15179 1.06398 0.01s
175 0.08393 0.14864 0.56463 0.01s
176 0.11697 0.13070 0.89489 0.01s
177 0.16888 0.11628 1.45238 0.01s
178 0.08649 0.11871 0.72856 0.01s
179 0.10527 0.13650 0.77119 0.01s
180 0.13215 0.15267 0.86561 0.01s
181 0.11524 0.14075 0.81879 0.01s
182 0.12995 0.12186 1.06637 0.01s
183 0.13974 0.11200 1.24773 0.01s
184 0.11034 0.10854 1.01658 0.01s
185 0.08954 0.12519 0.71520 0.01s
186 0.09614 0.13197 0.72849 0.01s
187 0.11542 0.13394 0.86171 0.01s
188 0.18394 0.12831 1.43355 0.01s
189 0.09980 0.13187 0.75678 0.01s
190 0.08412 0.12378 0.67955 0.01s
191 0.10366 0.11898 0.87120 0.01s
192 0.11733 0.11964 0.98062 0.01s
193 0.21819 0.11653 1.87243 0.01s
194 0.08570 0.11871 0.72198 0.01s
195 0.27932 0.11683 2.39073 0.01s
196 0.12238 0.11941 1.02483 0.01s
197 0.10302 0.12086 0.85238 0.01s
198 0.23725 0.12547 1.89097 0.01s
199 0.11024 0.11744 0.93872 0.01s
200 0.10038 0.11068 0.90688 0.01s
201 0.09320 0.11110 0.83889 0.01s
202 0.16984 0.12371 1.37288 0.01s
203 0.13055 0.11477 1.13744 0.01s
204 0.11040 0.11653 0.94740 0.01s
205 0.14614 0.10772 1.35668 0.01s
206 0.13186 0.10909 1.20868 0.01s
207 0.09123 0.11366 0.80265 0.01s
208 0.08160 0.13464 0.60607 0.01s
209 0.08654 0.13987 0.61867 0.01s
210 0.08025 0.13004 0.61715 0.01s
211 0.09097 0.12227 0.74403 0.01s
212 0.12589 0.13893 0.90616 0.01s
213 0.14856 0.14380 1.03314 0.01s
214 0.12276 0.12380 0.99156 0.01s
215 0.11822 0.11806 1.00135 0.01s
216 0.10319 0.11564 0.89234 0.01s
217 0.11402 0.11674 0.97668 0.01s
218 0.27109 0.11497 2.35795 0.01s
219 0.09970 0.12409 0.80348 0.01s
220 0.14090 0.11524 1.22273 0.01s
221 0.12857 0.12957 0.99222 0.01s
222 0.07628 0.12505 0.61001 0.01s
223 0.17852 0.11901 1.49997 0.01s
224 0.11503 0.10621 1.08311 0.01s
225 0.19264 0.11913 1.61707 0.01s
226 0.07438 0.13171 0.56471 0.01s
227 0.16558 0.13311 1.24386 0.01s
228 0.11467 0.13764 0.83310 0.01s
229 0.09415 0.12420 0.75808 0.01s
230 0.22680 0.12710 1.78445 0.01s
231 0.34571 0.11241 3.07541 0.01s
232 0.10193 0.11659 0.87422 0.01s
233 0.09103 0.11723 0.77646 0.01s
234 0.10405 0.12078 0.86146 0.01s
235 0.10594 0.11446 0.92553 0.01s
236 0.18869 0.11008 1.71401 0.01s
237 0.11835 0.12989 0.91113 0.01s
238 0.13303 0.12998 1.02346 0.01s
239 0.10228 0.11656 0.87746 0.01s
240 0.12744 0.11222 1.13569 0.01s
241 0.12617 0.10443 1.20820 0.01s
242 0.11751 0.11067 1.06178 0.01s
243 0.11329 0.11561 0.97993 0.01s
244 0.07801 0.11557 0.67497 0.01s
245 0.07290 0.11985 0.60827 0.01s
246 0.09177 0.12275 0.74764 0.01s
247 0.16516 0.11076 1.49116 0.01s
248 0.14838 0.10435 1.42195 0.01s
249 0.14063 0.12676 1.10937 0.01s
250 0.08579 0.12798 0.67035 0.01s
251 0.10262 0.12816 0.80075 0.01s
252 0.22056 0.12011 1.83630 0.01s
253 0.11093 0.11132 0.99655 0.01s
254 0.08826 0.11225 0.78626 0.01s
255 0.12252 0.11137 1.10011 0.01s
256 0.46803 0.10848 4.31445 0.01s
257 0.14076 0.10959 1.28442 0.01s
258 0.14253 0.10520 1.35483 0.01s
259 0.11187 0.11667 0.95887 0.01s
260 0.20481 0.10591 1.93380 0.01s
261 0.10244 0.10637 0.96308 0.01s
262 0.16964 0.10803 1.57036 0.01s
263 0.29071 0.11017 2.63867 0.01s
264 0.07704 0.11463 0.67203 0.01s
265 0.06943 0.12622 0.55009 0.01s
266 0.07959 0.12558 0.63378 0.01s
267 0.20754 0.12320 1.68458 0.01s
268 0.12011 0.11356 1.05770 0.01s
269 0.12476 0.11536 1.08154 0.01s
270 0.17503 0.11010 1.58978 0.01s
271 0.10125 0.10330 0.98007 0.01s
272 0.14162 0.11844 1.19563 0.01s
273 0.16627 0.11906 1.39654 0.01s
274 0.12680 0.11905 1.06513 0.01s
275 0.12465 0.11530 1.08104 0.01s
276 0.11628 0.10806 1.07604 0.01s
277 0.07748 0.10359 0.74797 0.01s
278 0.15013 0.10246 1.46517 0.01s
279 0.10133 0.10362 0.97792 0.01s
280 0.11667 0.11448 1.01914 0.01s
281 0.09241 0.11514 0.80261 0.01s
282 0.24398 0.11741 2.07809 0.01s
283 0.11680 0.11378 1.02655 0.01s
284 0.08042 0.11638 0.69101 0.01s
285 0.14535 0.10724 1.35532 0.01s
286 0.24521 0.10983 2.23260 0.01s
287 0.09396 0.11252 0.83508 0.01s
288 0.08388 0.12504 0.67082 0.01s
289 0.09342 0.13650 0.68435 0.01s
290 0.08754 0.12773 0.68532 0.01s
291 0.18065 0.12100 1.49297 0.01s
292 0.09112 0.11719 0.77754 0.01s
293 0.12370 0.11364 1.08859 0.01s
294 0.09218 0.10652 0.86536 0.01s
295 0.10706 0.10734 0.99744 0.01s
296 0.07818 0.11892 0.65741 0.01s
297 0.12493 0.12213 1.02295 0.01s
298 0.24867 0.12240 2.03163 0.01s
299 0.24085 0.12847 1.87484 0.01s
300 0.22015 0.12969 1.69757 0.01s
301 0.27346 0.12789 2.13818 0.01s
302 0.12892 0.12044 1.07040 0.01s
303 0.10835 0.11448 0.94647 0.01s
304 0.12714 0.11780 1.07929 0.01s
305 0.17920 0.11530 1.55419 0.01s
306 0.11329 0.12292 0.92161 0.01s
307 0.07553 0.11891 0.63522 0.01s
308 0.13481 0.11160 1.20799 0.01s
309 0.10359 0.10694 0.96872 0.01s
310 0.07985 0.10992 0.72646 0.01s
311 0.11029 0.10577 1.04277 0.01s
312 0.16066 0.10568 1.52025 0.01s
313 0.11485 0.10818 1.06166 0.01s
314 0.07152 0.11596 0.61681 0.01s
315 0.18642 0.12340 1.51066 0.01s
316 0.25431 0.11513 2.20886 0.01s
317 0.07030 0.11237 0.62558 0.01s
318 0.08992 0.12206 0.73670 0.01s
319 0.10449 0.12094 0.86394 0.01s
320 0.11660 0.10877 1.07195 0.01s
321 0.08406 0.10463 0.80346 0.01s
322 0.08647 0.11436 0.75617 0.01s
323 0.12884 0.11228 1.14749 0.01s
324 0.08573 0.10950 0.78296 0.01s
325 0.08238 0.11099 0.74220 0.01s
326 0.07104 0.10458 0.67925 0.01s
327 0.13142 0.10535 1.24749 0.01s
328 0.12778 0.11225 1.13832 0.01s
329 0.16600 0.11924 1.39215 0.01s
330 0.11120 0.12677 0.87719 0.01s
331 0.10203 0.12839 0.79470 0.01s
332 0.22979 0.11863 1.93708 0.01s
333 0.19929 0.10867 1.83399 0.01s
334 0.06854 0.10215 0.67095 0.01s
335 0.10265 0.10456 0.98169 0.01s
336 0.14792 0.10731 1.37846 0.01s
337 0.20062 0.12097 1.65844 0.01s
338 0.09548 0.12428 0.76824 0.01s
339 0.11293 0.11852 0.95284 0.01s
340 0.13741 0.11049 1.24371 0.01s
341 0.10877 0.11369 0.95672 0.01s
342 0.10017 0.11329 0.88421 0.01s
343 0.12561 0.11486 1.09352 0.01s
344 0.16905 0.11363 1.48779 0.01s
345 0.15755 0.10594 1.48713 0.01s
346 0.09806 0.10779 0.90980 0.01s
347 0.16437 0.11106 1.47998 0.01s
348 0.11702 0.11615 1.00744 0.01s
349 0.31535 0.12388 2.54559 0.01s
350 0.08488 0.12407 0.68414 0.01s
351 0.10687 0.11902 0.89794 0.01s
352 0.09052 0.11557 0.78328 0.01s
353 0.09330 0.11175 0.83487 0.01s
354 0.27415 0.10993 2.49389 0.01s
355 0.12707 0.12385 1.02607 0.01s
356 0.19793 0.12569 1.57478 0.01s
357 0.34713 0.12064 2.87739 0.01s
358 0.10644 0.12088 0.88053 0.01s
359 0.09510 0.10983 0.86587 0.01s
360 0.16441 0.10088 1.62978 0.01s
361 0.11593 0.11305 1.02554 0.01s
362 0.11494 0.12057 0.95329 0.01s
363 0.12216 0.11234 1.08741 0.01s
364 0.13576 0.12015 1.12992 0.01s
365 0.14869 0.11554 1.28688 0.01s
366 0.10738 0.12143 0.88432 0.01s
367 0.16793 0.10595 1.58495 0.01s
368 0.10210 0.10116 1.00929 0.01s
369 0.08089 0.10337 0.78250 0.01s
370 0.11187 0.10552 1.06021 0.01s
371 0.14882 0.10535 1.41259 0.01s
372 0.09565 0.10372 0.92219 0.01s
373 0.11587 0.10403 1.11386 0.01s
374 0.12957 0.10684 1.21273 0.01s
375 0.12155 0.11439 1.06260 0.01s
376 0.13024 0.11501 1.13242 0.01s
377 0.10404 0.10944 0.95068 0.01s
378 0.09425 0.10831 0.87016 0.01s
379 0.12229 0.11896 1.02794 0.01s
380 0.27788 0.11709 2.37327 0.01s
381 0.13730 0.10770 1.27481 0.01s
382 0.06769 0.10408 0.65034 0.01s
383 0.09990 0.10043 0.99476 0.01s
384 0.12530 0.10382 1.20684 0.01s
385 0.21650 0.11164 1.93926 0.01s
386 0.10334 0.12068 0.85628 0.01s
387 0.15810 0.11898 1.32879 0.01s
388 0.18800 0.10597 1.77414 0.01s
389 0.11202 0.11016 1.01683 0.01s
390 0.10055 0.11460 0.87738 0.01s
391 0.09213 0.11409 0.80750 0.01s
392 0.08892 0.11820 0.75232 0.01s
393 0.10874 0.11569 0.93992 0.01s
394 0.16282 0.11167 1.45812 0.01s
395 0.12754 0.11284 1.13032 0.01s
396 0.13915 0.12221 1.13862 0.01s
397 0.24500 0.12010 2.03998 0.01s
398 0.20603 0.11635 1.77081 0.01s
399 0.16430 0.11840 1.38771 0.01s
400 0.09977 0.11417 0.87387 0.01s
401 0.12946 0.11064 1.17013 0.01s
402 0.23283 0.10917 2.13271 0.01s
403 0.16833 0.11238 1.49794 0.01s
404 0.11248 0.10762 1.04519 0.01s
405 0.12317 0.10358 1.18914 0.01s
406 0.33181 0.10045 3.30333 0.01s
407 0.22352 0.10111 2.21060 0.01s
408 0.20221 0.10430 1.93880 0.01s
409 0.24501 0.10388 2.35846 0.01s
410 0.11376 0.10717 1.06150 0.01s
411 0.09118 0.11045 0.82554 0.01s
412 0.17486 0.12949 1.35032 0.01s
413 0.22003 0.11184 1.96732 0.01s
414 0.12543 0.11305 1.10949 0.01s
415 0.09908 0.11581 0.85553 0.01s
416 0.07564 0.11651 0.64925 0.01s
417 0.11281 0.10889 1.03600 0.01s
418 0.09717 0.11091 0.87610 0.01s
419 0.14669 0.11229 1.30634 0.01s
420 0.28603 0.11228 2.54747 0.01s
421 0.09208 0.11689 0.78778 0.01s
422 0.11471 0.11832 0.96951 0.01s
423 0.13615 0.10567 1.28846 0.01s
424 0.10088 0.10536 0.95740 0.01s
425 0.12328 0.10732 1.14873 0.01s
426 0.22350 0.12054 1.85416 0.01s
427 0.09210 0.11415 0.80685 0.01s
428 0.15362 0.11322 1.35680 0.01s
429 0.11331 0.10815 1.04770 0.01s
430 0.12148 0.10818 1.12301 0.01s
431 0.09257 0.10700 0.86507 0.01s
432 0.09705 0.10703 0.90672 0.01s
433 0.07938 0.10405 0.76288 0.01s
434 0.22082 0.10631 2.07709 0.01s
435 0.10758 0.10861 0.99053 0.01s
436 0.10711 0.10276 1.04230 0.01s
437 0.14319 0.10121 1.41477 0.01s
438 0.43683 0.09328 4.68290 0.01s
439 0.11782 0.10140 1.16196 0.01s
440 0.21246 0.11188 1.89904 0.01s
441 0.19481 0.11114 1.75289 0.01s
442 0.06085 0.10605 0.57379 0.01s
443 0.12157 0.10399 1.16899 0.01s
444 0.13614 0.09945 1.36895 0.01s
445 0.13134 0.10366 1.26697 0.01s
446 0.15359 0.10963 1.40101 0.01s
447 0.09094 0.10463 0.86914 0.01s
448 0.09409 0.12492 0.75320 0.01s
449 0.12303 0.11162 1.10221 0.01s
450 0.14937 0.10294 1.45100 0.01s
451 0.10203 0.09828 1.03812 0.01s
452 0.13967 0.09738 1.43420 0.01s
453 0.14812 0.10017 1.47873 0.01s
454 0.11346 0.11001 1.03138 0.01s
455 0.14265 0.11476 1.24295 0.01s
456 0.12859 0.11210 1.14717 0.01s
457 0.11018 0.11003 1.00134 0.01s
458 0.13159 0.10225 1.28691 0.01s
459 0.09561 0.10233 0.93435 0.01s
460 0.09103 0.10507 0.86632 0.01s
461 0.11055 0.09492 1.16473 0.01s
462 0.12157 0.09544 1.27374 0.01s
463 0.07093 0.10482 0.67671 0.01s
464 0.18711 0.11349 1.64871 0.01s
465 0.14038 0.10664 1.31639 0.01s
466 0.12566 0.10897 1.15320 0.01s
467 0.19319 0.11205 1.72406 0.01s
468 0.09147 0.10147 0.90147 0.01s
469 0.14029 0.11045 1.27019 0.01s
470 0.22488 0.11149 2.01710 0.01s
471 0.10785 0.10846 0.99440 0.01s
472 0.10929 0.10316 1.05941 0.01s
473 0.17153 0.10184 1.68436 0.01s
474 0.14045 0.10790 1.30172 0.01s
475 0.19863 0.12198 1.62838 0.01s
476 0.13273 0.11623 1.14202 0.01s
477 0.23120 0.10988 2.10402 0.01s
478 0.12489 0.10002 1.24865 0.01s
479 0.17528 0.09908 1.76904 0.01s
480 0.07944 0.10917 0.72769 0.01s
481 0.09229 0.11801 0.78204 0.01s
482 0.08539 0.11923 0.71624 0.01s
483 0.12025 0.11856 1.01427 0.01s
484 0.08485 0.11523 0.73631 0.01s
485 0.29845 0.10880 2.74312 0.01s
486 0.10727 0.10669 1.00549 0.01s
487 0.11725 0.11029 1.06315 0.01s
488 0.13444 0.11273 1.19267 0.01s
489 0.10725 0.12235 0.87660 0.01s
490 0.09939 0.11961 0.83092 0.01s
491 0.06879 0.11396 0.60364 0.01s
492 0.10803 0.11315 0.95476 0.01s
493 0.11545 0.11353 1.01689 0.01s
494 0.12147 0.11532 1.05333 0.01s
495 0.18089 0.12079 1.49757 0.01s
496 0.23218 0.11378 2.04061 0.01s
497 0.09412 0.11941 0.78820 0.01s
498 0.10273 0.10618 0.96755 0.01s
499 0.14538 0.11019 1.31939 0.01s
500 0.13773 0.11806 1.16664 0.01s
Out[166]:
NeuralNet(X_tensor_type=None,
batch_iterator_test=<nolearn.lasagne.base.BatchIterator object at 0x7fa67253e358>,
batch_iterator_train=<nolearn.lasagne.base.BatchIterator object at 0x7fa67253e550>,
custom_score=None, dropout1_p=0.5,
hidden1_nonlinearity=<function tanh at 0x7fa62838ff28>,
hidden1_num_units=241, input_shape=(None, 482),
layers=[('input', <class 'lasagne.layers.input.InputLayer'>), ('hidden1', <class 'lasagne.layers.dense.DenseLayer'>), ('dropout1', <class 'lasagne.layers.noise.DropoutLayer'>), ('nonlinear', <class 'lasagne.layers.dense.NonlinearityLayer'>), ('output', <class 'lasagne.layers.dense.DenseLayer'>)],
loss=None, max_epochs=500, more_params={},
objective=<function objective at 0x7fa672540510>,
objective_loss_function=<function squared_error at 0x7fa6280fc840>,
on_epoch_finished=[<nolearn.lasagne.handlers.PrintLog object at 0x7fa5c2153b00>],
on_training_finished=[],
on_training_started=[<nolearn.lasagne.handlers.PrintLayerInfo object at 0x7fa5c2153710>],
output_nonlinearity=None, output_num_units=1, regression=True,
train_split=<nolearn.lasagne.base.TrainSplit object at 0x7fa67253e7b8>,
update=<function nesterov_momentum at 0x7fa6281020d0>,
update_learning_rate=0.01, update_momentum=0.95,
use_label_encoder=False, verbose=1,
y_tensor_type=TensorType(float32, matrix))
In [167]:
nn1_preds = net1.predict(x_test)
nn1_mse = float(mean_squared_error(nn1_preds, y_test))
plt.figure(figsize=(3,3))
plt.scatter(y_test, nn1_preds)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('{0} Neural Network'.format(drug_cols[colnum]))
plt.gca().set_aspect('equal', 'datalim')
plt.annotate(s='mse: {0}'.format(str(Decimal(nn1_mse).quantize(TWOPLACES))), xy=(1,0), xycoords='axes fraction', ha='right', va='bottom')
plt.plot(x_equals_y(y_test), x_equals_y(y_test), color='red')
plt.show()
In [168]:
sps.pearsonr(nn1_preds, y_test.reshape(y_test.shape[0],1))
Out[168]:
(array([ 0.87914544], dtype=float32), array([ 3.59665475e-38], dtype=float32))
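Finally, a compact side-by-side of the models above: a sketch that simply gathers the *_mse values computed in the preceding cells (all evaluated on the same held-out split).

mses = {'Random Forest': rfr_mse, 'Bayesian Ridge': brr_mse, 'ARD': ardr_mse,
        'Grad. Boost': gbr_mse, 'AdaBoost': abr_mse, 'SVR': svr_mse, 'Neural Net': nn1_mse}
for name, mse in sorted(mses.items(), key=lambda kv: kv[1]):
    print('{0}: {1:.3f}'.format(name, mse))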
Content source: ericmjl/hiv-resistance-prediction