In [2]:
import numpy as np

from bokeh.plotting import HBox, VBox, figure, show, output_file, GridPlot
from bokeh.models.mappers import LinearColorMapper
from bokeh.models import BasicTicker, Grid 

from sklearn import metrics
from sklearn import preprocessing
from sklearn.datasets import fetch_olivetti_faces
from sklearn.utils.validation import check_random_state
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.cross_validation import train_test_split
from sklearn.covariance import GraphLassoCV, ledoit_wolf
from sklearn.grid_search import GridSearchCV
from scipy.spatial import distance
 
import sklearn
import OVFM.Model as md
import OVFM.FeatureMap as fm
import OVFM.Risk as rsk
import OVFM.LearningRate as lr
import OVFM.DataGeneration as dg
import OVFM.SGD as sgd
 
import time
import sys

In [3]:
def simplex_map( i ):
    # Recursively build the ( i - 1 ) x i simplex coding matrix: its i columns
    # are unit vectors forming the vertices of a regular simplex in R^( i - 1 ).
    if i > 2:
        return np.bmat( [ [ [ [ 1 ] ], np.repeat( -1.0 / ( i - 1 ), i - 1 ).reshape( ( 1, i - 1 ) ) ],
                          [ np.zeros( ( i - 2, 1 ) ), simplex_map( i - 1 ) * np.sqrt( 1.0 - 1.0 / ( i - 1 ) ** 2 ) ] ] )
    elif i == 2:
        return np.array( [ [ 1, -1 ] ] )
    else:
        raise ValueError( "invalid number of classes" )

In [54]:
# Load the training data set (the test set load is left commented out)
train = np.genfromtxt('train.csv', delimiter=',',skip_header=1)
# test = np.genfromtxt('test.csv', delimiter=',',skip_header=1)

# Create numpy arrays for use with scikit-learn
train_X = train[:,1:-1].astype( float )
train_y = train[:,-1:]
# test_X = test[:,1:]

D = 1000
gamma = 0.1
C = 0.000
eta0 = 0.5

scaler  = preprocessing.StandardScaler( )
train_X = scaler.fit_transform( train_X )
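# Note (added): if the held-out test set commented out above were used, its
# features should be transformed with this already-fitted scaler, e.g.
# test_X = scaler.transform( test_X ), rather than re-fitting it.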

In [55]:
rf = RandomForestClassifier( n_estimators = 500 )
X,X_,y,y_ = train_test_split( train_X, train_y.ravel( ), test_size = 0.33 )

rf.fit( X, y )
y_rf = rf.predict( X_ )

print metrics.classification_report( y_, y_rf )
print metrics.accuracy_score( y_, y_rf )


             precision    recall  f1-score   support

        1.0       0.80      0.76      0.78       713
        2.0       0.79      0.70      0.74       724
        3.0       0.84      0.81      0.82       745
        4.0       0.92      0.97      0.94       736
        5.0       0.90      0.95      0.92       702
        6.0       0.83      0.86      0.84       719
        7.0       0.93      0.98      0.95       651

avg / total       0.86      0.86      0.86      4990

0.858517034068
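
GridSearchCV is imported above but left unused; here is a minimal sketch (added, with an illustrative parameter grid) of how it could be used to tune the forest on the same split:

In [ ]:
# Hedged sketch, not part of the original run: tune the random forest with
# the GridSearchCV imported above.  The parameter grid is illustrative only.
param_grid = { 'n_estimators': [ 100, 300, 500 ], 'max_features': [ 'sqrt', 0.5 ] }
gs = GridSearchCV( RandomForestClassifier( ), param_grid, cv = 3 )
gs.fit( X, y )
print gs.best_params_, gs.best_score_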

In [59]:
lb      = preprocessing.LabelBinarizer( neg_label = 0, pos_label = 1 )
train_y = np.dot( lb.fit_transform( train_y ).astype( float ), simplex_map( 7 ).T )

# gff = fm.GaussianFF( gamma, train_X.shape[ 1 ], D )
# Kex = gff.kernel_exact( train_X )
# Kap = gff.kernel_approx( train_X )
# fig, axes = plt.subplots( nrows=1, ncols=2, sharex=False, sharey=False )
# im = axes[ 0 ].imshow( Kex, origin = 'lower' )
# im.set_cmap( 'hot' )
# axes[ 0 ].set_title( 'Kernel exact' )
# im = axes[ 1 ].imshow( Kap, origin = 'lower' )
# im.set_cmap( 'hot' )
# axes[ 1 ].set_title( 'Kernel approximation' )
# plt.show( )
# print 'Kernel approximation MSE:', np.linalg.norm( Kex - Kap ) ** 2 / train_X.size

# M = np.cov( train_y.T )
# Dg = np.diag( np.diag( M ) + np.sum( M, axis = 0 ) )
# L = np.linalg.inv( Dg - M )
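
As a small added illustration (not part of the original run), decoding a simplex-coded target by taking the largest inner product with the class codes recovers the class index; this is exactly what the argmax step in In [112] below relies on.

In [ ]:
# Round-trip check: one synthetic example per class, encoded as in In [59]
# and decoded as in In [112].
S      = np.asarray( simplex_map( 7 ) )
onehot = np.eye( 7 )                               # indicator rows for classes 0..6
codes  = np.dot( onehot, S.T )                     # simplex-coded targets (7 x 6)
print np.argmax( np.dot( codes, S ), axis = 1 )    # expected: [0 1 2 3 4 5 6]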

In [107]:
D = 1000
gamma = 0.2
C = 1e-5
eta0 = 1.0
L = simplex_map( 7 )

In [112]:
risk = rsk.Ridge( C )
lc = lr.Constant( 1. * eta0 )
lb = lr.Constant( 0.01 * eta0 )

X,X_,y,y_ = train_test_split( train_X, train_y, test_size = .33 )
model = md.Model( fm.DecomposableFF( gamma, train_X.shape[ 1 ], D, B = np.eye( 6 ) ) )
opt = sgd.SGD( risk, 5.0, lc, lb, 10, 10000 )
opt.fit( model, X, y )
y_rf = model( X_ )

S = simplex_map( 7 )
y_   = np.argmax( np.dot( y_, S ), axis = 1 )
y_rf = np.argmax( np.dot( y_rf, S ), axis = 1 )


0 0.165560177132 9e-06 0.00176383420738 0.100224615107 0.0107197433995
10000 0.0737598159906 0.0139015135305 0.0788687881154 152.008239447 0.576861488471
20000 0.066150321653 0.0168632148664 0.0849949975674 256.07962449 0.769727201819
30000 0.0623388881613 0.0221414430061 0.104866869848 344.238991882 0.886574765083
40000 0.0598341208272 0.0246087614811 0.11297575213 423.0117699 0.965790310925
50000 0.0580420555561 0.0290087151693 0.124583227819 493.630102144 1.04162124391
60000 0.0566791797539 0.0282307668542 0.118277156999 557.827264572 1.11995643438
70000 0.0555665105807 0.0298488371034 0.121501723271 613.678810319 1.16677086827
80000 0.0545180074145 0.0325503583001 0.129809861282 670.232574705 1.2355082644
90000 0.0540366848221 0.0327928618238 0.127940122863 723.764658689 1.28945724415
100000 0.0534907530678 0.035916335675 0.133796691952 771.551590539 1.3670423573
110000 0.0529840929645 0.0345868416652 0.138380367773 817.1410312 1.39878436952
120000 0.0527290738733 0.0335604997704 0.131300343061 862.125544025 1.43399371178
130000 0.052904169505 0.0360531276846 0.140215619658 903.771865239 1.50242524398
140000 0.0519534654636 0.0402506382663 0.147947155962 943.271757073 1.53611393869
150000 0.051634552902 0.0374849531631 0.142909194953 979.582029373 1.58025363145
160000 0.0516209296871 0.0392691101707 0.14338033064 1015.07158314 1.63565523772
170000 0.0512976122984 0.0389060223654 0.141543137538 1049.23690423 1.63061564675
180000 0.0511333445608 0.0372660598967 0.134194050073 1083.10272291 1.70187025672
190000 0.0507208618928 0.0397498336687 0.150598604592 1115.71400842 1.72083144112
200000 0.0504373482146 0.0393293859872 0.14505012344 1144.03381761 1.75488554924
210000 0.0504400094609 0.0424813724079 0.142371437342 1174.45153157 1.77378269886
220000 0.0503085787262 0.0387547064138 0.143389375342 1200.58337273 1.79897605599
230000 0.0499717409472 0.0408075213872 0.150951805883 1227.59353525 1.85527973723
240000 0.0500998716967 0.0397118209807 0.150239137614 1255.20639804 1.85473278737
250000 0.0501862490309 0.0409953573981 0.150621754605 1278.95406569 1.89757312416
260000 0.0496857775031 0.0403764124448 0.15071957634 1302.3255833 1.90758169513
270000 0.049617137791 0.0425562944891 0.149023792913 1325.17501749 1.93993936724
280000 0.0495673277101 0.0407120011837 0.1440591653 1348.56970713 1.96312785672
290000 0.0495167181518 0.0423426846291 0.152214647661 1370.05612964 1.96828889484
300000 0.0493332879504 0.0428705262136 0.151171350447 1389.15699471 2.02143744646
310000 0.0492386151512 0.0421194412968 0.149254976831 1410.47268345 2.02113727512
320000 0.0492484562802 0.0438742430774 0.158706616103 1429.78247452 2.04103923127
330000 0.0491280922943 0.0430958623357 0.152704792448 1447.49796898 2.04843720035
340000 0.0491751328643 0.0442316363591 0.156956847858 1466.71170381 2.07073592572
350000 0.0490627209001 0.0437592913414 0.151133338795 1483.64490313 2.09668762561
360000 0.0488706462465 0.0430091234502 0.149521407585 1501.27207033 2.07964122289
370000 0.0488549658629 0.0443705077448 0.152247964155 1517.72723735 2.1526398938
380000 0.0488889495979 0.0471325191959 0.162524575324 1531.96621407 2.15271322797
390000 0.0487674874284 0.0443771684346 0.162177185443 1548.1789345 2.17928437833
400000 0.0489425241496 0.0445607761712 0.155011162658 1564.67575694 2.18914064228
410000 0.0487830545274 0.0444018249998 0.155247883014 1574.452705 2.18890142486
420000 0.0486825163511 0.0445751727085 0.15224927582 1591.12831096 2.23500969834
430000 0.0487765357199 0.0434092343919 0.155400872716 1603.04949019 2.20192734154
440000 0.0487579392938 0.0442199271899 0.15190485771 1619.35092506 2.22909972563
450000 0.0486614561738 0.0459231464935 0.156836338101 1630.35658307 2.24561048221
460000 0.0484157792617 0.0453723736019 0.156787284477 1641.70166782 2.26036262986
470000 0.0483856416946 0.045743522683 0.159819428621 1653.87117309 2.28640658394
480000 0.0487349372864 0.0459591043547 0.159002067332 1668.34310391 2.28304483006
490000 0.0484850432507 0.0455738995514 0.159308065201 1678.33643251 2.24593028348
500000 0.0485319791549 0.0489882384049 0.160087077696 1689.95424427 2.30732971618

In [113]:
print metrics.classification_report( y_, y_rf )
print metrics.accuracy_score( y_, y_rf )


             precision    recall  f1-score   support

          0       0.67      0.63      0.65       685
          1       0.66      0.60      0.63       740
          2       0.76      0.67      0.71       719
          3       0.87      0.96      0.91       713
          4       0.80      0.89      0.84       701
          5       0.73      0.78      0.76       703
          6       0.92      0.92      0.92       729

avg / total       0.77      0.78      0.78      4990

0.77875751503
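
For a class-by-class view of the errors, the confusion matrix can also be printed (a small addition; the asarray/ravel calls only guard against the column-vector shape of y_):

In [ ]:
# Added for illustration: entry (i, j) counts test points of class i predicted as class j.
print metrics.confusion_matrix( np.asarray( y_ ).ravel( ), np.asarray( y_rf ).ravel( ) )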

In [105]:
model.bias


Out[105]:
array([[ 0.11810472,  0.1321761 ,  0.01324193, -0.02279855, -0.05710079,
        -0.08284327]])

In [ ]: