In [9]:
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras.models import Sequential
from keras.layers import Activation, Conv1D, Dense, Dropout, Flatten, MaxPooling1D
from keras.wrappers.scikit_learn import KerasClassifier

from matplotlib import pyplot as plt
%matplotlib inline
import numpy as np

from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from onset_detection.metrics import onset_metric
from onset_detection.read_data import read_data

In [2]:
active_datasets = {1, 2, 3, 4}
X_parts, y_parts, y_start_only_parts, ds_labels = read_data(active_datasets)


D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for AR_Lick11_FN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for AR_Lick11_KN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for AR_Lick11_MN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:151: UserWarning: Skipping non-wav file data\IDMT-SMT-GUITAR_V2\dataset2\audio\desktop.ini
  warn('Skipping non-wav file ' + path_to_wav)
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for FS_Lick11_FN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for FS_Lick11_KN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for FS_Lick11_MN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for LP_Lick11_FN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for LP_Lick11_KN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:149: UserWarning: No truth found for LP_Lick11_MN.wav, skipping file.
  warn('No truth found for ' + wav_file + ', skipping file.')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Career SG\fast\country_folk\audio\country_1_150BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Career SG\fast\metal\audio\metal_3_135BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Career SG\fast\rock_blues\audio\rock_1_120BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Career SG\fast\rock_blues\audio\rock_2_115BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\classical\audio\classical_8_100BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_1_160BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_2_170BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_3_120BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_4_70BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_5_80BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_6_150BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_7_140BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\jazz\audio\jazz_8_110BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\reggae_ska\audio\reggae_2_100BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:170: UserWarning: Skipping data\IDMT-SMT-GUITAR_V2\dataset4\Ibanez 2820\slow\rock_blues\audio\rock_5_100BPM.wav: no truth csv
  warn('Skipping ' + path_to_wav + ': no truth csv')
D:\Users\Michel\Documents\FH\module\8_IP6\git\onset_detection\read_data.py:14: UserWarning: Cannot handle stereo signal (data\IDMT-SMT-GUITAR_V2\dataset3\audio\pathetique_poly.wav), skipping file.
  warn('Cannot handle stereo signal (' + path_to_wav + '), skipping file.')

In [3]:
X_parts_train, X_parts_test, y_parts_train, y_parts_test, y_start_only_parts_train, y_start_only_parts_test, ds_labels_train, ds_labels_test = train_test_split(
    X_parts, y_parts, y_start_only_parts, ds_labels, test_size=0.2, random_state=42
)

In [4]:
X_train = np.concatenate(X_parts_train)
X_test = np.concatenate(X_parts_test)
y_train = np.concatenate(y_parts_train).ravel()
y_test = np.concatenate(y_parts_test).ravel()

y_start_only_train = np.concatenate(y_start_only_parts_train)
y_start_only_test = np.concatenate(y_start_only_parts_test)

ds_labels_flat_train = []
for y_part, ds_label in zip(y_parts_train, ds_labels_train):
    ds_labels_part = np.empty(len(y_part), dtype=np.int8)
    ds_labels_part.fill(ds_label)
    ds_labels_flat_train.append(ds_labels_part)
ds_labels_flat_train = np.concatenate(ds_labels_flat_train).ravel()

ds_labels_flat_test = []
for y_part, ds_label in zip(y_parts_test, ds_labels_test):
    ds_labels_part = np.empty(len(y_part), dtype=np.int8)
    ds_labels_part.fill(ds_label)
    ds_labels_flat_test.append(ds_labels_part)
ds_labels_flat_test = np.concatenate(ds_labels_flat_test).ravel()

print(X_train.shape)
print(y_train.shape)
print(y_start_only_train.shape)
print(ds_labels_flat_train.shape)
print(X_test.shape)
print(y_test.shape)
print(y_start_only_test.shape)
print(ds_labels_flat_test.shape)


(940179, 111)
(940179,)
(940179,)
(940179,)
(231747, 111)
(231747,)
(231747,)
(231747,)

In [5]:
ss = StandardScaler()
X_train = ss.fit_transform(X_train)
X_test = ss.transform(X_test)
print(X_train.mean())
print(X_train.std())
print(X_test.mean())
print(X_test.std())


5.48495118402e-19
1.0
2.3806596501e-05
1.03017544869

In [6]:
input_dim = X_train.shape[1]
X_train = X_train.reshape(X_train.shape[0], input_dim, 1)
X_test = X_test.reshape(X_test.shape[0], input_dim, 1)
input_shape = (input_dim, 1)
print(X_train.shape)
print(X_test.shape)
print(input_shape)


(940179, 111, 1)
(231747, 111, 1)
(111, 1)

In [10]:
def create_model(nb_filter=32, filter_length=8, padding='same', input_shape=(111, 1),
                 pool_size=2,
                 n_conv_layers=1, dropout=True):
    model = Sequential()
    model.add(Conv1D(nb_filter, filter_length, padding=padding, input_shape=input_shape))
    model.add(Activation('relu'))
    for i in range(0, n_conv_layers - 1):
        model.add(Convolution1D(nb_filter, filter_length, padding=padding))
        model.add(Activation('relu'))
    model.add(MaxPooling1D(pool_size=pool_size))
    if dropout:
        model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(128))
    model.add(Activation('relu'))
    if dropout:
        model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='adam',)
    
    return model


def fit_predict(clf):
    clf.fit(X_train, y_train)
    y_train_predicted = clf.predict(X_train).ravel()
    y_test_predicted = clf.predict(X_test).ravel()
    
    model = clf.model
    for layer in model.layers:
        print(layer.get_config())
    print('TRAIN')
    print(classification_report(y_train, y_train_predicted))
    print(onset_metric(y_train, y_start_only_train, y_train_predicted))
    print('TEST')
    print(classification_report(y_test, y_test_predicted))
    print(onset_metric(y_test, y_start_only_test, y_test_predicted))
    print('')

In [11]:
clfs = [
    KerasClassifier(
        build_fn=create_model,
        batch_size=1024, nb_epoch=500,
        validation_data=(X_test, y_test),
        callbacks=[EarlyStopping(monitor='loss', patience=4)],
        nb_filter=256, input_shape=input_shape,
    ),
    KerasClassifier(
        build_fn=create_model,
        batch_size=1024, nb_epoch=500,
        validation_data=(X_test, y_test),
        callbacks=[EarlyStopping(monitor='loss', patience=4)],
        filter_length=16, input_shape=input_shape,
    ),
    KerasClassifier(
        build_fn=create_model,
        batch_size=1024, nb_epoch=500,
        validation_data=(X_test, y_test),
        callbacks=[EarlyStopping(monitor='loss', patience=4)],
        input_shape=input_shape, n_conv_layers=2,
    ),
    KerasClassifier(
        build_fn=create_model,
        batch_size=1024, nb_epoch=500,
        validation_data=(X_test, y_test),
        callbacks=[EarlyStopping(monitor='loss', patience=4)],
        input_shape=input_shape, dropout=False,
    ),
]

clfs2 = [
    KerasClassifier(
        build_fn=create_model,
        batch_size=1024, nb_epoch=500,
        validation_data=(X_test, y_test),
        callbacks=[EarlyStopping(monitor='loss', patience=4)],
        input_shape=input_shape, nb_filter=64, filter_length=16, n_conv_layers=2,
    ),
]

for clf in clfs2:
    fit_predict(clf)


Train on 940179 samples, validate on 231747 samples
Epoch 1/10
 94208/940179 [==>...........................] - ETA: 77s - loss: 0.2157
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-11-b14415fdb4f1> in <module>()
     41 
     42 for clf in clfs2:
---> 43     fit_predict(clf)

<ipython-input-10-6cc24a7d0cc1> in fit_predict(clf)
     27 
     28 def fit_predict(clf):
---> 29     clf.fit(X_train, y_train)
     30     y_train_predicted = clf.predict(X_train).ravel()
     31     y_test_predicted = clf.predict(X_test).ravel()

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\wrappers\scikit_learn.py in fit(self, x, y, **kwargs)
    199             y = np.searchsorted(self.classes_, y)
    200         self.n_classes_ = len(self.classes_)
--> 201         return super(KerasClassifier, self).fit(x, y, **kwargs)
    202 
    203     def predict(self, x, **kwargs):

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\wrappers\scikit_learn.py in fit(self, x, y, **kwargs)
    147         fit_args.update(kwargs)
    148 
--> 149         history = self.model.fit(x, y, **fit_args)
    150 
    151         return history

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\models.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)
    843                               class_weight=class_weight,
    844                               sample_weight=sample_weight,
--> 845                               initial_epoch=initial_epoch)
    846 
    847     def evaluate(self, x, y, batch_size=32, verbose=1,

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)
   1483                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
   1484                               callback_metrics=callback_metrics,
-> 1485                               initial_epoch=initial_epoch)
   1486 
   1487     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\engine\training.py in _fit_loop(self, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch)
   1138                 batch_logs['size'] = len(batch_ids)
   1139                 callbacks.on_batch_begin(batch_index, batch_logs)
-> 1140                 outs = f(ins_batch)
   1141                 if not isinstance(outs, list):
   1142                     outs = [outs]

D:\ProgramFiles\Anaconda3_64\lib\site-packages\keras\backend\theano_backend.py in __call__(self, inputs)
   1092     def __call__(self, inputs):
   1093         assert isinstance(inputs, (list, tuple))
-> 1094         return self.function(*inputs)
   1095 
   1096 

D:\ProgramFiles\Anaconda3_64\lib\site-packages\theano\compile\function_module.py in __call__(self, *args, **kwargs)
    882         try:
    883             outputs =\
--> 884                 self.fn() if output_subset is None else\
    885                 self.fn(output_subset=output_subset)
    886         except Exception:

D:\ProgramFiles\Anaconda3_64\lib\site-packages\theano\gof\op.py in rval(p, i, o, n)
    870             # default arguments are stored in the closure of `rval`
    871             def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 872                 r = p(n, [x[0] for x in i], o)
    873                 for o in node.outputs:
    874                     compute_map[o][0] = True

KeyboardInterrupt: 

In [10]:
model = clfs2[0].model
for layer in model.layers:
    config = layer.get_config()
    if config['name'].startswith('convolution1d'):
        # squeeze = remove dimensions with shape 1
        filters = np.squeeze(layer.get_weights()[0]).T
        w = 8
        h = config['nb_filter'] // w
        fig, axarr = plt.subplots(h, w, figsize=(w*2, h*2), sharex=True, sharey=True)
        for i in range(h):
            for j in range(w):
                _ = axarr[i, j].plot(filters[i*w + j])
        break