MNIST classification


In [1]:
require 'nn';
require 'rnn';            -- Element-Research recurrent modules
matio = require 'matio'   -- reads MATLAB .mat files

In [2]:
data = matio.load('ex4data1.mat')  -- 5000 handwritten digits, 20x20 pixels each
trainset = {}
trainset.data = data.X             -- 5000 x 400 pixel matrix
trainset.label = data.y[{ {}, 1 }] -- 5000 labels (classes 1..10, with 10 encoding digit 0)
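
A quick shape check confirms the load; this assumes the usual Coursera ML exercise layout of ex4data1.mat (X is 5000x400, y is 5000x1):

print(trainset.data:size())   -- expected: 5000 x 400
print(trainset.label:size())  -- expected: 5000
print(trainset.label:min(), trainset.label:max())  -- expected: 1 and 10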

In [3]:
-- make trainset behave like an nn-style dataset: trainset[i] -> {input, target}
setmetatable(trainset,
    {__index = function(t,i)
                return {t.data[i], t.label[i]}
        end}
);

function trainset:size()
    return self.data:size(1)
end
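
With the metatable and size() method in place, trainset exposes the dataset interface that nn's trainers expect; a quick probe:

print(trainset:size())  -- 5000
sample = trainset[1]    -- {400-dim pixel vector, label}
print(sample[1]:size(), sample[2])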

In [4]:
-- standardize each of the 400 pixel features to zero mean, unit variance
mean = {}
stdv = {}
for i=1,400 do
    mean[i] = trainset.data[{ {},{i} }]:mean()
    stdv[i] = trainset.data[{ {}, {i} }]:std()
    --print(i .. 'th mean: ' .. mean[i])
    --print(i .. 'th std dev: ' .. stdv[i])
    trainset.data[{ {},{i} }]:add(-mean[i])
    if stdv[i] ~= 0 then  -- skip constant features to avoid dividing by zero
        trainset.data[{ {},{i} }]:div(stdv[i])
    end
end
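
A spot check on one column verifies the standardization (any feature with nonzero spread should now be roughly zero mean, unit variance):

col = trainset.data[{ {}, {200} }]
print(col:mean(), col:std())  -- expect ~0 and ~1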

Define model


In [5]:
batchSize = 5000   -- the full training set is fed as one batch
rho = 5            -- sequence length: strips fed to the RNN per update
hiddenSize = 10    -- size of the recurrent hidden state
nIndex = 20        -- input width: one 20-pixel row of the 20x20 image
nClass = 10        -- digits 1..9 plus class 10 for 0
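
The idea is to treat each flat 400-pixel image as a sequence of 20-pixel rows and feed the RNN rho rows at a time. A sketch of how one image decomposes (illustration only):

img = trainset.data[1]  -- one image as a 400-dim vector
rows = {}
for r = 0, 19 do
    table.insert(rows, img:narrow(1, r*20 + 1, 20))  -- rows of 20 pixels
end
print(#rows, rows[1]:size())  -- 20 rows, each of 20 pixels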

In [6]:
rnn = nn.Sequential()
r = nn.Recurrent(
    hiddenSize,                        -- start: size of the initial hidden state
    nn.Linear(nIndex, hiddenSize),     -- input layer applied to each 20-pixel strip
    nn.Linear(hiddenSize, hiddenSize), -- feedback layer applied to the previous state
    nn.Sigmoid(),                      -- transfer function
    rho                                -- maximum steps to backpropagate through time
)
rnn:add(r)
rnn:add(nn.Linear(hiddenSize, nClass))
rnn:add(nn.LogSoftMax())
rnn = nn.Sequencer(rnn)                -- apply the network to a whole table of steps at once
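
A dry run with dummy inputs (a hedged sketch, batch of 3) shows what the Sequencer expects and returns: a table of rho tensors in, a table of rho log-probability tensors out:

probe = {}
for step = 1, rho do table.insert(probe, torch.randn(3, nIndex)) end
out = rnn:forward(probe)
print(#out, out[1]:size())  -- rho tensors, each 3 x nClass
rnn:forget()                -- reset the hidden state before real training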

In [7]:
criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
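
SequencerCriterion applies the wrapped ClassNLLCriterion at every time step and sums the per-step losses, which is why the training loop below reports err/rho. A minimal illustration with fake tensors:

outs, tgts = {}, {}
for step = 1, rho do
    table.insert(outs, torch.log(torch.rand(3, nClass)))  -- fake log-probabilities
    table.insert(tgts, torch.Tensor{1, 2, 3})             -- fake labels
end
print(criterion:forward(outs, tgts))  -- one number: the sum over the rho steps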

Train


In [17]:
lr = 0.1  -- learning rate
i = 1     -- current pixel offset into the 400-pixel images

In [20]:
prev = 100
for epoch = 1,1e2 do
    -- assemble a rho-step sequence of 20-pixel strips, sliding i across the 400 pixels
    local inputs, targets = {}, {}
    for step=1,rho do
        table.insert(inputs, trainset.data[{{},{i, i+nIndex-1}}])
        table.insert(targets, trainset.label[{{}}])  -- same digit label at every step
        i = i+20
        if i+nIndex-1 > 400 then
            i = 1  -- wrap back to the start of the image
            break
        end
    end
    local outputs = rnn:forward(inputs)
    local err = criterion:forward(outputs, targets)
    if epoch%1 == 0 then print(epoch, err/rho, i) end  -- per-step loss and next offset
    local gradOutputs = criterion:backward(outputs, targets)
    rnn:backward(inputs, gradOutputs)
    rnn:updateParameters(lr)
    rnn:zeroGradParameters()
    -- halve the learning rate whenever the loss rises, with a floor of 0.1;
    -- note: since lr starts at 0.1, the halving never fires in this run
    if prev < err and lr > 0.1 then
        print("prev: ", prev, "cur: ", err, "lr", lr, "->", lr*0.5)
        lr = lr * 0.5
    end
    prev = err
end


Out[20]:
1	2.2623325467748	101
2	1.8440018558042	201
3	1.9691901419904	301
4	2.2418026958367	1
5	2.2622659944138	101
6	1.8409273917969	201
7	1.9669050395324	301
8	2.2410413045918	1
9	2.2621974495532	101
10	1.8378874776133	201
11	1.9646372134661	301
12	2.2402718408027	1
13	2.2621272738148	101
14	1.8348818272594	201
15	1.9623867934585	301
16	2.2394946072619	1
17	2.2620557953525	101
18	1.8319100508671	201
19	1.9601539283713	301
20	2.2387096413671	1
21	2.2619832877235	101
22	1.8289716878628	201
23	1.9579387060812	301
24	2.237916888933	1
25	2.2619100778473	101
26	1.8260662899965	201
27	1.9557412031501	301
28	2.2371167834692	1
29	2.2618369518513	101
30	1.8231934987589	201
31	1.9535617116008	301
32	2.2363110523703	1
33	2.2617656481358	101
34	1.8203530525926	201
35	1.9514007098337	301
36	2.235502418027	1
37	2.2616965096519	101
38	1.8175447163854	201
39	1.9492577919316	301
40	2.2346916198613	1
41	2.2616255760604	101
42	1.8147681116829	201
43	1.9471320895165	301
44	2.2338786256052	1
45	2.2615488420003	101
46	1.8120226263831	201
47	1.9450231445675	301
48	2.2330659384787	1
49	2.2614653209972	101
50	1.8093075875951	201
51	1.9429313708188	301
52	2.2322606024587	1
53	2.261384941503	101
54	1.8066223225857	201
55	1.9408589946083	301
56	2.2314756569954	1
57	2.2613202285551	101
58	1.8039663676047	201
59	1.9388063473393	301
60	2.2307159620443	1
61	2.2612681485664	101
62	1.8013396643819	201
63	1.9367724159201	301
64	2.2299798076258	1
65	2.2612226736151	101
66	1.7987420982266	201
67	1.9347564626203	301
68	2.2292661818862	1
69	2.261181676039	101
70	1.7961730346871	201
71	1.9327580344087	301
72	2.2285763906108	1
73	2.2611468033912	101
74	1.7936313127983	201
75	1.9307769820345	301
76	2.2279138645555	1
77	2.2611217301843	101
78	1.7911157233003	201
79	1.9288135521346	301
80	2.2272821027118	1
81	2.2611093570622	101
82	1.7886253974466	201
83	1.9268683541975	301
84	2.2266829991191	1
85	2.2611108686636	101
86	1.7861597089923	201
87	1.9249421927251	301
88	2.2261170413483	1
89	2.2611265847401	101
90	1.7837181009818	201
91	1.923035916227	301
92	2.2255838753852	1
93	2.2611565911132	101
94	1.7813000664733	201
95	1.9211503309159	301
96	2.2250825386842	1
97	2.2612008875501	101
98	1.7789051946369	201
99	1.9192861345837	301
100	2.2246115778632	1

Test


In [21]:
correction = {}
trainsize = 100 -- evaluate on the first 100 examples; use trainset:size() for all
for i=20,38 do
    correction[i] = 0
end
for i=1,trainsize do
    local answer = trainset.label[i]
    -- feed 39 overlapping 20-pixel strips (stride 10) of example i
    local inputs = {}
    for step=0,38 do
        table.insert(inputs, trainset.data[{{i},{step*10+1, step*10+20}}])
    end
    local prediction = rnn:forward(inputs)
    -- score the read-out at each of the later time steps
    for d=20,38 do
        local guess = prediction[d][{1,{}}]
        local confidences, indices = torch.max(guess, 1)
        -- if i%100 == 1 then print(answer, guess, indices[1]) end
        if answer == indices[1] then
            correction[d] = correction[d] + 1
        end
    end
end
for i=20,38 do
    print(i, " = ", correction[i], 100*correction[i]/trainsize .. '%')
end


Out[21]:
20	 = 	18	18%
21	 = 	2	2%
22	 = 	8	8%
23	 = 	3	3%
24	 = 	5	5%
25	 = 	8	8%
26	 = 	12	12%
27	 = 	33	33%
28	 = 	26	26%
29	 = 	61	61%
30	 = 	31	31%
31	 = 	52	52%
32	 = 	20	20%
33	 = 	36	36%
34	 = 	5	5%
35	 = 	14	14%
36	 = 	2	2%
37	 = 	1	1%
38	 = 	0	0%
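
Accuracy clearly depends on which time step the read-out is taken from, peaking around step 29. A small helper (sketch) extracts the best step from the table just computed:

best, bestStep = -1, nil
for d = 20, 38 do
    if correction[d] > best then best, bestStep = correction[d], d end
end
print('best step:', bestStep, 'accuracy:', 100*best/trainsize .. '%')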
