In [ ]:
import glob
import os

import matplotlib
import matplotlib.pyplot as plt
# import seaborn as sns
%matplotlib inline
In [ ]:
import utils
import metrics_utils
In [ ]:
# Global default font size for every figure produced by this notebook.
matplotlib.rcParams['font.size'] = 15
In [ ]:
# Flag forwarded to utils.save_plot by the plotting cells (presumably enables
# writing each figure to disk) and to metrics_utils.get_figsize, which picks
# the default figure size used by the plots below.
is_save: bool = True
figsize = metrics_utils.get_figsize(is_save)
In [ ]:
## Define what to plot
# Each entry is (legend label, base directory, pattern appended to base); both
# parts are forwarded to metrics_utils.plot together with the l2 criterion.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso', '../estimated/mnist/full-input/gaussian/0.1/', '/lasso/*'),
    ('VAE', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.0_adam_0.01_0.9_False_1000_10'),
    ('VAE+Reg', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
]
save_path = '../results/mnist_reconstr_l2.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([9, 800])
# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [10, 25, 50, 100, 200, 300, 400, 500, 750]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legends: labels are assigned to the plotted artists in drawing order
plt.legend(legends, fontsize=12.5)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# Each entry is (legend label, base directory, pattern appended to base); both
# parts are forwarded to metrics_utils.plot together with the l2 criterion.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso', '../estimated/mnist/full-input/gaussian/0.1/', '/lasso/*'),
    ('VAE', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.0_adam_0.01_0.9_False_1000_10'),
    ('VAE+Reg', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
    ('Fixed A', '../estimated/mnist/full-input/fixed/0.1/', '/learned/50-200'),
    ('Learned A', '../estimated/mnist/full-input/learned/0.1/', '/learned/50-200'),
]
save_path = '../results/mnist_reconstr_l2_all.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([9, 800])
# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [10, 25, 50, 100, 200, 300, 400, 500, 750]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legends: smaller font because five curves share the legend box
plt.legend(legends, fontsize=10.0)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# Same VAE+Reg configuration evaluated on two input sources: the test set
# ('full-input') and the generator's range ('gen-span').
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('From test set', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
    ('From generator', '../estimated/mnist/gen-span/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
]
save_path = '../results/mnist_gen_range.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([9, 800])
# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [10, 25, 50, 100, 200, 300, 400, 500, 750]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legend
plt.legend(legends, fontsize=12.5)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# Here the measurement count is fixed per curve (m in the label) and the
# x axis is the noise standard deviation, so base stops at 'gaussian/'.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso, m=500', '../estimated/mnist/full-input/gaussian/', '/500/lasso/0.1'),
    ('VAE, m=100', '../estimated/mnist/full-input/gaussian/', '/100/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_2'),
    ('VAE, m=500', '../estimated/mnist/full-input/gaussian/', '/500/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_2'),
]
save_path = '../results/mnist_noise_l2.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
# labels, ticks, titles
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Standard deviation of noise')
# Legend (loc=2 is matplotlib's 'upper left')
plt.legend(legends, fontsize=12.5, loc=2)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
base, tail = '../estimated/mnist/full-input/gaussian/0.1/', '/vae/'
# Every distinct VAE hyper-parameter directory found under any sub-directory
# of `base` becomes one curve, labelled by its directory name.
# - sorted(): iterating a set of strings yields a hash-randomised order, so
#   the legend order (and presumably the colour of each curve) varied run to run.
# - os.path.basename(): glob joins matched components with os.sep, so
#   splitting on '/' returned the wrong name on Windows.
other_hparams = sorted({os.path.basename(a) for a in glob.glob(base + '*' + tail + '*')})
legend_base_regexs = [(hparam, base, tail + hparam) for hparam in other_hparams]
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([9.5, 510])
# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500]
labels = [10, 25, 50, 100, 200, 300, 400, 500]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legend (small font: one entry per hyper-parameter setting).
# This exploratory figure is not saved.
plt.legend(legends, fontsize=8)
In [ ]:
## Define what to plot
# Each entry is (legend label, base directory, pattern appended to base); both
# parts are forwarded to metrics_utils.plot together with the l2 criterion.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso (DCT)', '../estimated/celebA/full-input/gaussian/0.01/', '/lasso-dct/*'),
    ('Lasso (Wavelet)', '../estimated/celebA/full-input/gaussian/0.01/', '/lasso-wavelet/1e-05'),
    ('DCGAN', '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.0_0.0_0.0_adam_0.1_0.9_False_500_2'),
    ('DCGAN+Reg', '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10'),
]
save_path = '../results/celebA_reconstr_l2.pdf'
## Plot
# NOTE: this rebinds the notebook-level `figsize`, so later cells that use
# `figsize` get this size too when the notebook is run top to bottom.
# Larger alternative: figsize = [12, 8.1]
figsize = [6, 4.05]
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([19, 11000])
# labels, ticks, titles
# (A second, kwarg-less plt.xticks(ticks, labels) call was removed: it only
# re-set the same labels and did not undo the rotation.)
ticks = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
labels = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legends (alternative size: fontsize=12.5)
plt.legend(legends, fontsize=8)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# DCGAN+Reg evaluated on two input sources: the test set ('full-input') and
# the generator's range ('gen-span').
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('From test set', '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10'),
    ('From generator', '../estimated/celebA/gen-span/gaussian/0.01/', '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_1'),
]
save_path = '../results/celebA_gen_range.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([19, 2600])
# labels, ticks, titles
ticks = [20, 50, 100, 200, 500, 1000, 2500]
labels = [20, 50, 100, 200, 500, 1000, 2500]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legends
plt.legend(legends, fontsize=12.5)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# Measurement count fixed at m=2500; the x axis is the noise standard
# deviation, so base stops at 'gaussian/'.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso (DCT), m=2500', '../estimated/celebA/full-input/gaussian/', '/2500/lasso-dct/*'),
    ('Lasso (Wavelet), m=2500', '../estimated/celebA/full-input/gaussian/', '/2500/lasso-wavelet/0.1'),
    ('DCGAN, m=2500', '../estimated/celebA/full-input/gaussian/', '/2500/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_5000_1'),
]
save_path = '../results/celebA_noise_l2.pdf'
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
# labels, ticks, titles
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Standard deviation of noise')
# Legend (loc=2 is matplotlib's 'upper left')
plt.legend(legends, fontsize=12.5, loc=2)
# Saving
utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
# Single curve; the pattern is empty, so metrics_utils.plot receives the base
# directory as-is.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('k-sparse-wavelet, m=12288', '../estimated/celebA/full-input/project/0.0/12288/k-sparse-wavelet/', ''),
]
# NOTE(review): if saving is re-enabled, choose a new file name; the path below
# is the one the celebA noise cell already writes to.
# save_path = '../results/celebA_noise_l2.pdf'
## Plot
plt.figure(figsize=[12, 8])
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([40, 13000])
# labels, ticks, titles
ticks = [50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000, 12288]
labels = [50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000, 12288]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('k')
# Legend (loc=1 is matplotlib's 'upper right')
plt.legend(legends, fontsize=12.5, loc=1)
# Saving (disabled for this exploratory plot)
# utils.save_plot(is_save, save_path)
In [ ]:
## Define what to plot
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
base, tail = '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/'
# Every distinct DCGAN hyper-parameter directory found under any
# sub-directory of `base` becomes one curve, labelled by its directory name.
# - sorted(): iterating a set of strings yields a hash-randomised order, so
#   the legend order (and presumably the colour of each curve) varied run to run.
# - os.path.basename(): glob joins matched components with os.sep, so
#   splitting on '/' returned the wrong name on Windows.
other_hparams = sorted({os.path.basename(a) for a in glob.glob(base + '*' + tail + '*')})
legend_base_regexs = [(hparam, base, tail + hparam) for hparam in other_hparams]
## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)
## Prettify
# axis
plt.gca().set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed
# the old name; try the new keyword first and fall back for older versions.
try:
    plt.gca().set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    plt.gca().set_xscale("log", nonposx='clip')
plt.gca().set_xlim([19, 11000])
# labels, ticks, titles
ticks = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
labels = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')
# Legends (small font: one entry per hyper-parameter setting).
# This exploratory figure is not saved.
plt.legend(legends, fontsize=8)