In [1]:
import theano
import theano.tensor as th
import numpy as np
import math

In [2]:
import sys
# Make the project-local `rnn` package (one directory up) importable. The
# membership check keeps re-running this cell from piling up duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
import rnn.theano.model
import rnn.theano.lstm
import rnn.theano.crossent
import rnn.theano.solvers
import rnn.theano.softmax
from rnn.initializers import orthogonal
# Re-import the project modules so edits to their source take effect without
# restarting the kernel (Python 2 built-in `reload`; importlib.reload on Py3).
# The cell displays the return value of the last reload (the softmax module).
# NOTE(review): the order (lstm before model) presumably matters if `model`
# pulls names from `lstm` — confirm before reordering.
# `rnn.initializers` is not reloaded and `orthogonal` was bound by a
# from-import, so edits to it are not picked up here.
reload(rnn.theano.lstm)
reload(rnn.theano.model)
reload(rnn.theano.crossent)
reload(rnn.theano.solvers)
reload(rnn.theano.softmax)


Out[3]:
<module 'rnn.theano.softmax' from '../rnn/theano/softmax.pyc'>

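Construct the model: an LSTM over the `n`-symbol inputs, followed by a linear output layer and a softmax, trained with RMSprop to predict the next symbol of each sequence.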


In [4]:
# Alphabet size: inputs and targets are integer symbols in [0, n).
n = 4

# Seed NumPy's global RNG so the NumPy-based weight initialisation, the random
# training data and the minibatch order are reproducible under Restart & Run All.
# NOTE(review): RNGs used inside the rnn package, if any, may need their own seed.
SEED = 0
np.random.seed(SEED)

In [5]:
# Symbolic int matrices of symbol ids: x are the inputs, y the targets.
# Their two axes are (steps, batch) — the LSTM output is reshaped from
# (steps * batch, units) back to (steps, batch, n) below.
x = th.imatrix()
y = th.imatrix()
seq_len, batch_size = x.shape

# Recurrent layer. The two zero tensors presumably are the initial output and
# cell states, one row per sequence in the batch.
lstm_units = 100
lstm = rnn.theano.lstm.LSTM(n, lstm_units)
init_out = th.zeros((batch_size, lstm_units))
init_cell = th.zeros((batch_size, lstm_units))
lstm_out, _ = lstm.scanl(init_out, init_cell, x)

# Output layer parameters: row 0 is the bias (initialised to zero), rows 1..
# are the weight matrix. NOTE(review): assumes orthogonal() fills the view it
# is given in place — its return value is discarded.
wlin_init = np.random.randn(lstm_units + 1, n).astype(theano.config.floatX)
wlin_init[0] = 0
orthogonal(wlin_init[1:])
wlin = theano.shared(wlin_init)

# Flatten time and batch so a single matrix product applies the layer at every
# step, then restore the (steps, batch, n) layout.
flat_out = lstm_out.reshape((seq_len * batch_size, lstm_units))
lin_out = (th.dot(flat_out, wlin[1:]) + wlin[0]).reshape((seq_len, batch_size, n))

yh = rnn.theano.softmax.softmax(lin_out)
err = th.sum(rnn.theano.crossent.crossent(yh, y))  # summed cross-entropy
acc = th.sum(th.eq(th.argmax(yh, axis=2), y))      # number of correct argmax predictions
count = y.size                                     # number of predictions made

solver = rnn.theano.solvers.RMSprop(0.01, decay=rnn.theano.solvers.GeomDecay(0.995))
model = rnn.theano.model.Model(
    [lstm.weights, wlin], yh, x, y, err, acc, count, solver=solver)

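Generate data: `samples` random sequences of `length` symbols, drawn uniformly from $[0, n)$. The symbols are independent, so on new data no model can predict the next one better than chance ($1/n = 25\%$). Accuracy above that on this fixed training set means the network is memorizing it.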


In [6]:
class Minibatch(object):
    """Iterate over a dataset in minibatches of at most ``size`` items.

    ``data`` must support ``len()`` and slicing. When ``shuffle`` is true it
    must also accept an integer index array (e.g. a NumPy array or ``Loader``).
    Each pass yields ``len(self)`` batches; the last one may be smaller.
    """

    def __init__(self, data, size=256, shuffle=True):
        self.data = data
        self.size = size
        # When true, every pass over the data (every epoch) uses a new random order.
        self.shuffle = shuffle

    def __len__(self):
        # Number of batches per pass, counting a final partial batch.
        return int(math.ceil(len(self.data) / float(self.size)))

    def __iter__(self):
        total = len(self.data)
        if self.shuffle:
            # Draw a fresh permutation per pass and gather each batch by index.
            order = np.random.permutation(total)
            for start in range(0, total, self.size):
                yield self.data[order[start:start + self.size]]
        else:
            for start in range(0, total, self.size):
                yield self.data[start:start + self.size]
            
class Loader(object):
    """Expose a time-major sequence array as (input, next-step target) pairs.

    ``loader[i]`` (``i`` an int, slice or index array over the sample axis)
    returns ``(data[:-1, i], data[1:, i])``: the target at every step is the
    following element of the same sequence.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        # Integer data is (time, samples) of symbol ids, so samples are the last
        # axis. Other dtypes are presumably (time, samples, features) — TODO confirm.
        # Use np.integer rather than the builtin ``int``: modern NumPy maps ``int``
        # to int64 only, which would misclassify the int32 data used with th.imatrix.
        if np.issubdtype(self.data.dtype, np.integer):
            return self.data.shape[-1]
        return self.data.shape[-2]

    def __getitem__(self, i):
        return self.data[:-1, i], self.data[1:, i]

In [7]:
# `samples` random sequences of `length` symbols drawn uniformly from [0, n).
# Cast to int32 to match the th.imatrix inputs of the model.
length, samples = 200, 100
batch_size = 20  # sequences per minibatch (100 / 20 -> 5 minibatches per epoch)
sequences = np.random.randint(0, n, (length, samples)).astype(np.int32)
D = Minibatch(Loader(sequences), size=batch_size)

In [8]:
# Number of minibatches per epoch: ceil(samples / batch size) = ceil(100 / 20).
len(D)


Out[8]:
5

In [9]:
from rnn.progress import progress_bar

max_iters = 500   # stop once `iters` reaches this many epochs
EPOCH_TOL = 1e-6  # tolerance for a float `iters` landing on a whole epoch

# Running totals over the current epoch; reset at each epoch boundary.
err_ = 0
acc_ = 0
n_ = 0
bar = progress_bar()
next(bar)  # prime the progress-bar coroutine before the first send()
# Loop variables are named batch_* so they do not clobber the symbolic
# `err`/`acc` graph nodes or the alphabet size `n` from earlier cells, which
# would break re-running those cells.
for iters, (batch_err, batch_acc, batch_n) in model.fit(D):
    # `iters` appears to be a fractional epoch counter (the log advances in
    # 1/len(D) steps). Snap values within EPOCH_TOL of a whole number, because
    # an exact `== 0` test on accumulated floats can miss the boundary.
    frac = iters % 1
    at_epoch_end = min(frac, 1 - frac) < EPOCH_TOL
    if at_epoch_end:
        iters = round(iters)
        frac = 0.0
    h = '{}/{}:'.format(int(iters), max_iters)
    err_ += batch_err
    acc_ += batch_acc
    n_ += batch_n
    bar.send((h, frac, {'error': float(err_) / n_, 'accuracy': float(acc_) / n_}))
    if iters >= max_iters:
        break
    if at_epoch_end:
        err_, acc_, n_ = 0, 0, 0


0/500: accurracy=0.241959798995, error=1.38738941333 
    [########                                ]  20.00%, eta 00:00:51
1/500: accurracy=0.246633165829, error=1.39067024446 
2/500: accurracy=0.247135678392, error=1.3881992384 
3/500: accurracy=0.250854271357, error=1.38817646735 
4/500: accurracy=0.256633165829, error=1.38625875004 
5/500: accurracy=0.257989949749, error=1.38627093905 
6/500: accurracy=0.260150753769, error=1.38569260436 
7/500: accurracy=0.261557788945, error=1.38579801091 
8/500: accurracy=0.259748743719, error=1.38552693395 
9/500: accurracy=0.26216080402, error=1.38553130353 
10/500: accurracy=0.260904522613, error=1.3853778538 
11/500: accurracy=0.266783919598, error=1.38524255638 
12/500: accurracy=0.266030150754, error=1.38491690142 
13/500: accurracy=0.263969849246, error=1.38467380411 
14/500: accurracy=0.263969849246, error=1.38465702191 
15/500: accurracy=0.259396984925, error=1.3861265225 
16/500: accurracy=0.26351758794, error=1.38506250273 
17/500: accurracy=0.264371859296, error=1.38491095583 
18/500: accurracy=0.271608040201, error=1.38414124415 
19/500: accurracy=0.269698492462, error=1.38329674913 
20/500: accurracy=0.275628140704, error=1.38233919456 
21/500: accurracy=0.27527638191, error=1.38230733861 
22/500: accurracy=0.276834170854, error=1.38161591472 
23/500: accurracy=0.281055276382, error=1.38137250607 
24/500: accurracy=0.280201005025, error=1.38095624419 
25/500: accurracy=0.281507537688, error=1.38181258093 
26/500: accurracy=0.279899497487, error=1.38146324897 
27/500: accurracy=0.283115577889, error=1.38025573488 
28/500: accurracy=0.287688442211, error=1.37841607299 
29/500: accurracy=0.288894472362, error=1.37773296326 
30/500: accurracy=0.286934673367, error=1.37865704825 
31/500: accurracy=0.29256281407, error=1.3765199142 
32/500: accurracy=0.288040201005, error=1.37506439424 
33/500: accurracy=0.298743718593, error=1.37256153683 
34/500: accurracy=0.29783919598, error=1.37054170882 
35/500: accurracy=0.303819095477, error=1.3691495017 
36/500: accurracy=0.30783919598, error=1.36621397458 
37/500: accurracy=0.314371859296, error=1.36386268634 
38/500: accurracy=0.311809045226, error=1.36549436288 
39/500: accurracy=0.313417085427, error=1.36306328854 
40/500: accurracy=0.313718592965, error=1.36273996261 
41/500: accurracy=0.317487437186, error=1.35939745503 
42/500: accurracy=0.321859296482, error=1.35693995467 
43/500: accurracy=0.326884422111, error=1.35468197027 
44/500: accurracy=0.327135678392, error=1.35466289175 
45/500: accurracy=0.329095477387, error=1.35071170495 
46/500: accurracy=0.339949748744, error=1.3441932179 
47/500: accurracy=0.346432160804, error=1.33981559034 
48/500: accurracy=0.349396984925, error=1.33726443221 
49/500: accurracy=0.343316582915, error=1.34205192981 
50/500: accurracy=0.341959798995, error=1.34196725613 
51/500: accurracy=0.343115577889, error=1.33933881898 
52/500: accurracy=0.352261306533, error=1.33298521253 
53/500: accurracy=0.355728643216, error=1.32601389585 
54/500: accurracy=0.358140703518, error=1.32355764546 
55/500: accurracy=0.364773869347, error=1.32079895117 
56/500: accurracy=0.370201005025, error=1.31648010362 
57/500: accurracy=0.370150753769, error=1.31430347775 
58/500: accurracy=0.375075376884, error=1.30834918418 
59/500: accurracy=0.379296482412, error=1.30408548098 
60/500: accurracy=0.383819095477, error=1.30105005995 
61/500: accurracy=0.386532663317, error=1.29702062035 
62/500: accurracy=0.379798994975, error=1.30136951305 
63/500: accurracy=0.384020100503, error=1.2983791407 
64/500: accurracy=0.391608040201, error=1.29422274736 
65/500: accurracy=0.394070351759, error=1.28899245755 
66/500: accurracy=0.401859296482, error=1.28279274865 
67/500: accurracy=0.402010050251, error=1.27895638319 
68/500: accurracy=0.404170854271, error=1.28055965947 
69/500: accurracy=0.404522613065, error=1.27661091596 
70/500: accurracy=0.404974874372, error=1.27615874444 
71/500: accurracy=0.411306532663, error=1.26701089956 
72/500: accurracy=0.413015075377, error=1.26745401693 
73/500: accurracy=0.414874371859, error=1.26196908533 
74/500: accurracy=0.419748743719, error=1.25486083574 
75/500: accurracy=0.428090452261, error=1.25106964058 
76/500: accurracy=0.422964824121, error=1.24895261054 
77/500: accurracy=0.425979899497, error=1.2491127038 
78/500: accurracy=0.434221105528, error=1.23962340756 
79/500: accurracy=0.437587939698, error=1.23261949076 
80/500: accurracy=0.433216080402, error=1.23518392891 
81/500: accurracy=0.432010050251, error=1.23657396441 
82/500: accurracy=0.438442211055, error=1.23725483032 
83/500: accurracy=0.440100502513, error=1.23147775934 
84/500: accurracy=0.445728643216, error=1.22318470191 
85/500: accurracy=0.450904522613, error=1.21576941728 
86/500: accurracy=0.449748743719, error=1.21347395236 
87/500: accurracy=0.452964824121, error=1.21158980755 
88/500: accurracy=0.454271356784, error=1.2121083485 
89/500: accurracy=0.455025125628, error=1.21063305629 
90/500: accurracy=0.458743718593, error=1.2049251882 
91/500: accurracy=0.461105527638, error=1.19953460705 
92/500: accurracy=0.460603015075, error=1.19743272672 
93/500: accurracy=0.467537688442, error=1.1929887165 
94/500: accurracy=0.470100502513, error=1.19056479805 
95/500: accurracy=0.471608040201, error=1.18417807236 
96/500: accurracy=0.470954773869, error=1.18180110763 
97/500: accurracy=0.481055276382, error=1.17415143446 
98/500: accurracy=0.480653266332, error=1.16846568206 
99/500: accurracy=0.484271356784, error=1.16817771436 
100/500: accurracy=0.480904522613, error=1.1737953748 
101/500: accurracy=0.475075376884, error=1.17638042851 
102/500: accurracy=0.469346733668, error=1.18372359727 
103/500: accurracy=0.478291457286, error=1.18011133275 
104/500: accurracy=0.482261306533, error=1.16471270732 
105/500: accurracy=0.489095477387, error=1.15547578989 
106/500: accurracy=0.490050251256, error=1.15163757978 
107/500: accurracy=0.49527638191, error=1.14811009793 
108/500: accurracy=0.495477386935, error=1.14387553983 
109/500: accurracy=0.503869346734, error=1.13240715064 
110/500: accurracy=0.504221105528, error=1.13273878997 
111/500: accurracy=0.504371859296, error=1.13350734298 
112/500: accurracy=0.50432160804, error=1.13286952681 
113/500: accurracy=0.498793969849, error=1.13460266437 
114/500: accurracy=0.497587939698, error=1.13987789066 
115/500: accurracy=0.497035175879, error=1.14082686309 
116/500: accurracy=0.501155778894, error=1.13439405939 
117/500: accurracy=0.504271356784, error=1.12646039232 
118/500: accurracy=0.517638190955, error=1.11032715042 
119/500: accurracy=0.521005025126, error=1.10343986833 
120/500: accurracy=0.520502512563, error=1.10241393272 
121/500: accurracy=0.525778894472, error=1.09981969187 
122/500: accurracy=0.521055276382, error=1.10043961521 
123/500: accurracy=0.522713567839, error=1.09460155864 
124/500: accurracy=0.529296482412, error=1.0877048403 
125/500: accurracy=0.528944723618, error=1.0892938038 
126/500: accurracy=0.52783919598, error=1.0890202376 
127/500: accurracy=0.533165829146, error=1.0819935135 
128/500: accurracy=0.534221105528, error=1.07733578626 
129/500: accurracy=0.536231155779, error=1.07222459293 
130/500: accurracy=0.538944723618, error=1.0728979873 
131/500: accurracy=0.535577889447, error=1.08040968119 
132/500: accurracy=0.532613065327, error=1.0818984344 
133/500: accurracy=0.536934673367, error=1.07165434092 
134/500: accurracy=0.544170854271, error=1.06418691584 
135/500: accurracy=0.548190954774, error=1.05760321376 
136/500: accurracy=0.544371859296, error=1.06098653959 
137/500: accurracy=0.551507537688, error=1.0479989215 
138/500: accurracy=0.561457286432, error=1.03825571031 
139/500: accurracy=0.558291457286, error=1.0347525222 
140/500: accurracy=0.562613065327, error=1.03189024249 
141/500: accurracy=0.559698492462, error=1.03501679797 
142/500: accurracy=0.562964824121, error=1.0339459345 
143/500: accurracy=0.557135678392, error=1.04361943066 
144/500: accurracy=0.553115577889, error=1.04600226223 
145/500: accurracy=0.553115577889, error=1.04752833892 
146/500: accurracy=0.559095477387, error=1.03613083177 
147/500: accurracy=0.562010050251, error=1.02482883254 
148/500: accurracy=0.569095477387, error=1.01965593163 
149/500: accurracy=0.577889447236, error=1.00496756151 
150/500: accurracy=0.580452261307, error=0.999589810295 
151/500: accurracy=0.583065326633, error=0.996200877612 
152/500: accurracy=0.58216080402, error=0.997872415903 
153/500: accurracy=0.576231155779, error=1.00630325219 
154/500: accurracy=0.571959798995, error=1.01234418138 
155/500: accurracy=0.578090452261, error=1.00830708693 
156/500: accurracy=0.576834170854, error=1.00358298858 
157/500: accurracy=0.575829145729, error=1.00247188831 
158/500: accurracy=0.582010050251, error=0.996603369024 
159/500: accurracy=0.580201005025, error=0.992123782031 
160/500: accurracy=0.586683417085, error=0.984015608061 
161/500: accurracy=0.594472361809, error=0.976553899626 
162/500: accurracy=0.597135678392, error=0.972659677249 
163/500: accurracy=0.598040201005, error=0.967932214598 
164/500: accurracy=0.600251256281, error=0.965076395675 
165/500: accurracy=0.597939698492, error=0.966474761174 
166/500: accurracy=0.589899497487, error=0.97808037449 
167/500: accurracy=0.580452261307, error=0.99295250539 
168/500: accurracy=0.579346733668, error=0.995466496524 
169/500: accurracy=0.59040201005, error=0.980609079904 
170/500: accurracy=0.595778894472, error=0.969692415005 
171/500: accurracy=0.598291457286, error=0.962558029798 
172/500: accurracy=0.598542713568, error=0.958808763869 
173/500: accurracy=0.607286432161, error=0.951580897009 
174/500: accurracy=0.609246231156, error=0.94291820526 
175/500: accurracy=0.61432160804, error=0.937085827097 
176/500: accurracy=0.616582914573, error=0.932836807402 
177/500: accurracy=0.61959798995, error=0.932437522611 
178/500: accurracy=0.612211055276, error=0.937331578881 
179/500: accurracy=0.614020100503, error=0.936091980906 
180/500: accurracy=0.613718592965, error=0.934074230747 
181/500: accurracy=0.620502512563, error=0.927021941225 
182/500: accurracy=0.614371859296, error=0.929822699171 
183/500: accurracy=0.61391959799, error=0.932191244551 
184/500: accurracy=0.609296482412, error=0.934525715898 
185/500: accurracy=0.612261306533, error=0.928983699279 
186/500: accurracy=0.616783919598, error=0.92342378478 
187/500: accurracy=0.62432160804, error=0.913378646529 
188/500: accurracy=0.627035175879, error=0.910636608874 
189/500: accurracy=0.620703517588, error=0.912929084153 
190/500: accurracy=0.625527638191, error=0.909371599909 
191/500: accurracy=0.622311557789, error=0.915720640222 
192/500: accurracy=0.613768844221, error=0.920714563159 
193/500: accurracy=0.621507537688, error=0.913038160416 
194/500: accurracy=0.628492462312, error=0.901704778946 
195/500: accurracy=0.631708542714, error=0.894165263761 
196/500: accurracy=0.637336683417, error=0.88899899567 
197/500: accurracy=0.63824120603, error=0.887832007116 
198/500: accurracy=0.638291457286, error=0.884897917151 
199/500: accurracy=0.640954773869, error=0.882774059314 
200/500: accurracy=0.636884422111, error=0.889045918697 
201/500: accurracy=0.628140703518, error=0.899252409332 
202/500: accurracy=0.634070351759, error=0.895377135562 
203/500: accurracy=0.631557788945, error=0.892494909897 
204/500: accurracy=0.637487437186, error=0.887005894549 
205/500: accurracy=0.641608040201, error=0.876223060586 
206/500: accurracy=0.64527638191, error=0.870899068695 
207/500: accurracy=0.651507537688, error=0.865610852206 
208/500: accurracy=0.650753768844, error=0.86193901415 
209/500: accurracy=0.651708542714, error=0.860009622176 
210/500: accurracy=0.651055276382, error=0.860038252923 
211/500: accurracy=0.648743718593, error=0.85785523441 
212/500: accurracy=0.652211055276, error=0.85829200867 
213/500: accurracy=0.651155778894, error=0.861535810865 
214/500: accurracy=0.646331658291, error=0.871103874501 
215/500: accurracy=0.642010050251, error=0.873025150151 
216/500: accurracy=0.644070351759, error=0.870184104277 
217/500: accurracy=0.646180904523, error=0.868861428915 
218/500: accurracy=0.649748743719, error=0.858546341723 
219/500: accurracy=0.652914572864, error=0.851774016519 
220/500: accurracy=0.661306532663, error=0.840868524713 
221/500: accurracy=0.665226130653, error=0.834376750414 
222/500: accurracy=0.668743718593, error=0.827345145786 
223/500: accurracy=0.670804020101, error=0.820766316968 
224/500: accurracy=0.673567839196, error=0.818624350298 
225/500: accurracy=0.672864321608, error=0.817096939851 
226/500: accurracy=0.673165829146, error=0.820177946803 
227/500: accurracy=0.655628140704, error=0.843458049234 
228/500: accurracy=0.651206030151, error=0.858115941203 
229/500: accurracy=0.652512562814, error=0.855283172011 
230/500: accurracy=0.655427135678, error=0.839673977306 
231/500: accurracy=0.664673366834, error=0.830902311257 
232/500: accurracy=0.671105527638, error=0.818823224006 
233/500: accurracy=0.677587939698, error=0.807520307919 
234/500: accurracy=0.682211055276, error=0.800387663701 
235/500: accurracy=0.684221105528, error=0.798253731729 
236/500: accurracy=0.671809045226, error=0.81762808918 
237/500: accurracy=0.666432160804, error=0.827025729553 
238/500: accurracy=0.668040201005, error=0.823191390019 
239/500: accurracy=0.675879396985, error=0.81035602957 
240/500: accurracy=0.679648241206, error=0.802581345606 
241/500: accurracy=0.683417085427, error=0.802618761668 
242/500: accurracy=0.682663316583, error=0.797536929973 
243/500: accurracy=0.686281407035, error=0.794934038911 
244/500: accurracy=0.688391959799, error=0.788730354956 
245/500: accurracy=0.687989949749, error=0.787352507726 
246/500: accurracy=0.687788944724, error=0.786125218958 
247/500: accurracy=0.689396984925, error=0.78332557567 
248/500: accurracy=0.684673366834, error=0.789899235828 
249/500: accurracy=0.684170854271, error=0.793337834021 
250/500: accurracy=0.684371859296, error=0.791101048393 
251/500: accurracy=0.68783919598, error=0.784889808235 
252/500: accurracy=0.689296482412, error=0.782539225209 
253/500: accurracy=0.690050251256, error=0.777814702107 
254/500: accurracy=0.695829145729, error=0.77154892934 
255/500: accurracy=0.696030150754, error=0.766161532399 
256/500: accurracy=0.696582914573, error=0.767882680537 
257/500: accurracy=0.696984924623, error=0.770565450974 
258/500: accurracy=0.692311557789, error=0.777087945293 
259/500: accurracy=0.686984924623, error=0.781339598679 
260/500: accurracy=0.693618090452, error=0.772576292504 
261/500: accurracy=0.696984924623, error=0.769268190778 
262/500: accurracy=0.699447236181, error=0.766899050289 
263/500: accurracy=0.700301507538, error=0.762429084918 
264/500: accurracy=0.698944723618, error=0.758861379622 
265/500: accurracy=0.703065326633, error=0.755769963803 
266/500: accurracy=0.702412060302, error=0.75107008282 
267/500: accurracy=0.708492462312, error=0.746954657325 
268/500: accurracy=0.705829145729, error=0.747284427243 
269/500: accurracy=0.71040201005, error=0.740864567249 
270/500: accurracy=0.712713567839, error=0.737316842808 
271/500: accurracy=0.712713567839, error=0.742242695991 
272/500: accurracy=0.703969849246, error=0.751786878163 
273/500: accurracy=0.699145728643, error=0.761776944389 
274/500: accurracy=0.699698492462, error=0.759726344131 
275/500: accurracy=0.701005025126, error=0.756188423615 
276/500: accurracy=0.702412060302, error=0.75558296101 
277/500: accurracy=0.70959798995, error=0.74464379928 
278/500: accurracy=0.715728643216, error=0.736680729984 
279/500: accurracy=0.713969849246, error=0.730287868372 
280/500: accurracy=0.721105527638, error=0.72290675987 
281/500: accurracy=0.719296482412, error=0.72163101112 
282/500: accurracy=0.720150753769, error=0.722572604899 
283/500: accurracy=0.722211055276, error=0.719571825469 
284/500: accurracy=0.717537688442, error=0.721515748999 
285/500: accurracy=0.721306532663, error=0.723431853722 
286/500: accurracy=0.718140703518, error=0.72627970029 
287/500: accurracy=0.718040201005, error=0.725859393667 
288/500: accurracy=0.717537688442, error=0.721170506963 
289/500: accurracy=0.717336683417, error=0.723006636188 
290/500: accurracy=0.717939698492, error=0.726284458348 
291/500: accurracy=0.716030150754, error=0.725108315979 
292/500: accurracy=0.71743718593, error=0.71736972709 
293/500: accurracy=0.726834170854, error=0.711089331995 
294/500: accurracy=0.72472361809, error=0.707744595489 
295/500: accurracy=0.732060301508, error=0.704632429306 
296/500: accurracy=0.726281407035, error=0.702977285028 
297/500: accurracy=0.734824120603, error=0.697511194219 
298/500: accurracy=0.732110552764, error=0.696331760884 
299/500: accurracy=0.734422110553, error=0.697136936858 
300/500: accurracy=0.730251256281, error=0.699121766306 
301/500: accurracy=0.727135678392, error=0.703590233116 
302/500: accurracy=0.725326633166, error=0.707648327633 
303/500: accurracy=0.727487437186, error=0.705396981806 
304/500: accurracy=0.723567839196, error=0.705295446478 
305/500: accurracy=0.729396984925, error=0.700906260244 
306/500: accurracy=0.726281407035, error=0.704178682973 
307/500: accurracy=0.730954773869, error=0.69484403258 
308/500: accurracy=0.732864321608, error=0.693321550506 
309/500: accurracy=0.736030150754, error=0.688091951041 
310/500: accurracy=0.740301507538, error=0.68410689806 
311/500: accurracy=0.739145728643, error=0.683550664621 
312/500: accurracy=0.745376884422, error=0.676357499141 
313/500: accurracy=0.746683417085, error=0.672362718494 
314/500: accurracy=0.746834170854, error=0.671588622893 
315/500: accurracy=0.748994974874, error=0.6681163504 
316/500: accurracy=0.746683417085, error=0.669185171513 
317/500: accurracy=0.745075376884, error=0.669146221038 
318/500: accurracy=0.742914572864, error=0.672371842257 
319/500: accurracy=0.742613065327, error=0.679348966624 
320/500: accurracy=0.73216080402, error=0.690158760621 
321/500: accurracy=0.729447236181, error=0.696990991775 
322/500: accurracy=0.731206030151, error=0.695115214473 
323/500: accurracy=0.737788944724, error=0.684092437626 
324/500: accurracy=0.744120603015, error=0.671156681335 
325/500: accurracy=0.747638190955, error=0.665880317098 
326/500: accurracy=0.753467336683, error=0.660022335653 
327/500: accurracy=0.751306532663, error=0.660167657936 
328/500: accurracy=0.752311557789, error=0.658725956648 
329/500: accurracy=0.755326633166, error=0.657373372348 
330/500: accurracy=0.753216080402, error=0.657939337199 
331/500: accurracy=0.748894472362, error=0.661621234664 
332/500: accurracy=0.746381909548, error=0.664399870109 
333/500: accurracy=0.749346733668, error=0.661993339318 
334/500: accurracy=0.747185929648, error=0.666315759251 
335/500: accurracy=0.744120603015, error=0.668221819931 
336/500: accurracy=0.746683417085, error=0.665845468423 
337/500: accurracy=0.747085427136, error=0.663341512842 
338/500: accurracy=0.748793969849, error=0.658110249462 
339/500: accurracy=0.749497487437, error=0.657546559598 
340/500: accurracy=0.754221105528, error=0.651641160011 
341/500: accurracy=0.758291457286, error=0.644722287765 
342/500: accurracy=0.762814070352, error=0.638707233802 
343/500: accurracy=0.76391959799, error=0.636480015504 
344/500: accurracy=0.765577889447, error=0.633739495382 
345/500: accurracy=0.77040201005, error=0.631661974346 
346/500: accurracy=0.767487437186, error=0.635399443586 
347/500: accurracy=0.755879396985, error=0.644942787297 
348/500: accurracy=0.756783919598, error=0.64729135708 
349/500: accurracy=0.75608040201, error=0.643222367194 
350/500: accurracy=0.758190954774, error=0.641828879621 
351/500: accurracy=0.757587939698, error=0.643326660422 
352/500: accurracy=0.758693467337, error=0.638994720386 
353/500: accurracy=0.759296482412, error=0.637763083045 
354/500: accurracy=0.763417085427, error=0.636196010313 
355/500: accurracy=0.759447236181, error=0.634483770214 
356/500: accurracy=0.767487437186, error=0.628290752101 
357/500: accurracy=0.769849246231, error=0.627147994197 
358/500: accurracy=0.768442211055, error=0.626776010923 
359/500: accurracy=0.770201005025, error=0.625008051151 
360/500: accurracy=0.770150753769, error=0.624893862214 
361/500: accurracy=0.767989949749, error=0.620628232592 
362/500: accurracy=0.770954773869, error=0.618962335524 
363/500: accurracy=0.770502512563, error=0.620388539471 
364/500: accurracy=0.772462311558, error=0.618688651083 
365/500: accurracy=0.767386934673, error=0.622710907675 
366/500: accurracy=0.765376884422, error=0.626322969094 
367/500: accurracy=0.766733668342, error=0.624990929783 
368/500: accurracy=0.77040201005, error=0.622735971228 
369/500: accurracy=0.760904522613, error=0.628407126769 
370/500: accurracy=0.765477386935, error=0.630197679158 
371/500: accurracy=0.765879396985, error=0.623591523035 
372/500: accurracy=0.77648241206, error=0.613367608349 
373/500: accurracy=0.778844221106, error=0.608096936273 
374/500: accurracy=0.783216080402, error=0.603918808287 
375/500: accurracy=0.787236180905, error=0.601379694581 
376/500: accurracy=0.784522613065, error=0.599719823105 
377/500: accurracy=0.780653266332, error=0.602017334065 
378/500: accurracy=0.779798994975, error=0.60759754317 
379/500: accurracy=0.772110552764, error=0.612309992218 
380/500: accurracy=0.77391959799, error=0.612884324933 
381/500: accurracy=0.775879396985, error=0.61167273907 
382/500: accurracy=0.775477386935, error=0.606809109117 
383/500: accurracy=0.780703517588, error=0.600837305112 
384/500: accurracy=0.785025125628, error=0.595457442035 
385/500: accurracy=0.786130653266, error=0.590845001444 
386/500: accurracy=0.791055276382, error=0.589262852024 
387/500: accurracy=0.792512562814, error=0.586340984907 
388/500: accurracy=0.792462311558, error=0.585845270201 
389/500: accurracy=0.790653266332, error=0.587218882537 
390/500: accurracy=0.787638190955, error=0.58995371421 
391/500: accurracy=0.787386934673, error=0.593029807981 
392/500: accurracy=0.782110552764, error=0.598837042798 
393/500: accurracy=0.780150753769, error=0.597355925615 
394/500: accurracy=0.782060301508, error=0.59625741706 
395/500: accurracy=0.784522613065, error=0.595293915392 
396/500: accurracy=0.786582914573, error=0.588245193445 
397/500: accurracy=0.788190954774, error=0.586648161766 
398/500: accurracy=0.788341708543, error=0.586062291711 
399/500: accurracy=0.788693467337, error=0.585363467809 
400/500: accurracy=0.78743718593, error=0.585420903821 
401/500: accurracy=0.789095477387, error=0.583707996601 
402/500: accurracy=0.787185929648, error=0.583332188984 
403/500: accurracy=0.790502512563, error=0.584505752782 
404/500: accurracy=0.785527638191, error=0.58492496525 
405/500: accurracy=0.793115577889, error=0.581863297027 
406/500: accurracy=0.795577889447, error=0.576359083314 
407/500: accurracy=0.795929648241, error=0.573572065931 
408/500: accurracy=0.794974874372, error=0.574593220351 
409/500: accurracy=0.79648241206, error=0.575726168882 
410/500: accurracy=0.795025125628, error=0.574412542511 
411/500: accurracy=0.795226130653, error=0.571906793074 
412/500: accurracy=0.795376884422, error=0.572914148229 
413/500: accurracy=0.794070351759, error=0.573926593673 
414/500: accurracy=0.796984924623, error=0.570717681356 
415/500: accurracy=0.794824120603, error=0.568380627131 
416/500: accurracy=0.795376884422, error=0.568767633835 
417/500: accurracy=0.798592964824, error=0.569927212661 
418/500: accurracy=0.798090452261, error=0.56940793106 
419/500: accurracy=0.798894472362, error=0.568834797594 
420/500: accurracy=0.797135678392, error=0.56692466918 
421/500: accurracy=0.795527638191, error=0.565581788506 
422/500: accurracy=0.79824120603, error=0.564911563362 
423/500: accurracy=0.797989949749, error=0.563360935959 
424/500: accurracy=0.798291457286, error=0.563345081832 
425/500: accurracy=0.80391959799, error=0.562045839721 
426/500: accurracy=0.803969849246, error=0.557971554434 
427/500: accurracy=0.806130653266, error=0.55594087521 
428/500: accurracy=0.802110552764, error=0.558913134703 
429/500: accurracy=0.798793969849, error=0.562137665331 
430/500: accurracy=0.802211055276, error=0.560041330646 
431/500: accurracy=0.799648241206, error=0.560578666456 
432/500: accurracy=0.801306532663, error=0.559071781426 
433/500: accurracy=0.806030150754, error=0.555365109589 
434/500: accurracy=0.806130653266, error=0.552167464588 
435/500: accurracy=0.809648241206, error=0.547681315265 
436/500: accurracy=0.813165829146, error=0.544951534582 
437/500: accurracy=0.810050251256, error=0.546048836983 
438/500: accurracy=0.811105527638, error=0.545667062238 
439/500: accurracy=0.809497487437, error=0.546498880125 
440/500: accurracy=0.806582914573, error=0.549525072573 
441/500: accurracy=0.807336683417, error=0.550835757883 
442/500: accurracy=0.801457286432, error=0.556212088205 
443/500: accurracy=0.797085427136, error=0.560686892825 
444/500: accurracy=0.804120603015, error=0.555043447847 
445/500: accurracy=0.805577889447, error=0.550209623818 
446/500: accurracy=0.811055276382, error=0.544360823854 
447/500: accurracy=0.811809045226, error=0.540220209458 
448/500: accurracy=0.814773869347, error=0.539221598899 
449/500: accurracy=0.813618090452, error=0.539115897142 
450/500: accurracy=0.812814070352, error=0.539952259972 
450/500: accurracy=0.809547738693, error=0.545520301173 
    [################################        ]  80.00%, eta 00:00:00
451/500: accurracy=0.808994974874, error=0.543827718333 
451/500: accurracy=0.806595477387, error=0.54739207881 
    [################################        ]  80.00%, eta 00:00:00
452/500: accurracy=0.808291457286, error=0.544266104497 
452/500: accurracy=0.809924623116, error=0.546143317696 
    [################################        ]  80.00%, eta 00:00:00
453/500: accurracy=0.811557788945, error=0.542986190806 
454/500: accurracy=0.808693467337, error=0.543859105322 
455/500: accurracy=0.815879396985, error=0.536808570904 
456/500: accurracy=0.81648241206, error=0.533208131634 
457/500: accurracy=0.815125628141, error=0.5332328414 
458/500: accurracy=0.816834170854, error=0.531583214307 
459/500: accurracy=0.818944723618, error=0.530656972526 
460/500: accurracy=0.819195979899, error=0.531753062388 
461/500: accurracy=0.814422110553, error=0.534378886768 
462/500: accurracy=0.814824120603, error=0.53608812982 
463/500: accurracy=0.807688442211, error=0.538701912396 
464/500: accurracy=0.814522613065, error=0.538171805784 
465/500: accurracy=0.809798994975, error=0.538013010318 
466/500: accurracy=0.813467336683, error=0.537520458494 
467/500: accurracy=0.812864321608, error=0.536678347739 
468/500: accurracy=0.815929648241, error=0.533053937096 
469/500: accurracy=0.818894472362, error=0.529726231936 
470/500: accurracy=0.821557788945, error=0.525890785622 
471/500: accurracy=0.819145728643, error=0.523914946019 
472/500: accurracy=0.82, error=0.525421393415 
473/500: accurracy=0.821105527638, error=0.524985028876 
474/500: accurracy=0.82391959799, error=0.522182961705 
475/500: accurracy=0.824020100503, error=0.52006658912 
476/500: accurracy=0.823969849246, error=0.519238944702 
477/500: accurracy=0.823969849246, error=0.519543770026 
478/500: accurracy=0.824773869347, error=0.521020347865 
479/500: accurracy=0.816633165829, error=0.527997791315 
480/500: accurracy=0.816331658291, error=0.52584392501 
481/500: accurracy=0.814974874372, error=0.529753855074 
482/500: accurracy=0.816532663317, error=0.52705292146 
483/500: accurracy=0.819045226131, error=0.524838568956 
484/500: accurracy=0.818040201005, error=0.525621707167 
485/500: accurracy=0.819949748744, error=0.521336998959 
486/500: accurracy=0.827989949749, error=0.513980115964 
487/500: accurracy=0.829547738693, error=0.510540838072 
488/500: accurracy=0.830351758794, error=0.510462088017 
489/500: accurracy=0.828442211055, error=0.51096536258 
490/500: accurracy=0.828542713568, error=0.512088793507 
491/500: accurracy=0.828542713568, error=0.512627163575 
492/500: accurracy=0.826633165829, error=0.513171017145 
493/500: accurracy=0.82864321608, error=0.512170132534 
494/500: accurracy=0.826381909548, error=0.512776867457 
495/500: accurracy=0.828592964824, error=0.510052850827 
496/500: accurracy=0.828090452261, error=0.509873539773 
497/500: accurracy=0.830703517588, error=0.509270274762 
498/500: accurracy=0.828190954774, error=0.508573435046 
499/500: accurracy=0.831658291457, error=0.507237002467 
500/500: accurracy=0.829045226131, error=0.509276403663 

In [ ]:


In [ ]:


In [ ]: