In [51]:
%load_ext autoreload
%autoreload 2
import vislab._results
import os
import vislab.datasets
import vislab.dataset_viz
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
dirname = '/Users/sergeyk/work/aphrodite-writeup/figures/content_correlation2'
vislab.util.makedirs(dirname)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Out[51]:
'/Users/sergeyk/work/aphrodite-writeup/figures/content_correlation2'

Flickr


In [53]:
# Inspect the 'predict' MongoDB database: list its collections, then show
# every record of the collection named by `collection_name` as a table.
#
# NOTE(review): `collection_name` is not assigned in this cell or in any cell
# above it; it is only set by the plotting cells further down. On a fresh
# kernel this cell raises NameError unless one of those cells has run first.
predict_db = vislab.util.get_mongodb_client()['predict']  # one client lookup, reused below
print(predict_db.collection_names())
c = predict_db[collection_name]
# Optional, destructive cleanup of 'noise'-feature records, kept disabled:
# if c.find({'features': 'noise'}).count() > 0:
#     c.remove({'features': 'noise'})
pd.DataFrame(list(c.find()))


[u'system.indexes', u'default', u'behance_dec28', u'behance_illustration_jan15', u'flickr_mar23', u'flickr_on_pinterest_80k_mar23', u'pascal_mar23', u'pascal_mc_mar23', u'pascal_mc_on_flickr_mar23', u'pascal_mc_on_pinterest_80k_mar23', u'pascal_mc_on_wikipaintings_mar23', u'pinterest_80k_mar23', u'pinterest_80k_on_flickr_mar23', u'wikipaintings_mar23']
Out[53]:
_id data features num_test num_train num_val quadratic results_name score_test score_val task
0 532fdf6e9f00136077eae900 flickr_metaclass_ALL [caffe_fc7] 16000 48000 16000 False data_flickr_metaclass_ALL_features_['caffe_fc7... 0.2504 0.245003 clf

1 rows × 11 columns


In [63]:
def plot_gt_pred_correlation(
        collection_name, pred_feat, dataset_name,
        pred_prefix, multiclass, label_df, color_anchor=None):
    """
    Plot the correlation between ground-truth labels and binarized predictions.

    Predictions are loaded through vislab._results.load_pred_results,
    thresholded at 0, joined with `label_df` on the index, and the
    label-vs-prediction part of the resulting correlation matrix is passed to
    vislab.dataset_viz.plot_occurrence (predictions as rows, labels as
    columns after the final transpose).

    Parameters
    ----------
    collection_name: str
        Results collection name, forwarded to load_pred_results.
    pred_feat: str
        Minor-axis key of the predictions panel to take the cross-section at.
    dataset_name: str
        The substring 'clf <dataset_name>_' is removed from the names of the
        prediction columns.
    pred_prefix: str
        Only prediction columns whose (stripped) name starts with this
        prefix are kept. The empty string keeps all of them.
    multiclass: bool
        Forwarded to load_pred_results.
    label_df: pandas.DataFrame
        Ground-truth label columns; joined with the predictions on its index.
    color_anchor: list of two numbers, optional
        Forwarded to plot_occurrence. Defaults to [-1, 1].

    Returns
    -------
    The object returned by vislab.dataset_viz.plot_occurrence (callers in
    this notebook call .savefig on it).
    """
    # Avoid a shared mutable default; the effective default is unchanged.
    if color_anchor is None:
        color_anchor = [-1, 1]

    # Load predictions (the results frame returned alongside is not needed).
    _, preds_panel = vislab._results.load_pred_results(
        collection_name, os.path.expanduser('~/work/vislab/data/shared/results'),
        multiclass=multiclass, force=False)

    # Transform to workable form: strip the task/dataset prefix from the
    # column names, then keep only the requested prediction columns.
    mc_pred_df = preds_panel.minor_xs(pred_feat)
    task_prefix = 'clf {}_'.format(dataset_name)
    mc_pred_df.columns = [
        name.replace(task_prefix, '') for name in mc_pred_df.columns
    ]
    kept_columns = [
        name for name in mc_pred_df.columns if name.startswith(pred_prefix)
    ]
    mc_pred_df = mc_pred_df[kept_columns].astype(float)

    # Form correlation matrix: ground truth as rows.
    # Inner join on purpose: with the default left join, label rows that have
    # no prediction get NaN in the prediction columns, and astype(bool) turns
    # NaN into True, i.e. a spurious positive prediction for every class.
    joined = label_df.join(mc_pred_df > 0, how='inner')
    cdf = joined.astype(bool).corr()
    cdf = cdf.loc[list(label_df.columns), list(mc_pred_df.columns)]

    # Shorten the tick labels for display.
    cdf.columns = [
        name.replace('class_', '').replace('pred_meta', '')
        for name in cdf.columns
    ]
    row_labels = [
        name.replace('style_', '').replace('_', ' ') for name in cdf.index
    ]
    cdf.index = [
        'Geometric' if name == 'Geometric Composition' else name
        for name in row_labels
    ]
    cdf = cdf.T

    # Plot.
    fig = vislab.dataset_viz.plot_occurrence(
        cdf, color_anchor=color_anchor, cmap=plt.cm.RdBu_r, font_size=16,
        x_tick_rot=90)
    return fig

In [64]:
# Correlate the Flickr style labels with the predictions stored in the
# 'pascal_mc_on_flickr_mar23' collection and save the figure.
collection_name = 'pascal_mc_on_flickr_mar23'
dataset_name = 'flickr'
pred_feat = 'caffe_fc7 False vw'
pred_prefix = 'pred_'
multiclass = True

flickr_df = vislab.datasets.flickr.get_df()
label_df = flickr_df[vislab.datasets.flickr.underscored_style_names]

fig = plot_gt_pred_correlation(
    collection_name=collection_name,
    pred_feat=pred_feat,
    dataset_name=dataset_name,
    pred_prefix=pred_prefix,
    multiclass=multiclass,
    label_df=label_df,
    color_anchor=[-.3, .3])
figure_path = dirname + '/pascal_on_flickr.pdf'
fig.savefig(figure_path, bbox_inches='tight')


Loaded from cache: 1 records

In [3]:
# Same plot as the cell above, but for the 'pascal_on_flickr_oct29'
# collection with decaf_fc6 features, no column-prefix filter and
# multiclass=False.
collection_name = 'pascal_on_flickr_oct29'
pred_feat = 'decaf_fc6 False vw'
dataset_name = 'flickr'
pred_prefix = ''
multiclass = False
# NOTE(review): the cell above loads the same labels through
# vislab.datasets.flickr.get_df(); confirm load_flickr_df() still exists.
label_df = vislab.datasets.flickr.load_flickr_df()[vislab.datasets.flickr.underscored_style_names]

fig = plot_gt_pred_correlation(
    collection_name, pred_feat, dataset_name, pred_prefix,
    multiclass, label_df, [-.3, .3])
# Saved under its own name: this cell used to write '/pascal_on_flickr.pdf',
# the same file as the mar23 cell above, so a top-to-bottom run silently
# replaced that figure with this one.
fig.savefig(dirname + '/pascal_on_flickr_oct29.pdf', bbox_inches='tight')


Loaded from cache: 10 records

In [61]:
# Correlate the Wikipaintings style labels with the predictions stored in
# the 'pascal_mc_on_wikipaintings_mar23' collection (figure is not saved).
dataset_name = 'wikipaintings'
collection_name = 'pascal_mc_on_wikipaintings_mar23'
pred_feat = 'caffe_fc7 False vw'
pred_prefix = 'pred_'
multiclass = True

# Keep only columns whose name does not start with '_', then drop the URL
# column so that only label columns remain.
style_df = vislab.datasets.wikipaintings.get_style_df()
public_columns = [
    name for name in style_df.columns if not name.startswith('_')
]
label_df = style_df[public_columns]
del label_df['image_url']

fig = plot_gt_pred_correlation(
    collection_name=collection_name,
    pred_feat=pred_feat,
    dataset_name=dataset_name,
    pred_prefix=pred_prefix,
    multiclass=multiclass,
    label_df=label_df,
    color_anchor=[-.3, .3])
# fig.savefig(dirname + '/pascal_on_wp.pdf', bbox_inches='tight')


Loaded from cache: 1 records

In [7]:
# Correlate the PASCAL classification labels with the predictions stored in
# the 'wp_on_pascal_oct30' collection and save the figure.
dataset_name = 'pascal'
collection_name = 'wp_on_pascal_oct30'
pred_feat = 'decaf_fc6 False vw'
pred_prefix = 'pred_'
multiclass = True

# Keep only columns whose name does not start with '_'.
pascal_df = vislab.datasets.pascal.get_clf_df()
public_columns = [
    name for name in pascal_df.columns if not name.startswith('_')
]
label_df = pascal_df[public_columns]

fig = plot_gt_pred_correlation(
    collection_name=collection_name,
    pred_feat=pred_feat,
    dataset_name=dataset_name,
    pred_prefix=pred_prefix,
    multiclass=multiclass,
    label_df=label_df,
    color_anchor=[-.5, .5])
figure_path = dirname + '/wp_on_pascal.pdf'
fig.savefig(figure_path, bbox_inches='tight')


Loaded from cache: 1 records

In [5]:
# Correlate the PASCAL classification labels with the predictions stored in
# the 'flickr_on_pascal_oct30' collection and save the figure.
dataset_name = 'pascal'
collection_name = 'flickr_on_pascal_oct30'
pred_feat = 'decaf_fc6 False vw'
pred_prefix = 'pred_'
multiclass = True

# Keep only columns whose name does not start with '_'.
pascal_df = vislab.datasets.pascal.get_clf_df()
public_columns = [
    name for name in pascal_df.columns if not name.startswith('_')
]
label_df = pascal_df[public_columns]

fig = plot_gt_pred_correlation(
    collection_name=collection_name,
    pred_feat=pred_feat,
    dataset_name=dataset_name,
    pred_prefix=pred_prefix,
    multiclass=multiclass,
    label_df=label_df,
    color_anchor=[-.3, .3])
figure_path = dirname + '/flickr_on_pascal.pdf'
fig.savefig(figure_path, bbox_inches='tight')


Loaded from cache: 1 records