In [51]:
%load_ext autoreload
%autoreload 2
import vislab._results
import os
import vislab.datasets
import vislab.dataset_viz
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# `vislab.util` provides makedirs (used here) and get_mongodb_client (used in
# the next cell); import it explicitly rather than relying on it being loaded
# as a side effect of importing other vislab submodules.
import vislab.util

# Output directory for every figure saved below. Expanded from '~' instead of
# a hardcoded /Users/<name> path so the notebook runs for other users too.
dirname = os.path.expanduser(
    '~/work/aphrodite-writeup/figures/content_correlation2')
vislab.util.makedirs(dirname)
Out[51]:
In [53]:
# List the prediction collections stored in MongoDB and show one of them.
predict_db = vislab.util.get_mongodb_client()['predict']
print(predict_db.collection_names())
# The collection is bound explicitly here. This cell used to read
# `collection_name`, which is only assigned in later cells, so it raised
# NameError on a fresh kernel and otherwise showed whichever collection ran last.
# NOTE(review): pascal_mc_on_flickr_mar23 is the first collection plotted
# below; change it to inspect a different one.
inspect_collection_name = 'pascal_mc_on_flickr_mar23'
inspect_collection = predict_db[inspect_collection_name]
pd.DataFrame(list(inspect_collection.find()))
Out[53]:
In [63]:
def _load_prediction_scores(collection_name, pred_feat, dataset_name,
                            pred_prefix, multiclass, results_dirname):
    """Return prediction scores for one feature setting as a float DataFrame.

    Columns are classifier names with the 'clf <dataset_name>_' prefix
    removed, restricted to those that start with `pred_prefix`.
    """
    # The first return value (results_df) is not needed for this plot.
    _, preds_panel = vislab._results.load_pred_results(
        collection_name, os.path.expanduser(results_dirname),
        multiclass=multiclass, force=False)
    # NOTE(review): `minor_xs` is a pandas Panel method and Panel was removed
    # in pandas 0.25; this needs an old pandas or a port in vislab — confirm.
    pred_df = preds_panel.minor_xs(pred_feat)
    clf_prefix = 'clf {}_'.format(dataset_name)
    pred_df.columns = [col.replace(clf_prefix, '') for col in pred_df.columns]
    kept_cols = [col for col in pred_df.columns if col.startswith(pred_prefix)]
    return pred_df[kept_cols].astype(float)


def _build_correlation_df(label_df, pred_df):
    """Correlate ground-truth label columns with binarized prediction columns.

    Predictions are binarized as score > 0. Returns a DataFrame with
    prediction classes as rows and ground-truth labels as columns, with names
    shortened for display.
    """
    # Inner join, so only rows present in both frames are correlated. With a
    # left join, rows lacking predictions became NaN, and NaN.astype(bool) is
    # True, so they were silently counted as positive predictions.
    joined = label_df.join(pred_df > 0, how='inner').astype(bool)
    # Ground truth as rows, predictions as columns (.ix was removed in pandas 1.0).
    corr_df = joined.corr().loc[list(label_df.columns), list(pred_df.columns)]
    corr_df.columns = [
        col.replace('class_', '').replace('pred_meta', '')
        for col in corr_df.columns]
    pretty_index = []
    for name in corr_df.index:
        name = name.replace('style_', '').replace('_', ' ')
        pretty_index.append(
            'Geometric' if name == 'Geometric Composition' else name)
    corr_df.index = pretty_index
    return corr_df.T


def plot_gt_pred_correlation(
        collection_name, pred_feat, dataset_name,
        pred_prefix, multiclass, label_df, color_anchor=None,
        results_dirname='~/work/vislab/data/shared/results'):
    """Plot the correlation between ground-truth labels and predictions.

    Parameters
    ----------
    collection_name : str
        Results collection passed to vislab._results.load_pred_results.
    pred_feat : str
        Feature setting whose predictions are used (selected via minor_xs).
    dataset_name : str
        Used to strip the 'clf <dataset_name>_' prefix from classifier names.
    pred_prefix : str
        Only prediction columns starting with this prefix are kept
        ('' keeps all of them).
    multiclass : bool
        Forwarded to load_pred_results.
    label_df : DataFrame
        Ground-truth label columns, joined with the predictions on the index.
    color_anchor : sequence of two numbers, optional
        Forwarded to plot_occurrence (presumably the colour-scale limits).
        Defaults to [-1, 1].
    results_dirname : str, optional
        Directory holding the prediction results; '~' is expanded.

    Returns
    -------
    The figure returned by vislab.dataset_viz.plot_occurrence.
    """
    if color_anchor is None:
        # A fresh list per call instead of a shared mutable default.
        color_anchor = [-1, 1]
    pred_df = _load_prediction_scores(
        collection_name, pred_feat, dataset_name, pred_prefix, multiclass,
        results_dirname)
    corr_df = _build_correlation_df(label_df, pred_df)
    return vislab.dataset_viz.plot_occurrence(
        corr_df, color_anchor=color_anchor, cmap=plt.cm.RdBu_r, font_size=16,
        x_tick_rot=90)
