In [1]:
require 'torch';
require 'nn';
require 'optim';

Prepare data


In [2]:
local matio = require 'matio'
data = matio.load('ex4data1.mat')

In [3]:
trainset = {}
trainset.data = data.X
trainset.label = data.y[{ {}, 1}]
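A quick sanity check on the loaded tensors. (In the standard Coursera ML exercise, ex4data1.mat holds 5000 flattened 20x20 grayscale digits, so X should be 5000x400 and y should be 5000x1, with labels 1..10 where 10 stands for the digit 0.)

print(trainset.data:size())                        -- expected: 5000 x 400
print(trainset.label:size())                       -- expected: 5000
print(trainset.label:min(), trainset.label:max())  -- expected: 1  10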

In [4]:
trainset

In [5]:
setmetatable(trainset,
    {__index = function(t,i)
                return {t.data[i], t.label[i]}
        end}
);
 
function trainset:size()
    return self.data:size(1)
end
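With the metatable in place, trainset behaves like a dataset object: trainset[i] returns an {input, label} pair and trainset:size() returns the number of examples (the same interface nn.StochasticGradient expects, although training below uses optim instead). For example:

print(trainset:size())   -- 5000
local sample = trainset[1]
print(sample[1]:size())  -- 400 (the flattened input image)
print(sample[2])         -- its label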

In [6]:
mean = {}
stdv = {}
for i=1,400 do
    mean[i] = trainset.data[{ {},{i} }]:mean()
    stdv[i] = trainset.data[{ {}, {i} }]:std()
    --print(i .. 'th mean: ' .. mean[i])
    --print(i .. 'th std dev: ' .. stdv[i])
    trainset.data[{ {},{i} }]:add(-mean[i])
    if stdv[i] ~= 0 then
        trainset.data[{ {},{i} }]:div(stdv[i])
    end
end
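The loop-based standardization above can also be written with tensor operations; a minimal vectorized sketch that performs the same per-column normalization (run one version or the other, not both):

local mu = trainset.data:mean(1)    -- 1x400 per-column means
local sigma = trainset.data:std(1)  -- 1x400 per-column standard deviations
sigma[sigma:eq(0)] = 1              -- guard against constant (zero-variance) columns
trainset.data:add(-1, mu:expandAs(trainset.data))
trainset.data:cdiv(sigma:expandAs(trainset.data))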

Define model


In [7]:
-- note: the custom trainset:size() ignores its argument, so query the tensor directly
n_train_data = trainset.data:size(1) -- number of training examples (rows)
n_inputs = trainset.data:size(2)     -- number of cols = number of dims of input
n_outputs = 10                       -- highest label = # of classes

In [8]:
net = nn.Sequential()
net:add(nn.Linear(400, 25))  -- 400 input pixels -> 25 hidden units
net:add(nn.Sigmoid())
net:add(nn.Linear(25, 10))   -- 25 hidden units -> 10 class scores
net:add(nn.Sigmoid())
net:add(nn.LogSoftMax())     -- log-probabilities, as ClassNLLCriterion expects
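Printing the container gives a layer-by-layer summary, and the weight and bias tensors can be counted to confirm the model size ((400*25 + 25) + (25*10 + 10) = 10285 parameters):

print(net)                        -- layer-by-layer summary
local weights = net:parameters()  -- table of weight and bias tensors
local n = 0
for _, w in ipairs(weights) do n = n + w:nElement() end
print(n)                          -- 10285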

Define loss function


In [9]:
opt = {
    optimization = 'sgd',
    batch_size = 5000,  -- equal to the dataset size, so every step is full-batch
    train_size = 5000,  -- set to 0 or 5000 to use all 5000 training examples
    test_size = 0,      -- 0 means load all data
    epochs = 1e3,       -- **approximate** number of passes through the training data
                        -- (see the `iterations` variable below, calculated from this)
}                       -- these options are used throughout

In [10]:
criterion = nn.ClassNLLCriterion()
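ClassNLLCriterion computes the negative log-likelihood and expects log-probabilities as input, which is why the model ends in LogSoftMax. A single-sample check; for an untrained net the loss should sit near -log(1/10) ≈ 2.30:

local logprobs = net:forward(trainset.data[1])
print(criterion:forward(logprobs, trainset.label[1]))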

In [11]:
parameters, gradParameters = net:getParameters()

In [12]:
counter = 0
-- feval(x): return the loss and the gradient of the loss at parameter vector x.
-- It runs on the *entire* training set, so each optimizer step below is a
-- full-batch gradient step (opt.batch_size equals the dataset size).
feval = function(x)
  if x ~= parameters then
    parameters:copy(x)
  end

  gradParameters:zero()

  local batch_inputs = trainset.data    -- all 5000 examples
  local batch_targets = trainset.label

  batch_outputs = net:forward(batch_inputs)
  batch_loss = criterion:forward(batch_outputs, batch_targets)
  dloss_doutput = criterion:backward(batch_outputs, batch_targets)
  net:backward(batch_inputs, dloss_doutput)

  return batch_loss, gradParameters
end
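One direct call to feval makes a useful smoke test before training: it should return the current full-batch loss and a gradient vector with one entry per parameter.

local loss, grad = feval(parameters)
print(loss, grad:nElement(), grad:norm())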

Train


In [13]:
optimState = {
    learningRate = 5,
    weightDecay = 0,
    momentum = 0,
    learningRateDecay = 1e-2
}
optimMethod = optim.sgd
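With learningRateDecay set, optim.sgd anneals the step size as learningRate / (1 + t * learningRateDecay), where t counts function evaluations. A quick sketch of what these settings imply:

local function effective_lr(t)
    return optimState.learningRate / (1 + t * optimState.learningRateDecay)
end
print(effective_lr(0), effective_lr(1000))  -- 5 at the start, ~0.45 by iteration 1000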

In [14]:
losses = {}          -- training losses for each iteration/minibatch
epochs = opt.epochs  -- number of full passes over all the training data
iterations = epochs * math.ceil(n_train_data / opt.batch_size) -- integer number of minibatches to process

for i = 1, iterations do
  local _, minibatch_loss = optimMethod(feval, parameters, optimState)

  if i % 10 == 1 then -- don't print *every* iteration, this is enough to get the gist
      print(string.format("minibatches processed: %6s, loss = %6.6f", i, minibatch_loss[1]))
  end
  losses[#losses + 1] = minibatch_loss[1] -- append the new loss
end


Out[14]:
minibatches processed:      1, loss = 2.300654
minibatches processed:     11, loss = 2.027442
minibatches processed:     21, loss = 1.867546
minibatches processed:     31, loss = 1.787721
minibatches processed:     41, loss = 1.739668
minibatches processed:     51, loss = 1.707050
minibatches processed:     61, loss = 1.683614
minibatches processed:     71, loss = 1.666046
minibatches processed:     81, loss = 1.652381
minibatches processed:     91, loss = 1.641420
minibatches processed:    101, loss = 1.632400
minibatches processed:    111, loss = 1.624816
minibatches processed:    121, loss = 1.618327
minibatches processed:    131, loss = 1.612691
minibatches processed:    141, loss = 1.607736
minibatches processed:    151, loss = 1.603340
minibatches processed:    161, loss = 1.599405
minibatches processed:    171, loss = 1.595860
minibatches processed:    181, loss = 1.592646
minibatches processed:    191, loss = 1.589715
minibatches processed:    201, loss = 1.587031
minibatches processed:    211, loss = 1.584563
minibatches processed:    221, loss = 1.582286
minibatches processed:    231, loss = 1.580177
minibatches processed:    241, loss = 1.578218
minibatches processed:    251, loss = 1.576394
minibatches processed:    261, loss = 1.574689
minibatches processed:    271, loss = 1.573092
minibatches processed:    281, loss = 1.571592
minibatches processed:    291, loss = 1.570181
minibatches processed:    301, loss = 1.568850
minibatches processed:    311, loss = 1.567592
minibatches processed:    321, loss = 1.566401
minibatches processed:    331, loss = 1.565271
minibatches processed:    341, loss = 1.564198
minibatches processed:    351, loss = 1.563176
minibatches processed:    361, loss = 1.562201
minibatches processed:    371, loss = 1.561270
minibatches processed:    381, loss = 1.560380
minibatches processed:    391, loss = 1.559528
minibatches processed:    401, loss = 1.558711
minibatches processed:    411, loss = 1.557928
minibatches processed:    421, loss = 1.557176
minibatches processed:    431, loss = 1.556454
minibatches processed:    441, loss = 1.555760
minibatches processed:    451, loss = 1.555092
minibatches processed:    461, loss = 1.554449
minibatches processed:    471, loss = 1.553829
minibatches processed:    481, loss = 1.553231
minibatches processed:    491, loss = 1.552654
minibatches processed:    501, loss = 1.552096
minibatches processed:    511, loss = 1.551557
minibatches processed:    521, loss = 1.551035
minibatches processed:    531, loss = 1.550529
minibatches processed:    541, loss = 1.550040
minibatches processed:    551, loss = 1.549564
minibatches processed:    561, loss = 1.549103
minibatches processed:    571, loss = 1.548656
minibatches processed:    581, loss = 1.548221
minibatches processed:    591, loss = 1.547798
minibatches processed:    601, loss = 1.547387
minibatches processed:    611, loss = 1.546986
minibatches processed:    621, loss = 1.546597
minibatches processed:    631, loss = 1.546217
minibatches processed:    641, loss = 1.545848
minibatches processed:    651, loss = 1.545487
minibatches processed:    661, loss = 1.545136
minibatches processed:    671, loss = 1.544793
minibatches processed:    681, loss = 1.544458
minibatches processed:    691, loss = 1.544131
minibatches processed:    701, loss = 1.543812
minibatches processed:    711, loss = 1.543500
minibatches processed:    721, loss = 1.543195
minibatches processed:    731, loss = 1.542897
minibatches processed:    741, loss = 1.542605
minibatches processed:    751, loss = 1.542319
minibatches processed:    761, loss = 1.542040
minibatches processed:    771, loss = 1.541766
minibatches processed:    781, loss = 1.541498
minibatches processed:    791, loss = 1.541235
minibatches processed:    801, loss = 1.540978
minibatches processed:    811, loss = 1.540726
minibatches processed:    821, loss = 1.540478
minibatches processed:    831, loss = 1.540235
minibatches processed:    841, loss = 1.539997
minibatches processed:    851, loss = 1.539764
minibatches processed:    861, loss = 1.539534
minibatches processed:    871, loss = 1.539309
minibatches processed:    881, loss = 1.539088
minibatches processed:    891, loss = 1.538871
minibatches processed:    901, loss = 1.538657
minibatches processed:    911, loss = 1.538447
minibatches processed:    921, loss = 1.538241
minibatches processed:    931, loss = 1.538038
minibatches processed:    941, loss = 1.537839
minibatches processed:    951, loss = 1.537643
minibatches processed:    961, loss = 1.537450
minibatches processed:    971, loss = 1.537260
minibatches processed:    981, loss = 1.537074
minibatches processed:    991, loss = 1.536890
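The losses table now holds one value per iteration; a short sketch for plotting the curve (this assumes the gnuplot package for Torch is installed):

local gnuplot = require 'gnuplot'
gnuplot.plot('training loss', torch.Tensor(losses), '-')
gnuplot.xlabel('iteration')
gnuplot.ylabel('negative log-likelihood')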

Test


In [16]:
correct = 0
n_test = 10
for i = 1001, 1000 + n_test do
    local answer = trainset.label[i]
    local prediction = net:forward(trainset.data[i])
    print(answer, prediction)
    local confidences, indices = torch.sort(prediction, true)  -- descending
    if answer == indices[1] then
        correct = correct + 1
    end
end
-- report accuracy over the samples actually tested, not the whole set
print(correct, 100 * correct / n_test .. '%')


Out[16]:
2	-2.4712
-1.4774
-2.4559
-2.4708
-2.4664
-2.4632
-2.4655
-2.4106
-2.4688
-2.4360
[torch.DoubleTensor of size 10]

2	-2.5053
-1.5489
-2.4928
-2.5028
-2.5052
-2.5028
-2.1984
-2.5047
-2.5043
-2.2691
[torch.DoubleTensor of size 10]

2	-2.5514
-1.5907
-1.9310
-2.5526
-2.5379
-2.5515
-2.5525
-2.3325
-2.5525
-2.4587
[torch.DoubleTensor of size 10]

2	-2.5299
-1.5749
-2.4749
-2.5303
-2.5307
-2.5308
-1.9349
-2.4673
-2.5303
-2.5080
[torch.DoubleTensor of size 10]

2	-2.4929
-1.5183
-2.2434
-2.5072
-2.5060
-2.5061
-2.5057
-2.2937
-2.5072
-2.4833
[torch.DoubleTensor of size 10]

2	-2.4796
-1.4814
-2.3914
-2.4803
-2.4782
-2.4615
-2.4782
-2.4698
-2.4803
-2.3830
[torch.DoubleTensor of size 10]

2	-2.4631
-1.4896
-2.4646
-2.4183
-2.4596
-2.3907
-2.4655
-2.4727
-2.4716
-2.4715
[torch.DoubleTensor of size 10]

2	-2.5108
-1.5131
-2.1363
-2.5116
-2.4921
-2.5052
-2.5111
-2.4692
-2.5116
-2.4307
[torch.DoubleTensor of size 10]

2	-2.5308
-1.5626
-2.3020
-2.5318
-2.5312
-2.5293
-2.0228
-2.5299
-2.5318
-2.5196
[torch.DoubleTensor of size 10]

2	-2.4518
-1.5067
-2.4919
-2.5024
-2.5023
-2.1448
-2.4976
-2.5010
-2.5026
-2.4956
[torch.DoubleTensor of size 10]

10	100%
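Scoring only ten consecutive samples of a single class says little about the model; a sketch that evaluates accuracy over the whole training set instead (ex4data1.mat ships no separate test split):

local n_correct = 0
for i = 1, trainset:size() do
    local logprobs = net:forward(trainset.data[i])
    local _, idx = torch.max(logprobs, 1)  -- index of the highest log-probability
    if idx[1] == trainset.label[i] then
        n_correct = n_correct + 1
    end
end
print(string.format('training accuracy: %.2f%% (%d/%d)',
                    100 * n_correct / trainset:size(), n_correct, trainset:size()))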
