This example uses Data Set 1 from BCI Competition III. Download the data set and copy it into a directory called data next to this script; you should then be able to follow this example.
In [1]:
from __future__ import division
import numpy as np
import scipy as sp
from scipy.io import loadmat
from matplotlib import pyplot as plt
import matplotlib as mpl
from wyrm import processing as proc
from wyrm.types import Data
from wyrm import plot
from wyrm.io import load_bcicomp3_ds1
# Apply wyrm's plot styling before any figure is created.
# NOTE(review): the exact settings changed by beautify() are not visible here.
plot.beautify()
In [2]:
# Directory holding the competition data set; passed to the loader below.
DATA_DIR = 'data/BCI_COMP_III_Tuebingen/'
# Text file with the true labels of the test set, stored in the same directory.
TRUE_LABELS = DATA_DIR + 'true_labels.txt'
In [3]:
# Read the training and test sets from the competition directory.
dat_train, dat_test = load_bcicomp3_ds1(DATA_DIR)
# Read the true test labels as integers and recode -1 as 0, so they can be
# compared with the 0/1 predictions computed further down.
true_labels = np.loadtxt(TRUE_LABELS).astype('int')
true_labels = np.where(true_labels == -1, 0, true_labels)
In [4]:
def plot_csp_pattern(a):
    """Draw the first and last column of a CSP pattern matrix side by side.

    Column 0 of ``a`` is shown under the title 'Pinky' and the last column
    under the title 'Tongue'. Both share one color bar with limits that are
    symmetric around zero. Each column is reshaped to an 8x8 grid, so ``a``
    must have 64 rows.

    The function draws into the current matplotlib figure and returns
    nothing.
    """
    # Symmetric color limits taken from the two columns that get plotted.
    limit = np.max(np.abs(a[:, [0, -1]]))
    image_kwargs = dict(interpolation='None', vmin=-limit, vmax=limit)

    # One row of 11 cells: 5 per pattern image, 1 for the color bar.
    grid = (1, 11)
    ax_pinky = plt.subplot2grid(grid, (0, 0), colspan=5)
    ax_tongue = plt.subplot2grid(grid, (0, 5), colspan=5)
    ax_colorbar = plt.subplot2grid(grid, (0, 10))

    ax_pinky.imshow(a[:, 0].reshape(8, 8), **image_kwargs)
    ax_pinky.set_title('Pinky')

    tongue_image = ax_tongue.imshow(a[:, -1].reshape(8, 8), **image_kwargs)
    ax_tongue.set_title('Tongue')

    # Both images use the same vmin/vmax, so one color bar serves both.
    plt.colorbar(tongue_image, cax=ax_colorbar)
    plt.tight_layout()
In [5]:
def preprocess(data, filt=None):
    """Filter, subsample and CSP-project ``data`` into log-variance features.

    Parameters
    ----------
    data
        Input data object; it must provide ``copy()`` and an ``fs``
        attribute (presumably a wyrm ``Data`` instance -- confirm with the
        caller). The input itself is not modified; all steps run on a copy.
    filt
        CSP filter to apply. When ``None``, the filter is computed from
        ``data`` via ``proc.calculate_csp`` and the resulting pattern is
        plotted with ``plot_csp_pattern``.

    Returns
    -------
    tuple
        ``(features, filt)``: the processed data and the CSP filter that
        was applied (either the one passed in or the newly computed one).
    """
    processed = data.copy()
    nyquist = processed.fs / 2

    # Two order-5 Butterworth passes, applied in this order: low-pass with
    # cutoff 13, then high-pass with cutoff 9 (cutoffs are normalised by the
    # Nyquist frequency, i.e. expressed in the same unit as ``fs``).
    for cutoff, kind in ((13, 'low'), (9, 'high')):
        b, a = proc.signal.butter(5, [cutoff / nyquist], btype=kind)
        processed = proc.filtfilt(processed, b, a)

    processed = proc.subsample(processed, 50)

    if filt is None:
        # No filter supplied: derive it from this data and show the pattern.
        filt, pattern, _ = proc.calculate_csp(processed)
        plot_csp_pattern(pattern)

    processed = proc.apply_csp(processed, filt)
    processed = proc.logarithm(proc.variance(processed))
    return processed, filt
In [6]:
# Training pass: no filter is supplied, so preprocess() computes the CSP
# filter from the training data (plotting its pattern) and returns it.
fv_train, filt = preprocess(dat_train)
# Test pass: reuse the filter obtained from the training data.
fv_test, _ = preprocess(dat_test, filt=filt)
In [7]:
# Train the LDA classifier on the training features and apply it to the test
# features.
cfy = proc.lda_train(fv_train)
result = proc.lda_apply(fv_test, cfy)
# Turn the sign of the classifier output into labels: negative -> 0,
# positive -> 1 (an output of exactly 0 becomes 0.5 and matches no label).
result = (np.sign(result) + 1) / 2
# The format string ends in a literal '%', so report a percentage, not the raw
# fraction of correct predictions. 100.0 forces float division.
accuracy = 100.0 * (result == true_labels).sum() / len(result)
# Parenthesised single-argument print gives the same output on Python 2 and
# also parses on Python 3.
print('LDA Accuracy %.2f%%' % accuracy)
In [8]:
# Display the figure drawn so far (the CSP pattern plot produced during the
# training call to preprocess()).
plt.show()
In [8]: