In [1]:
require 'nn';
require 'torch';

Prepare data


In [54]:
local matio = require 'matio'
data = matio.load('ex4data1.mat')
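
If ex4data1.mat is the handwritten-digit set from the Coursera machine learning exercises (5000 flattened 20x20 grayscale images in X and a 5000x1 label vector y with classes 1-10, where 10 stands for the digit 0), a quick shape check confirms what was loaded:

In [ ]:
print(data.X:size())  -- expected: 5000x400
print(data.y:size())  -- expected: 5000x1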

In [55]:
trainset = {}
trainset.data = data.X                 -- 5000x400 input matrix
trainset.label = data.y[{ {}, 1}]      -- first (only) column of y as a 1-D label vector

In [56]:
trainset

In [57]:
setmetatable(trainset,
    {__index = function(t,i)
                return {t.data[i], t.label[i]}
        end}
);
 
function trainset:size()
    return self.data:size(1)
end


Out[57]:
{
  data : DoubleTensor - size: 5000x400
  label : DoubleTensor - size: 5000
}
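
The metatable and the size() method are there because nn.StochasticGradient expects a dataset that answers dataset:size() and returns an {input, target} pair from dataset[i]. A quick check that the protocol is in place:

In [ ]:
print(trainset:size())        -- number of examples (5000)
print(trainset[1][1]:size())  -- 400-dimensional input vector
print(trainset[1][2])         -- its label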

In [58]:
-- standardize each of the 400 pixel features to zero mean and unit variance
mean = {}
stdv = {}
for i=1,400 do
    mean[i] = trainset.data[{ {},{i} }]:mean()
    stdv[i] = trainset.data[{ {}, {i} }]:std()
    --print(i .. 'th mean: ' .. mean[i])
    --print(i .. 'th std dev: ' .. stdv[i])
    trainset.data[{ {},{i} }]:add(-mean[i])
    if stdv[i] ~= 0 then                  -- skip constant features to avoid division by zero
        trainset.data[{ {},{i} }]:div(stdv[i])
    end
end
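
The same per-feature standardization can also be written with column-wise tensor operations instead of an explicit loop. A sketch of that alternative (run either this cell or the loop above, not both):

In [ ]:
colMean = trainset.data:mean(1)   -- 1x400 column means
colStd  = trainset.data:std(1)    -- 1x400 column standard deviations
colStd[colStd:eq(0)] = 1          -- leave constant columns unscaled
trainset.data:add(-1, colMean:expandAs(trainset.data))
trainset.data:cdiv(colStd:expandAs(trainset.data))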

Define the model


In [59]:
net = nn.Sequential()
net:add(nn.Linear(400,25))    -- 400 input pixels -> 25 hidden units
net:add(nn.Sigmoid())
net:add(nn.Linear(25,10))     -- 25 hidden units -> 10 classes
net:add(nn.Sigmoid())
net:add(nn.LogSoftMax())      -- log-probabilities for ClassNLLCriterion
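
Printing the container shows the layer stack: 400 inputs, a 25-unit sigmoid hidden layer, and 10 outputs. Note that the second Sigmoid is not required; Linear followed directly by LogSoftMax is the more common pairing with ClassNLLCriterion. Squashing the output logits into (0,1) flattens the log-probabilities (the average loss cannot drop much below roughly 1.46 for 10 classes), which may explain the slow error decrease in the training log below.

In [ ]:
print(net)  -- Sequential: Linear(400 -> 25) -> Sigmoid -> Linear(25 -> 10) -> Sigmoid -> LogSoftMax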

Define the loss function


In [60]:
criterion = nn.ClassNLLCriterion()
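
ClassNLLCriterion consumes the log-probabilities produced by LogSoftMax together with a target class index in 1..10. The loss for a single example can be inspected directly:

In [ ]:
-- negative log-likelihood of the first training example under the (still untrained) network
local logProbs = net:forward(trainset.data[1])
print(criterion:forward(logProbs, trainset.label[1]))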

Let's train


In [61]:
trainer = nn.StochasticGradient(net, criterion)
trainer.learningRate = 0.001
trainer.maxIteration = 1e2
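
nn.StochasticGradient makes maxIteration shuffled passes over the dataset, updating the weights after every example with the fixed learningRate and printing the average error once per pass. For additional per-epoch reporting the trainer also exposes a hookIteration callback; a minimal sketch:

In [ ]:
-- called by the trainer after each full pass over the dataset
trainer.hookIteration = function(self, iteration)
    if iteration % 10 == 0 then
        print('finished epoch ' .. iteration)
    end
end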

In [62]:
trainer:train(trainset)


Out[62]:
# StochasticGradient: training
# current error = 2.2919424613996
# current error = 2.2623607889839
# current error = 2.2320541713042
# current error = 2.2006777919079
# current error = 2.168539441064
# current error = 2.136266976882
# current error = 2.1045216386857
# current error = 2.0738772603513
# current error = 2.0447717893028
# current error = 2.0174868472812
# current error = 1.9921557315415
# current error = 1.9687918809252
# current error = 1.9473252182393
# current error = 1.9276363789713
# current error = 1.9095829761714
# current error = 1.8930167082486
# current error = 1.8777932931237
# current error = 1.8637776287851
# current error = 1.8508455714591
# current error = 1.8388839048255
# current error = 1.8277902435673
# current error = 1.8174728210088
# current error = 1.8078499691825
# current error = 1.7988494845912
# current error = 1.7904078209015
# current error = 1.7824692704477
# current error = 1.7749852483583
# current error = 1.7679136182669
# current error = 1.7612179979538
# current error = 1.7548670215456
# current error = 1.7488335768304
# current error = 1.7430940597637
# current error = 1.7376276846547
# current error = 1.7324158913333
# current error = 1.7274418922739
# current error = 1.7226903699712
# current error = 1.7181472968042
# current error = 1.7137998219794
# current error = 1.7096361358339
# current error = 1.7056452713351
# current error = 1.7018169467665
# current error = 1.6981415234959
# current error = 1.6946100062297
# current error = 1.6912140229238
# current error = 1.6879457858581
# current error = 1.6847980478246
# current error = 1.6817640592968
# current error = 1.6788375278242
# current error = 1.676012580153
# current error = 1.6732837276552
# current error = 1.6706458352558
# current error = 1.6680940935506
# current error = 1.6656239937012
# current error = 1.6632313047934
# current error = 1.6609120536788
# current error = 1.6586625073457
# current error = 1.6564791577077
# current error = 1.6543587085247
# current error = 1.6522980639652
# current error = 1.6502943181067
# current error = 1.6483447445889
# current error = 1.6464467855753
# current error = 1.6445980398386
# current error = 1.6427962509004
# current error = 1.6410392967841
# current error = 1.6393251817682
# current error = 1.6376520286444
# current error = 1.6360180698685
# current error = 1.6344216381803
# current error = 1.6328611589585
# current error = 1.6313351458557
# current error = 1.6298421989898
# current error = 1.6283810037427
# current error = 1.6269503286812
# current error = 1.6255490221372
# current error = 1.6241760076194
# current error = 1.6228302784771
# current error = 1.6215108921748
# current error = 1.6202169644662
# current error = 1.618947663716
# current error = 1.6177022055761
# current error = 1.6164798481998
# current error = 1.615279888089
# current error = 1.6141016566154
# current error = 1.6129445171429
# current error = 1.6118078626521
# current error = 1.6106911137369
# current error = 1.6095937168614
# current error = 1.6085151428036
# current error = 1.6074548852351
# current error = 1.6064124594164
# current error = 1.6053874010077
# current error = 1.6043792650534
# current error = 1.6033876251947
# current error = 1.6024120733069
# current error = 1.6014522198081
# current error = 1.6005076949361
# current error = 1.5995781510644
# current error = 1.5986632653485
# current error = 1.5977627407013
# StochasticGradient: you have reached the maximum number of iterations
# training error = 1.5977627407013

Test


In [63]:
correct = 0
for i=1,trainset:size() do
    local answer = trainset.label[i]
    local prediction = net:forward(trainset.data[i])
    -- sort the log-probabilities in descending order; indices[1] is the predicted class
    local confidences, indices = torch.sort(prediction, true)
    if (answer == indices[1]) then
        correct = correct + 1
    end
end
print(correct, 100*correct/trainset:size() .. '%')


Out[63]:
4615	92.3%	
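
Keep in mind that this 92.3% is measured on the same 5000 examples the network was trained on; the data was never split, so it is not an estimate of generalization. A single prediction can be inspected like this (index 10 is just an arbitrary example):

In [ ]:
local logProbs = net:forward(trainset.data[10])
local _, predicted = torch.max(logProbs, 1)  -- index of the largest log-probability
print('predicted: ' .. predicted[1] .. ', actual: ' .. trainset.label[10])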

In [ ]: