metrics



In [ ]:
import glob
import os

import matplotlib
import matplotlib.pyplot as plt
# import seaborn as sns
%matplotlib inline

In [ ]:
import utils
import metrics_utils

In [ ]:
# Global default font size for every figure drawn below.
matplotlib.rcParams['font.size'] = 15

In [ ]:
# Save flag passed to utils.save_plot by the plotting cells below.
is_save = True
# Figure size is chosen by the project helper from the save flag; it is used as
# plt.figure(figsize=figsize) in most cells below (the celebA reconstruction cell rebinds it).
figsize = metrics_utils.get_figsize(is_save)

MNIST plots

Reconstruction from Gaussian measurements


In [ ]:
## Define what to plot
# Each entry is (legend label, results base directory, sub-path pattern under it);
# the pair is handed to metrics_utils.plot as (base, regex).
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso'  , '../estimated/mnist/full-input/gaussian/0.1/', '/lasso/*'),
    ('VAE'    , '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.0_adam_0.01_0.9_False_1000_10'),
    ('VAE+Reg', '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10')
]
save_path = '../results/mnist_reconstr_l2.pdf'


## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify
# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([9, 800])

# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legends
plt.legend(legends, fontsize=12.5)

# Saving
utils.save_plot(is_save, save_path)

Reconstruction: all models


In [ ]:
## Define what to plot
# Each entry is (legend label, results base directory, sub-path pattern under it).
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso'    , '../estimated/mnist/full-input/gaussian/0.1/', '/lasso/*'),
    ('VAE'      , '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.0_adam_0.01_0.9_False_1000_10'),
    ('VAE+Reg'  , '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
    ('Fixed A'  , '../estimated/mnist/full-input/fixed/0.1/'   , '/learned/50-200'),
    ('Learned A', '../estimated/mnist/full-input/learned/0.1/' , '/learned/50-200'),
]
save_path = '../results/mnist_reconstr_l2_all.pdf'


## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify
# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([9, 800])

# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legends (smaller font: five entries)
plt.legend(legends, fontsize=10.0)

# Saving
utils.save_plot(is_save, save_path)

Sensing images from the range of the generator


In [ ]:
## Define what to plot
# Same VAE+Reg hyperparameters, comparing test-set images against images
# taken from the generator (the 'gen-span' results directory).
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('From test set'  , '../estimated/mnist/full-input/gaussian/0.1/', '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
    ('From generator' , '../estimated/mnist/gen-span/gaussian/0.1/'  , '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10'),
]
save_path = '../results/mnist_gen_range.pdf'


## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify

# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([9, 800])

# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500, 750]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legend
plt.legend(legends, fontsize=12.5)

# Saving
utils.save_plot(is_save, save_path)

Noise tolerance


In [ ]:
## Define what to plot
# Here the base directory stops before the noise level, so the x axis comes from
# the noise-std directories matched between `base` and the sub-path.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso, m=500' , '../estimated/mnist/full-input/gaussian/', '/500/lasso/0.1'),
    ('VAE, m=100'   , '../estimated/mnist/full-input/gaussian/', '/100/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_2'),
    ('VAE, m=500'   , '../estimated/mnist/full-input/gaussian/', '/500/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_2'),
]
save_path = '../results/mnist_noise_l2.pdf'


## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify

# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')

# labels, ticks, titles
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Standard deviation of noise')

# Legend ('upper left' is the named form of loc=2)
plt.legend(legends, fontsize=12.5, loc='upper left')

# Saving
utils.save_plot(is_save, save_path)

Compare hyperparameter settings


In [ ]:
## Define what to plot
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
base, tail = '../estimated/mnist/full-input/gaussian/0.1/', '/vae/'
# One curve per distinct VAE hyperparameter directory found under any first-level
# subdirectory of `base`. Sorted so curve and legend order are stable across runs
# (iterating a set is not). os.path.basename also handles OS-specific separators
# that glob may return, which a plain split('/') does not.
other_hparams = sorted({os.path.basename(path)
                        for path in glob.glob(base + '*' + tail + '*')})
legend_base_regexs = [(hparam, base, tail + hparam) for hparam in other_hparams]

## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify

# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([9.5, 510])

# labels, ticks, titles
ticks = [10, 25, 50, 100, 200, 300, 400, 500]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legend (exploratory comparison; not saved to disk)
plt.legend(legends, fontsize=8)

celebA plots

Reconstruction from Gaussian measurements


In [ ]:
## Define what to plot
# Each entry is (legend label, results base directory, sub-path pattern under it).
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso (DCT)'     , '../estimated/celebA/full-input/gaussian/0.01/', '/lasso-dct/*'),
    ('Lasso (Wavelet)' , '../estimated/celebA/full-input/gaussian/0.01/', '/lasso-wavelet/1e-05'),
    ('DCGAN'           , '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.0_0.0_0.0_adam_0.1_0.9_False_500_2'),
    ('DCGAN+Reg'       , '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10'),
]
save_path = '../results/celebA_reconstr_l2.pdf'

## Plot
# NOTE: this rebinds the notebook-level `figsize` (originally from
# metrics_utils.get_figsize in the config cell). The later celebA cells that call
# plt.figure(figsize=figsize) therefore also use this size when run in order.
figsize = [6, 4.05]
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify
# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([19, 11000])

# labels, ticks, titles
# A single xticks call; the former follow-up call without `rotation` was redundant.
ticks = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legends
plt.legend(legends, fontsize=8)

# Saving
utils.save_plot(is_save, save_path)

Sensing images from the span of the generator


In [ ]:
## Define what to plot
# Test-set images vs. images from the generator ('gen-span' results directory).
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('From test set' , '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_10'),
    ('From generator', '../estimated/celebA/gen-span/gaussian/0.01/'  , '/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_500_1'),
]
save_path = '../results/celebA_gen_range.pdf'


## Plot
# Uses the notebook-level `figsize`; when run in order this is the value the
# celebA reconstruction cell assigned ([6, 4.05]).
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify
# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([19, 2600])

# labels, ticks, titles
ticks = [20, 50, 100, 200, 500, 1000, 2500]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legends
plt.legend(legends, fontsize=12.5)

# Saving
utils.save_plot(is_save, save_path)

Noise tolerance


In [ ]:
## Define what to plot
# The base directory stops before the noise level, so the x axis comes from the
# noise-std directories matched between `base` and the sub-path.
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('Lasso (DCT), m=2500'     , '../estimated/celebA/full-input/gaussian/', '/2500/lasso-dct/*'),
    ('Lasso (Wavelet), m=2500' , '../estimated/celebA/full-input/gaussian/', '/2500/lasso-wavelet/0.1'),
    ('DCGAN, m=2500'           , '../estimated/celebA/full-input/gaussian/', '/2500/dcgan/0.0_1.0_0.001_0.0_0.0_adam_0.1_0.9_False_5000_1'),
]
save_path = '../results/celebA_noise_l2.pdf'

## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify

# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')

# labels, ticks, titles
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Standard deviation of noise')

# Legend ('upper left' is the named form of loc=2)
plt.legend(legends, fontsize=12.5, loc='upper left')

# Saving
utils.save_plot(is_save, save_path)

k-sparse-wavelet


In [ ]:
## Define what to plot
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
legend_base_regexs = [
    ('k-sparse-wavelet, m=12288' , '../estimated/celebA/full-input/project/0.0/12288/k-sparse-wavelet/', '')
]
# Saving is intentionally disabled for this exploratory plot. The previously
# commented-out save_path reused '../results/celebA_noise_l2.pdf' and would have
# overwritten the noise-tolerance figure; choose a distinct path before enabling it.

## Plot
plt.figure(figsize=[12, 8])
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify

# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([40, 13000])

# labels, ticks, titles
ticks = [50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000, 12288]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('k')

# Legend ('upper right' is the named form of loc=1)
plt.legend(legends, fontsize=12.5, loc='upper right')

Compare hyperparameter settings


In [ ]:
## Define what to plot
criterion = ['l2', 'mean']
retrieve_list = [['l2', 'mean'], ['l2', 'std']]
base, tail = '../estimated/celebA/full-input/gaussian/0.01/', '/dcgan/'
# One curve per distinct DCGAN hyperparameter directory found under any first-level
# subdirectory of `base`. Sorted so curve and legend order are stable across runs
# (iterating a set is not). os.path.basename also handles OS-specific separators
# that glob may return, which a plain split('/') does not.
other_hparams = sorted({os.path.basename(path)
                        for path in glob.glob(base + '*' + tail + '*')})
legend_base_regexs = [(hparam, base, tail + hparam) for hparam in other_hparams]


## Plot
plt.figure(figsize=figsize)
legends = []
for legend, base, regex in legend_base_regexs:
    metrics_utils.plot(base, regex, criterion, retrieve_list)
    legends.append(legend)


## Prettify
# axis
ax = plt.gca()
ax.set_ylim(bottom=0)
# matplotlib 3.3 renamed LogScale's `nonposx` to `nonpositive` and 3.5 removed the
# old name; try the new keyword first and fall back for older matplotlib releases.
try:
    ax.set_xscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax.set_xscale("log", nonposx='clip')
ax.set_xlim([19, 11000])

# labels, ticks, titles
ticks = [20, 50, 100, 200, 500, 1000, 2500, 5000, 7500, 10000]
labels = [str(t) for t in ticks]
plt.xticks(ticks, labels, rotation=90)
plt.ylabel('Reconstruction error (per pixel)')
plt.xlabel('Number of measurements')

# Legends (exploratory comparison; not saved to disk)
plt.legend(legends, fontsize=8)