Check Gaussian Criterion


In [1]:
require 'criteria/KLDCriterion'
require 'criteria/GaussianCriterion'
require 'utils/Sampler'

In [2]:
-- Instantiate the criteria and the sampler used by the checks below.
-- The criteria are constructed (note the trailing parentheses) in the same way
-- as the sampler, so that each name refers to its own instance rather than to
-- the shared class table; calling methods on the class itself skips the
-- constructor and stores per-call state on the class.
-- NOTE(review): assumes both constructors take no required arguments — confirm
-- against criteria/KLDCriterion and criteria/GaussianCriterion.
kld = nn.KLDCriterion()
gauss = nn.GaussianCriterion()
sampler = nn.Sampler()

In [3]:
-- Tensors shared by the following cells (kept global because later cells read
-- mu, lv, pmu and plv): one random pair and one all-zero pair.
-- Presumably (mu, lv) are a mean / log-variance pair and (pmu, plv) the
-- corresponding prior parameters — confirm against the criteria.
local rows, cols = 3, 4
mu = torch.randn(rows, cols)
-- Log of a squared standard-normal draw.
lv = torch.log(torch.pow(torch.randn(rows, cols), 2))
pmu = torch.zeros(rows, cols)
plv = torch.zeros(rows, cols)

In [4]:
-- Evaluate the KLD criterion with {pmu, plv} as input and {mu, lv} as target.
-- Kept as a single bare expression: the recorded output shows the notebook
-- echoing the returned value. Presumably this is the closed-form divergence
-- that the sampled estimate in a later cell is compared against — confirm.
kld:forward({pmu, plv}, {mu, lv})


Out[4]:
10.784102064081	

In [35]:
-- Sampled estimate of D(posterior || prior) = E[log posterior - log prior].
-- Each of the N rounds draws a code from the sampler given {mu, lv}, scores it
-- with the Gaussian criterion under {mu, lv} and under {pmu, plv}, and adds
-- 1/N of each score to the running averages.
N = 10
log_posterior, log_prior = 0, 0
for _ = 1, N do
    code = sampler({mu, lv})
    local under_posterior = gauss:forward({mu, lv}, code)
    local under_prior = gauss:forward({pmu, plv}, code)
    log_posterior = log_posterior + under_posterior/N
    log_prior = log_prior + under_prior/N
end
-- Difference of the two averages is the estimate.
print(log_posterior - log_prior)


Out[35]:
10.790013963543	

In [6]:
-- Sanity check of the Gaussian criterion on a single element where the two
-- parameters and the code are all zero. The recorded output, -0.9189...,
-- equals -0.5 * log(2 * pi).
zero_mu, zero_logv, code = torch.zeros(1), torch.zeros(1), torch.zeros(1)
local zero_params = {zero_mu, zero_logv}
print(gauss:forward(zero_params, code))


Out[6]:
-0.91893853320467	

Check Gradient


In [7]:
-- Set-up for the finite-difference gradient check of the Gaussian criterion.
local rows, cols = 3, 4
mu = torch.randn(rows, cols)
lv = torch.log(torch.pow(torch.randn(rows, cols), 2))
pmu = torch.zeros(rows, cols)
plv = torch.zeros(rows, cols)
-- Draw the code from the first (mu, lv) pair ...
code = sampler({mu, lv})
-- ... then overwrite both with fresh random values, so the gradient is
-- evaluated at parameters other than the ones the code was sampled from.
mu = torch.randn(rows, cols)
lv = torch.log(torch.pow(torch.randn(rows, cols), 2))
-- Step size used by the central differences in the following cells.
h = 1e-5
-- Gradients reported by the criterion for {mu, lv} at this code; the first
-- entry is kept as dmu and the second as dlv.
local grads = gauss:backward({mu, lv}, code)
dmu, dlv = grads[1], grads[2]

In [8]:
-- Finite-difference check of the gradient with respect to mu.
-- For every entry, the central difference (f(x + h) - f(x - h)) / (2h) of
-- gauss:forward is compared with the matching entry of dmu and the error is
-- printed; values near zero mean the two agree.
for i = 1, mu:size(1) do
    for j = 1, mu:size(2) do
        -- Keep the exact entry so it can be restored bit-for-bit. Undoing the
        -- perturbation with +h / -h arithmetic leaves rounding drift in mu and
        -- makes the minus point differ slightly from x - h.
        local original = mu[{i, j}]
        mu[{i, j}] = original + h
        fph = gauss:forward({mu, lv}, code)
        mu[{i, j}] = original - h
        fmp = gauss:forward({mu, lv}, code)
        mu[{i, j}] = original
        print((fph - fmp)/2/h - dmu[{i, j}])
    end
end


Out[8]:
-1.0892229340698e-09	
1.0344436418563e-09	
1.1293965762604e-11	
-4.725304592057e-10	
6.4559291246269e-10	
-1.1422819357065e-09	
1.1262029087078e-09	
-1.5933434571735e-09	
3.1108386977508e-09	
-1.7533992036078e-09	
5.0998139045078e-10	
-7.0252781370073e-10	

In [9]:
-- Finite-difference check of the gradient with respect to lv.
-- For every entry, the central difference (f(x + h) - f(x - h)) / (2h) of
-- gauss:forward is compared with the matching entry of dlv and the error is
-- printed; values near zero mean the two agree.
for i = 1, lv:size(1) do
    for j = 1, lv:size(2) do
        -- Keep the exact entry so it can be restored bit-for-bit. Undoing the
        -- perturbation with +h / -h arithmetic leaves rounding drift in lv and
        -- makes the minus point differ slightly from x - h.
        local original = lv[{i, j}]
        lv[{i, j}] = original + h
        fph = gauss:forward({mu, lv}, code)
        lv[{i, j}] = original - h
        fmp = gauss:forward({mu, lv}, code)
        lv[{i, j}] = original
        print((fph - fmp)/2/h - dlv[{i, j}])
    end
end


Out[9]:
-4.5784531721438e-10	
2.8162148169031e-09	
-7.4037234676361e-10	
-8.7991369746021e-10	
-6.3426597307625e-11	
8.5557852469442e-10	
-1.2577909269673e-09	
8.821449126728e-11	
-4.4890757777694e-11	
-1.2077188138448e-09	
9.8273034154772e-10	
-6.3421907725569e-09