Getting the Dataset

This example uses Data Set II from BCI Competition III. After downloading it and copying it into a directory called data next to this script, you should be able to follow this example.


In [1]:
from __future__ import division

import numpy as np
import scipy as sp
from matplotlib import pyplot as plt
from matplotlib import ticker
import matplotlib as mpl

from wyrm import plot
plot.beautify()
from wyrm.types import Data
from wyrm import processing as proc
from wyrm.io import load_bcicomp3_ds2

In [2]:
# Paths to the BCI Competition III, Data Set II (P300 speller) matlab files.
TRAIN_A = 'data/BCI_Comp_III_Wads_2004/Subject_A_Train.mat'
TRAIN_B = 'data/BCI_Comp_III_Wads_2004/Subject_B_Train.mat'

TEST_A = 'data/BCI_Comp_III_Wads_2004/Subject_A_Test.mat'
TEST_B = 'data/BCI_Comp_III_Wads_2004/Subject_B_Test.mat'

# The letters actually spelled in the test sessions; used below to
# compute the classification accuracy of the reconstructed text.
TRUE_LABELS_A = 'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU'
TRUE_LABELS_B = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'

# The 6x6 speller matrix; a predicted (row, column) pair is mapped back
# to a character via MATRIX[row][col].
MATRIX = ['abcdef',
          'ghijkl',
          'mnopqr',
          'stuvwx',
          'yz1234',
          '56789_']

# Marker definitions mapping marker names in the data to class names.
MARKER_DEF_TRAIN = {'target': ['target'], 'nontarget': ['nontarget']}
MARKER_DEF_TEST = {'flashing': ['flashing']}

# Segmentation interval around each stimulus onset, in milliseconds.
SEG_IVAL = [0, 700]

# Per-subject time intervals (ms) whose means form the feature vectors.
# NOTE: the comma expressions make each of these a *tuple of lists*.
# The trailing percentages are the accuracies achieved with them.
JUMPING_MEANS_IVALS_A = [150, 220], [200, 260], [310, 360], [550, 660] # 91%
JUMPING_MEANS_IVALS_B = [150, 250], [200, 280], [280, 380], [480, 610] # 91%

In [3]:
def preprocessing_simple(dat, MRK_DEF, *args, **kwargs):
    """Simple preprocessing that reaches 97% accuracy.

    Low-pass filters the continuous data at 10Hz (zero-phase),
    subsamples it to 20Hz, cuts it into epochs around the markers in
    ``MRK_DEF`` and turns the epochs into feature vectors.

    Extra positional/keyword arguments are accepted (and ignored) so
    this function stays call-compatible with ``preprocessing``.
    """
    nyquist = dat.fs / 2
    b_lp, a_lp = proc.signal.butter(5, [10 / nyquist], btype='low')
    dat = proc.filtfilt(dat, b_lp, a_lp)

    dat = proc.subsample(dat, 20)
    epo = proc.segment_dat(dat, MRK_DEF, SEG_IVAL)
    fv = proc.create_feature_vectors(epo)
    return fv, epo

In [4]:
def preprocessing(dat, MRK_DEF, JUMPING_MEANS_IVALS):
    """Full preprocessing chain used for classification.

    Sorts the channels, band-limits the continuous data with a causal
    30Hz low-pass followed by a causal 0.4Hz high-pass, subsamples it
    to 60Hz, epochs it around the markers in ``MRK_DEF`` and reduces
    each epoch to jumping means over the given time intervals, which
    become the feature vectors.
    """
    dat = proc.sort_channels(dat)

    nyquist = dat.fs / 2
    # causal filters (lfilter), applied one after the other
    b_lp, a_lp = proc.signal.butter(5, [30 / nyquist], btype='low')
    dat = proc.lfilter(dat, b_lp, a_lp)
    b_hp, a_hp = proc.signal.butter(5, [.4 / nyquist], btype='high')
    dat = proc.lfilter(dat, b_hp, a_hp)

    dat = proc.subsample(dat, 60)
    epo = proc.segment_dat(dat, MRK_DEF, SEG_IVAL)

    fv = proc.jumping_means(epo, JUMPING_MEANS_IVALS)
    fv = proc.create_feature_vectors(fv)
    return fv, epo

In [5]:
epo = [None, None]
acc = 0
for subject in range(2):
    if subject == 0:
        training_set = TRAIN_A
        testing_set = TEST_A
        labels = TRUE_LABELS_A
        jumping_means_ivals = JUMPING_MEANS_IVALS_A
    else:
        training_set = TRAIN_B
        testing_set = TEST_B
        labels = TRUE_LABELS_B
        jumping_means_ivals = JUMPING_MEANS_IVALS_B
    
    # load the training set
    dat = load_bcicomp3_ds2(training_set)
    fv_train, epo[subject] = preprocessing(dat, MARKER_DEF_TRAIN, jumping_means_ivals)
    
    # train the lda
    cfy = proc.lda_train(fv_train)
    
    # load the testing set
    dat = load_bcicomp3_ds2(testing_set)
    fv_test, _ = preprocessing(dat, MARKER_DEF_TEST, jumping_means_ivals)
    
    # predict
    lda_out_prob = proc.lda_apply(fv_test, cfy)
    
    # unscramble the order of stimuli
    unscramble_idx = fv_test.stimulus_code.reshape(100, 15, 12).argsort()
    static_idx = np.indices(unscramble_idx.shape)
    lda_out_prob = lda_out_prob.reshape(100, 15, 12)
    lda_out_prob = lda_out_prob[static_idx[0], static_idx[1], unscramble_idx]
    
    #lda_out_prob = lda_out_prob[:, :5, :]
    
    # destil the result of the 15 runs
    #lda_out_prob = lda_out_prob.prod(axis=1)
    lda_out_prob = lda_out_prob.sum(axis=1)
        
    # 
    lda_out_prob = lda_out_prob.argsort()
    
    cols = lda_out_prob[lda_out_prob <= 5].reshape(100, -1)
    rows = lda_out_prob[lda_out_prob > 5].reshape(100, -1)
    text = ''
    for i in range(100):
        row = rows[i][-1]-6
        col = cols[i][-1]
        letter = MATRIX[row][col]
        text += letter
    print
    print 'Result for subject %d' % (subject+1)
    print 'Constructed labels: %s' % text.upper()
    print 'True labels       : %s' % labels
    a = np.array(list(text.upper()))
    b = np.array(list(labels))
    accuracy = np.count_nonzero(a == b) / len(a)
    print 'Accuracy: %.1f%%' % (accuracy * 100)
    acc += accuracy
print
print 'Overal accuracy: %.1f%%' % (100 * acc / 2)


WARNING:wyrm.processing:Subsampling led to loss of 2 samples, in an online setting consider using a BlockBuffer with a buffer size of a multiple of 4 samples.
WARNING:wyrm.processing:Subsampling led to loss of 2 samples, in an online setting consider using a BlockBuffer with a buffer size of a multiple of 4 samples.
Result for subject 1
Constructed labels: WQXPLZCOMRKOW7YFZDEZ1DPI9NN2GRKDJCUJRMEUOCOJD2UFYPOO6J7LDGYEGOA5VHNEKBW4OO1TDOILUEE5BFAEEXAW_K3R3MRU
True labels       : WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU
Accuracy: 91.0%

Result for subject 2
Constructed labels: MERMIROOMUZJPXJOHUVFBORZP3GLOO7AUFDKEFTWEOOALZOP9R1CGZE11Y19EWX65QUYU7NAK_1ACJDVDNGQXOJBEV2B5EFDIDTR
True labels       : MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR
Accuracy: 91.0%

Overall accuracy: 91.0%

Analysis of the data

The following part shows how to visualize interesting aspects of the data.


In [6]:
# Plot the class-wise average time course of three midline channels
# (FCz, Cz, Oz) for both subjects: one row per subject, one column
# per channel.
avgs = [None, None]
fig, axes = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(9, 6))
# NOTE: the original loop iterated `enumerate([TRAIN_A, TRAIN_B])` but
# only used the index; the unused loop variable also shadowed the
# builtin `file`, so iterate over the indices directly.
for idx in range(2):
    avgs[idx] = proc.calculate_classwise_average(epo[idx])
    #avgs[idx] = proc.correct_for_baseline(avgs[idx], [0, 50])
    
    d = proc.select_channels(avgs[idx], ["fcz", "cz", "oz"])
    for i in range(3):
        axes[idx, i].plot(d.axes[-2], d.data[..., i].T)
        axes[idx, i].grid()

# channel names as column titles
for i in range(3):        
    axes[0, i].set_title(d.axes[-1][i])
    
axes[1, 1].set_xlabel('time [ms]')
for i in range(2):
    axes[i, 0].set_ylabel(u'voltage [a.u.]')

# subject labels on the right-hand side
for i in range(2):
    axes[i, 2].yaxis.set_label_position("right")
    axes[i, 2].set_ylabel('Subject %s' % 'AB'[i])

axes[0, -1].legend(d.class_names)
plt.tight_layout()

In [7]:
def plot_scalps(epo, ivals):
    """Plot a grid of scalp maps: one row per class, one column per
    interval in ``ivals``, plus a shared colorbar on the right.

    Parameters
    ----------
    epo : Data
        epoched (e.g. class-wise averaged) data; the first axis of
        ``epo.data`` is assumed to index the classes
    ivals : sequence of [start, end] intervals in milliseconds
    """
    # ratio scalp to colorbar width
    scale = 10
    dat = proc.jumping_means(epo, ivals)
    n_classes = epo.data.shape[0]
    n_ivals = len(ivals)
    # symmetric color range shared by all scalp plots; this only
    # depends on dat.data, so compute it once instead of per class
    vmax = np.abs(dat.data).max()
    vmax = round(vmax)
    vmin = -vmax
    for class_idx in range(n_classes):
        for ival_idx in range(n_ivals):
            ax = plt.subplot2grid((n_classes, scale*n_ivals+1), (class_idx, scale*ival_idx), colspan=scale)
            plot.ax_scalp(dat.data[class_idx, ival_idx, :], epo.axes[-1], vmin=vmin, vmax=vmax)
            # interval labels under the bottom row of scalps
            if class_idx == 1:
                ax.text(0, -1.5, ivals[ival_idx], horizontalalignment='center')
            # class labels on the left of the first column
            if ival_idx == 0:
                ax.text(-1.5, 0, ['nontarget', 'target'][class_idx], color='bm'[class_idx], rotation='vertical', verticalalignment='center')
    
    # colorbar
    ax = plt.subplot2grid((n_classes, scale*n_ivals+1), (0, scale*n_ivals), rowspan=n_classes)
    plot.ax_colorbar(vmin, vmax, label='voltage [a.u.]', ticks=[vmin, 0, vmax])

In [8]:
# Draw the scalp-map grid for both subjects, each with the jumping-means
# intervals that were used for that subject's features.
for subj_idx, ivals in enumerate([JUMPING_MEANS_IVALS_A, JUMPING_MEANS_IVALS_B]):
    fig = plt.figure(figsize=(11, 6))
    plot_scalps(avgs[subj_idx], ivals)
    plt.tight_layout()
    fig.subplots_adjust(left=.06, bottom=.10, right=None, top=None, wspace=0, hspace=0)

In [9]:
# Plot the signed r^2 values (class discriminability) over time and
# channels for both subjects.
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True)
for i in range(2):
    r2 = proc.calculate_signed_r_square(epo[i])
    # switch the sign to make the plot more consistent with the
    # timecourse; this is equivalent to reordering the class indices
    # and calculating r^2
    r2 *= -1
    
    # symmetric color scale around zero; named `absmax` so the
    # builtin `max` is not shadowed
    absmax = np.max(np.abs(r2))
    im = axes[i].imshow(r2.T, aspect='auto', interpolation='None', vmin=-absmax, vmax=absmax)
    
    axes[i].set_ylabel('%s' % (epo[i].names[-1]))
    axes[i].grid()
    axes[i].set_title("Subject %s" % "AB"[i])
    cb = plt.colorbar(im, ax=axes[i])
    cb.set_label('[a.u.]')

# label the y axis with channel names, placing ticks only at the
# midline channels (names ending in 'z')
axes[1].yaxis.set_major_formatter(ticker.IndexFormatter(epo[i].axes[-1]))
mask = map(lambda x: True if x.lower().endswith('z') else False, epo[i].axes[-1])
axes[1].yaxis.set_major_locator(ticker.FixedLocator(np.nonzero(mask)[0]))
# x axis shows the time axis values, one tick every 6 samples
axes[1].xaxis.set_major_formatter(ticker.IndexFormatter(['%d' % j for j in epo[i].axes[-2]]))
axes[1].xaxis.set_major_locator(ticker.MultipleLocator(6))
axes[1].set_xlabel('%s [%s]' % (epo[i].names[-2], epo[i].units[-2]))

plt.tight_layout()
plt.show()

In [9]: