In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
In [45]:
import imp
import pumpp
import pickle
import keras as K
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
import json
import jams
import sklearn.metrics as skm
In [61]:
import matplotlib.pyplot as plt
%matplotlib nbagg
import seaborn as sns
sns.set_style('whitegrid')
In [6]:
# Restore the pickled `pump` object; later cells index it as
# pump['chord_tag'] to reach the chord label encoder.
# NOTE(review): pickle executes arbitrary code on load -- only use trusted files.
with open('/home/bmcfee/git/chord_models/data/pump.pkl', mode='rb') as fd:
    pump = pickle.load(fd)
In [7]:
model_cr1 = imp.load_source('train_model', '/home/bmcfee/git/chord_models/code/train_model.py')
In [8]:
model_cr2 = imp.load_source('train_deep', '/home/bmcfee/git/chord_models/code/train_deep.py')
In [12]:
WORKING = '/home/bmcfee/working/chords/'
In [13]:
AUGMENTATION = True
WEIGHTED = False
In [36]:
model_cr1.make_output_path(WORKING, True, AUGMENTATION, WEIGHTED)
Out[36]:
In [ ]:
# Load model weights for fold -- needs model, split id, output path
# Load test index for fold -- needs split id
# Iterate over points -- needs split id
# accumulate confusions
In [16]:
# First, CR1
In [26]:
def load_datum(test_id):
    """Open the stored ``.npz`` archive for one track.

    Parameters
    ----------
    test_id : str
        Track identifier; the archive is expected at
        ``<WORKING>/pump/<test_id>.npz``.

    Returns
    -------
    The object produced by ``np.load`` for that file.
    """
    filename = '{}.npz'.format(test_id)
    return np.load(os.path.join(WORKING, 'pump', filename))
In [120]:
def confusion(model, test_id, pump, structured):
    """Compute the chord confusion matrix for a single test track.

    Parameters
    ----------
    model
        Object exposing ``predict``; called on the track's ``'cqt/mag'`` array.
    test_id : str
        Track identifier, passed to ``load_datum``.
    pump
        Mapping whose ``'chord_tag'`` entry carries ``encoder.classes_``;
        its length fixes the size of the returned matrix.
    structured : bool
        If true, the prediction is indexed ``[0][0]`` before the argmax;
        otherwise it is indexed ``[0]``.
        (presumably structured models return a list of outputs with the
        chord tag first -- TODO confirm against the model definition)

    Returns
    -------
    np.ndarray
        Square confusion matrix (rows = reference, columns = estimate) as
        returned by ``sklearn.metrics.confusion_matrix``.
    """
    # Use the archive as a context manager so its file handle is released
    # immediately instead of lingering until garbage collection.  Arrays
    # must be read inside the block because the archive loads them lazily.
    with load_datum(test_id) as x:
        ytrue = x['chord_tag/chord'][0].squeeze()
        cqt = x['cqt/mag']

    output = model.predict(cqt)[0]
    if structured:
        output = output[0]
    ypred = output.argmax(axis=1)

    # Reference and estimate may differ in length; compare the common prefix.
    n = min(len(ypred), len(ytrue))
    n_classes = len(pump['chord_tag'].encoder.classes_)

    return skm.confusion_matrix(ytrue[:n], ypred[:n],
                                labels=np.arange(n_classes))
In [162]:
# Build the model
def confuse_split(split, structured, pump, model_class):
    """Return the summed confusion matrix over all test tracks of one fold.

    Per-track confusion matrices are cached next to the model weights in
    ``fold{split:02d}_confusion.pkl``; if that file already exists it is
    loaded and the model is not rebuilt.

    Parameters
    ----------
    split : int
        Fold number; used to locate the weights, test index and cache files.
    structured : bool
        Forwarded to ``make_output_path``, ``construct_model`` and
        ``confusion``.
    pump
        Forwarded to ``construct_model`` and ``confusion``.
    model_class
        Module providing ``make_output_path`` and ``construct_model``.

    Returns
    -------
    np.ndarray
        Element-wise sum of the per-track confusion matrices.
    """
    # Locate the output directory for this model configuration
    path = model_class.make_output_path(WORKING, structured, AUGMENTATION, WEIGHTED)
    conf_file = os.path.join(path, 'fold{:02d}_confusion.pkl'.format(split))

    if os.path.exists(conf_file):
        # Cached result.  NOTE(review): pickle executes arbitrary code on
        # load; this assumes the cache file was written by this notebook.
        with open(conf_file, 'rb') as fd:
            C = pickle.load(fd)
        return np.sum(C, axis=0)

    # Build the model and restore the fold's weights
    model = model_class.construct_model(pump, structured)[0]
    weight_path = os.path.join(path, 'fold{:02d}_weights.pkl'.format(split))
    model.load_weights(weight_path)

    # Load the test index (first column of a header-less CSV)
    test_idx = pd.read_csv(os.path.join(WORKING, 'test{:02d}.csv'.format(split)),
                           header=None)[0]

    C = [confusion(model, test_id, pump, structured) for test_id in tqdm(test_idx)]

    # Save the per-track confusion matrices for later runs
    with open(conf_file, 'wb') as fd:
        pickle.dump(C, fd)

    return np.sum(C, axis=0)
In [163]:
# CR2: model from train_deep with structured=False.
# Collect one summed confusion matrix per fold (folds 0-4).
C_CR2 = []
for split in tqdm(range(5)):
    C_CR2 += [confuse_split(split, structured=False, pump=pump, model_class=model_cr2)]
In [176]:
# CR2S: model from train_deep with structured=True.
# Collect one summed confusion matrix per fold (folds 0-4).
C_CR2S = []
for split in tqdm(range(5)):
    C_CR2S += [confuse_split(split, structured=True, pump=pump, model_class=model_cr2)]
In [356]:
# Can we aggregate down to root confusion and quality confusion?
def conf_to_qconf(C):
    """Collapse a chord confusion matrix to a 16x16 quality confusion matrix.

    ``C`` is read as 12 consecutive blocks of 14 classes followed by two
    trailing classes.  Only the diagonal blocks (same block index for the
    row and the column) contribute to the 14x14 quality part; the two
    trailing rows/columns are summed over all 12 blocks.

    Returns an array with the same dtype as ``C``.
    """
    # Qualities = 14 + 2
    Q = np.zeros((16, 16), dtype=C.dtype)

    for root in range(12):
        block = slice(14 * root, 14 * (root + 1))
        Q[:14, :14] += C[block, block]   # same block on both axes
        Q[:14, 14:] += C[block, -2:]     # row block vs. the two trailing classes
        Q[14:, :14] += C[-2:, block]     # trailing classes vs. column block

    # Trailing-vs-trailing corner is copied unchanged
    Q[14:, 14:] = C[-2:, -2:]
    return Q
In [357]:
# What about comparisons across roots?
# For each root i, sum over all j != i
def conf_to_xconf(C):
# Collapse a chord confusion matrix into a 16x16 quality-by-quality matrix
# built from the off-diagonal blocks only: row block j and column block i
# are combined for every pair with j != i.
# The slicing reads C as 12 consecutive blocks of 14 classes followed by
# two trailing classes (labelled 'N' and 'X' elsewhere in this notebook).
# The result has the same dtype as C.
# Qualities = 14 + 2
Q = np.zeros((16, 16), dtype=C.dtype)
for j in range(0, 12):
for i in range(0, 12):
# Skip the diagonal block (same block index on both axes)
if j == i:
continue
Q[:14, :14] += C[j * 14: j * 14 + 14,
i * 14: i * 14 + 14]
# NOTE(review): indentation was lost in this export, so it is unclear
# whether the next two accumulations run once per (j, i) pair or once per
# outer iteration.  If they sit in the inner loop, each block-vs-trailing
# slice is added 11 times -- confirm against the original notebook.
Q[:14, 14:] += C[j * 14:j * 14 + 14, -2:]
Q[14:, :14] += C[-2:, i * 14:i * 14 + 14]
# Trailing-vs-trailing corner is copied unchanged
Q[14:, 14:] = C[-2:, -2:]
return Q
In [358]:
# What about root confusions?
def conf_to_rconf(C):
    """Collapse a chord confusion matrix to a 14x14 block-level matrix.

    ``C`` is read as 12 consecutive blocks of 14 classes followed by two
    trailing classes.  Entry ``(i, j)`` of the result is the total count in
    row block ``i`` and column block ``j``; the last two rows/columns hold
    the trailing classes.

    Returns an array with the same dtype as ``C``.
    """
    Q = np.zeros((14, 14), dtype=C.dtype)

    # C[i*14:i*14+14]    == rows belonging to block i
    # C[:, j*14:j*14+14] == columns belonging to block j
    blocks = [slice(14 * k, 14 * k + 14) for k in range(12)]

    for i, row_block in enumerate(blocks):
        for j, col_block in enumerate(blocks):
            Q[i, j] = np.sum(C[row_block, col_block])

        # Block i against the two trailing classes, in both directions
        Q[i, -2:] = np.sum(C[row_block, -2:], axis=0)
        Q[-2:, i] = np.sum(C[-2:, row_block], axis=1)

    # Trailing-vs-trailing corner is copied unchanged
    Q[-2:, -2:] = C[-2:, -2:]
    return Q
In [359]:
T_CR2 = np.sum(C_CR2, axis=0)
In [360]:
T_CR2.diagonal().sum() / T_CR2.sum()
Out[360]:
In [361]:
T_CR2s = np.sum(C_CR2S, axis=0)
In [362]:
T_CR2s.diagonal().sum() / T_CR2s.sum()
Out[362]:
In [363]:
T_CR2s.diagonal().sum() / T_CR2s.sum() - T_CR2.diagonal().sum() / T_CR2.sum()
Out[363]:
In [364]:
def norm(x, axis=1):
    """Scale ``x`` so that it sums to one along ``axis``.

    With the default ``axis=1`` every row of a 2-D array is divided by its
    row total.  Slices whose total is zero are divided by zero and are not
    given special treatment.
    """
    totals = x.sum(axis=axis, keepdims=True)
    return x / totals
In [365]:
Q_CR2 = conf_to_qconf(np.sum(C_CR2, axis=0))
In [366]:
X_CR2 = conf_to_xconf(np.sum(C_CR2, axis=0))
In [367]:
R_CR2 = conf_to_rconf(np.sum(C_CR2, axis=0))
In [368]:
Q_CR2S = conf_to_qconf(np.sum(C_CR2S, axis=0))
In [369]:
X_CR2S = conf_to_xconf(np.sum(C_CR2S, axis=0))
In [370]:
R_CR2S = conf_to_rconf(np.sum(C_CR2S, axis=0))
In [371]:
vocab = pump['chord_tag'].encoder.inverse_transform(np.arange(170))
In [372]:
# Quality names: the text after the first ':' in each of the first 14
# vocabulary entries, followed by the two extra labels 'N' and 'X'.
qvocab = [label[label.index(':') + 1:] for label in vocab[:14]]
qvocab += ['N', 'X']
In [373]:
list(enumerate(qvocab))
Out[373]:
In [411]:
# Display order for the 16 quality classes: a permutation of range(16) applied
# to both qvocab (tick labels) and the rows/columns of the quality matrices.
# The plots below draw separators after positions 2, 4, 6, 12 and 14 of this order.
idx = [8, 5, 2, 1, 9, 6, 10, 11, 7, 0, 3, 4, 12, 13, 14, 15]
In [412]:
qv = [qvocab[i] for i in idx]
In [413]:
qv
Out[413]:
In [414]:
SAVE = True
In [415]:
plt.figure(figsize=(6,5))
mesh = plt.pcolormesh(norm(Q_CR2S[idx][:, idx]), cmap='magma_r', vmin=0, vmax=1)
plt.grid('off')
# Major-minor
plt.axvline(2, color='w', alpha=0.95)
plt.axhline(2, color='w', alpha=0.95)
# Triads
plt.axvline(4, color='w', alpha=0.95)
plt.axhline(4, color='w', alpha=0.95)
# 6ths
plt.axvline(6, color='w', alpha=0.95)
plt.axhline(6, color='w', alpha=0.95)
# sus
plt.axvline(12, color='w', alpha=0.95)
plt.axhline(12, color='w', alpha=0.95)
# Partition NX
plt.axvline(14, color='w', alpha=0.95)
plt.axhline(14, color='w', alpha=0.95)
plt.xticks(np.arange(0.5, 16+0.5), qv, rotation=90)
plt.yticks(np.arange(0.5, 16+0.5), qv)
plt.gca().invert_yaxis()
plt.xlabel('Estimate')
plt.ylabel('Reference')
#plt.title('Quality confusions: CR2+S+A')
plt.colorbar()
plt.tight_layout()
mesh.set_rasterized(True)
if SAVE:
plt.savefig('../paper/figs/qualconf.pdf', transparent=True, pad_inches=0)
!convert ../paper/figs/qualconf.pdf ../paper/figs/qualconf.eps
In [416]:
# Labels for the 14 rows/columns of the root confusion matrix: 12 roots plus 'N' and 'X'.
# NOTE(review): presumably this mirrors the label encoder's class order
# (where 'A#' sorts before 'A') -- confirm against pump['chord_tag'].encoder.classes_.
root_vocab = ['A#', 'A', 'B', 'C#', 'C', 'D#', 'D', 'E', 'F#', 'F', 'G#', 'G', 'N', 'X']
In [417]:
list(enumerate(root_vocab))
Out[417]:
In [418]:
# Permutation of root_vocab into chromatic order starting at C:
# C, C#, D, D#, E, F, F#, G, G#, A, A#, B, followed by N and X.
ridx = [4, 3, 6, 5, 7, 9, 8, 11, 10, 1, 0, 2, 12, 13]
In [419]:
len(ridx)
Out[419]:
In [420]:
rv = [root_vocab[i] for i in ridx]
In [422]:
plt.figure(figsize=(6,5))
mesh = plt.pcolormesh(norm(R_CR2S[ridx][:, ridx]), cmap='magma_r', vmin=0, vmax=1)
plt.grid('off')
plt.xticks(np.arange(0.5, 14+0.5), rv)
plt.yticks(np.arange(0.5, 14+0.5), rv)
plt.gca().invert_yaxis()
plt.xlabel('Estimate')
plt.ylabel('Reference')
#plt.title('Quality confusions: CR2+S+A')
plt.colorbar()
plt.tight_layout()
mesh.set_rasterized(True)
if SAVE:
plt.savefig('../paper/figs/rconf.pdf', transparent=True, pad_inches=0)
!convert ../paper/figs/rconf.pdf ../paper/figs/rconf.eps
In [423]:
plt.figure(figsize=(6,5))
delta = norm(Q_CR2S[idx][:, idx]) - norm(Q_CR2[idx][:, idx])
mesh = plt.pcolormesh(delta, cmap='RdBu_r', vmin=-np.max(np.abs(delta)), vmax=np.max(np.abs(delta)))
plt.grid('off')
# Major-minor
plt.axvline(2, color='w', alpha=0.95)
plt.axhline(2, color='w', alpha=0.95)
# Triads
plt.axvline(4, color='w', alpha=0.95)
plt.axhline(4, color='w', alpha=0.95)
# 6ths
plt.axvline(6, color='w', alpha=0.95)
plt.axhline(6, color='w', alpha=0.95)
# sus
plt.axvline(12, color='w', alpha=0.95)
plt.axhline(12, color='w', alpha=0.95)
# Partition NX
plt.axvline(14, color='w', alpha=0.95)
plt.axhline(14, color='w', alpha=0.95)
plt.xticks(np.arange(0.5, 16+0.5), qv, rotation=90)
plt.yticks(np.arange(0.5, 16+0.5), qv)
plt.gca().invert_yaxis()
plt.xlabel('Estimate')
plt.ylabel('Reference')
#plt.title('Structured - Unstructured')
plt.colorbar()
plt.tight_layout()
mesh.set_rasterized(True)
if SAVE:
plt.savefig('../paper/figs/confdelta.pdf', transparent=True, pad_inches=0)
!convert ../paper/figs/confdelta.pdf ../paper/figs/confdelta.eps
In [424]:
(2 * delta.diagonal().sum() - delta.sum())
Out[424]:
In [425]:
plt.figure(figsize=(6,5))
mesh = plt.pcolormesh(norm(X_CR2S[idx][:, idx]), cmap='magma_r', vmin=0, vmax=1)
plt.grid('off')
# Major-minor
plt.axvline(2, color='w', alpha=0.95)
plt.axhline(2, color='w', alpha=0.95)
# Triads
plt.axvline(4, color='w', alpha=0.95)
plt.axhline(4, color='w', alpha=0.95)
# 6ths
plt.axvline(6, color='w', alpha=0.95)
plt.axhline(6, color='w', alpha=0.95)
# sus
plt.axvline(12, color='w', alpha=0.95)
plt.axhline(12, color='w', alpha=0.95)
# Partition NX
plt.axvline(14, color='w', alpha=0.95)
plt.axhline(14, color='w', alpha=0.95)
plt.xticks(np.arange(0.5, 16+0.5), qv, rotation=90)
plt.yticks(np.arange(0.5, 16+0.5), qv)
plt.gca().invert_yaxis()
plt.xlabel('Predicted')
plt.ylabel('True')
#plt.title('Quality confusions: CR2+S+A')
plt.colorbar()
plt.tight_layout()
mesh.set_rasterized(True)
if SAVE:
plt.savefig('../paper/figs/xconf.pdf', transparent=True, pad_inches=0)
!convert ../paper/figs/xconf.pdf ../paper/figs/xconf.eps
In [ ]: