MNIST classification


In [1]:
require 'nn';
require 'rnn';
matio = require 'matio'

In [2]:
data = matio.load('ex4data1.mat')  -- 5000 digit images, 20x20 pixels unrolled into 400 features
trainset = {}
trainset.data = data.X             -- 5000 x 400 feature matrix
trainset.label = data.y[{ {}, 1 }] -- 5000 labels, one class index (1..10) per image

In [4]:
-- let trainset[i] return the {features, label} pair for example i
setmetatable(trainset,
    {__index = function(t,i)
                return {t.data[i], t.label[i]}
        end}
);

function trainset:size()
    return self.data:size(1)
end
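
With the metatable and size() in place, trainset behaves like a standard nn-style dataset. A quick sanity check of the shapes might look like the following (a hypothetical cell; the exact label value depends on the example):

print(trainset:size())       -- expected: 5000
print(trainset.data:size())  -- expected: 5000 x 400
print(trainset[1][1]:size()) -- feature vector of the first example: 400
print(trainset[1][2])        -- its label, an integer class index in 1..10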

Data Normalization


In [6]:
mean = {}
stdv = {}
for i=1,400 do -- standardize each of the 400 pixel columns
    mean[i] = trainset.data[{ {},{i} }]:mean()
    stdv[i] = trainset.data[{ {},{i} }]:std()
    --print(i .. 'th mean: ' .. mean[i])
    --print(i .. 'th std dev: ' .. stdv[i])
    trainset.data[{ {},{i} }]:add(-mean[i]) -- zero mean
    if stdv[i] ~= 0 then -- constant columns are left unscaled to avoid dividing by zero
        trainset.data[{ {},{i} }]:div(stdv[i]) -- unit variance
    end
end
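
After this loop every non-constant column should have mean close to 0 and standard deviation close to 1. A hypothetical spot check (column 200 is picked arbitrarily from the middle of the image, where pixels usually vary):

print(trainset.data[{ {}, {200} }]:mean()) -- close to 0
print(trainset.data[{ {}, {200} }]:std())  -- close to 1 (0 if the column was constant)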

Define the model


In [7]:
batchSize = 5000  -- number of training examples (the full set is fed as one batch below)
rho = 20          -- sequence length: the 400-pixel image is split into 20 steps
hiddenSize = 10   -- size of the recurrent hidden state
nIndex = 20       -- input size per step: 20 pixels
nClass = 10       -- ten digit classes

In [8]:
rnn = nn.Sequential()
r = nn.Recurrent(
    hiddenSize,                        -- size of the hidden state
    nn.Linear(nIndex, hiddenSize),     -- input-to-hidden weights (W_hx)
    nn.Linear(hiddenSize, hiddenSize), -- hidden-to-hidden weights (W_hh)
    nn.Tanh(),                         -- transfer function
    rho                                -- maximum number of steps to backpropagate through time
)
rnn:add(r)
rnn:add(nn.Linear(hiddenSize, nClass)) -- hidden state to class scores
rnn:add(nn.LogSoftMax())               -- log-probabilities for ClassNLLCriterion
rnn = nn.Sequencer(rnn)                -- apply the module to every element of an input sequence (a table of tensors)
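
The Sequencer wrapper expects a Lua table with one tensor per time step and returns a table of per-step outputs. A hypothetical shape check on dummy data (a toy batch of 2, before any training):

local dummy = {}
for step = 1, rho do
    table.insert(dummy, torch.randn(2, nIndex)) -- 2 x 20 input per step
end
local out = rnn:forward(dummy)
print(#out, out[1]:size()) -- rho outputs, each 2 x nClass log-probabilities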

In [9]:
criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) -- applies ClassNLLCriterion at every step and sums the per-step losses

Train


In [10]:
lr = 0.1 -- learning rate

In [11]:
for epoch = 1,1e3 do
    -- split each 400-pixel image into a sequence of rho=20 chunks of nIndex=20 pixels;
    -- all 5000 examples form a single batch, and the label is repeated at every step
    local inputs, targets = {}, {}
    for step=1,rho do -- 1 ~ 20
        table.insert(inputs, trainset.data[{{},{(step-1)*nIndex+1, step*nIndex}}])
        table.insert(targets, trainset.label[{{}}])
    end

    local outputs = rnn:forward(inputs)
    local err = criterion:forward(outputs, targets)
    if epoch%10 == 1 then print(epoch, err/rho) end -- average loss per step
    local gradOutputs = criterion:backward(outputs, targets)
    rnn:backward(inputs, gradOutputs)
    rnn:updateParameters(lr)
    rnn:zeroGradParameters()
end


Out[11]:
1	2.3391271816177
11	2.1955281583623
21	2.037658252251
31	1.9690607747917
41	1.9593234961249
51	1.9831929945857
61	1.8620890182392
71	1.9148388041667
81	1.9557262654271
91	1.8021422117858
101	1.8188248065218
111	1.717234843504
121	1.7623135490069
131	1.79423168286
141	1.7726373670219
151	1.9834949072924
161	1.6997501949595
171	1.6492612479463
181	2.1208070556762
191	1.6821686974183
201	1.6318084080856
211	1.6194401872758
221	1.9092329085307
231	1.856933600125
241	1.8243470661334
251	2.031291997915
261	2.0166399091812
271	1.8939151789164
281	1.9844990549649
291	1.8686006093978
301	1.8129094266696
311	1.8059165071629
321	1.9269985363757
331	1.6999124079312
341	1.8645111286057
351	2.5128108530264
361	1.7570424130973
371	1.8349937668285
381	1.7687974118329
391	1.8403859580368
401	1.8246828940703
411	1.7423012112847
421	1.7294762827616
431	1.8694934743739
441	1.7682192822609
451	1.7605444764079
461	1.7484330770191
471	1.7650141805605
481	1.789059489947
491	1.6934223480175
501	1.6660960211912
511	1.7443063285868
521	1.7009674322367
531	1.6986856051417
541	1.6865454695006
551	1.7411362033103
561	1.6912126938625
571	1.6347865508689
581	1.7103972993512
591	1.8859219735329
601	1.7403357852309
611	1.7748609068793
621	1.7583254468828
631	1.8857419940146
641	1.7992770553966
651	2.2019843001456
661	2.1009637692277
671	2.0553492381859
681	2.1873517365684
691	2.0910587881434
701	2.0645586680785
711	2.0411537749646
721	2.0337288829184
731	2.0386315002356
741	2.0114692246773
751	2.0042521928585
761	2.076581891392
771	2.0004553883316
781	2.0174584730144
791	1.9738821894372
801	2.1543788260467
811	2.1238615624285
821	2.1007468713526
831	2.083947453716
841	2.0535403511206
851	1.9729857218067
861	1.957627588988
871	1.9522227621764
881	1.948584517834
891	1.9398604464103
901	1.927979918414
911	1.9267730202584
921	1.9137384341865
931	1.9048116611769
941	1.9156521368639
951	1.9037856805473
961	1.8979352848381
971	1.8926875283293
981	1.8950950975504
991	1.8761098743622

Test


In [12]:
-- measure training-set accuracy of the prediction made at each step d = 10..20
correction = {}
for i=10,20 do
    correction[i] = 0
end
for i=1,trainset:size() do
    local answer = trainset.label[i]
    local inputs = {}
    for step=1,rho do
        table.insert(inputs, trainset.data[{{i},{(step-1)*nIndex+1, step*nIndex}}])
    end
    local prediction = rnn:forward(inputs) -- table of rho outputs, each 1 x nClass
    for d=10,20 do
        local guess = prediction[d][{1,{}}] -- log-probabilities at step d
        local confidences, indices = torch.sort(guess, true) -- sort descending
        -- if i%100 == 1 then print(answer, guess, indices[1]) end
        if (answer == indices[1]) then -- top prediction matches the label
            correction[d] = correction[d] + 1
        end
    end
end
for i=10,20 do
    print(i, " = ", correction[i], 100*correction[i]/trainset:size() .. '%')
end


Out[12]:
10	 = 	1331	26.62%
11	 = 	1111	22.22%
12	 = 	1149	22.98%
13	 = 	1383	27.66%
14	 = 	1540	30.8%
15	 = 	1642	32.84%
16	 = 	1684	33.68%
17	 = 	1664	33.28%
18	 = 	1658	33.16%
19	 = 	1692	33.84%
20	 = 	1519	30.38%
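
The per-example loop above calls rnn:forward 5000 times; since the training loop already pushes all examples through in one batch, the same trick can be reused for evaluation. A hypothetical batched variant that scores only the final step's prediction:

local inputs = {}
for step = 1, rho do
    table.insert(inputs, trainset.data[{ {}, {(step-1)*nIndex+1, step*nIndex} }])
end
local outputs = rnn:forward(inputs)             -- table of rho tensors, each 5000 x nClass
local _, predicted = torch.max(outputs[rho], 2) -- 5000 x 1 predicted class indices
local correct = predicted:view(-1):eq(trainset.label:long()):sum()
print('step ' .. rho .. ' accuracy: ' .. 100*correct/trainset:size() .. '%')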
