In [11]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import os
from datetime import datetime
import seaborn as sns
%matplotlib inline

# Global plot styling and figure-output configuration.
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['figure.figsize'] = (10,6.180)    #golden ratio
# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"

def save_fig(fig_id, tight_layout=True):
    """Save the current matplotlib figure as a 300-dpi PNG under
    ./images/<CHAPTER_ID>/<fig_id>.png.

    Parameters
    ----------
    fig_id : str
        File name (without extension) for the saved figure.
    tight_layout : bool, default True
        Apply ``plt.tight_layout()`` before saving.
    """
    directory = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
    # Create the target directory if missing; otherwise plt.savefig raises
    # FileNotFoundError on a fresh checkout.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, fig_id + ".png")
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)

In [12]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import PolynomialFeatures
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier

from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score


# Create a class to select numerical or categorical columns 
# since Scikit-Learn doesn't handle DataFrames yet
class DataFrameSelector(BaseEstimator, TransformerMixin):
    """Select named DataFrame columns and expose them as a NumPy array.

    Adapter for scikit-learn pipelines, which historically expected plain
    arrays rather than DataFrames.
    """

    def __init__(self, attribute_names):
        # Columns to extract from the incoming DataFrame.
        self.attribute_names = attribute_names

    def fit(self, X, y=None):
        # Stateless transformer: there is nothing to learn.
        return self

    def transform(self, X):
        selected = X[self.attribute_names]
        return selected.values
class RemoveFirstFrame(BaseEstimator, TransformerMixin):
    """Drop rows whose ``Step`` falls on the first frame of each window.

    Keeps rows satisfying ``Step % frame != 1``.
    """

    def __init__(self, frame):
        # Window length used in the modulo filter.
        self.frame = frame

    def fit(self, X, y=None):
        # Stateless transformer: nothing to learn.
        return self

    def transform(self, X):
        # Bug fix: the query previously interpolated the bare name ``frame``,
        # which is undefined at call time (NameError) — it must read the
        # attribute stored by __init__.
        return X.query(f"Step % {self.frame} != 1")
def choose_top_rw(data, n=5):
    """Return ``data`` with a boolean ``chosen`` column flagging the ``n``
    rows with the smallest Rw (ties broken by row order)."""
    ranks = data.Rw.rank(method='first')
    return data.assign(chosen=ranks <= n)
def choose_top_vtotal(data, n=5):
    """Return ``data`` with a boolean ``chosen`` column flagging the ``n``
    rows with the smallest VTotal (ties broken by row order)."""
    ranks = data.VTotal.rank(method='first')
    return data.assign(chosen=ranks <= n)
def choose_top(data, col="Qw", n=5, ascending=False):
    """Return ``data`` with a boolean ``chosen`` column flagging the top-``n``
    rows of ``data[col]`` — largest first unless ``ascending`` is True."""
    ranks = data[col].rank(ascending=ascending, method='first')
    return data.assign(chosen=ranks <= n)

In [64]:
# NOTE(review): hardcoded absolute local path — consider a configurable DATA_DIR.
raw_test_data = pd.read_csv("/Users/weilu/Research/data/test_data/complete_data_mar27.csv", index_col=0)

In [66]:
import seaborn
# Rw energy vs. simulation Step, one trace per structure Name.
# NOTE(review): seaborn is already imported as sns at the top; FacetGrid's
# `size=` argument is deprecated in newer seaborn (use `height=`).
fg = seaborn.FacetGrid(data=raw_test_data, hue='Name',  aspect=1.61, size=10)
fg.map(plt.plot, 'Step', 'Rw').add_legend()


Out[66]:
<seaborn.axisgrid.FacetGrid at 0x114638208>

In [65]:
# NOTE(review): renders the entire 16000-row frame; prefer raw_test_data.head().
raw_test_data


Out[65]:
Burial Chain Chi DSSP Frag_Mem GDT Helix Name P_AP QGO Qw Rama Rmsd Rw Step VTotal VwithoutGo Water isGood
0 -125.527566 198.348772 50.712241 -0.000001 -270.954728 64.0425 -31.903978 1MBA -3.510140 20.437024 0.612991 -608.466001 2.83047 -23936.084181 2 -815.660991 -836.098015 -44.796614 False
1 -127.621760 202.520695 48.958767 -0.000000 -264.937301 60.4475 -44.122252 1MBA -4.427336 22.635465 0.608468 -598.442065 3.11564 -24206.526672 3 -810.210414 -832.845879 -44.774626 False
2 -126.897613 201.171981 45.683736 -0.000016 -285.284868 59.9325 -44.705003 1MBA -2.790484 32.291579 0.563456 -601.607208 3.34374 -23746.829469 4 -831.611615 -863.903194 -49.473719 False
3 -129.227459 211.599927 43.332617 -0.000000 -284.484618 60.1050 -44.940302 1MBA -3.275868 31.026159 0.568731 -608.560594 3.31919 -23770.585063 5 -835.827804 -866.853963 -51.297665 False
4 -125.640142 210.098567 52.818004 -0.000000 -301.367085 57.8775 -47.194714 1MBA -4.095876 34.395024 0.552951 -616.317106 3.65728 -23788.221633 6 -845.665112 -880.060136 -48.361783 False
5 -126.596875 178.674935 37.575281 -0.000000 -313.910386 60.1050 -49.238866 1MBA -3.546136 32.651915 0.556017 -621.844923 3.31292 -23655.850471 7 -912.270330 -944.922245 -46.035274 False
6 -126.363437 170.111959 39.655774 -0.000002 -274.534180 61.3025 -50.380709 1MBA -2.877070 28.954400 0.584000 -632.719444 3.15679 -23915.867387 8 -896.728930 -925.683330 -48.576221 False
7 -126.587872 183.226188 42.971890 -0.000220 -318.007849 63.7000 -49.805937 1MBA -3.325704 26.661054 0.603713 -616.551470 3.01623 -24025.154290 9 -912.484393 -939.145447 -51.064472 False
8 -127.045921 198.521740 43.660693 -0.000174 -281.688080 61.3000 -38.269670 1MBA -3.149545 27.835810 0.576972 -605.484179 3.13825 -23730.122955 10 -833.630852 -861.466662 -48.011525 False
9 -127.821988 167.915369 38.463602 -0.000064 -306.039620 62.6725 -39.167075 1MBA -4.185415 35.135000 0.575112 -617.809782 3.16134 -23934.967672 11 -900.801360 -935.936360 -47.291386 False
10 -125.240649 181.325201 34.561878 -0.000238 -315.219654 58.3900 -43.768237 1MBA -5.029298 37.718750 0.553789 -607.918017 3.72640 -24099.479019 12 -888.005070 -925.723820 -44.434806 False
11 -124.226981 205.134813 42.344343 -0.000000 -293.497929 62.6750 -37.746544 1MBA -3.801658 34.764365 0.575847 -588.741736 3.35418 -23944.603696 13 -814.673618 -849.437983 -48.902291 False
12 -125.989391 194.906633 42.589569 -0.000001 -294.277982 60.2750 -38.457045 1MBA -4.508254 39.090859 0.570161 -592.773313 3.34147 -23790.997766 14 -828.026801 -867.117660 -48.607876 False
13 -125.741841 178.203643 39.134167 -0.000000 -273.580572 55.9925 -41.514275 1MBA -4.652959 44.909367 0.519456 -593.756051 3.76136 -23824.765782 15 -822.848990 -867.758357 -45.850469 False
14 -124.642489 181.659548 48.008885 -0.000009 -278.523992 59.9350 -42.209516 1MBA -3.738200 38.368045 0.570946 -584.533577 3.56355 -23764.935603 16 -813.520737 -851.888782 -47.909434 False
15 -125.530121 188.494522 42.314610 -0.000018 -285.343492 61.3025 -43.495398 1MBA -4.205698 37.842228 0.569005 -578.898001 3.44114 -23692.646370 17 -816.811202 -854.653430 -47.989833 False
16 -124.287663 175.296152 37.500142 -0.000003 -293.606272 64.0400 -37.590581 1MBA -4.365536 35.220046 0.573256 -589.743613 3.26999 -24305.005952 18 -855.607756 -890.827802 -54.030427 False
17 -125.075483 187.313933 40.738186 -0.000000 -292.340023 59.0775 -49.293628 1MBA -3.677689 39.559915 0.531779 -599.238728 3.74614 -23944.499620 19 -857.537677 -897.097592 -55.524159 False
18 -121.858498 194.738181 42.466727 -0.000000 -298.006799 62.3275 -52.797391 1MBA -4.728501 37.759331 0.555548 -625.505374 3.24677 -24285.544872 20 -880.317343 -918.076674 -52.385021 False
19 -123.204917 192.659707 43.133986 -0.000000 -296.960869 67.8100 -44.841098 1MBA -4.566400 30.655603 0.587209 -593.153120 3.07364 -24098.449630 21 -847.570628 -878.226231 -51.293521 False
20 -123.190282 204.618600 48.052902 -0.000000 -295.648436 63.3575 -46.407497 1MBA -5.822877 33.181558 0.565960 -596.292804 3.53937 -24256.890147 22 -829.000540 -862.182098 -47.491703 False
21 -124.595662 189.604372 42.608549 -0.000000 -283.012565 64.9000 -40.614604 1MBA -4.719388 29.526446 0.578371 -594.184739 3.22417 -24247.170465 23 -834.748620 -864.275066 -49.361027 False
22 -122.678348 181.559710 40.259295 -0.000010 -275.898318 69.1775 -44.072288 1MBA -5.743637 29.334476 0.632536 -611.109724 2.73339 -24644.193666 24 -862.625373 -891.959849 -54.276530 False
23 -124.564220 188.735953 36.267410 -0.000000 -290.339000 69.0050 -39.935562 1MBA -5.192057 29.401787 0.620164 -594.817525 2.71500 -24547.207577 25 -852.108705 -881.510492 -51.665491 False
24 -124.052734 208.189432 51.563421 -0.001324 -278.700496 67.6375 -53.513829 1MBA -4.817955 30.258235 0.622350 -603.410891 2.81950 -24350.112243 26 -829.211348 -859.469583 -54.725206 False
25 -124.536710 183.449538 40.945526 -0.000412 -299.523146 70.2075 -53.350348 1MBA -5.538342 25.594625 0.648444 -595.119148 2.63474 -24806.155768 27 -881.206776 -906.801401 -53.128357 False
26 -124.663733 216.545557 64.874880 -0.000000 -298.626663 65.0675 -41.794867 1MBA -5.131951 31.589724 0.591445 -597.854836 3.18823 -24383.187291 28 -803.787032 -835.376756 -48.725144 False
27 -125.313311 212.957320 55.534384 -0.000000 -298.701955 62.5000 -45.709347 1MBA -4.902316 31.390320 0.583252 -604.835334 3.23966 -24819.680897 29 -836.316249 -867.706569 -56.736011 False
28 -123.981910 184.565762 59.499322 -0.000000 -289.500332 68.8375 -54.294207 1MBA -5.214478 26.738521 0.639773 -613.413865 2.88495 -24742.065886 30 -869.817802 -896.556323 -54.216615 False
29 -124.067313 181.595380 55.654757 -0.000000 -275.983369 65.9225 -43.558057 1MBA -4.776085 33.474381 0.610935 -611.510716 3.02239 -23974.668244 31 -837.274764 -870.749145 -48.103742 False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
15970 -91.625161 0.000000 0.000000 -11.919602 -184.592857 65.9700 -13.262848 T0251 -9.803505 56.200292 0.637768 -416.166714 3.32948 -18543.735553 1981 -713.373160 -769.573452 -42.202765 False
15971 -91.838054 0.000000 0.000000 -12.931464 -169.395263 63.6600 -13.135241 T0251 -10.555346 57.839760 0.586586 -433.345455 3.71267 -18806.586856 1982 -718.584087 -776.423847 -45.223024 False
15972 -91.437609 0.000000 0.000000 -12.268465 -178.217448 66.2025 -14.178784 T0251 -8.726733 55.337226 0.619550 -408.846264 3.40968 -18292.898968 1983 -702.237162 -757.574388 -43.899086 False
15973 -90.770878 0.000000 0.000000 -12.875443 -170.135792 66.2050 -11.894053 T0251 -9.199376 52.754278 0.629604 -422.369879 3.34732 -18486.265954 1984 -708.605858 -761.360136 -44.114716 False
15974 -92.095448 0.000000 0.000000 -12.317949 -171.162376 65.0450 -12.396215 T0251 -9.727326 52.296442 0.610305 -426.150941 3.29791 -18811.978836 1985 -720.084559 -772.381001 -48.530746 False
15975 -91.741652 0.000000 0.000000 -13.778394 -169.587986 64.1200 -11.415608 T0251 -9.470298 53.024501 0.610773 -413.416375 3.37226 -18375.427272 1986 -702.092985 -755.117486 -45.707172 False
15976 -89.808516 0.000000 0.000000 -14.150543 -168.499845 64.8150 -11.146861 T0251 -9.061044 54.955143 0.604477 -420.907855 3.42709 -18327.910269 1987 -702.353846 -757.308989 -43.734325 False
15977 -93.551472 0.000000 0.000000 -12.096543 -170.950329 66.6650 -11.357688 T0251 -9.434898 53.627463 0.645486 -422.623081 3.24725 -18625.428596 1988 -709.640245 -763.267708 -43.253698 False
15978 -93.275770 0.000000 0.000000 -11.312046 -160.942602 65.9725 -14.144596 T0251 -8.687470 50.687311 0.631807 -406.856309 3.18153 -18612.543824 1989 -691.108826 -741.796137 -46.577342 False
15979 -91.672222 0.000000 0.000000 -16.291054 -176.140083 65.7425 -14.607518 T0251 -10.930240 50.002463 0.618662 -412.185224 3.17333 -18749.123612 1990 -719.043700 -769.046163 -47.219822 False
15980 -92.684538 0.000000 0.000000 -12.000243 -169.392051 66.9000 -14.179498 T0251 -9.521931 54.304999 0.628868 -400.109164 3.04343 -18597.520155 1991 -689.110858 -743.415857 -45.528432 False
15981 -91.844435 0.000000 0.000000 -7.185468 -175.537399 64.8150 -16.855605 T0251 -9.541850 54.806245 0.613521 -407.732177 3.11175 -18384.509247 1992 -695.088812 -749.895057 -41.198123 False
15982 -93.604086 0.000000 0.000000 -9.096837 -173.055957 64.5850 -8.099020 T0251 -9.869841 51.216066 0.580552 -402.629671 3.36591 -18314.099901 1993 -686.587942 -737.804008 -41.448596 False
15983 -92.382859 0.000000 0.000000 -12.557272 -171.255735 64.3500 -12.012246 T0251 -9.809064 53.590184 0.578085 -415.315758 3.61281 -18286.540450 1994 -703.030931 -756.621115 -43.288180 False
15984 -91.572649 0.000000 0.000000 -11.387066 -173.889063 65.9700 -13.623254 T0251 -10.029134 51.003160 0.615589 -393.825531 3.51837 -18341.028348 1995 -682.435542 -733.438702 -39.112006 False
15985 -91.997823 0.000000 0.000000 -11.952323 -168.868828 67.3600 -12.763230 T0251 -11.184105 50.134098 0.624565 -400.715469 3.51635 -18289.442213 1996 -689.378736 -739.512834 -42.031057 False
15986 -92.580205 0.000000 0.000000 -12.564112 -174.420359 64.5850 -15.552759 T0251 -11.095399 51.954294 0.602795 -411.471148 3.55720 -18712.681771 1997 -708.521102 -760.475396 -42.791414 False
15987 -91.906959 0.000000 0.000000 -11.353787 -164.049919 64.8175 -13.034382 T0251 -10.440080 51.679293 0.618961 -413.541618 3.69826 -18482.524375 1998 -697.387374 -749.066667 -44.739923 False
15988 -92.174700 0.000000 0.000000 -13.870746 -169.978603 66.4375 -12.332784 T0251 -10.408912 52.870109 0.630862 -413.136788 3.76465 -18162.063254 1999 -699.070749 -751.940858 -40.038325 False
15989 -92.540159 0.000000 0.000000 -11.985480 -170.363053 67.3625 -14.368892 T0251 -9.328371 50.713208 0.627825 -412.905327 3.60437 -18345.451381 2000 -701.686860 -752.400068 -40.908787 False
15990 -93.095717 0.000000 0.000000 -13.098034 -170.599752 62.7325 -11.713871 T0251 -10.747843 50.215999 0.604300 -415.671737 3.60184 -18262.328080 2001 -704.374936 -754.590935 -39.663980 False
15991 -92.676088 0.000000 0.000000 -14.122128 -169.948793 66.2025 -14.229989 T0251 -11.293775 49.157165 0.631086 -414.410748 3.53470 -18422.168203 2002 -706.120250 -755.277415 -38.595893 False
15992 -93.958813 0.000000 0.000000 -14.912586 -167.463426 67.3625 -10.171771 T0251 -11.594634 51.465175 0.631761 -420.387811 3.46452 -18366.671484 2003 -706.468363 -757.933538 -39.444496 False
15993 -94.022752 0.000000 0.000000 -12.600719 -174.908114 66.9000 -17.473581 T0251 -10.337650 49.544524 0.630106 -428.279276 3.37273 -18813.711926 2004 -732.425108 -781.969632 -44.347541 False
15994 -92.992885 0.000000 0.000000 -14.020112 -176.323051 66.2050 -13.726338 T0251 -10.916914 49.111358 0.625945 -411.773704 3.52210 -18435.053224 2005 -711.686588 -760.797946 -41.044942 False
15995 -92.987056 0.000000 0.000000 -14.297675 -182.005906 68.9850 -17.013077 T0251 -10.174647 50.154168 0.636392 -416.846914 3.31385 -18782.711506 2006 -723.489913 -773.644081 -40.318805 False
15996 -92.419793 0.000000 0.000000 -11.580190 -176.139795 66.4350 -17.880439 T0251 -10.599407 49.903316 0.606045 -433.331408 3.67136 -18611.220785 2007 -731.276632 -781.179948 -39.228917 False
15997 -94.937787 0.000000 0.000000 -12.534783 -167.851668 66.9000 -14.921327 T0251 -9.964347 49.666925 0.625081 -419.639569 3.66594 -18521.880007 2008 -710.155956 -759.822881 -39.973399 False
15998 -92.122035 0.000000 0.000000 -8.943756 -176.753484 69.4450 -16.339722 T0251 -10.465368 49.254122 0.647274 -410.289604 3.30106 -18677.492763 2009 -706.011481 -755.265603 -40.351634 False
15999 -95.439265 0.000000 0.000000 -13.273688 -173.621326 66.4350 -12.888568 T0251 -11.383948 51.040138 0.624174 -410.635877 3.51677 -18524.682746 2010 -704.827677 -755.867815 -38.625143 False

16000 rows × 19 columns


In [14]:
# The eight protein trajectories present in the dataset.
raw_test_data["Name"].unique()


Out[14]:
array(['1MBA', 'T0792', 'T0815', 'T0766', 'T0784', 'T0803', 'T0833',
       'T0251'], dtype=object)

In [35]:
# Energy-term columns fed to the classifier; commented-out entries were
# tried and excluded from the current feature set.
FEATURES = ['Rw',
#      'VTotal',
#      'QGO',
#      'VwithoutGo',
     'Burial',
     'Water',
     'Rama',
#      'DSSP',
#      'P_AP',
     'Helix',
#      'Frag_Mem'
               ]
# Number of top-ranked structures each selection strategy keeps.
n = 5
def my_transform(data, label, degree, FEATURES=FEATURES):
    """Standardize + polynomial-expand FEATURES and append the label column.

    Returns a NumPy array whose last column is ``data[label]`` untouched,
    preceded by the scaled polynomial feature columns.
    """
    feature_pipeline = Pipeline([
        ('selector', DataFrameSelector(FEATURES)),
        ('std_scaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=degree, include_bias=False)),
    ])
    # The label passes through unscaled; it is merely selected.
    label_pipeline = Pipeline([
        ('selector', DataFrameSelector([label])),
    ])
    combined = FeatureUnion(transformer_list=[
        ("num_pipeline", feature_pipeline),
        ("cat_pipeline", label_pipeline),
    ])
    return combined.fit_transform(data)

def my_transform_predict(data, degree, FEATURES=FEATURES):
    """Standardize + polynomial-expand FEATURES only (no label appended)."""
    pipeline = Pipeline([
        ('selector', DataFrameSelector(FEATURES)),
        ('std_scaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=degree, include_bias=False)),
    ])
    return pipeline.fit_transform(data)

In [36]:
# Per-protein scatter of Rw vs. Qw, colored by the isGood label.
# NOTE(review): this and the two following cells are copy-pastes differing
# only in the x-column — a candidate for a small plotting helper function.
g = sns.FacetGrid(raw_test_data, col="Name", hue="isGood", col_wrap=4)
g = g.map(plt.scatter, "Rw", "Qw")



In [37]:
# Per-protein scatter of VwithoutGo vs. Qw, colored by the isGood label.
g = sns.FacetGrid(raw_test_data, col="Name", hue="isGood", col_wrap=4)
g = g.map(plt.scatter, "VwithoutGo", "Qw")



In [38]:
# Per-protein scatter of QGO vs. Qw, colored by the isGood label.
g = sns.FacetGrid(raw_test_data, col="Name", hue="isGood", col_wrap=4)
g = g.map(plt.scatter, "QGO", "Qw")



In [39]:
# Train only on the T0784 trajectory; alternative training sets tried
# earlier are kept commented out for reference.
raw_data_T0784 = raw_test_data.groupby("Name").get_group("T0784")
# raw_data_T0792 = raw_test_data.groupby("Name").get_group("T0792")
# raw_data = pd.concat([raw_data_T0784, raw_data_T0792])
# raw_data = raw_data_T0792
# raw_data = raw_test_data.groupby("Name").get_group("1mba")
raw_data = raw_data_T0784

In [40]:
# FEATURES = ["Rw", "VTotal", "QGO"]
# FEATURES = ["Rw", "VTotal", "QGO", "Burial", "Frag_Mem", "Water"]
# FEATURES = list(raw_test_data.columns[2:-3])
def train_and_test(raw_data, label="Qw", degree=1, p=0.1):
    """Stratified 80/20 split on ``isGood`` followed by feature transformation.

    Returns ``(train_set, train_y, test_set, test_y)`` as NumPy arrays; the
    label column appended by ``my_transform`` is peeled off as the y vector.

    NOTE(review): ``p`` is accepted but never used — kept only so existing
    call sites remain valid.
    """
    from sklearn.model_selection import StratifiedShuffleSplit

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=142)
    # n_splits=1, so there is exactly one (train, test) index pair.
    train_index, test_index = next(splitter.split(raw_data, raw_data["isGood"]))
    strat_train_set = raw_data.iloc[train_index]
    strat_test_set = raw_data.iloc[test_index]

    X_train = my_transform(strat_train_set, label, degree)
    X_test = my_transform(strat_test_set, label, degree)
    # Last column of the transformed array is the label.
    train_y, train_set = X_train[:, -1], X_train[:, :-1]
    test_y, test_set = X_test[:, -1], X_test[:, :-1]
    return (train_set, train_y, test_set, test_y)

In [54]:
# Fit a logistic-regression classifier on the T0784 training split, then
# score every structure of every protein with its predicted probability
# of being "good".
label = "isGood"
degree = 1
p = 0.1
train_set, train_y, test_set, test_y = train_and_test(raw_data, label=label, degree=degree)
log_clf = LogisticRegression(random_state=140, penalty='l2')

# log_clf = LogisticRegression(random_state=14, class_weight={0:p, 1:(1-p)}, penalty='l1')
log_clf.fit(train_set, train_y)
# NOTE(review): y_pred is computed but never used below.
y_pred = log_clf.predict(train_set)
# n = 100
prediction_list = []
for name, data in raw_test_data.groupby("Name"):
    print(name)
#     X = full_pipeline.fit_transform(data)
    X = my_transform(data, label, degree)
    eval_y = X[:,-1]
    eval_set = X[:,:-1]
    # Probability of the positive (isGood) class.
    test= log_clf.predict_proba(eval_set)[:,1]
    one = data.assign(prediction=test)
    prediction_list.append(one)
#     prediction_list.append(pd.Series(test))
t = pd.concat(prediction_list)
# t = raw_test_data.assign(prediction=prediction.values)
# Keep the n=5 highest-probability structures per protein.
best_by_prediction = t.groupby("Name").apply(choose_top, n=5, col="prediction").query("chosen==True")


1MBA
T0251
T0766
T0784
T0792
T0803
T0815
T0833

In [55]:
# Per-protein table of the selected structures, sorted by prediction score.
a = best_by_prediction.reset_index(drop=True)[["Name", "Qw", "GDT", "Rmsd", "prediction", "Step"]].groupby("Name").apply(lambda x: x.sort_values("prediction", ascending=False))

In [56]:
# best_by_prediction.to_csv("/Users/weilu/Research/davinci/structure_selector_apr09/selected.csv")
# a = pd.read_csv("/Users/weilu/Research/davinci/structure_selector_apr09/selected.csv")
# for i, row in a.iterrows():
#     print(i, row["Name"], row["Step"])

In [57]:


In [58]:


In [59]:
# Fitted logistic-regression coefficient for each feature.
print(*(zip(FEATURES, log_clf.coef_[0])))


('Rw', -0.49825533989977439) ('Burial', 0.11535990415932022) ('Water', -0.096733273202118164) ('Rama', -0.16924406917860341) ('Helix', -0.16299557498004696)

In [60]:
# Baseline selection strategies: top-n per protein by Rw, by VTotal, and by
# QGO (ascending — lower QGO ranks first).
n = 5
chosen_by_rw = raw_test_data.groupby("Name").apply(choose_top_rw, n)
chosen_by_vtotal = raw_test_data.groupby("Name").apply(choose_top_vtotal, n)
chosen_by_qgo = raw_test_data.groupby("Name").apply(choose_top, n=n, col="QGO", ascending=True)
top_rw = chosen_by_rw.query("chosen==True")
top_vtotal = chosen_by_vtotal.query("chosen==True")
top_qgo = chosen_by_qgo.query("chosen==True")

In [63]:
# T0784
# Compare selection strategies on the GDT metric.
label = "GDT"
# "Best" is the oracle: the top-n structures ranked by GDT itself. This cell
# previously passed ascending=True, which selected the n LOWEST-GDT rows —
# inconsistent with the other GDT comparison cell in this notebook, which
# uses the default descending rank.
best = raw_test_data.groupby("Name").apply(choose_top, n=n, col=label).query("chosen==True")
a2 = best.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Best"})
b2 = top_rw.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Rw"})
c2 = top_vtotal.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Awsem"})
d2 = best_by_prediction.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Selected"})
e2 = top_qgo.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"QGo"})
# final2 = a2.merge(b2, on="Name").merge(c2, on="Name").merge(d2, on="Name").melt(id_vars="Name")
# final2 = pd.concat([a2, e2["QGo"], b2["Rw"], c2["Awsem"], d2["prediction"]], axis=1).melt(id_vars="Name")
# Wide -> long so seaborn draws one trace per selection method.
final3 = pd.concat([a2, b2["Rw"], c2["Awsem"], d2["Selected"]], axis=1).melt(id_vars="Name", value_name=label, var_name=" ")
order = ["T0251", "T0833", "T0815", "T0803", "T0792", "T0784", "T0766", "1MBA"]
sns.pointplot("Name", label, data=final3, hue=" ", errwidth=0, order=order)
# plt.savefig("/Users/weilu/Desktop/fig6_GDT.png", dpi=300)
# plt.ylim([0.4,1])
final_gdt = final3



In [62]:
# T0784
label = "Qw"
best = raw_test_data.groupby("Name").apply(choose_top, n=n, col=label).query("chosen==True")
a2 = best.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Best"})
b2 = top_rw.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Rw"})
c2 = top_vtotal.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Awsem"})
d2 = best_by_prediction.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Selected"})
e2 = top_qgo.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"QGo"})
# final2 = a2.merge(b2, on="Name").merge(c2, on="Name").merge(d2, on="Name").melt(id_vars="Name")
# final2 = pd.concat([a2, e2["QGo"], b2["Rw"], c2["Awsem"], d2["prediction"]], axis=1).melt(id_vars="Name")
final3 = pd.concat([a2, b2["Rw"], c2["Awsem"], d2["Selected"]], axis=1).melt(id_vars="Name", value_name=label, var_name=" ")
# sns.pointplot("Name","value", data=final2, hue="variable", hue_order=["prediction", "Awsem", "Rw", "best"])
# sns.stripplot("value", "Name", data=final2, hue="variable")
order = ["T0251", "T0833", "T0815", "T0803", "T0792", "T0784", "T0766", "1MBA"]
sns.pointplot("Name", label, data=final3, hue=" ", errwidth=0, order=order)
plt.savefig("/Users/weilu/Desktop/fig6_Qw.png", dpi=300)
# plt.ylim([0.4,1])
final_Qw = final3



In [16]:
# NOTE(review): the saved output here lists (Rw, QGO, VwithoutGo) — it was
# produced with an older FEATURES list and is stale relative to the current one.
print(*(zip(FEATURES, log_clf.coef_[0])))


('Rw', -0.19232275270600122) ('QGO', -1.6223111058603839) ('VwithoutGo', -0.28966535406693028)

In [25]:
# NOTE(review): exact duplicate of an earlier baseline-selection cell in this
# notebook; consider deleting one copy.
n = 5
chosen_by_rw = raw_test_data.groupby("Name").apply(choose_top_rw, n)
chosen_by_vtotal = raw_test_data.groupby("Name").apply(choose_top_vtotal, n)
chosen_by_qgo = raw_test_data.groupby("Name").apply(choose_top, n=n, col="QGO", ascending=True)
top_rw = chosen_by_rw.query("chosen==True")
top_vtotal = chosen_by_vtotal.query("chosen==True")
top_qgo = chosen_by_qgo.query("chosen==True")

In [15]:
# T0784
label = "GDT"
best = raw_test_data.groupby("Name").apply(choose_top, n=n, col=label).query("chosen==True")
a2 = best.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Best"})
b2 = top_rw.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Rw"})
c2 = top_vtotal.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Awsem"})
d2 = best_by_prediction.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Selected"})
e2 = top_qgo.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"QGo"})
# final2 = a2.merge(b2, on="Name").merge(c2, on="Name").merge(d2, on="Name").melt(id_vars="Name")
# final2 = pd.concat([a2, e2["QGo"], b2["Rw"], c2["Awsem"], d2["prediction"]], axis=1).melt(id_vars="Name")
final3 = pd.concat([a2, b2["Rw"], c2["Awsem"], d2["Selected"]], axis=1).melt(id_vars="Name", value_name=label, var_name=" ")
# sns.pointplot("Name","value", data=final2, hue="variable", hue_order=["prediction", "Awsem", "Rw", "best"])
# sns.stripplot("value", "Name", data=final2, hue="variable")
order = ["T0251", "T0833", "T0815", "T0803", "T0792", "T0784", "T0766", "1MBA"]
sns.pointplot("Name", label, data=final3, hue=" ", errwidth=0, order=order)
# plt.savefig("/Users/weilu/Desktop/fig6_GDT.png", dpi=300)
# plt.ylim([0.4,1])
final_gdt = final3



In [23]:
# T0784
label = "Qw"
best = raw_test_data.groupby("Name").apply(choose_top, n=n, col=label).query("chosen==True")
a2 = best.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Best"})
b2 = top_rw.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Rw"})
c2 = top_vtotal.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Awsem"})
d2 = best_by_prediction.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Selected"})
e2 = top_qgo.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"QGo"})
# final2 = a2.merge(b2, on="Name").merge(c2, on="Name").merge(d2, on="Name").melt(id_vars="Name")
# final2 = pd.concat([a2, e2["QGo"], b2["Rw"], c2["Awsem"], d2["prediction"]], axis=1).melt(id_vars="Name")
final3 = pd.concat([a2, b2["Rw"], c2["Awsem"], d2["Selected"]], axis=1).melt(id_vars="Name", value_name=label, var_name=" ")
# sns.pointplot("Name","value", data=final2, hue="variable", hue_order=["prediction", "Awsem", "Rw", "best"])
# sns.stripplot("value", "Name", data=final2, hue="variable")
order = ["T0251", "T0833", "T0815", "T0803", "T0792", "T0784", "T0766", "1MBA"]
sns.pointplot("Name", label, data=final3, hue=" ", errwidth=0, order=order)
plt.savefig("/Users/weilu/Desktop/fig6_Qw.png", dpi=300)
# plt.ylim([0.4,1])
final_Qw = final3



In [24]:
# T0784
label = "Rmsd"
best = raw_test_data.groupby("Name").apply(choose_top, n=n, col=label, ascending=True).query("chosen==True")
a2 = best.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Best"})
b2 = top_rw.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Rw"})
c2 = top_vtotal.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"Awsem"})
d2 = best_by_prediction.reset_index(drop=True)[["Name", label]].rename(index=str,columns={label:"Selected"})
e2 = top_qgo.reset_index(drop=True)[[label, "Name"]].rename(index=str,columns={label:"QGo"})
# final2 = a2.merge(b2, on="Name").merge(c2, on="Name").merge(d2, on="Name").melt(id_vars="Name")
# final2 = pd.concat([a2, e2["QGo"], b2["Rw"], c2["Awsem"], d2["prediction"]], axis=1).melt(id_vars="Name")
final3 = pd.concat([a2, b2["Rw"], c2["Awsem"], d2["Selected"]], axis=1).melt(id_vars="Name", value_name=label, var_name=" ")
# sns.pointplot("Name","value", data=final2, hue="variable", hue_order=["prediction", "Awsem", "Rw", "best"])
# sns.stripplot("value", "Name", data=final2, hue="variable")
order = ["T0251", "T0833", "T0815", "T0803", "T0792", "T0784", "T0766", "1MBA"]
sns.pointplot("Name", label, data=final3, hue=" ", errwidth=0, order=order)
plt.savefig("/Users/weilu/Desktop/fig6_Rmsd.png", dpi=300)
# plt.ylim([0.4,1])
final_Rmsd = final3



In [32]:
# Side-by-side table of all three metrics for the final summary figure.
final = pd.concat([final_gdt, final_Qw["Qw"], final_Rmsd["Rmsd"]], axis=1)

In [64]:
# Three stacked panels (GDT, Qw, Rmsd) sharing the protein axis; only the
# bottom panel keeps its x labels and a single legend is drawn once.
f, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(15,15))
sns.pointplot("Name", "GDT", data=final, hue=" ", errwidth=0, order=order, ax=ax1)
ax1.legend_.remove()
ax1.get_xaxis().set_visible(False)
sns.pointplot("Name", "Qw", data=final, hue=" ", errwidth=0, order=order, ax=ax2)
ax2.legend_.remove()
ax2.get_xaxis().set_visible(False)
sns.pointplot("Name", "Rmsd", data=final, hue=" ", errwidth=0, order=order, ax=ax3)
plt.legend(loc=9, bbox_to_anchor=(0.9, 2.8), ncol=1, fontsize='x-large')
plt.savefig("/Users/weilu/Desktop/fig6.png", dpi=300)



In [ ]: