In [ ]:
from utils_io import read_json, load_items
# Load run-time settings and the list of items to analyse.
params = read_json('parameters.json')
# Target image size used when displaying cluster members.
RESIZE_X, RESIZE_Y = (params['resize'][axis] for axis in ('x', 'y'))
ITEM_FOLDER = params['item_folder']
items = load_items(ITEM_FOLDER)
# Histogram list: one (item, view, hist, cluster_centers) tuple per readable view file.
hl = []
In [ ]:
# View names looked up for every item; each has its own '<item>_<view>_dc.json' file.
views = ['top_01', 'top-side_01', 'top-side_02',
         'bottom_01', 'bottom-side_01', 'bottom-side_02']
for item in items:
    for view in views:
        dc_path = '/'.join([ITEM_FOLDER, item, item + '_' + view + '_dc.json'])
        try:
            dc = read_json(dc_path)
        except IOError:
            # Deliberately best-effort: a view file that cannot be read is skipped.
            continue
        hl.append((item, view, dc['hist'], dc['cluster_centers']))
In [ ]:
import numpy as np
# Item/view distance matrix: square, one row and one column per entry of hl, zero-filled.
ivdm = np.zeros((len(hl),) * 2)
In [ ]:
from utils_color import calc_EMD2
# Fill the upper triangle of ivdm with calc_EMD2 between every pair of
# distinct hl entries and mirror it, so ivdm is symmetric; the diagonal keeps
# its initial value.
for a, (_item_a, _view_a, hist_a, cc_a) in enumerate(hl):
    for b in range(a + 1, len(hl)):
        _item_b, _view_b, hist_b, cc_b = hl[b]
        ivdm[a][b] = ivdm[b][a] = calc_EMD2(hist_a, cc_a, hist_b, cc_b)
In [ ]:
from matplotlib import pyplot as plt
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline
In [ ]:
# Heat map of the view-to-view distance matrix; trailing ';' suppresses the cell's return-value output.
plt.imshow(X=ivdm, cmap='jet');
In [ ]:
n = len(items)
# Item distance matrix (n x n float array); every entry starts at the
# sentinel value 1000 and is lowered later when a smaller distance is found.
idm = np.full((n, n), 1000.0)
In [ ]:
# Collapse the view-level matrix ivdm into the item-level matrix idm: the
# distance between two items becomes the smallest distance between any view
# of one and any view of the other. Pairs of views of the same item update
# the diagonal entry of that item. Items with no readable views keep the
# initial value of idm.
n = len(hl)
# Resolve each view's item index once (items.index is a linear scan), rather
# than twice for every view pair inside the double loop.
view_item = [items.index(entry[0]) for entry in hl]
for i in range(n):
    iti = view_item[i]
    for j in range(i + 1, n):
        itj = view_item[j]
        if ivdm[i, j] < idm[iti, itj]:
            # Keep idm symmetric.
            idm[iti, itj] = idm[itj, iti] = ivdm[i, j]
In [ ]:
# Heat map of the item-to-item distance matrix; trailing ';' suppresses the cell's return-value output.
plt.imshow(X=idm, cmap='jet');
In [ ]:
def plot_distance(item, d):
    """Plot `item`'s distances to all items and list those closer than `d`.

    Draws row `item` of the item distance matrix `idm` as a blue line with dot
    markers, titled with the item name, plus a red horizontal line at the
    threshold `d`. Then prints every item whose distance is below `d`, sorted
    by ascending distance (ties broken by name), as '<distance> <item>'.

    Args:
        item: an entry of `items`.
        d: distance threshold.
    """
    row = idm[items.index(item)]
    plt.plot(row, 'b-')
    plt.plot(row, 'bo')
    plt.title(item)
    # axhline spans the full width of the axes whatever the number of items
    # (a fixed x-range such as [0, 40] only fits one catalogue size).
    plt.axhline(d, color='r', linestyle='-')
    plt.show()
    # Use distinct names so the threshold `d` is never rebound by the loop.
    close = sorted((dist, name) for name, dist in zip(items, row) if dist < d)
    for dist, name in close:
        print('%f %s' % (dist, name))
In [ ]:
from ipywidgets import interact
# Interactive explorer for plot_distance: choose an item from `items` and a
# distance threshold d (initial value 20).
interact(plot_distance, item=items, d=20);
In [ ]:
from sklearn.cluster import AffinityPropagation
In [ ]:
# With affinity='precomputed', sklearn treats the input as a similarity matrix
# (larger = more alike), so the item distances are negated before fitting.
af = AffinityPropagation(affinity='precomputed', damping=0.5, verbose=True)
af.fit(-idm);
In [ ]:
# Indices (into items) of the cluster centres, and the cluster label of every item.
cluster_centers_indices, labels = af.cluster_centers_indices_, af.labels_
In [ ]:
# List every cluster: its centre item first, then the other members indented
# beneath it. A member belongs to cluster k when its label equals k, the
# position of the centre in cluster_centers_indices.
for label, center in enumerate(cluster_centers_indices):
    center_item = items[center]
    print(center_item)
    for member, member_label in zip(items, labels):
        if member_label == label and member != center_item:
            print(' %s' % member)
    # Blank line between clusters (a bare `print` is a no-op in Python 3).
    print()
In [ ]:
def print_cluster(item):
    """Print, one per line, every item that has the same cluster label as `item`."""
    target = labels[items.index(item)]
    for name, lb in zip(items, labels):
        if lb == target:
            print(name)

# Interactive chooser over the cluster-centre items.
interact(print_cluster, item=[items[i] for i in cluster_centers_indices]);
In [ ]:
from utils_io import imread_rgb
import cv2
def show_cluster(item, view):
    """Display the `view` image of every item in the same cluster as `item`.

    Each member's '<ITEM_FOLDER>/<member>/<member>_<view>.png' image is read
    with imread_rgb and resized to (RESIZE_X, RESIZE_Y) with cv2.resize. The
    images are drawn three subplots per row, and each row is shown as its own
    figure.

    Args:
        item: an entry of `items` whose cluster is displayed.
        view: view name used to build each member's image filename.
    """
    ncols = 3
    label = labels[items.index(item)]
    # Distinct loop name so the `item` parameter is not shadowed.
    members = [name for name, lb in zip(items, labels) if lb == label]
    for k, member in enumerate(members):
        filename = ITEM_FOLDER + '/' + member + '/' + member + '_' + view + '.png'
        image = cv2.resize(imread_rgb(filename), (RESIZE_X, RESIZE_Y))
        col = k % ncols
        plt.subplot(1, ncols, col + 1)
        plt.imshow(image)
        plt.axis('off')
        if col == ncols - 1:
            plt.show()
    # The loop only shows full rows; show a final partial row explicitly
    # instead of relying on the backend to flush the still-open figure.
    if len(members) % ncols:
        plt.show()
In [ ]:
from ipywidgets import interact
# Interactive viewer for show_cluster: choose a cluster-centre item and one of `views`.
interact(show_cluster, item=[items[i] for i in cluster_centers_indices], view=views);
In [ ]: