In [1]:
# As usual, a bit of setup

import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import pickle

from gradient_check import eval_numerical_gradient, eval_numerical_gradient_array
from layers import *
from rnn1 import *
from solver import *

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

def rel_error(x, y):
  """ returns relative error """
  return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

In [35]:
raw = pickle.load(open("stories.pck", "rb"))
print len(raw), "examples"
print max([len(x) for x in raw]), "supporting facts and questions at most"
print max([len(y.split(' ')) for y in x for x in raw]), "words per sentence at most"

_null, _start, _query, _end = "<NULL>", "<Start>", "<Query>", "<End>"

words = [_null, _start, _query, _end] + [q for ex in raw for sent in ex for w in sent.split(' ') for q in w.split(',')]
words = sorted(set(words))
word_to_idx = {w:i for i,w in enumerate(words)}

print len(words), "total words"

T = 31 # longest story, 35 words unfolding
T2= 5

data = []
for ex in raw:
    sLen = 0
    while ex[sLen].find(',')==-1:
        sLen+=1
    supports = word_to_idx[_null] * np.ones(T, dtype=int)
    queries = word_to_idx[_null] * np.ones((len(ex)-sLen, T2), dtype=int)
    
    pos=0
    for idx, sent in enumerate(ex):
        if idx<sLen:
            sent = [word_to_idx[_start]] + [word_to_idx[x] for x in sent.split(' ')]
            supports[pos:pos+len(sent)+1] = sent + [word_to_idx[_end]]
            pos += len(sent)
        else:
            sent = sent.split(',')[0]
            sent = [word_to_idx[_query]] + [word_to_idx[x] for x in sent.split(' ')]
            sent = sent + [word_to_idx[_null]]*(T2-len(sent)-1) + [word_to_idx[_end]]#null pad the sentence
            queries[idx-sLen, :] = sent
    
    answers = np.asarray([word_to_idx[x.split(',')[1]] for x in ex[sLen:]]).reshape(len(ex)-sLen, 1)
    for i in xrange(queries.shape[0]):
        data.append(np.hstack((supports, queries[i,:], answers[i,:])))

data = np.asarray(data)
data_train = data[:-1000,:]
data_test = data[-1000:,:]
print data_train.shape


1000 examples
14 supporting facts and questions at most
5 words per sentence at most
159 total words
(3275, 37)

In [15]:
model = SeqNN(word_to_idx, cell_type='rnn', hidden_dim=256, wordvec_dim=512)
solver = SeqNNSolver(model, data_train[:50],
           update_rule='adam',
           num_epochs=75,
           batch_size=25,
           optim_config={
             'learning_rate': 1e-3,
           },
           lr_decay=.995,
           verbose=True, print_every=10,
         )
solver.train()

# Plot the training losses
plt.plot(solver.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()


(Iteration 1 / 150) loss: 4.482038
(Iteration 11 / 150) loss: 2.433492
(Iteration 21 / 150) loss: 2.325440
(Iteration 31 / 150) loss: 1.823419
(Iteration 41 / 150) loss: 1.179898
(Iteration 51 / 150) loss: 0.688655
(Iteration 61 / 150) loss: 0.390618
(Iteration 71 / 150) loss: 0.201693
(Iteration 81 / 150) loss: 0.151589
(Iteration 91 / 150) loss: 0.098668
(Iteration 101 / 150) loss: 0.151300
(Iteration 111 / 150) loss: 0.069594
(Iteration 121 / 150) loss: 0.053710
(Iteration 131 / 150) loss: 0.037960
(Iteration 141 / 150) loss: 0.027944

In [17]:
minibatch = data_train[:50]
print "Train:",
answ=model.loss(minibatch, sample=True)
print (answ==minibatch[:,-1]).mean()

minibatch = data_test[:50]
print "Test:",
answ=model.loss(minibatch, sample=True)
print (answ==minibatch[:,-1]).mean()

print "\n".join(" ".join([words[x] for x in data_train[0,:-T2-1]]).split(_start))
i=0
while np.all(data_train[0,:-T2-1]==data_train[i,:-T2-1]):
  print " ".join([words[x] for x in data_train[i,-T2-1:]])
  i=i+1

print
print "Other answers:"
print "\n".join([words[x]+" "+words[y] for x,y in zip(model.loss(minibatch, sample=True), minibatch[:,-1])])


Train: 1.0
Test: 0.02

 arif tershane 'a gitti . 
 aygul bahce 'a gitti . 
 aylin sigara 'yi tershane 'a tasidi . 
 abbas bolum 'a gitti . <End> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL> <NULL>
<Query> aylin nerede ? <NULL> <NULL> <End> tershane
<Query> sigara nerede ? <NULL> <NULL> <End> tershane
<Query> aygul nerede ? <NULL> <NULL> <End> bahce
<Query> arif nerede ? <NULL> <NULL> <End> tershane
<Query> abbas nerede ? <NULL> <NULL> <End> bolum

Other answers:
okul amfi
servis mutfak
sira otobus
amfi amfi
sinif amfi
bahce kantin
tershane amfi
servis araba
tuvalet park
bahce okul
mutfak hastane
mutfak sinif
tuvalet sinif
tamirhane tershane
tuvalet sehpa
amfi tershane
mutfak sehpa
ev sehpa
ev bolum
servis dersane
tuvalet tamirhane
tuvalet dersane
tuvalet sehpa
tuvalet sandalye
sira sehpa
tuvalet sandalye
tuvalet kamyon
tuvalet hastane
hastane okul
tershane kamyon
bahce masa
hastane okul
tuvalet dersane
kamyon dersane
tuvalet sira
tershane masa
mutfak masa
otobus labaratuvar
bahce bolum
tershane bolum
bahce oda
bahce duvar
bahce bolum
bahce oda
tershane duvar
bahce amfi
bahce amfi
sinif amfi
tershane ev
bolum dersane

In [38]:
model = SeqNN(word_to_idx, cell_type='rnn', hidden_dim=512, wordvec_dim=256)
solver = SeqNNSolver(model, data_train,
           update_rule='adam',
           num_epochs=100,
           batch_size=50,
           optim_config={
             'learning_rate': 1e-3,
           },
           lr_decay=.995,
           verbose=True, print_every=10,
         )
solver.train()

# Plot the training losses
plt.plot(solver.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()


(Iteration 1 / 6500) loss: 5.081163
(Iteration 11 / 6500) loss: 3.408903
(Iteration 21 / 6500) loss: 3.553241
(Iteration 31 / 6500) loss: 3.691982
(Iteration 41 / 6500) loss: 3.472588
(Iteration 51 / 6500) loss: 3.408628
(Iteration 61 / 6500) loss: 3.459859
(Iteration 71 / 6500) loss: 3.322526
(Iteration 81 / 6500) loss: 3.423279
(Iteration 91 / 6500) loss: 3.698760
(Iteration 101 / 6500) loss: 3.478104
(Iteration 111 / 6500) loss: 3.472032
(Iteration 121 / 6500) loss: 3.394311
(Iteration 131 / 6500) loss: 3.434845
(Iteration 141 / 6500) loss: 3.482549
(Iteration 151 / 6500) loss: 3.291287
(Iteration 161 / 6500) loss: 3.233374
(Iteration 171 / 6500) loss: 3.203009
(Iteration 181 / 6500) loss: 3.214114
(Iteration 191 / 6500) loss: 3.434284
(Iteration 201 / 6500) loss: 3.508972
(Iteration 211 / 6500) loss: 3.209338
(Iteration 221 / 6500) loss: 3.244726
(Iteration 231 / 6500) loss: 3.227492
(Iteration 241 / 6500) loss: 3.182003
(Iteration 251 / 6500) loss: 3.369503
(Iteration 261 / 6500) loss: 3.506626
(Iteration 271 / 6500) loss: 3.431571
(Iteration 281 / 6500) loss: 3.446562
(Iteration 291 / 6500) loss: 3.337542
(Iteration 301 / 6500) loss: 3.290618
(Iteration 311 / 6500) loss: 3.149266
(Iteration 321 / 6500) loss: 3.281240
(Iteration 331 / 6500) loss: 3.463018
(Iteration 341 / 6500) loss: 3.117781
(Iteration 351 / 6500) loss: 3.183804
(Iteration 361 / 6500) loss: 3.283659
(Iteration 371 / 6500) loss: 3.235309
(Iteration 381 / 6500) loss: 3.470165
(Iteration 391 / 6500) loss: 3.317800
(Iteration 401 / 6500) loss: 3.189820
(Iteration 411 / 6500) loss: 3.370758
(Iteration 421 / 6500) loss: 3.225166
(Iteration 431 / 6500) loss: 3.257957
(Iteration 441 / 6500) loss: 3.196287
(Iteration 451 / 6500) loss: 3.267181
(Iteration 461 / 6500) loss: 3.336601
(Iteration 471 / 6500) loss: 3.095485
(Iteration 481 / 6500) loss: 3.070567
(Iteration 491 / 6500) loss: 3.214824
(Iteration 501 / 6500) loss: 3.453414
(Iteration 511 / 6500) loss: 3.225292
(Iteration 521 / 6500) loss: 3.365623
(Iteration 531 / 6500) loss: 3.489733
(Iteration 541 / 6500) loss: 3.298666
(Iteration 551 / 6500) loss: 3.395562
(Iteration 561 / 6500) loss: 3.144460
(Iteration 571 / 6500) loss: 3.245115
(Iteration 581 / 6500) loss: 3.112459
(Iteration 591 / 6500) loss: 3.527957
(Iteration 601 / 6500) loss: 3.316309
(Iteration 611 / 6500) loss: 3.179916
(Iteration 621 / 6500) loss: 3.377768
(Iteration 631 / 6500) loss: 3.300570
(Iteration 641 / 6500) loss: 3.362940
(Iteration 651 / 6500) loss: 3.300925
(Iteration 661 / 6500) loss: 3.191254
(Iteration 671 / 6500) loss: 3.232590
(Iteration 681 / 6500) loss: 3.170605
(Iteration 691 / 6500) loss: 3.368349
(Iteration 701 / 6500) loss: 3.095831
(Iteration 711 / 6500) loss: 3.229149
(Iteration 721 / 6500) loss: 3.227655
(Iteration 731 / 6500) loss: 3.125763
(Iteration 741 / 6500) loss: 3.047175
(Iteration 751 / 6500) loss: 3.278969
(Iteration 761 / 6500) loss: 3.206191
(Iteration 771 / 6500) loss: 3.374711
(Iteration 781 / 6500) loss: 3.284283
(Iteration 791 / 6500) loss: 3.187273
(Iteration 801 / 6500) loss: 3.158733
(Iteration 811 / 6500) loss: 3.247691
(Iteration 821 / 6500) loss: 3.060297
(Iteration 831 / 6500) loss: 3.237614
(Iteration 841 / 6500) loss: 3.207551
(Iteration 851 / 6500) loss: 3.001288
(Iteration 861 / 6500) loss: 3.294262
(Iteration 871 / 6500) loss: 3.206636
(Iteration 881 / 6500) loss: 2.954641
(Iteration 891 / 6500) loss: 3.124338
(Iteration 901 / 6500) loss: 3.259053
(Iteration 911 / 6500) loss: 3.422081
(Iteration 921 / 6500) loss: 3.139704
(Iteration 931 / 6500) loss: 3.135388
(Iteration 941 / 6500) loss: 3.298052
(Iteration 951 / 6500) loss: 3.207340
(Iteration 961 / 6500) loss: 3.266441
(Iteration 971 / 6500) loss: 3.207970
(Iteration 981 / 6500) loss: 3.360744
(Iteration 991 / 6500) loss: 3.258570
(Iteration 1001 / 6500) loss: 3.403770
(Iteration 1011 / 6500) loss: 3.406311
(Iteration 1021 / 6500) loss: 3.158614
(Iteration 1031 / 6500) loss: 3.341207
(Iteration 1041 / 6500) loss: 3.194422
(Iteration 1051 / 6500) loss: 3.251693
(Iteration 1061 / 6500) loss: 3.242279
(Iteration 1071 / 6500) loss: 3.231348
(Iteration 1081 / 6500) loss: 3.152963
(Iteration 1091 / 6500) loss: 3.182526
(Iteration 1101 / 6500) loss: 3.312180
(Iteration 1111 / 6500) loss: 3.297531
(Iteration 1121 / 6500) loss: 3.378451
(Iteration 1131 / 6500) loss: 3.284748
(Iteration 1141 / 6500) loss: 3.263523
(Iteration 1151 / 6500) loss: 3.290260
(Iteration 1161 / 6500) loss: 3.254647
(Iteration 1171 / 6500) loss: 3.266753
(Iteration 1181 / 6500) loss: 3.369875
(Iteration 1191 / 6500) loss: 3.240891
(Iteration 1201 / 6500) loss: 3.164631
(Iteration 1211 / 6500) loss: 3.252614
(Iteration 1221 / 6500) loss: 3.237964
(Iteration 1231 / 6500) loss: 3.132415
(Iteration 1241 / 6500) loss: 3.436526
(Iteration 1251 / 6500) loss: 3.295453
(Iteration 1261 / 6500) loss: 3.149812
(Iteration 1271 / 6500) loss: 2.958878
(Iteration 1281 / 6500) loss: 3.293267
(Iteration 1291 / 6500) loss: 3.117164
(Iteration 1301 / 6500) loss: 3.077472
(Iteration 1311 / 6500) loss: 3.392416
(Iteration 1321 / 6500) loss: 3.512175
(Iteration 1331 / 6500) loss: 3.254942
(Iteration 1341 / 6500) loss: 3.185214
(Iteration 1351 / 6500) loss: 3.207530
(Iteration 1361 / 6500) loss: 3.433467
(Iteration 1371 / 6500) loss: 3.376125
(Iteration 1381 / 6500) loss: 3.374799
(Iteration 1391 / 6500) loss: 3.277039
(Iteration 1401 / 6500) loss: 3.340922
(Iteration 1411 / 6500) loss: 3.157929
(Iteration 1421 / 6500) loss: 3.240678
(Iteration 1431 / 6500) loss: 3.258865
(Iteration 1441 / 6500) loss: 3.194071
(Iteration 1451 / 6500) loss: 3.124094
(Iteration 1461 / 6500) loss: 3.290638
(Iteration 1471 / 6500) loss: 3.464995
(Iteration 1481 / 6500) loss: 3.358133
(Iteration 1491 / 6500) loss: 3.128901
(Iteration 1501 / 6500) loss: 3.222432
(Iteration 1511 / 6500) loss: 3.409962
(Iteration 1521 / 6500) loss: 3.177048
(Iteration 1531 / 6500) loss: 3.226326
(Iteration 1541 / 6500) loss: 3.255045
(Iteration 1551 / 6500) loss: 3.209271
(Iteration 1561 / 6500) loss: 3.103720
(Iteration 1571 / 6500) loss: 3.451807
(Iteration 1581 / 6500) loss: 3.123331
(Iteration 1591 / 6500) loss: 3.122302
(Iteration 1601 / 6500) loss: 3.272734
(Iteration 1611 / 6500) loss: 3.331461
(Iteration 1621 / 6500) loss: 3.213937
(Iteration 1631 / 6500) loss: 3.255308
(Iteration 1641 / 6500) loss: 3.225739
(Iteration 1651 / 6500) loss: 3.320260
(Iteration 1661 / 6500) loss: 3.141707
(Iteration 1671 / 6500) loss: 3.357069
(Iteration 1681 / 6500) loss: 3.297128
(Iteration 1691 / 6500) loss: 3.273867
(Iteration 1701 / 6500) loss: 3.241858
(Iteration 1711 / 6500) loss: 3.212563
(Iteration 1721 / 6500) loss: 3.088705
(Iteration 1731 / 6500) loss: 3.340276
(Iteration 1741 / 6500) loss: 3.204524
(Iteration 1751 / 6500) loss: 3.092583
(Iteration 1761 / 6500) loss: 3.082846
(Iteration 1771 / 6500) loss: 3.333141
(Iteration 1781 / 6500) loss: 3.348559
(Iteration 1791 / 6500) loss: 3.306636
(Iteration 1801 / 6500) loss: 3.285388
(Iteration 1811 / 6500) loss: 3.149512
(Iteration 1821 / 6500) loss: 3.253929
(Iteration 1831 / 6500) loss: 3.280740
(Iteration 1841 / 6500) loss: 3.345027
(Iteration 1851 / 6500) loss: 3.173229
(Iteration 1861 / 6500) loss: 3.248663
(Iteration 1871 / 6500) loss: 2.983986
(Iteration 1881 / 6500) loss: 3.334197
(Iteration 1891 / 6500) loss: 3.237433
(Iteration 1901 / 6500) loss: 3.162352
(Iteration 1911 / 6500) loss: 3.301589
(Iteration 1921 / 6500) loss: 3.080243
(Iteration 1931 / 6500) loss: 3.257055
(Iteration 1941 / 6500) loss: 3.133094
(Iteration 1951 / 6500) loss: 3.133337
(Iteration 1961 / 6500) loss: 3.350655
(Iteration 1971 / 6500) loss: 3.176377
(Iteration 1981 / 6500) loss: 3.173264
(Iteration 1991 / 6500) loss: 3.085672
(Iteration 2001 / 6500) loss: 3.296470
(Iteration 2011 / 6500) loss: 3.243347
(Iteration 2021 / 6500) loss: 3.403524
(Iteration 2031 / 6500) loss: 3.264016
(Iteration 2041 / 6500) loss: 3.275004
(Iteration 2051 / 6500) loss: 3.344241
(Iteration 2061 / 6500) loss: 3.107344
(Iteration 2071 / 6500) loss: 3.287318
(Iteration 2081 / 6500) loss: 3.213448
(Iteration 2091 / 6500) loss: 3.056137
(Iteration 2101 / 6500) loss: 3.092863
(Iteration 2111 / 6500) loss: 2.898487
(Iteration 2121 / 6500) loss: 3.182106
(Iteration 2131 / 6500) loss: 3.099532
(Iteration 2141 / 6500) loss: 3.157031
(Iteration 2151 / 6500) loss: 3.359168
(Iteration 2161 / 6500) loss: 3.286603
(Iteration 2171 / 6500) loss: 3.235262
(Iteration 2181 / 6500) loss: 3.149535
(Iteration 2191 / 6500) loss: 3.216575
(Iteration 2201 / 6500) loss: 3.289324
(Iteration 2211 / 6500) loss: 3.181918
(Iteration 2221 / 6500) loss: 3.178326
(Iteration 2231 / 6500) loss: 3.257136
(Iteration 2241 / 6500) loss: 3.141605
(Iteration 2251 / 6500) loss: 2.995117
(Iteration 2261 / 6500) loss: 3.140910
(Iteration 2271 / 6500) loss: 3.338181
(Iteration 2281 / 6500) loss: 3.200074
(Iteration 2291 / 6500) loss: 3.124386
(Iteration 2301 / 6500) loss: 3.167158
(Iteration 2311 / 6500) loss: 3.281858
(Iteration 2321 / 6500) loss: 3.207462
(Iteration 2331 / 6500) loss: 3.244360
(Iteration 2341 / 6500) loss: 3.179955
(Iteration 2351 / 6500) loss: 3.109914
(Iteration 2361 / 6500) loss: 3.144919
(Iteration 2371 / 6500) loss: 3.243493
(Iteration 2381 / 6500) loss: 3.186051
(Iteration 2391 / 6500) loss: 3.380705
(Iteration 2401 / 6500) loss: 3.260077
(Iteration 2411 / 6500) loss: 3.298461
(Iteration 2421 / 6500) loss: 3.282838
(Iteration 2431 / 6500) loss: 3.270495
(Iteration 2441 / 6500) loss: 3.223551
(Iteration 2451 / 6500) loss: 2.972221
(Iteration 2461 / 6500) loss: 3.312611
(Iteration 2471 / 6500) loss: 3.457439
(Iteration 2481 / 6500) loss: 3.313928
(Iteration 2491 / 6500) loss: 3.314639
(Iteration 2501 / 6500) loss: 3.290813
(Iteration 2511 / 6500) loss: 3.163170
(Iteration 2521 / 6500) loss: 3.042027
(Iteration 2531 / 6500) loss: 3.252811
(Iteration 2541 / 6500) loss: 3.168526
(Iteration 2551 / 6500) loss: 3.105883
(Iteration 2561 / 6500) loss: 3.125408
(Iteration 2571 / 6500) loss: 3.103707
(Iteration 2581 / 6500) loss: 3.409317
(Iteration 2591 / 6500) loss: 3.126922
(Iteration 2601 / 6500) loss: 3.224691
(Iteration 2611 / 6500) loss: 3.205022
(Iteration 2621 / 6500) loss: 3.320655
(Iteration 2631 / 6500) loss: 3.226919
(Iteration 2641 / 6500) loss: 3.131138
(Iteration 2651 / 6500) loss: 3.304069
(Iteration 2661 / 6500) loss: 3.123339
(Iteration 2671 / 6500) loss: 3.076708
(Iteration 2681 / 6500) loss: 2.981643
(Iteration 2691 / 6500) loss: 3.333719
(Iteration 2701 / 6500) loss: 3.039862
(Iteration 2711 / 6500) loss: 3.418711
(Iteration 2721 / 6500) loss: 3.291615
(Iteration 2731 / 6500) loss: 3.121584
(Iteration 2741 / 6500) loss: 3.238757
(Iteration 2751 / 6500) loss: 3.261269
(Iteration 2761 / 6500) loss: 3.366617
(Iteration 2771 / 6500) loss: 3.213320
(Iteration 2781 / 6500) loss: 3.196143
(Iteration 2791 / 6500) loss: 3.152910
(Iteration 2801 / 6500) loss: 3.318454
(Iteration 2811 / 6500) loss: 3.010117
(Iteration 2821 / 6500) loss: 3.068359
(Iteration 2831 / 6500) loss: 3.054662
(Iteration 2841 / 6500) loss: 3.210464
(Iteration 2851 / 6500) loss: 3.208020
(Iteration 2861 / 6500) loss: 3.167940
(Iteration 2871 / 6500) loss: 3.252401
(Iteration 2881 / 6500) loss: 3.238251
(Iteration 2891 / 6500) loss: 3.249012
(Iteration 2901 / 6500) loss: 3.268276
(Iteration 2911 / 6500) loss: 3.202016
(Iteration 2921 / 6500) loss: 3.358900
(Iteration 2931 / 6500) loss: 3.280429
(Iteration 2941 / 6500) loss: 3.256697
(Iteration 2951 / 6500) loss: 3.314723
(Iteration 2961 / 6500) loss: 3.200151
(Iteration 2971 / 6500) loss: 3.140251
(Iteration 2981 / 6500) loss: 3.218271
(Iteration 2991 / 6500) loss: 3.121976
(Iteration 3001 / 6500) loss: 3.175098
(Iteration 3011 / 6500) loss: 3.084890
(Iteration 3021 / 6500) loss: 3.373078
(Iteration 3031 / 6500) loss: 3.105694
(Iteration 3041 / 6500) loss: 3.259424
(Iteration 3051 / 6500) loss: 3.122558
(Iteration 3061 / 6500) loss: 3.103528
(Iteration 3071 / 6500) loss: 2.988836
(Iteration 3081 / 6500) loss: 3.099552
(Iteration 3091 / 6500) loss: 3.119796
(Iteration 3101 / 6500) loss: 3.149021
(Iteration 3111 / 6500) loss: 3.259957
(Iteration 3121 / 6500) loss: 3.205905
(Iteration 3131 / 6500) loss: 3.124528
(Iteration 3141 / 6500) loss: 3.102661
(Iteration 3151 / 6500) loss: 3.033034
(Iteration 3161 / 6500) loss: 2.989570
(Iteration 3171 / 6500) loss: 3.536449
(Iteration 3181 / 6500) loss: 3.404926
(Iteration 3191 / 6500) loss: 3.135219
(Iteration 3201 / 6500) loss: 3.037280
(Iteration 3211 / 6500) loss: 3.309605
(Iteration 3221 / 6500) loss: 3.226583
(Iteration 3231 / 6500) loss: 3.141690
(Iteration 3241 / 6500) loss: 3.199807
(Iteration 3251 / 6500) loss: 3.060846
(Iteration 3261 / 6500) loss: 3.169383
(Iteration 3271 / 6500) loss: 3.036749
(Iteration 3281 / 6500) loss: 3.101948
(Iteration 3291 / 6500) loss: 3.224894
(Iteration 3301 / 6500) loss: 2.965902
(Iteration 3311 / 6500) loss: 3.174678
(Iteration 3321 / 6500) loss: 3.135977
(Iteration 3331 / 6500) loss: 3.236188
(Iteration 3341 / 6500) loss: 3.051222
(Iteration 3351 / 6500) loss: 3.330589
(Iteration 3361 / 6500) loss: 3.314052
(Iteration 3371 / 6500) loss: 3.195367
(Iteration 3381 / 6500) loss: 3.382930
(Iteration 3391 / 6500) loss: 3.211231
(Iteration 3401 / 6500) loss: 3.262714
(Iteration 3411 / 6500) loss: 3.041757
(Iteration 3421 / 6500) loss: 3.285519
(Iteration 3431 / 6500) loss: 3.219639
(Iteration 3441 / 6500) loss: 3.204387
(Iteration 3451 / 6500) loss: 3.020248
(Iteration 3461 / 6500) loss: 3.222610
(Iteration 3471 / 6500) loss: 3.357396
(Iteration 3481 / 6500) loss: 3.288370
(Iteration 3491 / 6500) loss: 3.236477
(Iteration 3501 / 6500) loss: 3.153038
(Iteration 3511 / 6500) loss: 3.213330
(Iteration 3521 / 6500) loss: 3.380553
(Iteration 3531 / 6500) loss: 3.157968
(Iteration 3541 / 6500) loss: 3.259161
(Iteration 3551 / 6500) loss: 3.210481
(Iteration 3561 / 6500) loss: 3.158160
(Iteration 3571 / 6500) loss: 3.242752
(Iteration 3581 / 6500) loss: 3.113264
(Iteration 3591 / 6500) loss: 3.129353
(Iteration 3601 / 6500) loss: 3.208165
(Iteration 3611 / 6500) loss: 3.206379
(Iteration 3621 / 6500) loss: 3.064849
(Iteration 3631 / 6500) loss: 3.052793
(Iteration 3641 / 6500) loss: 3.147829
(Iteration 3651 / 6500) loss: 3.196326
(Iteration 3661 / 6500) loss: 3.231201
(Iteration 3671 / 6500) loss: 3.089174
(Iteration 3681 / 6500) loss: 3.428618
(Iteration 3691 / 6500) loss: 3.082861
(Iteration 3701 / 6500) loss: 3.260761
(Iteration 3711 / 6500) loss: 3.107771
(Iteration 3721 / 6500) loss: 3.155005
(Iteration 3731 / 6500) loss: 3.178942
(Iteration 3741 / 6500) loss: 3.119209
(Iteration 3751 / 6500) loss: 3.255609
(Iteration 3761 / 6500) loss: 3.300644
(Iteration 3771 / 6500) loss: 3.219170
(Iteration 3781 / 6500) loss: 3.073737
(Iteration 3791 / 6500) loss: 3.247109
(Iteration 3801 / 6500) loss: 3.004824
(Iteration 3811 / 6500) loss: 3.008551
(Iteration 3821 / 6500) loss: 3.291581
(Iteration 3831 / 6500) loss: 3.242744
(Iteration 3841 / 6500) loss: 3.182067
(Iteration 3851 / 6500) loss: 3.205007
(Iteration 3861 / 6500) loss: 3.129074
(Iteration 3871 / 6500) loss: 3.119228
(Iteration 3881 / 6500) loss: 3.232647
(Iteration 3891 / 6500) loss: 3.198757
(Iteration 3901 / 6500) loss: 3.178018
(Iteration 3911 / 6500) loss: 3.156205
(Iteration 3921 / 6500) loss: 3.114129
(Iteration 3931 / 6500) loss: 3.089466
(Iteration 3941 / 6500) loss: 3.093334
(Iteration 3951 / 6500) loss: 3.175014
(Iteration 3961 / 6500) loss: 3.158354
(Iteration 3971 / 6500) loss: 3.154164
(Iteration 3981 / 6500) loss: 3.110854
(Iteration 3991 / 6500) loss: 3.087659
(Iteration 4001 / 6500) loss: 3.253248
(Iteration 4011 / 6500) loss: 3.340842
(Iteration 4021 / 6500) loss: 3.099040
(Iteration 4031 / 6500) loss: 3.124088
(Iteration 4041 / 6500) loss: 3.098585
(Iteration 4051 / 6500) loss: 3.074766
(Iteration 4061 / 6500) loss: 3.251236
(Iteration 4071 / 6500) loss: 3.085186
(Iteration 4081 / 6500) loss: 3.036401
(Iteration 4091 / 6500) loss: 3.075170
(Iteration 4101 / 6500) loss: 3.022859
(Iteration 4111 / 6500) loss: 3.082636
(Iteration 4121 / 6500) loss: 3.232541
(Iteration 4131 / 6500) loss: 3.265728
(Iteration 4141 / 6500) loss: 3.033381
(Iteration 4151 / 6500) loss: 3.224923
(Iteration 4161 / 6500) loss: 2.975570
(Iteration 4171 / 6500) loss: 3.185651
(Iteration 4181 / 6500) loss: 3.128310
(Iteration 4191 / 6500) loss: 3.310016
(Iteration 4201 / 6500) loss: 3.279446
(Iteration 4211 / 6500) loss: 3.370289
(Iteration 4221 / 6500) loss: 3.523191
(Iteration 4231 / 6500) loss: 3.022225
(Iteration 4241 / 6500) loss: 3.225037
(Iteration 4251 / 6500) loss: 3.137309
(Iteration 4261 / 6500) loss: 3.096520
(Iteration 4271 / 6500) loss: 3.092523
(Iteration 4281 / 6500) loss: 3.134114
(Iteration 4291 / 6500) loss: 3.087708
(Iteration 4301 / 6500) loss: 3.248252
(Iteration 4311 / 6500) loss: 3.362140
(Iteration 4321 / 6500) loss: 3.262130
(Iteration 4331 / 6500) loss: 3.008355
(Iteration 4341 / 6500) loss: 2.983497
(Iteration 4351 / 6500) loss: 3.258147
(Iteration 4361 / 6500) loss: 3.093742
(Iteration 4371 / 6500) loss: 3.049689
(Iteration 4381 / 6500) loss: 3.185015
(Iteration 4391 / 6500) loss: 3.183446
(Iteration 4401 / 6500) loss: 3.219481
(Iteration 4411 / 6500) loss: 3.206947
(Iteration 4421 / 6500) loss: 3.307300
(Iteration 4431 / 6500) loss: 3.231768
(Iteration 4441 / 6500) loss: 3.127490
(Iteration 4451 / 6500) loss: 3.113276
(Iteration 4461 / 6500) loss: 3.081661
(Iteration 4471 / 6500) loss: 3.127054
(Iteration 4481 / 6500) loss: 3.153799
(Iteration 4491 / 6500) loss: 3.163819
(Iteration 4501 / 6500) loss: 3.204151
(Iteration 4511 / 6500) loss: 3.123404
(Iteration 4521 / 6500) loss: 3.314206
(Iteration 4531 / 6500) loss: 3.207819
(Iteration 4541 / 6500) loss: 3.169136
(Iteration 4551 / 6500) loss: 3.057458
(Iteration 4561 / 6500) loss: 3.159207
(Iteration 4571 / 6500) loss: 3.062155
(Iteration 4581 / 6500) loss: 3.212522
(Iteration 4591 / 6500) loss: 2.989435
(Iteration 4601 / 6500) loss: 3.023780
(Iteration 4611 / 6500) loss: 3.104466
(Iteration 4621 / 6500) loss: 2.961605
(Iteration 4631 / 6500) loss: 3.241505
(Iteration 4641 / 6500) loss: 3.123450
(Iteration 4651 / 6500) loss: 3.211448
(Iteration 4661 / 6500) loss: 3.198398
(Iteration 4671 / 6500) loss: 3.283188
(Iteration 4681 / 6500) loss: 3.168611
(Iteration 4691 / 6500) loss: 3.314898
(Iteration 4701 / 6500) loss: 3.099060
(Iteration 4711 / 6500) loss: 3.178217
(Iteration 4721 / 6500) loss: 3.132255
(Iteration 4731 / 6500) loss: 3.212854
(Iteration 4741 / 6500) loss: 2.974342
(Iteration 4751 / 6500) loss: 3.165208
(Iteration 4761 / 6500) loss: 2.937145
(Iteration 4771 / 6500) loss: 3.287653
(Iteration 4781 / 6500) loss: 2.990134
(Iteration 4791 / 6500) loss: 3.222145
(Iteration 4801 / 6500) loss: 2.952821
(Iteration 4811 / 6500) loss: 3.077408
(Iteration 4821 / 6500) loss: 3.233161
(Iteration 4831 / 6500) loss: 3.244088
(Iteration 4841 / 6500) loss: 3.120233
(Iteration 4851 / 6500) loss: 3.137705
(Iteration 4861 / 6500) loss: 3.056806
(Iteration 4871 / 6500) loss: 3.140650
(Iteration 4881 / 6500) loss: 3.174081
(Iteration 4891 / 6500) loss: 3.146333
(Iteration 4901 / 6500) loss: 3.101686
(Iteration 4911 / 6500) loss: 2.910747
(Iteration 4921 / 6500) loss: 3.050982
(Iteration 4931 / 6500) loss: 3.326048
(Iteration 4941 / 6500) loss: 3.302544
(Iteration 4951 / 6500) loss: 2.941517
(Iteration 4961 / 6500) loss: 3.192916
(Iteration 4971 / 6500) loss: 3.153873
(Iteration 4981 / 6500) loss: 3.223632
(Iteration 4991 / 6500) loss: 3.127348
(Iteration 5001 / 6500) loss: 3.077652
(Iteration 5011 / 6500) loss: 3.204094
(Iteration 5021 / 6500) loss: 3.286052
(Iteration 5031 / 6500) loss: 3.067675
(Iteration 5041 / 6500) loss: 3.133687
(Iteration 5051 / 6500) loss: 3.220713
(Iteration 5061 / 6500) loss: 3.261526
(Iteration 5071 / 6500) loss: 3.122312
(Iteration 5081 / 6500) loss: 3.243432
(Iteration 5091 / 6500) loss: 3.109125
(Iteration 5101 / 6500) loss: 3.294788
(Iteration 5111 / 6500) loss: 3.063448
(Iteration 5121 / 6500) loss: 3.219771
(Iteration 5131 / 6500) loss: 3.063831
(Iteration 5141 / 6500) loss: 3.316438
(Iteration 5151 / 6500) loss: 3.227706
(Iteration 5161 / 6500) loss: 3.070319
(Iteration 5171 / 6500) loss: 3.100753
(Iteration 5181 / 6500) loss: 3.177833
(Iteration 5191 / 6500) loss: 3.389536
(Iteration 5201 / 6500) loss: 3.057270
(Iteration 5211 / 6500) loss: 3.170821
(Iteration 5221 / 6500) loss: 3.227208
(Iteration 5231 / 6500) loss: 3.306525
(Iteration 5241 / 6500) loss: 3.143133
(Iteration 5251 / 6500) loss: 3.241541
(Iteration 5261 / 6500) loss: 3.097567
(Iteration 5271 / 6500) loss: 3.137881
(Iteration 5281 / 6500) loss: 3.379761
(Iteration 5291 / 6500) loss: 2.976743
(Iteration 5301 / 6500) loss: 3.211886
(Iteration 5311 / 6500) loss: 3.066688
(Iteration 5321 / 6500) loss: 3.180775
(Iteration 5331 / 6500) loss: 3.131103
(Iteration 5341 / 6500) loss: 3.175842
(Iteration 5351 / 6500) loss: 2.975848
(Iteration 5361 / 6500) loss: 3.031389
(Iteration 5371 / 6500) loss: 3.194212
(Iteration 5381 / 6500) loss: 3.042728
(Iteration 5391 / 6500) loss: 3.256925
(Iteration 5401 / 6500) loss: 3.109778
(Iteration 5411 / 6500) loss: 3.067532
(Iteration 5421 / 6500) loss: 3.208317
(Iteration 5431 / 6500) loss: 3.155585
(Iteration 5441 / 6500) loss: 3.238702
(Iteration 5451 / 6500) loss: 3.167611
(Iteration 5461 / 6500) loss: 3.139582
(Iteration 5471 / 6500) loss: 2.962633
(Iteration 5481 / 6500) loss: 3.203734
(Iteration 5491 / 6500) loss: 3.105920
(Iteration 5501 / 6500) loss: 2.953353
(Iteration 5511 / 6500) loss: 3.223272
(Iteration 5521 / 6500) loss: 3.050356
(Iteration 5531 / 6500) loss: 3.188413
(Iteration 5541 / 6500) loss: 3.182814
(Iteration 5551 / 6500) loss: 3.011836
(Iteration 5561 / 6500) loss: 3.181683
(Iteration 5571 / 6500) loss: 2.970350
(Iteration 5581 / 6500) loss: 3.200682
(Iteration 5591 / 6500) loss: 3.068029
(Iteration 5601 / 6500) loss: 3.157168
(Iteration 5611 / 6500) loss: 3.111010
(Iteration 5621 / 6500) loss: 3.198570
(Iteration 5631 / 6500) loss: 3.244223
(Iteration 5641 / 6500) loss: 3.173526
(Iteration 5651 / 6500) loss: 3.252762
(Iteration 5661 / 6500) loss: 3.272377
(Iteration 5671 / 6500) loss: 3.045999
(Iteration 5681 / 6500) loss: 3.198718
(Iteration 5691 / 6500) loss: 3.092830
(Iteration 5701 / 6500) loss: 2.976280
(Iteration 5711 / 6500) loss: 2.939376
(Iteration 5721 / 6500) loss: 3.035165
(Iteration 5731 / 6500) loss: 3.123749
(Iteration 5741 / 6500) loss: 3.061396
(Iteration 5751 / 6500) loss: 3.138943
(Iteration 5761 / 6500) loss: 3.337300
(Iteration 5771 / 6500) loss: 3.105203
(Iteration 5781 / 6500) loss: 3.121954
(Iteration 5791 / 6500) loss: 3.238953
(Iteration 5801 / 6500) loss: 3.192701
(Iteration 5811 / 6500) loss: 3.198984
(Iteration 5821 / 6500) loss: 3.172417
(Iteration 5831 / 6500) loss: 3.005950
(Iteration 5841 / 6500) loss: 3.266971
(Iteration 5851 / 6500) loss: 2.956843
(Iteration 5861 / 6500) loss: 3.013122
(Iteration 5871 / 6500) loss: 3.125925
(Iteration 5881 / 6500) loss: 3.085147
(Iteration 5891 / 6500) loss: 3.312744
(Iteration 5901 / 6500) loss: 3.075357
(Iteration 5911 / 6500) loss: 3.331591
(Iteration 5921 / 6500) loss: 3.111987
(Iteration 5931 / 6500) loss: 3.302929
(Iteration 5941 / 6500) loss: 3.147874
(Iteration 5951 / 6500) loss: 3.132140
(Iteration 5961 / 6500) loss: 3.036311
(Iteration 5971 / 6500) loss: 3.163517
(Iteration 5981 / 6500) loss: 3.148374
(Iteration 5991 / 6500) loss: 3.117929
(Iteration 6001 / 6500) loss: 2.969580
(Iteration 6011 / 6500) loss: 3.131476
(Iteration 6021 / 6500) loss: 3.252439
(Iteration 6031 / 6500) loss: 3.142022
(Iteration 6041 / 6500) loss: 3.123322
(Iteration 6051 / 6500) loss: 3.160108
(Iteration 6061 / 6500) loss: 3.105633
(Iteration 6071 / 6500) loss: 3.339191
(Iteration 6081 / 6500) loss: 3.246847
(Iteration 6091 / 6500) loss: 3.230234
(Iteration 6101 / 6500) loss: 3.206764
(Iteration 6111 / 6500) loss: 3.159972
(Iteration 6121 / 6500) loss: 3.080728
(Iteration 6131 / 6500) loss: 3.173167
(Iteration 6141 / 6500) loss: 3.055533
(Iteration 6151 / 6500) loss: 3.135834
(Iteration 6161 / 6500) loss: 3.062875
(Iteration 6171 / 6500) loss: 3.208627
(Iteration 6181 / 6500) loss: 3.200404
(Iteration 6191 / 6500) loss: 3.191258
(Iteration 6201 / 6500) loss: 3.250434
(Iteration 6211 / 6500) loss: 3.055741
(Iteration 6221 / 6500) loss: 3.182753
(Iteration 6231 / 6500) loss: 3.052222
(Iteration 6241 / 6500) loss: 3.142158
(Iteration 6251 / 6500) loss: 3.094199
(Iteration 6261 / 6500) loss: 2.960225
(Iteration 6271 / 6500) loss: 3.205226
(Iteration 6281 / 6500) loss: 3.198716
(Iteration 6291 / 6500) loss: 3.141877
(Iteration 6301 / 6500) loss: 3.036747
(Iteration 6311 / 6500) loss: 3.112796
(Iteration 6321 / 6500) loss: 3.102436
(Iteration 6331 / 6500) loss: 3.247133
(Iteration 6341 / 6500) loss: 3.029989
(Iteration 6351 / 6500) loss: 3.048305
(Iteration 6361 / 6500) loss: 3.017422
(Iteration 6371 / 6500) loss: 3.221219
(Iteration 6381 / 6500) loss: 3.100928
(Iteration 6391 / 6500) loss: 2.961815
(Iteration 6401 / 6500) loss: 3.241863
(Iteration 6411 / 6500) loss: 3.064771
(Iteration 6421 / 6500) loss: 3.108584
(Iteration 6431 / 6500) loss: 3.310657
(Iteration 6441 / 6500) loss: 3.079437
(Iteration 6451 / 6500) loss: 3.214549
(Iteration 6461 / 6500) loss: 3.073925
(Iteration 6471 / 6500) loss: 3.353747
(Iteration 6481 / 6500) loss: 2.986878
(Iteration 6491 / 6500) loss: 3.118944

In [19]:
print "Train:",
answ=model.loss(data_train[:1000,:], sample=True)
print (answ==data_train[:1000,-1]).mean()
minibatch = data_train[:10]
print "\n".join([words[x]+" "+words[y] for x,y in zip(model.loss(minibatch, sample=True), minibatch[:,-1])])

print
print "Test:",
answ=model.loss(data_test[:1000], sample=True)
print (answ==data_test[:1000,-1]).mean()
minibatch = data_test[:10]
print "\n".join([words[x]+" "+words[y] for x,y in zip(model.loss(minibatch, sample=True), minibatch[:,-1])])


Train: 0.709
tershane tershane
tershane tershane
tershane bahce
tershane tershane
tershane bolum
amfi sinif
mutfak servis
mutfak mutfak
mutfak mutfak
mutfak mutfak

Test: 0.047
bolum amfi
labaratuvar mutfak
araba otobus
araba amfi
oda amfi
amfi kantin
sira amfi
masa araba
kamyon park
kamyon okul

In [36]:
model = SeqNN(word_to_idx, cell_type='rnn', hidden_dim=256, wordvec_dim=256)
solver = SeqNNSolver(model, data_train[:50],
           update_rule='adam',
           num_epochs=75,
           batch_size=25,
           optim_config={
             'learning_rate': 1e-3,
           },
           lr_decay=.995,
           verbose=True, print_every=10,
         )
solver.train()

# Plot the training losses
plt.plot(solver.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()


(Iteration 1 / 150) loss: 5.059273
(Iteration 11 / 150) loss: 3.393621
(Iteration 21 / 150) loss: 2.451049
(Iteration 31 / 150) loss: 2.209528
(Iteration 41 / 150) loss: 1.664157
(Iteration 51 / 150) loss: 1.500286
(Iteration 61 / 150) loss: 0.786686
(Iteration 71 / 150) loss: 0.698852
(Iteration 81 / 150) loss: 0.318055
(Iteration 91 / 150) loss: 0.276624
(Iteration 101 / 150) loss: 0.135957
(Iteration 111 / 150) loss: 0.084413
(Iteration 121 / 150) loss: 0.076423
(Iteration 131 / 150) loss: 0.064472
(Iteration 141 / 150) loss: 0.053949

In [ ]: