In [1]:
require 'nn'
require 'BKLDCriterion'
require 'KLDCriterion'
bkld = nn.BKLDCriterion()   -- takes (mu, logvar) only; compared below against a N(0, I) prior
kld = nn.KLDCriterion()     -- takes both distributions as {mu, logvar} pairs
In [2]:
mu = torch.randn(3,4)               -- posterior means
lv = torch.randn(3,4):pow(2):log()  -- posterior log-variances
pmu = torch.zeros(3,4)              -- prior means: standard normal
plv = torch.zeros(3,4)              -- prior log-variances: standard normal
In [3]:
print(bkld:forward(mu, lv))
print(kld:forward({pmu, plv}, {mu, lv})) -- with the N(0, I) prior above the two should match; make sure you use the right one as the "target"!
Out[3]:
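In [ ]:
-- quick cross-check, assuming KLDCriterion implements the usual closed-form KL
-- between the diagonal Gaussians q = N(mu, exp(lv)) and p = N(pmu, exp(plv)):
--   KL(q || p) = 0.5 * sum( plv - lv + exp(lv - plv) + (mu - pmu)^2 / exp(plv) - 1 )
-- (whether the criterion sums or averages over the batch is an assumption here)
kl_ref = (plv - lv + torch.exp(lv - plv)
          + torch.cmul(mu - pmu, mu - pmu):cdiv(torch.exp(plv))):sum()
print(0.5 * (kl_ref - mu:nElement()))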
In [4]:
-- re-draw everything with a non-trivial prior so the gradients w.r.t. pmu and plv are exercised too
mu = torch.randn(3,4)
lv = torch.randn(3,4):pow(2):log()
pmu = torch.randn(3,4)
plv = torch.randn(3,4):pow(2):log()
In [5]:
dpmu, dplv, dmu, dlv = unpack(kld:backward({pmu, plv}, {mu, lv}))
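In [ ]:
-- sanity check: each returned gradient should have the same shape as the tensor
-- it is taken with respect to (the ordering assumed by the unpack above is then
-- verified numerically by the finite-difference loops below)
assert(dpmu:isSameSizeAs(pmu) and dplv:isSameSizeAs(plv))
assert(dmu:isSameSizeAs(mu) and dlv:isSameSizeAs(lv))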
In [6]:
h = 1e-5 -- step size for the central-difference gradient checks below
In [7]:
-- central-difference check of the gradient w.r.t. mu:
-- (f(mu_ij + h) - f(mu_ij - h)) / (2h) should match dmu[{i, j}], so each printed difference should be ~0
for i = 1, mu:size(1) do
    for j = 1, mu:size(2) do
        mu[{i, j}] = mu[{i, j}] + h
        fph = kld:forward({pmu, plv}, {mu, lv})
        mu[{i, j}] = mu[{i, j}] - h - h
        fmp = kld:forward({pmu, plv}, {mu, lv})
        mu[{i, j}] = mu[{i, j}] + h
        print((fph - fmp)/2/h - dmu[{i, j}])
    end
end
Out[7]:
In [8]:
-- same check for the gradient w.r.t. lv (the posterior log-variances)
for i = 1, lv:size(1) do
    for j = 1, lv:size(2) do
        lv[{i, j}] = lv[{i, j}] + h
        fph = kld:forward({pmu, plv}, {mu, lv})
        lv[{i, j}] = lv[{i, j}] - h - h
        fmp = kld:forward({pmu, plv}, {mu, lv})
        lv[{i, j}] = lv[{i, j}] + h
        print((fph - fmp)/2/h - dlv[{i, j}])
    end
end
Out[8]:
In [9]:
-- same check for the gradient w.r.t. pmu (the prior means)
for i = 1, pmu:size(1) do
    for j = 1, pmu:size(2) do
        pmu[{i, j}] = pmu[{i, j}] + h
        fph = kld:forward({pmu, plv}, {mu, lv})
        pmu[{i, j}] = pmu[{i, j}] - h - h
        fmp = kld:forward({pmu, plv}, {mu, lv})
        pmu[{i, j}] = pmu[{i, j}] + h
        print((fph - fmp)/2/h - dpmu[{i, j}])
    end
end
Out[9]:
In [10]:
-- same check for the gradient w.r.t. plv (the prior log-variances)
for i = 1, plv:size(1) do
    for j = 1, plv:size(2) do
        plv[{i, j}] = plv[{i, j}] + h
        fph = kld:forward({pmu, plv}, {mu, lv})
        plv[{i, j}] = plv[{i, j}] - h - h
        fmp = kld:forward({pmu, plv}, {mu, lv})
        plv[{i, j}] = plv[{i, j}] + h
        print((fph - fmp)/2/h - dplv[{i, j}])
    end
end
Out[10]:
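In [ ]:
-- the four cells above repeat the same central-difference loop; a sketch of a
-- reusable helper that does the same thing (checkgrad is a made-up name, not
-- part of either criterion); it mutates the global tensors in place exactly as above
function checkgrad(param, grad)
    local maxerr = 0
    for i = 1, param:size(1) do
        for j = 1, param:size(2) do
            param[{i, j}] = param[{i, j}] + h
            local fph = kld:forward({pmu, plv}, {mu, lv})
            param[{i, j}] = param[{i, j}] - 2 * h
            local fmh = kld:forward({pmu, plv}, {mu, lv})
            param[{i, j}] = param[{i, j}] + h
            maxerr = math.max(maxerr, math.abs((fph - fmh)/2/h - grad[{i, j}]))
        end
    end
    return maxerr
end
print(checkgrad(pmu, dpmu), checkgrad(plv, dplv), checkgrad(mu, dmu), checkgrad(lv, dlv))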
In [ ]:
-- check my implementation is correct: for a diagonal Sigma,
-- (mu1 - mu2)' * Sigma^-1 * (mu1 - mu2) equals the sum of (mu1 - mu2)^2 / diag(Sigma)
-- create the mu's
mu1 = torch.randn(3,1)
mu2 = torch.randn(3,1)
In [ ]:
-- create a diagonal covariance and its diagonal as a column vector
sig = torch.diag(torch.randn(3):pow(2))
sig_lin = torch.diag(sig):reshape(3,1)
In [ ]:
-- first approach: full quadratic form with the inverse covariance
print((mu1 - mu2):t() * torch.inverse(sig) * (mu1 - mu2))
In [ ]:
-- second approach: element-wise, using only the diagonal of sigma
print(torch.Tensor(1):fill((mu1 - mu2):pow(2):cdiv(sig_lin):sum()))
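In [ ]:
-- the two values above should agree up to floating-point error, since for a
-- diagonal Sigma the quadratic form reduces to the element-wise expression
print(((mu1 - mu2):t() * torch.inverse(sig) * (mu1 - mu2))[{1, 1}]
      - (mu1 - mu2):pow(2):cdiv(sig_lin):sum())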
In [ ]:
-- scratch: check how chained in-place adds behave when the argument aliases the target
x = torch.Tensor(1):fill(3)
In [ ]:
x:add(3):add(x)   -- x becomes 6, then 12: add(x) sees the already-updated x
print(x)