In [2]:
cd ../backprop


/Users/darioml/src/fyp/backprop

In [14]:
%pylab inline
from nn_scipy_opti import NN_1HL
import numpy as np
from sklearn import cross_validation
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from pprint import pprint
import scipy.io
import time


Populating the interactive namespace from numpy and matplotlib

In [4]:
def test_simple_backprop(data, labels, hidden_nodes, iterations, maxiter=200, plot=False,
                         test_size=0.2, random_state=None):
    """Repeatedly train and evaluate NN_1HL on random train/test splits.

    Each round splits ``data``/``labels`` with
    ``cross_validation.train_test_split``, fits a fresh
    ``NN_1HL(maxiter=maxiter, hidden_layer_size=hidden_nodes)`` on the
    training part, and scores ``accuracy_score`` on the held-out part.
    Only the ``fit`` call is timed (wall-clock, via ``time.time``).

    Parameters
    ----------
    data, labels : array-like
        Samples and their targets, aligned sample-for-sample (both are
        passed together to ``train_test_split``).
    hidden_nodes : int
        Passed to NN_1HL as ``hidden_layer_size``.
    iterations : int
        Number of independent split/fit/score rounds.
    maxiter : int, optional
        Passed to NN_1HL as ``maxiter``. Default 200.
    plot : bool, optional
        Currently unused; accepted for backward compatibility.
    test_size : float or int, optional
        Held-out portion forwarded to ``train_test_split``. Default 0.2
        (the value that used to be hard-coded).
    random_state : int or None, optional
        If None (default), every split is unseeded, as before. If an int,
        round ``i`` uses ``random_state + i``, so rounds differ from each
        other but the sequence of splits is reproducible. Any randomness
        inside NN_1HL itself is not seeded here.

    Returns
    -------
    tuple
        ``(mean_accuracy, mean_fit_seconds, accuracies, fit_seconds)``, where
        the last two are per-round lists.
    """
    accuracies = []
    fit_seconds = []

    for i in range(iterations):
        split_seed = None if random_state is None else random_state + i
        X_train, X_test, y_train, y_test = cross_validation.train_test_split(
            data, labels, test_size=test_size, random_state=split_seed)
        nn = NN_1HL(maxiter=maxiter, hidden_layer_size=hidden_nodes)

        # Time only the optimisation, not the split or the prediction.
        start = time.time()
        nn.fit(X_train, y_train)
        fit_seconds.append(time.time() - start)

        accuracies.append(accuracy_score(y_test, nn.predict(X_test)))

    return np.mean(accuracies), np.mean(fit_seconds), accuracies, fit_seconds

In [21]:
# Path is relative to the backprop directory the notebook changes into at the top.
data_file = scipy.io.loadmat('../data/mat/ball_with_speed.mat')

data = np.array(data_file['X'])

# Cast labels to uint8, transpose, and keep row 0 as a flat vector.
# NOTE(review): this assumes 'Y' is stored as a column (N x 1). If it were a
# row (1 x N), only the first label would survive; the assert below catches that.
labels = np.array(data_file['Y'], 'uint8').T[0, :].flatten()

# test_simple_backprop splits data and labels together, so they need one label per row.
assert data.shape[0] == labels.shape[0], (
    'expected one label per sample, got %d samples and %d labels'
    % (data.shape[0], labels.shape[0]))




In [17]:
# Baseline on unscaled features: 20 hidden units, 3 random splits, 400 optimizer iterations.
a, b, c, d = test_simple_backprop(data, labels, hidden_nodes=20, iterations=3, maxiter=400)

# One line each: mean accuracy, mean fit time (s), per-run accuracies, per-run fit times.
print('\n'.join(str(item) for item in (a, b, c, d)))


0.733716475096
26.0937476953
[0.70114942528735635, 0.74137931034482762, 0.75862068965517238]
[26.382867097854614, 26.194391012191772, 25.70398497581482]

In [24]:
# Rescale features by 1/255 (presumably 8-bit pixel intensities -- confirm).
# Divide by a float literal: under Python 2 (this notebook uses print statements)
# an integer-dtype array divided by the int 255 is floor-divided, which would
# collapse almost every value to 0 instead of scaling into [0, 1].
data_1 = data / 255.0

a, b, c, d = test_simple_backprop(data_1, labels, hidden_nodes=20, iterations=3, maxiter=400)

# Mean accuracy, mean fit time (s), per-run accuracies, per-run fit times.
print(a)
print(b)
print(c)
print(d)


0.929118773946
9.13822809855
[0.93678160919540232, 0.91954022988505746, 0.93103448275862066]
[9.888365030288696, 8.147611141204834, 9.378708124160767]
scipy_optim.py:68: RuntimeWarning: divide by zero encountered in log
  costNegative = (1 - Y) * np.log(1 - h).T
scipy_optim.py:68: RuntimeWarning: invalid value encountered in multiply
  costNegative = (1 - Y) * np.log(1 - h).T

In [25]:
# Rescale features by 1/255 (presumably 8-bit pixel intensities -- confirm).
# Divide by a float literal: under Python 2 (this notebook uses print statements)
# an integer-dtype array divided by the int 255 is floor-divided, which would
# collapse almost every value to 0 instead of scaling into [0, 1].
data_1 = data / 255.0

a, b, c, d = test_simple_backprop(data_1, labels, hidden_nodes=23, iterations=4, maxiter=600)

# Mean accuracy, mean fit time (s), per-run accuracies, per-run fit times.
print(a)
print(b)
print(c)
print(d)


0.916666666667
12.2243272662
[0.96551724137931039, 0.9022988505747126, 0.85632183908045978, 0.94252873563218387]
[20.849663019180298, 8.79947304725647, 7.358099937438965, 11.89007306098938]