require 'nn'      -- neural network modules
require 'optim'   -- optimization algorithms (SGD, L-BFGS)
require 'csvigo'  -- CSV file loading
In [ ]:
-- load the CSV; csvigo returns a table mapping each column name to a
-- list of that column's values
loaded = csvigo.load('example-logistic-regression.csv')
In [ ]:
-- convert each column to a 1D tensor
brands = torch.Tensor(loaded.brand)
females = torch.Tensor(loaded.female)
ages = torch.Tensor(loaded.age)

-- pack the two input variables into an N x 2 design matrix,
-- where N = (#brands)[1] is the number of rows:
-- column 1 = female, column 2 = age
dataset_inputs = torch.Tensor( (#brands)[1],2 )
dataset_inputs[{ {},1 }] = females
dataset_inputs[{ {},2 }] = ages

-- the targets are the brand ids (1, 2, or 3)
dataset_outputs = brands
numberOfBrands = torch.max(dataset_outputs) - torch.min(dataset_outputs) + 1
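
-- Quick shape check (illustrative, not part of the original listing):
-- the design matrix should be N x 2 and each row should pair a female
-- flag with an age.
print(#dataset_inputs)                        -- N x 2; N depends on the CSV
print(dataset_inputs[1], dataset_outputs[1])  -- first sample and its brand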
In [ ]:
-- the model: a linear layer mapping the 2 inputs to one score per brand,
-- followed by LogSoftMax, so the output is a vector of log-probabilities
-- (numberOfBrands is 3 for this dataset)
linLayer = nn.Linear(2, numberOfBrands)
softMaxLayer = nn.LogSoftMax()
model = nn.Sequential()
model:add(linLayer)
model:add(softMaxLayer)
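
-- Sanity check (illustrative; the model is still untrained): a forward
-- pass on a dummy (female, age) input returns one log-probability per
-- brand, and exponentiating gives probabilities that sum to 1.
print(torch.exp(model:forward(torch.Tensor{1, 30})))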
In [ ]:
-- ClassNLLCriterion expects log-probabilities as input and a class
-- index (here 1, 2, or 3) as target, which is exactly what our
-- LogSoftMax output and brand ids provide
criterion = nn.ClassNLLCriterion()
In [ ]:
-- flatten the model's parameters and gradients into two vectors
x, dl_dx = model:getParameters()

-- closure that evaluates the loss and gradient for ONE training sample,
-- cycling through the dataset via the global counter _nidx_
feval = function(x_new)
   -- set x to x_new, if different
   if x ~= x_new then
      x:copy(x_new)
   end

   -- select the next training sample, wrapping around at the end
   _nidx_ = (_nidx_ or 0) + 1
   if _nidx_ > (#dataset_inputs)[1] then _nidx_ = 1 end

   local inputs = dataset_inputs[_nidx_]
   local target = dataset_outputs[_nidx_]

   -- reset gradients (they are accumulated otherwise)
   dl_dx:zero()

   -- forward/backward pass for this single sample
   local loss_x = criterion:forward(model:forward(inputs), target)
   model:backward(inputs, criterion:backward(model.output, target))

   -- return loss(x) and dloss/dx
   return loss_x, dl_dx
end
sgd_params = {
   learningRate = 1e-3,
   learningRateDecay = 1e-4,
   weightDecay = 0,
   momentum = 0
}
In [ ]:
epochs = 1e2 -- number of times to cycle over our training data
print('')
print('============================================================')
print('Training with SGD')
print('')
for i = 1,epochs do
   current_loss = 0
   -- one epoch: a full pass over the training data, one SGD step per sample
   for j = 1,(#dataset_inputs)[1] do
      -- optim.sgd returns the new x and a table containing the loss value
      _,fs = optim.sgd(feval,x,sgd_params)
      current_loss = current_loss + fs[1]
   end
   -- report the average per-sample loss for this epoch
   current_loss = current_loss / (#dataset_inputs)[1]
   print('epoch = ' .. i ..
         ' of ' .. epochs ..
         ' current loss = ' .. current_loss)
end
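
-- Inspect the learned parameters (illustrative, not in the original
-- tutorial): linLayer.weight is numberOfBrands x 2 (female and age
-- columns) and linLayer.bias holds the per-brand intercepts.
print('weights:'); print(linLayer.weight)
print('biases:'); print(linLayer.bias)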
In [ ]:
-- re-initialize the parameters so L-BFGS trains the model from scratch
model:reset()
-- next we re-define the closure that evaluates f and df/dx, so that
-- it estimates the true f, and true (exact) df/dx, over the entire
-- dataset. This is a full batch approach.
feval = function(x_new)
   -- set x to x_new, if different
   -- (in this simple example, x_new will typically always point to x,
   -- so the copy is really useless)
   if x ~= x_new then
      x:copy(x_new)
   end

   -- reset gradients (gradients are always accumulated, to accommodate
   -- batch methods)
   dl_dx:zero()

   -- and batch over the whole training dataset:
   local loss_x = 0
   for i = 1,(#dataset_inputs)[1] do
      -- select a new training sample
      _nidx_ = (_nidx_ or 0) + 1
      if _nidx_ > (#dataset_inputs)[1] then _nidx_ = 1 end

      local inputs = dataset_inputs[_nidx_]
      local target = dataset_outputs[_nidx_]

      -- evaluate the loss function and its derivative wrt x, for that sample
      loss_x = loss_x + criterion:forward(model:forward(inputs), target)
      model:backward(inputs, criterion:backward(model.output, target))
   end

   -- normalize with batch size (div modifies dl_dx in place)
   loss_x = loss_x / (#dataset_inputs)[1]
   dl_dx:div( (#dataset_inputs)[1] )

   -- return loss(x) and dloss/dx
   return loss_x, dl_dx
end
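
-- Optional gradient check (illustrative, not in the original tutorial):
-- compare the analytic derivative of the first parameter against a
-- central finite-difference estimate; the two should agree closely.
local _, grad = feval(x)
local analytic = grad[1]
local eps = 1e-4
x[1] = x[1] + eps
local lossPlus = feval(x)
x[1] = x[1] - 2*eps
local lossMinus = feval(x)
x[1] = x[1] + eps  -- restore the original parameter value
print('analytic', analytic, 'numeric', (lossPlus - lossMinus) / (2*eps))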
In [ ]:
lbfgs_params = {
   lineSearch = optim.lswolfe,  -- line search satisfying the Wolfe conditions
   maxIter = epochs,
   verbose = true
}
print('')
print('============================================================')
print('Training with L-BFGS')
print('')
_,fs = optim.lbfgs(feval,x,lbfgs_params)
-- fs contains all the evaluations of f, during optimization
print('history of L-BFGS evaluations:')
print(fs)
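
-- (Illustrative) the last entry of fs is the final full-batch loss:
print('final loss after L-BFGS: ' .. fs[#fs])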
In [ ]:
print('')
print('============================================================')
print('Testing the model')
print('')
-- Now that the model is trained, one can test it by evaluating it
-- on new samples.
-- The model constructed and trained above computes the probabilities
-- of each class given the input values.
-- We want to compare our model's results with those from the text.
-- The input variables have narrow ranges, so we simply compare over all
-- input values that occur in the training data.
-- First, determine the actual brand frequencies for each female-age
-- pair in the training data.
-- return the index of the largest of three values
function maxIndex(a,b,c)
   if a >= b and a >= c then return 1
   elseif b >= a and b >= c then return 2
   else return 3 end
end
-- return predicted brand and probabilities of each brand
-- for the model in the text
-- The R code in the text computes the probabilities of choosing
-- brands 2 and 3 relative to the probability of choosing brand 1:
-- Prob(brand=2)/prob(brand=1) = exp(-11.77 + 0.52*female + 0.37*age)
-- Prob(brand=3)/prob(brand=1) = exp(-22.72 + 0.47*female + 0.69*age)
function predictText(age, female)
   -- 1: calculate the logits
   -- The coefficients come from the text.
   -- If you download the R script and run it, you may see slightly
   -- different results.
   local logit1 = 0
   local logit2 = -11.774655 + 0.523814 * female + 0.368206 * age
   local logit3 = -22.721396 + 0.465941 * female + 0.685908 * age

   -- 2: calculate the unnormalized probabilities
   local uprob1 = math.exp(logit1)
   local uprob2 = math.exp(logit2)
   local uprob3 = math.exp(logit3)

   -- 3: normalize the probabilities
   local z = uprob1 + uprob2 + uprob3
   local prob1 = (1/z) * uprob1
   local prob2 = (1/z) * uprob2
   local prob3 = (1/z) * uprob3

   return maxIndex(prob1, prob2, prob3), prob1, prob2, prob3
end
-- return predicted brand and the probabilities of each brand
-- for our model
function predictOur(age, female)
   local input = torch.Tensor(2)
   input[1] = female -- must be in the same order as when the model was trained!
   input[2] = age
   -- the model returns log-probabilities; exponentiate to get probabilities
   local logProbs = model:forward(input)
   local probs = torch.exp(logProbs)
   local prob1, prob2, prob3 = probs[1], probs[2], probs[3]
   return maxIndex(prob1, prob2, prob3), prob1, prob2, prob3
end
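
-- Spot check (illustrative; the exact numbers depend on training): both
-- predictors return the chosen brand followed by the three probabilities.
print(predictText(32, 0))
print(predictOur(32, 0))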
counts = {}

function makeKey(age, brand, female)
   -- return a string encoding the values
   -- Note that returning a table would not work, because each table
   -- constructor creates a distinct object, so two identical-looking
   -- tables would index different entries.
   -- Because Lua interns its strings, a string with a given sequence
   -- of characters is stored only once.
   return string.format('%2d%1d%1f', age, brand, female)
end
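
-- A quick demonstration of the point above (illustrative): equal strings
-- compare equal, while two identical-looking tables are distinct objects.
print(makeKey(24, 1, 0) == makeKey(24, 1, 0)) -- true
print({24, 1, 0} == {24, 1, 0})               -- false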
-- count how often each (age, brand, female) triple occurs
for i = 1,(#brands)[1] do
   local brand = brands[i]
   local female = females[i]
   local age = ages[i]
   local key = makeKey(age, brand, female)
   counts[key] = (counts[key] or 0) + 1
end
-- return the probability of each brand conditioned on age and female
function actualProbabilities(age, female)
   -- countOf is declared local so it does not leak into the global scope
   local function countOf(age, brand, female)
      return counts[makeKey(age, brand, female)] or 0
   end
   local count1 = countOf(age, 1, female)
   local count2 = countOf(age, 2, female)
   local count3 = countOf(age, 3, female)
   local sumCounts = count1 + count2 + count3
   if sumCounts == 0 then
      return 0, 0, 0
   else
      return count1/sumCounts, count2/sumCounts, count3/sumCounts
   end
end
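
-- summarizeData is called below but is not defined in this listing; the
-- following is a minimal reconstruction (an assumption, not the original
-- code) that prints the range of each variable.
function summarizeData()
   local function summarize(name, tensor)
      print(string.format('%10s min = %d max = %d',
                          name, torch.min(tensor), torch.max(tensor)))
   end
   summarize('brand', brands)
   summarize('female', females)
   summarize('age', ages)
end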
print(' ')
print('summary of data')
summarizeData()
print(' ')
print('training variables')
for k,v in pairs(sgd_params) do
   print(string.format('%20s %f', k, v))
end
print(string.format('%20s %f', 'epochs', epochs))
print(' ')
print('current loss', current_loss) -- the last average loss from the SGD phase
-- print the headers
print(' ')
lineFormat = '%-6s %-3s| %-17s | %-17s | %-17s | %-1s %-1s %-1s'
print(string.format(lineFormat,
                    '', '',
                    'actual probs', 'text probs', 'our probs',
                    'best', '', ''))
choices = 'brnd1 brnd2 brnd3'
print(string.format(lineFormat,
                    'female', 'age',
                    choices, choices, choices,
                    'a', 't', 'o'))
-- print each row in the table
function formatFemale(female)
   return string.format('%1d', female)
end

function formatAge(age)
   return string.format('%2d', age)
end

function formatProbs(p1, p2, p3)
   return string.format('%5.3f %5.3f %5.3f', p1, p2, p3)
end

function indexString(p1, p2, p3)
   -- return the index of the highest probability, or '-' if all are
   -- nearly zero
   if p1 < 0.001 and p2 < 0.001 and p3 < 0.001 then
      return '-'
   else
      return string.format('%1d', maxIndex(p1, p2, p3))
   end
end
-- print table rows and accumulate accuracy
for female = 0,1 do
   for age = torch.min(ages),torch.max(ages) do
      -- calculate the actual probabilities in the training data
      local actual1, actual2, actual3 = actualProbabilities(age, female)
      -- calculate the prediction and probabilities using the model in the text
      local textBrand, textProb1, textProb2, textProb3 =
         predictText(age, female)
      -- calculate the probabilities using the model we just trained
      local ourBrand, ourProb1, ourProb2, ourProb3 =
         predictOur(age, female)
      print(string.format(lineFormat,
                          formatFemale(female),
                          formatAge(age),
                          formatProbs(actual1, actual2, actual3),
                          formatProbs(textProb1, textProb2, textProb3),
                          formatProbs(ourProb1, ourProb2, ourProb3),
                          indexString(actual1,actual2,actual3),
                          indexString(textProb1,textProb2,textProb3),
                          indexString(ourProb1,ourProb2,ourProb3)))
   end
end