In [64]:
# Flickr style labels vs. predictions from collection pascal_mc_on_flickr_mar23
# (caffe_fc7 features, multiclass).
label_df = vislab.datasets.flickr.get_df()[vislab.datasets.flickr.underscored_style_names]
collection_name = 'pascal_mc_on_flickr_mar23'
dataset_name = 'flickr'
pred_feat = 'caffe_fc7 False vw'
pred_prefix = 'pred_'
multiclass = True
fig = plot_gt_pred_correlation(
    collection_name, pred_feat, dataset_name, pred_prefix,
    multiclass, label_df, [-.3, .3])
fig.savefig(os.path.join(dirname, 'pascal_on_flickr.pdf'), bbox_inches='tight')
In [3]:
# Flickr style labels vs. predictions from collection pascal_on_flickr_oct29
# (decaf_fc6 features, binary classifiers).
collection_name = 'pascal_on_flickr_oct29'
pred_feat = 'decaf_fc6 False vw'
dataset_name = 'flickr'
pred_prefix = ''
multiclass = False
# NOTE(review): the Mar 23 Flickr cell loads labels via flickr.get_df();
# confirm load_flickr_df() still exists in vislab.datasets.flickr.
label_df = vislab.datasets.flickr.load_flickr_df()[vislab.datasets.flickr.underscored_style_names]
fig = plot_gt_pred_correlation(
    collection_name, pred_feat, dataset_name, pred_prefix,
    multiclass, label_df, [-.3, .3])
# Saved under its own name: the Mar 23 caffe_fc7 cell also writes
# pascal_on_flickr.pdf, and under Restart & Run All this older result would
# overwrite that figure.
fig.savefig(os.path.join(dirname, 'pascal_on_flickr_oct29.pdf'), bbox_inches='tight')
In [61]:
# Wikipaintings style labels vs. predictions from collection
# pascal_mc_on_wikipaintings_mar23 (caffe_fc7 features, multiclass).
collection_name = 'pascal_mc_on_wikipaintings_mar23'
pred_feat = 'caffe_fc7 False vw'
dataset_name = 'wikipaintings'
pred_prefix = 'pred_'
multiclass = True
label_df = vislab.datasets.wikipaintings.get_style_df()
# Drop '_'-prefixed columns and 'image_url' in one selection, rather than
# deleting a column from the selected frame afterwards.
label_df = label_df[[
    col for col in label_df.columns
    if not col.startswith('_') and col != 'image_url']]
fig = plot_gt_pred_correlation(
    collection_name, pred_feat, dataset_name, pred_prefix,
    multiclass, label_df, [-.3, .3])
# Saving was disabled in this cell (its savefig call was commented out);
# set this flag to True to write the figure.
save_pascal_on_wp = False
if save_pascal_on_wp:
    fig.savefig(os.path.join(dirname, 'pascal_on_wp.pdf'), bbox_inches='tight')
In [7]:
# PASCAL class labels vs. predictions from collection wp_on_pascal_oct30
# (decaf_fc6 features, multiclass); wider colour range than the other plots.
collection_name, dataset_name = 'wp_on_pascal_oct30', 'pascal'
pred_feat, pred_prefix, multiclass = 'decaf_fc6 False vw', 'pred_', True
label_df = vislab.datasets.pascal.get_clf_df()
# Columns whose name starts with '_' are not labels to correlate; drop them.
label_df = label_df[[col for col in label_df.columns if not col.startswith('_')]]
fig = plot_gt_pred_correlation(
    collection_name=collection_name, pred_feat=pred_feat,
    dataset_name=dataset_name, pred_prefix=pred_prefix,
    multiclass=multiclass, label_df=label_df, color_anchor=[-.5, .5])
fig.savefig('/'.join((dirname, 'wp_on_pascal.pdf')), bbox_inches='tight')
In [5]:
# PASCAL class labels vs. predictions from collection flickr_on_pascal_oct30
# (decaf_fc6 features, multiclass).
collection_name = 'flickr_on_pascal_oct30'
dataset_name = 'pascal'
pred_feat, pred_prefix = 'decaf_fc6 False vw', 'pred_'
multiclass = True
label_df = vislab.datasets.pascal.get_clf_df()
# Keep only columns that do not start with '_'.
label_df = label_df[[name for name in label_df.columns if not name.startswith('_')]]
fig = plot_gt_pred_correlation(
    collection_name, pred_feat, dataset_name, pred_prefix, multiclass,
    label_df, color_anchor=[-.3, .3])
fig.savefig('/'.join((dirname, 'flickr_on_pascal.pdf')), bbox_inches='tight')