In [1]:
import glob
import os
import sys
from collections import OrderedDict

import numpy as np
from scipy.io import loadmat
# Root directory holding the exact solutions and every solver's output.
MAINFOLDER = '/data/ml2/rahul/fw-inference/RBM/'

# Available RBM test cases -> number of variables N. Select one via TCNAME.
TESTCASES = {'rbm_20': 40, 'rbm_21': 42, 'rbm_22': 44}
TCNAME = 'rbm_21'
N = TESTCASES[TCNAME]

# Positions 1, 3, ..., 2N-1 of the flattened node-marginal vectors are the ones
# compared against the exact solution.
# NOTE(review): presumably these are the p(x_i = 1) entries of binary marginals
# stored as [p0, p1] pairs -- confirm against the .mat layout.
comparison_idx = np.arange(1, 2 * N, 2)

# Directory with the exact node marginals / log partition function per test case.
EXACTDIR = MAINFOLDER + '/mat_out'

# Plot label -> directory with that method's results. An OrderedDict keeps the
# colour/marker/legend assignment in the plots deterministic (plain dicts are
# unordered in Python 2). Raw strings keep the LaTeX backslashes literal.
folder = OrderedDict([
    (r'Exact MAP $\mathcal{M}_{\delta}$',
     MAINFOLDER + '/resultsSpanning_MAPsolver/TRW/'),
    (r'Approx MAP $\mathcal{M}_{\delta}$',
     MAINFOLDER + '/resultsSpanningapprox_approxMAPsolver/TRW/'),
    ('TRBP', MAINFOLDER + '/TRWBP_opt/'),
])
def add_vec(v1, v2):
    """Element-wise sum of two 1-D arrays whose lengths may differ.

    The shorter array is extended to the longer one's length by repeating its
    final element; the padded array is then added to the longer one.

    Raises AssertionError if either input is not 1-D, and IndexError if the
    shorter input is empty (it has no final element to repeat).
    """
    assert all(len(v.shape) == 1 for v in (v1, v2)), 'v1 and v2 must be of dimension 1'
    n1, n2 = len(v1), len(v2)
    if n1 == n2:
        return v1 + v2
    # Order the pair so that `shorter` is the one that needs padding.
    shorter, longer = (v1, v2) if n1 < n2 else (v2, v1)
    fill = np.zeros(len(longer) - len(shorter)) + shorter[-1]
    return np.append(shorter, fill) + longer
def append_vec(v1, v2):
    """Concatenate two arrays into one flat 1-D array, treating None as "nothing yet".

    Returns v2 unchanged if v1 is None, v1 unchanged if v2 is None, and otherwise
    np.append(v1.ravel(), v2.ravel()).
    """
    # Use `is None`, not `== None`: for a NumPy array, `v1 == None` is an
    # element-wise comparison, and using its result in `if` raises ValueError
    # ("truth value ... is ambiguous") whenever the array has more than one element.
    if v1 is None:
        return v2
    if v2 is None:
        return v1
    return np.append(v1.ravel(), v2.ravel())
In [2]:
# For every method in `folder`, compute the L1 error of the node marginals and the
# absolute error of log Z against the exact solution, one value per rho update.
# Results go into plot_data_l1 / plot_data_logz, keyed by plot label.

# FW results are stored as <TCNAME>-sp0.mat ... -sp<NUM_RHO_UPDATES-1>.mat.
NUM_RHO_UPDATES = 10

# The exact solution depends only on the test case, so load it once.
exact = loadmat(EXACTDIR + '/' + TCNAME + '.mat')
exact_marg = exact['exact_node_marginals'].ravel()[comparison_idx]
exact_logz = exact['log_partition'].ravel()[0]


def _trbp_errors(result_dir):
    """Return (mean L1 marginal error, |log Z error|) of the single TRW-BP result file.

    Reads the one file matching <result_dir><TCNAME>*.mat. Raises RuntimeError
    if zero or several files match.
    """
    pattern = result_dir + TCNAME + '*.mat'
    files = glob.glob(pattern)
    if len(files) != 1:
        raise RuntimeError('Expected exactly one TRWBP file matching %s, found %d: %s'
                           % (pattern, len(files), files))
    mat = loadmat(files[0])
    l1 = np.abs(mat['trwbp_opt_marg'].ravel()[comparison_idx] - exact_marg).mean()
    logz = np.abs(mat['trwbp_opt_logz'].ravel()[0] - exact_logz)
    return l1, logz


def _fw_errors(result_dir, ignore_gap):
    """Return arrays (L1 marginal error, |log Z error|), one entry per rho update.

    For update k, <result_dir><TCNAME>-sp<k>.mat is read and only its last
    iterate is scored. The log Z estimate is -Obj_Val + duality gap, where the
    gap is taken from 'gap_full' when present and otherwise from 'Dual_Gap'.
    When ignore_gap is True the gap term is treated as 0.
    """
    l1_errs, logz_errs = [], []
    for rhoit in range(NUM_RHO_UPDATES):
        f_rho = result_dir + TCNAME + '-sp' + str(rhoit) + '.mat'
        print(f_rho)
        mat = loadmat(f_rho)
        # L1 marginal error at the final iterate (last row of IterSet).
        final_marg = mat['IterSet'][-1, comparison_idx]
        l1 = np.abs(final_marg - exact_marg).mean()
        print(l1)
        l1_errs.append(l1)
        # log Z at the final iterate. Using the last element of each vector
        # explicitly avoids broadcasting Obj_Val against the gap vector.
        dual_gap = mat['gap_full'] if 'gap_full' in mat else mat['Dual_Gap']
        final_gap = 0.0 if ignore_gap else dual_gap.ravel()[-1]
        logz_est = -mat['Obj_Val'].ravel()[-1] + final_gap
        logz_errs.append(np.abs(logz_est - exact_logz))
    return np.array(l1_errs), np.array(logz_errs)


plot_data_l1 = {}
plot_data_logz = {}
for plotname, result_dir in folder.items():
    print('Processing : %s' % plotname)
    if plotname == 'TRBP':
        # TRW-BP has no rho updates: draw its single result as a flat line.
        l1, logz = _trbp_errors(result_dir)
        plot_data_l1[plotname] = np.ones(NUM_RHO_UPDATES) * l1
        plot_data_logz[plotname] = np.ones(NUM_RHO_UPDATES) * logz
    else:
        print('Looking for files: %s' % (result_dir + TCNAME + '-sp*.mat'))
        # Approx-MAP runs ignore the duality gap (original behaviour; the reason
        # is not recorded in this notebook).
        plot_data_l1[plotname], plot_data_logz[plotname] = _fw_errors(
            result_dir, ignore_gap='Approx' in plotname)
In [3]:
# Plot per-method error curves against rho updates and save them as PDFs in ./plots/.
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['lines.linewidth'] = 4.5
mpl.rcParams['lines.markersize'] = 12
mpl.rcParams['text.usetex'] = True
# 'text.latex.unicode' no longer exists in newer Matplotlib releases, and setting
# an unknown rcParam raises KeyError, so only set it where it is still defined.
if 'text.latex.unicode' in mpl.rcParams:
    mpl.rcParams['text.latex.unicode'] = True
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Times New Roman'
# Raw string: in a normal Python 3 literal '\u' starts a unicode escape (SyntaxError).
mpl.rcParams['text.latex.preamble'] = r'\usepackage{amsfonts}'
mpl.rcParams['font.size'] = 40
mpl.rcParams['axes.labelsize'] = 40
mpl.rcParams['legend.fontsize'] = 33

colors = ['r', 'b', 'g', 'k', 'c']
markers = ['*', 'v', '<', '>', '|']
PLOT_DIR = './plots/'


def _plot_curves(ax, plot_data):
    """Draw one line per method in `folder` on `ax`; x runs over 1..L rho updates."""
    for idx, plottype in enumerate(folder):
        y = plot_data[plottype].ravel()
        # Cycle styles so more than len(colors) methods do not raise IndexError.
        ax.plot(np.arange(1, len(y) + 1), y, label=plottype,
                color=colors[idx % len(colors)], marker=markers[idx % len(markers)])
    ax.set_xlabel(r'Updates to $\rho$')


# Figure 1: L1 error of the node marginals.
fig_l1, ax_l1 = plt.subplots(figsize=(10, 10))
_plot_curves(ax_l1, plot_data_l1)
ax_l1.set_ylim([0, 0.5])
ax_l1.set_ylabel(r'Error in Marginals $(\zeta_{\mu})$')
ax_l1.legend(bbox_to_anchor=(0.9, 0.9), bbox_transform=fig_l1.transFigure)

# Figure 2: absolute error in log Z, on a log scale.
fig_logz, ax_logz = plt.subplots(figsize=(10, 10))
_plot_curves(ax_logz, plot_data_logz)
ax_logz.set_yscale('log')
ax_logz.set_ylabel(r'Error in LogZ $(\zeta_{\mathrm{\log Z}})$')
ax_logz.set_ylim([1, 10**4])
ax_logz.legend()

# savefig overwrites existing files but fails if the output directory is missing.
if not os.path.isdir(PLOT_DIR):
    os.makedirs(PLOT_DIR)
fig_l1.savefig(PLOT_DIR + TCNAME + '-l1.pdf', bbox_inches='tight')
fig_logz.savefig(PLOT_DIR + TCNAME + '-logz.pdf', bbox_inches='tight')