In [ ]:
from utils_io import read_json, load_items

# Project configuration: image size used when displaying items and the folder holding them.
params = read_json('parameters.json')
RESIZE_X, RESIZE_Y = params['resize']['x'], params['resize']['y']  # dsize passed to cv2.resize
ITEM_FOLDER = params['item_folder']
items = load_items(ITEM_FOLDER)

hl = []  # histogram list: (item, view, hist, cluster_centers) tuples

In [ ]:
# Views looked up for every item; each view's data lives in
# '<ITEM_FOLDER>/<item>/<item>_<view>_dc.json' with 'hist' and 'cluster_centers' keys.
views = ['top_01','top-side_01','top-side_02','bottom_01','bottom-side_01','bottom-side_02']

hl = []             # reset so re-running this cell does not append duplicate entries
missing_files = []  # item/view files that could not be read (skipped, but reported)
for item in items:
    for view in views:
        filename = ITEM_FOLDER + '/' + item + '/' + item + '_' + view + '_dc.json'
        try:
            dc = read_json(filename)
        except IOError:
            # NOTE(review): assumes read_json raises IOError/OSError when the file is absent;
            # not every item is expected to have every view, so skipping stays best-effort.
            missing_files.append(filename)
            continue
        hl.append((item, view, dc['hist'], dc['cluster_centers']))

print('%d item/view histograms loaded, %d files skipped' % (len(hl), len(missing_files)))

In [ ]:
import numpy as np

ivdm = np.zeros([len(hl)] * 2)  # item/view distance matrix: [i, j] compares hl[i] with hl[j]

In [ ]:
from utils_color import calc_EMD2

# Fill the symmetric item/view distance matrix: calc_EMD2 is evaluated once per
# unordered pair (i < j) and mirrored; the diagonal is left untouched.
for i in range(len(hl)):
    hist_i, cc_i = hl[i][2], hl[i][3]
    for j in range(i + 1, len(hl)):
        hist_j, cc_j = hl[j][2], hl[j][3]
        ivdm[i, j] = calc_EMD2(hist_i, cc_i, hist_j, cc_j)
        ivdm[j, i] = ivdm[i, j]

In [ ]:
from matplotlib import pyplot as plt
%matplotlib inline

In [ ]:
# Heatmap of the item/view distance matrix; one row/column per entry of `hl`.
plt.imshow(ivdm, cmap='jet')
plt.colorbar(label='calc_EMD2 distance')
plt.title('Item/view distance matrix')
plt.xlabel('item/view index')
plt.ylabel('item/view index');

In [ ]:
# Item distance matrix. Every entry starts at a placeholder and is lowered to the minimum
# view-to-view distance by the aggregation cell; item pairs with no data keep the placeholder.
# NOTE(review): the placeholder also caps real distances - a pair whose true minimum
# exceeds it stays at NO_DISTANCE. Raise it if calc_EMD2 values can get that large.
NO_DISTANCE = 1000.0
n = len(items)
idm = np.full((n, n), NO_DISTANCE)  # float64, identical to np.ones((n, n)) * 1000

In [ ]:
# Collapse the item/view matrix into the item matrix: the distance between two items is the
# smallest distance over all pairs of their views. Pairs of views of the *same* item write to
# the diagonal, so idm[k][k] becomes item k's smallest view-to-view distance (or keeps its
# initial value when the item has fewer than two views).
item_index = {}
for k, it in enumerate(items):
    item_index.setdefault(it, k)  # first-match semantics of items.index, but O(1) lookups
n_views = len(hl)  # separate name so `n` (= len(items)) is not clobbered
for i in range(n_views):
    iti = item_index[hl[i][0]]  # hoisted: depends only on i
    for j in range(i + 1, n_views):
        itj = item_index[hl[j][0]]
        if ivdm[i][j] < idm[iti][itj]:
            idm[iti][itj] = ivdm[i][j]
            idm[itj][iti] = idm[iti][itj]

In [ ]:
# Heatmap of the item distance matrix (minimum over view pairs); one row/column per item.
plt.imshow(idm, cmap='jet')
plt.colorbar(label='min calc_EMD2 distance over views')
plt.title('Item distance matrix')
plt.xlabel('item index')
plt.ylabel('item index');

In [ ]:
def plot_distance(item, d):
    """Plot the distance from `item` to every item and list the items closer than `d`.

    Parameters
    ----------
    item : str
        Reference item; must be an element of `items`.
    d : float
        Threshold drawn as a horizontal red line. Items whose entry in the reference
        row of `idm` is below it are printed as "distance name", nearest first. The
        reference item itself can appear, since the diagonal of `idm` holds its
        smallest view-to-view distance.
    """
    idx = items.index(item)
    row = idm[idx]
    plt.plot(row, 'b-')
    plt.plot(row, 'bo')
    # axhline spans the full x-range whatever the item count (was hard-coded to x = 0..40).
    plt.axhline(d, color='r', linestyle='-')
    plt.title(items[idx])
    plt.xlabel('item index')
    plt.ylabel('distance')
    plt.show()
    # Distinct loop name so the threshold parameter `d` is not shadowed.
    for dist, it in sorted((dist, it) for it, dist in zip(items, row) if dist < d):
        print('%f %s' % (dist, it))

In [ ]:
from ipywidgets import interact
interact(
    plot_distance,
    item=items,  # dropdown over every item
    d=20,        # initial threshold for the slider
);

In [ ]:
from sklearn.cluster import AffinityPropagation

In [ ]:
# Affinity propagation on a precomputed similarity matrix: distances are negated so that
# closer items count as more similar.
af = AffinityPropagation(affinity='precomputed', damping=0.5, verbose=True)
af.fit(-idm);

In [ ]:
# Exemplar indices (into `items`) and the cluster label assigned to each item.
cluster_centers_indices, labels = af.cluster_centers_indices_, af.labels_

In [ ]:
# List each cluster: its exemplar item first, then the other members indented,
# with a blank line between clusters. Cluster `idx` is the one whose label equals idx.
for idx, kls in enumerate(cluster_centers_indices):
    exemplar = items[kls]
    print(exemplar)
    for it, lb in zip(items, labels):
        if lb == idx and it != exemplar:
            print('    %s' % it)
    print()  # was a bare `print`, which is a no-op expression in Python 3

In [ ]:
def print_cluster(item):
    """Print, one per line, every item that carries the same cluster label as `item`."""
    wanted = labels[items.index(item)]
    for member, member_label in zip(items, labels):
        if member_label == wanted:
            print(member)

# Dropdown restricted to the cluster exemplars.
interact(print_cluster, item=[items[c] for c in cluster_centers_indices]);

In [ ]:
from utils_io import imread_rgb
import cv2
def show_cluster(item, view):
    """Display the `view` image of every item in `item`'s cluster, three per row.

    Parameters
    ----------
    item : str
        Any element of `items`; its cluster label selects the cluster to show.
    view : str
        View name; each member's image is read from
        '<ITEM_FOLDER>/<member>/<member>_<view>.png' and resized to (RESIZE_X, RESIZE_Y).
    """
    label = labels[items.index(item)]
    cluster = [it for it, lb in zip(items, labels) if lb == label]
    n_cols = 3
    # `member` rather than `item` so the parameter is not shadowed.
    for k, member in enumerate(cluster):
        col = k % n_cols + 1
        filename = ITEM_FOLDER + '/' + member + '/' + member + '_' + view + '.png'
        image = imread_rgb(filename)
        image = cv2.resize(image, (RESIZE_X, RESIZE_Y))
        plt.subplot(1, n_cols, col)
        plt.imshow(image)
        plt.title(member, fontsize='small')  # label each panel so images can be identified
        plt.axis('off')
        if col == n_cols:
            plt.show()
    # Show a trailing, partially filled row explicitly instead of relying on the inline backend.
    if len(cluster) % n_cols:
        plt.show()

In [ ]:
from ipywidgets import interact
interact(
    show_cluster,
    item=[items[c] for c in cluster_centers_indices],  # one entry per cluster exemplar
    view=views,
);

In [ ]: