This example uses Data Set II from BCI Competition III. After downloading the data and copying it into a directory called "data" next to this script, you should be able to follow this example.
In [1]:
from __future__ import division
import numpy as np
import scipy as sp
from matplotlib import pyplot as plt
from matplotlib import ticker
import matplotlib as mpl
from wyrm import plot
plot.beautify()
from wyrm.types import Data
from wyrm import processing as proc
from wyrm.io import load_bcicomp3_ds2
In [2]:
# Paths to the BCI Competition III, Data Set II (P300 speller) matlab files.
TRAIN_A = 'data/BCI_Comp_III_Wads_2004/Subject_A_Train.mat'
TRAIN_B = 'data/BCI_Comp_III_Wads_2004/Subject_B_Train.mat'
TEST_A = 'data/BCI_Comp_III_Wads_2004/Subject_A_Test.mat'
TEST_B = 'data/BCI_Comp_III_Wads_2004/Subject_B_Test.mat'
# True spelled characters of the test sets (100 letters per subject),
# published after the competition; used to score the classifier output.
TRUE_LABELS_A = 'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU'
TRUE_LABELS_B = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'
# The 6x6 character matrix of the P300 speller; a letter is identified by
# its row and column index.
MATRIX = ['abcdef',
'ghijkl',
'mnopqr',
'stuvwx',
'yz1234',
'56789_']
# Marker definitions used to segment the continuous data into epochs.
# The training set has target/nontarget labels; the test set only marks
# that a stimulus was flashing.
MARKER_DEF_TRAIN = {'target': ['target'], 'nontarget': ['nontarget']}
MARKER_DEF_TEST = {'flashing': ['flashing']}
# Segmentation interval around each stimulus onset, in milliseconds.
SEG_IVAL = [0, 700]
# Per-subject time intervals (ms) over which the epochs are averaged
# ("jumping means") to form the feature vectors.
JUMPING_MEANS_IVALS_A = [150, 220], [200, 260], [310, 360], [550, 660] # 91%
JUMPING_MEANS_IVALS_B = [150, 250], [200, 280], [280, 380], [480, 610] # 91%
In [3]:
def preprocessing_simple(dat, MRK_DEF, *args, **kwargs):
    """Minimal preprocessing chain (reaches 97% accuracy).

    Low-pass filters the continuous data at 10Hz, subsamples it to
    20Hz, cuts epochs around the markers defined in ``MRK_DEF`` and
    turns the epochs into feature vectors.

    Returns the feature vectors and the epoched data.
    """
    nyquist = dat.fs / 2
    # 5th-order Butterworth low-pass at 10Hz, applied forwards and
    # backwards (filtfilt) for zero phase distortion
    num, denom = proc.signal.butter(5, [10 / nyquist], btype='low')
    filtered = proc.filtfilt(dat, num, denom)
    downsampled = proc.subsample(filtered, 20)
    epo = proc.segment_dat(downsampled, MRK_DEF, SEG_IVAL)
    fv = proc.create_feature_vectors(epo)
    return fv, epo
In [4]:
def preprocessing(dat, MRK_DEF, JUMPING_MEANS_IVALS):
    """Full preprocessing chain used for classification.

    Band-pass filters the continuous data (0.4-30Hz), subsamples it
    to 60Hz, epochs it around the markers in ``MRK_DEF`` and reduces
    each epoch to its channel-wise means over the given
    ``JUMPING_MEANS_IVALS`` intervals.

    Returns the feature vectors and the epoched data.
    """
    dat = proc.sort_channels(dat)
    nyquist = dat.fs / 2
    # causal (lfilter) 5th-order Butterworth low- and high-pass,
    # together forming a 0.4-30Hz band-pass
    num, denom = proc.signal.butter(5, [30 / nyquist], btype='low')
    dat = proc.lfilter(dat, num, denom)
    num, denom = proc.signal.butter(5, [.4 / nyquist], btype='high')
    dat = proc.lfilter(dat, num, denom)
    dat = proc.subsample(dat, 60)
    epo = proc.segment_dat(dat, MRK_DEF, SEG_IVAL)
    fv = proc.jumping_means(epo, JUMPING_MEANS_IVALS)
    fv = proc.create_feature_vectors(fv)
    return fv, epo
In [5]:
# Train an LDA classifier per subject on the training set, apply it to the
# test set, reconstruct the spelled letters and score them against the
# published true labels.
epo = [None, None]
acc = 0
for subject in range(2):
    # pick the files and jumping-means intervals for the current subject
    if subject == 0:
        training_set = TRAIN_A
        testing_set = TEST_A
        labels = TRUE_LABELS_A
        jumping_means_ivals = JUMPING_MEANS_IVALS_A
    else:
        training_set = TRAIN_B
        testing_set = TEST_B
        labels = TRUE_LABELS_B
        jumping_means_ivals = JUMPING_MEANS_IVALS_B
    # load the training set
    dat = load_bcicomp3_ds2(training_set)
    fv_train, epo[subject] = preprocessing(dat, MARKER_DEF_TRAIN, jumping_means_ivals)
    # train the lda
    cfy = proc.lda_train(fv_train)
    # load the testing set
    dat = load_bcicomp3_ds2(testing_set)
    fv_test, _ = preprocessing(dat, MARKER_DEF_TEST, jumping_means_ivals)
    # predict
    lda_out_prob = proc.lda_apply(fv_test, cfy)
    # unscramble the order of stimuli: reshape to
    # (100 letters, 15 repetitions, 12 stimuli) and sort each repetition
    # by its stimulus code so that index 0-5 are columns, 6-11 are rows
    unscramble_idx = fv_test.stimulus_code.reshape(100, 15, 12).argsort()
    static_idx = np.indices(unscramble_idx.shape)
    lda_out_prob = lda_out_prob.reshape(100, 15, 12)
    lda_out_prob = lda_out_prob[static_idx[0], static_idx[1], unscramble_idx]
    #lda_out_prob = lda_out_prob[:, :5, :]
    # distill the result of the 15 runs by summing the classifier outputs
    #lda_out_prob = lda_out_prob.prod(axis=1)
    lda_out_prob = lda_out_prob.sum(axis=1)
    # rank the 12 stimuli per letter; the most probable row and column are
    # the last entries of their respective halves after argsort
    lda_out_prob = lda_out_prob.argsort()
    cols = lda_out_prob[lda_out_prob <= 5].reshape(100, -1)
    rows = lda_out_prob[lda_out_prob > 5].reshape(100, -1)
    text = ''
    for i in range(100):
        # row codes are 6..11, so subtract 6 to index into MATRIX
        row = rows[i][-1]-6
        col = cols[i][-1]
        letter = MATRIX[row][col]
        text += letter
    print
    print 'Result for subject %d' % (subject+1)
    print 'Constructed labels: %s' % text.upper()
    print 'True labels       : %s' % labels
    a = np.array(list(text.upper()))
    b = np.array(list(labels))
    # fraction of correctly spelled letters
    accuracy = np.count_nonzero(a == b) / len(a)
    print 'Accuracy: %.1f%%' % (accuracy * 100)
    acc += accuracy
print
print 'Overal accuracy: %.1f%%' % (100 * acc / 2)
In [6]:
# Plot the classwise average timecourses of three midline channels
# (FCz, Cz, Oz) for both subjects: one row per subject, one column per
# channel.
avgs = [None, None]
fig, axes = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(9, 6))
# NOTE: the original loop iterated `enumerate([TRAIN_A, TRAIN_B])` binding
# the unused loop variable `file` (shadowing the builtin); only the index
# is needed.
for idx in range(2):
    avgs[idx] = proc.calculate_classwise_average(epo[idx])
    #avgs[idx] = proc.correct_for_baseline(avgs[idx], [0, 50])
    d = proc.select_channels(avgs[idx], ["fcz", "cz", "oz"])
    for i in range(3):
        # time on the x-axis, one line per class
        axes[idx, i].plot(d.axes[-2], d.data[..., i].T)
        axes[idx, i].grid()
# channel names as column titles
for i in range(3):
    axes[0, i].set_title(d.axes[-1][i])
axes[1, 1].set_xlabel('time [ms]')
for i in range(2):
    axes[i, 0].set_ylabel(u'voltage [a.u.]')
# subject labels on the right-hand side
for i in range(2):
    axes[i, 2].yaxis.set_label_position("right")
    axes[i, 2].set_ylabel('Subject %s' % 'AB'[i])
axes[0, -1].legend(d.class_names)
plt.tight_layout()
In [7]:
def plot_scalps(epo, ivals):
    """Plot a grid of scalp maps: one row per class, one column per interval.

    The epoched data is reduced to its means over ``ivals`` and each
    class/interval combination is drawn as a scalp plot, with a shared
    colorbar in a narrow extra column on the right.
    """
    # ratio scalp to colorbar width
    scale = 10
    dat = proc.jumping_means(epo, ivals)
    n_classes = epo.data.shape[0]
    n_ivals = len(ivals)
    # symmetric color range around zero, shared by all scalps.
    # Loop-invariant (dat does not change below), so compute it once
    # instead of once per class as the original did.
    vmax = np.abs(dat.data).max()
    vmax = round(vmax)
    vmin = -vmax
    for class_idx in range(n_classes):
        for ival_idx in range(n_ivals):
            ax = plt.subplot2grid((n_classes, scale*n_ivals+1), (class_idx, scale*ival_idx), colspan=scale)
            plot.ax_scalp(dat.data[class_idx, ival_idx, :], epo.axes[-1], vmin=vmin, vmax=vmax)
            if class_idx == 1:
                # interval label below the bottom row
                ax.text(0, -1.5, ivals[ival_idx], horizontalalignment='center')
            if ival_idx == 0:
                # class label to the left of the first column
                ax.text(-1.5, 0, ['nontarget', 'target'][class_idx], color='bm'[class_idx], rotation='vertical', verticalalignment='center')
    # colorbar in the last (narrow) grid column, spanning all rows
    ax = plt.subplot2grid((n_classes, scale*n_ivals+1), (0, scale*n_ivals), rowspan=n_classes)
    plot.ax_colorbar(vmin, vmax, label='voltage [a.u.]', ticks=[vmin, 0, vmax])
In [8]:
# Draw the scalp-map grids for both subjects, using each subject's own
# jumping-means intervals.
for subj_idx in range(2):
    fig = plt.figure(figsize=(11, 6))
    ivals = [JUMPING_MEANS_IVALS_A, JUMPING_MEANS_IVALS_B][subj_idx]
    plot_scalps(avgs[subj_idx], ivals)
    plt.tight_layout()
    # make room on the left/bottom for the labels drawn outside the axes
    fig.subplots_adjust(left=.06, bottom=.10, right=None, top=None, wspace=0, hspace=0)
In [9]:
# Plot the signed r^2 values (channels x time) for both subjects.
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True)
for i in range(2):
    r2 = proc.calculate_signed_r_square(epo[i])
    # switch the sign to make the plot more consistent with the
    # timecourse; this is equivalent to reordering the class indices
    # and recalculating r^2
    r2 *= -1
    # symmetric color range around zero (renamed from `max`, which
    # shadowed the builtin)
    absmax = np.max(np.abs(r2))
    im = axes[i].imshow(r2.T, aspect='auto', interpolation='None', vmin=-absmax, vmax=absmax)
    axes[i].set_ylabel('%s' % (epo[i].names[-1]))
    axes[i].grid()
    axes[i].set_title("Subject %s" % "AB"[i])
    cb = plt.colorbar(im, ax=axes[i])
    cb.set_label('[a.u.]')
# label the y-axis with channel names, placing ticks only at the midline
# ('z') channels; list comprehension replaces the map(lambda ...) of the
# original with the same result
axes[1].yaxis.set_major_formatter(ticker.IndexFormatter(epo[i].axes[-1]))
mask = [ch.lower().endswith('z') for ch in epo[i].axes[-1]]
axes[1].yaxis.set_major_locator(ticker.FixedLocator(np.nonzero(mask)[0]))
axes[1].xaxis.set_major_formatter(ticker.IndexFormatter(['%d' % j for j in epo[i].axes[-2]]))
axes[1].xaxis.set_major_locator(ticker.MultipleLocator(6))
axes[1].set_xlabel('%s [%s]' % (epo[i].names[-2], epo[i].units[-2]))
plt.tight_layout()
plt.show()
In [9]: