In [30]:
%matplotlib inline
from matplotlib import pylab as pl
import cPickle as pickle
import pandas as pd
import numpy as np
import os
import random

In [31]:
import sys
# Make the project root importable (the `common` package lives one level up).
sys.path.append('..')

Read precomputed features

Uncomment the relevant pipeline in ../seizure_detection.py and run

cd ..
./doall data

or

./doall td
./doall tt

In [32]:
FEATURES = 'gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70'

In [33]:
# Recover the band/window layout from the FEATURES tag.
tokens = FEATURES.split('-')

# Band-edge tokens look like 'b0.2', 'b4', ...; N edges delimit N-1 bands.
nbands = sum(1 for tok in tokens if tok[0] == 'b') - 1

# The 'w<N>' token carries the number of time windows per segment
# (if several were present, the last one would win).
nwindows = 0
for tok in tokens:
    if tok[0] == 'w':
        nwindows = int(tok[1:])

nbands, nwindows


Out[33]:
(5, 60)

In [34]:
NUNITS = 1

In [35]:
from common.data import CachedDataLoader
# Feature files are read from the project's data cache directory.
cached_data_loader = CachedDataLoader('../data-cache')

In [36]:
def read_data(target, data_type):
    fname = 'data_%s_%s_%s'%(data_type,target,FEATURES)
    print fname
    return cached_data_loader.load(fname,None)

Predict


In [37]:
def process(X, percentile=[0.1,0.5,0.9],nunits=NUNITS):
    N, Nf = X.shape
    print '# samples',N,'# power points', Nf
    print '# channels', Nf / (nbands*nwindows)
    
    newX = []
    for i in range(N):
        nw = nwindows//nunits
        windows = X[i,:].reshape((nunits,nw,-1))
        sorted_windows = np.sort(windows, axis=1)
        features = np.concatenate([sorted_windows[:,int(p*nw),:] for p in percentile], axis=-1)
        newX.append(features.ravel())
    newX = np.array(newX)

    return newX

In [40]:
from sklearn import preprocessing
from nolearn.dbn import DBN
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# z-score scaler, applied outside the pipeline (see the training loop below).
scale = StandardScaler()

min_max_scaler = preprocessing.MinMaxScaler() # scale features to be [0..1] which is DBN requirement

# Deep belief network classifier. The -1 layer sizes are placeholders that are
# overwritten per target via set_params(dbn__layer_sizes=[NF,300,2]) before fit.
dbn = DBN(
    [-1, 300, -1], # first layer has size X.shape[1], hidden layer(s), last layer will have number of classes in y (2))
    learn_rates=0.3,
    learn_rate_decays=0.9,
    epochs=500,
    dropouts=[0.1,0.5],
    verbose=0,
    )

# Full model: rescale to [0,1], then classify with the DBN.
clf = Pipeline([('min_max_scaler', min_max_scaler), ('dbn', dbn)])

In [41]:
# Open the submission file and write the expected CSV header.
fpout = open('../submissions/141105-predict.3.csv','w')
print >>fpout,'clip,preictal'

In [27]:
# Train one model per subject and append its test predictions to the
# submission file.
# NOTE(review): random.shuffle is not seeded, so results are not reproducible
# run-to-run.
for target in ['Dog_1', 'Dog_2', 'Dog_3', 'Dog_4', 'Dog_5', 'Patient_1', 'Patient_2']:
    pdata = read_data(target, 'preictal') # positive examples
    ndata = read_data(target, 'interictal') # negative examples
    X = np.concatenate((pdata.X, ndata.X))
    X = process(X)
    _, NF = X.shape
    
    # z-score then clip outliers so the pipeline's MinMax scaling is not
    # dominated by extreme values
    X = scale.fit_transform(X)
    X = np.clip(X,-3,5)
    
    clf.set_params(dbn__layer_sizes=[NF,300,2]) # we need to reset each time because NF is different
    # labels: the first pdata.X.shape[0] rows are the preictal (positive) class
    y = np.zeros(X.shape[0])
    y[:pdata.X.shape[0]] = 1
    # shuffle
    idxs=range(len(y))
    random.shuffle(idxs)
    X = X[idxs,:]
    y = y[idxs]
    # model
    clf.fit(X,y)
    # predict
    tdata = read_data(target, 'test') # test examples
    Xt = process(tdata.X)
    
    # test features use the scaler fitted on this target's training data
    # (transform, not fit_transform)
    Xt = scale.transform(Xt)
    Xt = np.clip(Xt,-3,5)
    
    y_proba = clf.predict_proba(Xt)[:,1]
    # write results
    for i,p in enumerate(y_proba):
        print >>fpout,'%s_test_segment_%04d.mat,%.15f' % (target, i+1, p)


data_preictal_Dog_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Dog_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 664 # power points 4800
# channels 16
data_test_Dog_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 502 # power points 4800
# channels 16
data_preictal_Dog_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Dog_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 822 # power points 4800
# channels 16
data_test_Dog_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 1000 # power points 4800
# channels 16
data_preictal_Dog_3_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Dog_3_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 1992 # power points 4800
# channels 16
data_test_Dog_3_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 907 # power points 4800
# channels 16
data_preictal_Dog_4_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Dog_4_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 1541 # power points 4800
# channels 16
data_test_Dog_4_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 990 # power points 4800
# channels 16
data_preictal_Dog_5_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Dog_5_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 680 # power points 4500
# channels 15
data_test_Dog_5_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 191 # power points 4500
# channels 15
data_preictal_Patient_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Patient_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 188 # power points 4500
# channels 15
data_test_Patient_1_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 195 # power points 4500
# channels 15
data_preictal_Patient_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
data_interictal_Patient_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 180 # power points 7200
# channels 24
data_test_Patient_2_gen-8_allbands2-usf-w60-b0.2-b4-b8-b12-b30-b70
# samples 150 # power points 7200
# channels 24

In [28]:
fpout.close()

In [29]:
pl.hist(Xt.ravel(),bins=50)


Out[29]:
(array([  757.,    44.,    46.,    59.,    76.,    92.,   148.,   182.,
          212.,   267.,   375.,   475.,   632.,   795.,  1005.,  1272.,
         1586.,  2093.,  2300.,  2492.,  2541.,  2519.,  2542.,  2454.,
         2403.,  2202.,  2112.,  2118.,  2105.,  1982.,  1774.,  1722.,
         1599.,  1409.,  1284.,  1195.,   997.,   838.,   716.,   589.,
          558.,   491.,   358.,   317.,   274.,   208.,   193.,   174.,
          112.,  1306.]),
 array([-3.  , -2.84, -2.68, -2.52, -2.36, -2.2 , -2.04, -1.88, -1.72,
        -1.56, -1.4 , -1.24, -1.08, -0.92, -0.76, -0.6 , -0.44, -0.28,
        -0.12,  0.04,  0.2 ,  0.36,  0.52,  0.68,  0.84,  1.  ,  1.16,
         1.32,  1.48,  1.64,  1.8 ,  1.96,  2.12,  2.28,  2.44,  2.6 ,
         2.76,  2.92,  3.08,  3.24,  3.4 ,  3.56,  3.72,  3.88,  4.04,
         4.2 ,  4.36,  4.52,  4.68,  4.84,  5.  ]),
 <a list of 50 Patch objects>)

In [ ]: