In [6]:
import numpy as np
from scipy.io import loadmat
import os,glob
#Utility functions
def cleanLine(s):
    """Strip surrounding whitespace from ``s`` and delete every '(', ')', '{', '}' and 'x' character.

    The whitespace strip happens before the characters are removed, so blanks
    that were shielded by a removed character stay in the result.
    """
    cleaned = s.strip()
    for unwanted in ('(', ')', '{', '}', 'x'):
        cleaned = cleaned.replace(unwanted, '')
    return cleaned
def parseExactFile(fname,uai_matf):
    """Parse an exact-inference output file into ``(logZ, marginals)``.

    The lines strictly between the single line containing "Exact factor marginals"
    and the single line containing "Exact log partition sum" are each passed
    through ``cleanLine`` and split on ','. A line with 3 fields contributes its
    last 2 values to the marginal vector; a line with 6 fields contributes its
    last 4 values and its first 2 fields (as ints) to the edge list.

    The parsed edge list is then compared with the first two columns of the
    'List' entry of the .mat file ``uai_matf`` after subtracting 1 (presumably
    converting 1-based indices -- TODO confirm), so that a differently ordered
    file is detected instead of silently yielding a shuffled marginal vector.

    Parameters:
        fname: path of the exact-inference text output.
        uai_matf: path of the .mat file holding the 'List' array.

    Returns:
        (logZ, marginals): logZ is the float after the ':' on the log-partition
        line; marginals is a 1-D numpy array in file order.

    Raises:
        AssertionError: if the marker lines are not each present exactly once,
        a factor line has an unsupported number of fields, or the edge lists
        differ (including in shape).
    """
    # Read everything inside a context manager so the handle is always closed.
    with open(fname) as fh:
        all_lines = [s.strip() for s in fh]
    fac_idx_start = [idx for idx,s in enumerate(all_lines) if "Exact factor marginals" in s]
    fac_idx_end = [idx for idx,s in enumerate(all_lines) if "Exact log partition sum" in s]
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if len(fac_idx_start)!=1 or len(fac_idx_end)!=1:
        raise AssertionError('Expected exactly one "Exact factor marginals" line and one '
                             '"Exact log partition sum" line in exact file')
    fac_idx_start = fac_idx_start[0]
    fac_idx_end = fac_idx_end[0]
    logZ = float(all_lines[fac_idx_end].split(':')[1].strip())
    marginals = []
    edgeList = []
    for line in all_lines[fac_idx_start+1:fac_idx_end]:
        data = cleanLine(line).split(',')
        if len(data)==3:
            # node factor: 1 id field followed by 2 values
            marginals.extend(float(k) for k in data[1:])
        elif len(data)==6:
            # edge factor: 2 id fields followed by 4 values
            marginals.extend(float(k) for k in data[2:])
            edgeList.append([int(k) for k in data[:2]])
        else:
            raise AssertionError('Beyond edges not supported')
    #check that the order of edgeList is what is expected
    mat = loadmat(uai_matf)
    edgeListMat = (mat['List'][:,:2]-1).astype(int)
    edgesParsed = np.array(edgeList,dtype=int).reshape(-1,2)
    # Compare shape-aware: subtracting the arrays would broadcast, letting a
    # length mismatch pass (or crash) instead of being reported.
    bothEmpty = edgeListMat.size==0 and edgesParsed.size==0
    if not (bothEmpty or np.array_equal(edgeListMat,edgesParsed)):
        raise AssertionError('Edge Lists do not match up. marginal vector may be shuffled')
    return logZ,np.array(marginals)
#Try the exact-file parser on instance 1 and show logZ with the marginal vector's shape
logZ,marginals = parseExactFile(
    '/data/ml2/rahul/fw-inference/Synthetic/exact_out/1.out',
    '/data/ml2/rahul/fw-inference/Synthetic/uai_mat/1.mat',
)
print('%s %s' % (logZ,marginals.shape))
#Collect perturbMAP results
def parsePerturbMAP(matf):
    """Read the perturbMAP results stored in the .mat file ``matf``.

    Returns a pair: the first element of the flattened
    'perturbMAP_log_partition_ub' entry, and the flattened
    'perturbMAP_node_marginals_ub' entry.
    """
    contents = loadmat(matf)
    log_partition_ub = contents['perturbMAP_log_partition_ub'].ravel()[0]
    node_marginals_ub = contents['perturbMAP_node_marginals_ub'].ravel()
    return log_partition_ub,node_marginals_ub
#Try the perturbMAP parser on instance 1.
#Note: this rebinds logZ and marginals, replacing the values assigned by the parseExactFile call above.
logZ,marginals = parsePerturbMAP('/data/ml2/rahul/fw-inference/Synthetic/mat_out/1.mat')
#Get index of trial
def getIdx(trial):
    """Return the bucket index of a trial number.

    Trials 1..100 give 0, 101..200 give 1, and so on up to 801..900 giving 8.
    Any value outside those closed ranges (including values that fall between
    two ranges, such as 100.5) gives -1, which marks an error.
    """
    for bucket in range(9):
        first = 100*bucket+1
        last = 100*(bucket+1)
        if first<=trial<=last:
            return bucket
    return -1#error
def removeIfExists(fname):
    """Delete the file at ``fname`` when that path exists; otherwise do nothing."""
    if not os.path.exists(fname):
        return
    os.remove(fname)
def createIfAbsent(dirname):
    """Create the directory ``dirname`` with os.mkdir unless the path already exists."""
    already_present = os.path.exists(dirname)
    if already_present:
        return
    os.mkdir(dirname)
In [7]:
#Gathers the exact and perturbMAP reference values for the synthetic instances.
#The cells below look at the plots under different metrics, particularly for LOCAL.
MAINFOLDER = '/data/ml2/rahul/fw-inference/Synthetic/'
#'_final' does not occur in MAINFOLDER as written, so the replace() calls leave it unchanged
EXACTFOLDER = MAINFOLDER.replace('_final','')+'/exact_out'
UAI_FOLDER = MAINFOLDER+'/uai_mat'
marg_exact,logz_exact = {},{}
logz_pmap,marg_pmap = {},{}
for tc in range(1,901):
    fname = '%s/%s.out' % (EXACTFOLDER,tc)
    fname_mat = '%s/%s.mat' % (UAI_FOLDER.replace('_final',''),tc)
    logz,marginals = parseExactFile(fname,fname_mat)
    marg_exact[tc],logz_exact[tc] = marginals,logz
    #the perturbMAP file has the same name, with 'uai_mat' in the path swapped for 'mat_out'
    logz_pmap_val,marginals_pmap_val = parsePerturbMAP(fname_mat.replace('uai_mat','mat_out'))
    logz_pmap[tc],marg_pmap[tc] = logz_pmap_val,marginals_pmap_val
    if not tc%100:
        #progress report every 100 instances
        print('%s %s' % (tc,logz))
print("Compiled exact marginals and logz")
In [8]:
#Result folder for each method, keyed by the label used in the plot legends.
#The labels are LaTeX, so they are written as raw strings: backslashes reach the
#renderer untouched and no sequence (e.g. \r in \rho) can be read as a Python escape.
#Downstream, keys containing 'opt' select the '*-sp9.mat' files and all others '*-sp0.mat'.
folder = {}
folder[r'$\mathbb{L}_{\delta}$'] = MAINFOLDER.replace('_final','')+'/'+'resultsSpanning_lpLOCAL'
folder[r'$\mathbb{L}_{\delta}$'+r'$(\rho_{\mathrm{opt}})$'] = MAINFOLDER.replace('_final','')+'/'+'resultsSpanning_lpLOCAL'
folder[r'$\mathcal{M}_{\delta}$'] = MAINFOLDER+'/'+'resultsSpanning_MAPsolver'
folder[r'$\mathcal{M}_{\delta}$'+r'$(\rho_{\mathrm{opt}})$'] = MAINFOLDER+'/'+'resultsSpanning_MAPsolver'
#This is a bypass so that all trials are accounted for: perturbMAP reuses the
#resultsSpanning_MAPsolver file list only to enumerate trials.
folder['perturbMAP'] = MAINFOLDER+'/'+'resultsSpanning_MAPsolver'
In [9]:
#Collect Data
#X holds the nine x-axis values; every per-method accumulator below has one slot per value.
X = np.array([0.5,1,2,3,4,5,6,7,8])
l1_ours,logz_ours,trialctr,debug = {},{},{},{}
for plottype in folder:
    l1_ours[plottype] = np.zeros(9)
    logz_ours[plottype] = np.zeros(9)
    trialctr[plottype] = np.zeros(9,dtype=int)
    debug[plottype] = np.zeros(9,dtype=int)
N = 10
#odd positions 1,3,...,2N-1 of the marginal vectors are the entries that get compared
comparison_idx = list(range(1,2*N,2))
#For ours, we need 4 plots
for plottype in folder:
    #labels containing 'opt' read the sp9 result files, everything else the sp0 files
    ftype = 'TRW'+('/*-sp9.mat' if 'opt' in plottype else '/*-sp0.mat')
    print('%s %s %s' % (plottype,': looking for filetype: ',ftype))
    ctr = 1
    for f in glob.glob(folder[plottype]+'/'+ftype):
        if ctr%100==0:
            print('%s %s' % ("Done: ",ctr))
        ctr += 1
        #the trial number is the part of the file name before the first '-'
        trial = int(os.path.basename(f).split('-')[0])
        idx = getIdx(trial)
        #L1 in marginals & logz
        l1_exact = marg_exact[trial]
        logz_val = logz_exact[trial]
        trialctr[plottype][idx] += 1
        #Collecting results
        if plottype!='perturbMAP':
            mat = loadmat(f)
            l1 = mat['IterSet'][-1,:]
            l1_ours[plottype][idx] += np.abs(l1[comparison_idx]-l1_exact[comparison_idx]).mean()
            #Logz: use 'gap_full' when the file provides it, otherwise 'Dual_Gap'
            gap_key = 'gap_full' if 'gap_full' in mat else 'Dual_Gap'
            logz_ours[plottype][idx] += np.abs(-1*mat['Obj_Val'].ravel()[-1]+mat[gap_key].ravel()[-1]-logz_val)
        else:
            #perturbMAP values come from the precompiled dictionaries; the file only names the trial
            logz_ours[plottype][idx] += np.abs(logz_pmap[trial]-logz_val)
            l1_ours[plottype][idx] += np.abs(marg_pmap[trial][comparison_idx]-l1_exact[comparison_idx]).mean()
    print('%s %s %s' % ("Processed: ",trialctr[plottype].sum()," trials"))
print("Done collecting data")
#turn the accumulated sums into per-slot averages
for plottype in folder:
    counts = trialctr[plottype].astype(float)
    l1_ours[plottype] /= counts
    logz_ours[plottype] /= counts
print("Done normalizing data")
In [21]:
#Plot the marginal error and the logZ error of every method against theta and save both figures as PDF.
import matplotlib as mpl
#pyplot is imported explicitly so this cell does not depend on a %pylab-populated namespace
import matplotlib.pyplot as plt
mpl.rcParams['lines.linewidth']=4.5
mpl.rcParams['lines.markersize']=12
mpl.rcParams['text.usetex']=True
#'text.latex.unicode' no longer exists in newer matplotlib releases; set it only where it is still a valid key
if 'text.latex.unicode' in mpl.rcParams:
    mpl.rcParams['text.latex.unicode']=True
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Times New Roman'
#raw string: '\u' must reach LaTeX as a literal backslash-u
mpl.rcParams['text.latex.preamble']= r'\usepackage{amsfonts}'
mpl.rcParams['font.size'] = 40
mpl.rcParams['axes.labelsize']=40
mpl.rcParams['legend.fontsize']=33
colors = ['r','b','g','k','c']
markers= ['*','v','<','>','|']
tick_labels = ['0.5','1','2','3','4','5','6','7','8']
#Figure 1: error in marginals
f1 = plt.figure(1,figsize=(10,10))
for idx,plottype in enumerate(folder):
    plt.plot(X,l1_ours[plottype],label=plottype,color=colors[idx],marker = markers[idx])
plt.xlabel(r'$\theta$')
plt.xticks(X,tick_labels)
plt.ylim([0,0.8])
plt.ylabel(r'Error in Marginals $(\zeta_{\mu})$')
plt.legend(bbox_to_anchor=(0.9, 0.9),bbox_transform=f1.transFigure)
#Figure 2: error in logZ, on a logarithmic y axis
f2 = plt.figure(2,figsize=(10,10))
for idx,plottype in enumerate(folder):
    plt.plot(X,logz_ours[plottype],label=plottype,color=colors[idx],marker = markers[idx])
plt.xlabel(r'$\theta$')
plt.xticks(X,tick_labels)
plt.yscale('log')
plt.ylabel(r'Error in LogZ $(\zeta_{\mathrm{\log Z}})$')
plt.legend(bbox_to_anchor=(0.9, 0.8),bbox_transform=f2.transFigure)
#Save to file
createIfAbsent('./plots')
removeIfExists('./plots/SyntheticL1VsTheta.pdf')
removeIfExists('./plots/SyntheticlogzVsTheta.pdf')
#save through the figure handles so each PDF is tied to its figure, not to the "current" one
f1.savefig('./plots/SyntheticL1VsTheta.pdf',bbox_inches='tight')
f2.savefig('./plots/SyntheticlogzVsTheta.pdf',bbox_inches='tight')
In [ ]: