Feature Extraction


In [1]:
import csv
import time
import re
from datetime import datetime
from decimal import Decimal
from matplotlib import *
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
pylab.rcParams['figure.figsize'] = 14, 6

import warnings
warnings.filterwarnings('ignore', 'DeprecationWarning')

# Import config params
DATA_URL = ''
SAVE_URL = ''
try:
   from dev_settings import *
except ImportError:
   pass
print DATA_URL
print SAVE_URL


/home/sagar/Dropbox/Academics/nc_EDM/eeg_project/data/all
/home/sagar/Dropbox/Academics/nc_EDM/eeg_project/edm_eeg.git/preprocess/all

Plotting Raw Signal


In [2]:
def plot_signal(x_ax, y_ax, label, ax=None):
    if ax==None:
        fig, ax = plt.subplots()
    ax.plot(x_ax, y_ax, label=label)
    ax.grid(True)
    fig.tight_layout()
    plt.legend(loc='upper left')
    plt.show()
    return ax

with open(SAVE_URL + "/raw_incremental_label.csv", 'r') as fi:
    fr = csv.reader(fi, delimiter='\t')
    next(fr)#header
    
    data = list(fr)
    time_x = [i[0] for i in data]
    signal_x = [i[1] for i in data]
    
    plot_signal(time_x, signal_x, 'Raw Signal - whole')
    plot_signal(time_x[0:1000], signal_x[0:1000], 'Raw Signal - 0-1000')
    plot_signal(time_x[315000:len(time_x)], signal_x[315000:len(signal_x)], 'Raw Signal - 315000 - end')


Filters and Power Spectrum

Designing Butterworth Filter


In [63]:
from scipy.signal import butter, lfilter, freqz

# http://wiki.scipy.org/Cookbook/ButterworthBandpass
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y


#From Bao Hong Tan's Thesis - p16
fs = 512.0
lowcut = 0.1
highcut = 20.0

# Plot the frequency response for a few different orders.
plt.figure(1)
plt.clf()
for order in [1,2,3,4,5, 6, 7, 8, 9]:
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    w, h = freqz(b, a, worN=2000)
    plt.plot((fs * 0.5 / np.pi) * w, abs(h), label="order = %d" % order)
    plt.title("Sample frequency responses of the band filter")
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Gain')
    plt.grid(True)
    plt.legend(loc='best')


Filtering our Signal

The band pass filter is used to filter the signal


In [5]:
data_np = np.array([float(i[1]) for i in data])
data_filtered = butter_bandpass_filter(data_np, 0.1, 20, 512, 4)

fig, ax = plt.subplots()
ax.plot([i[1] for i in data[0:200]], label="Original Signal")
ax.plot(data_filtered[0:200], label="Filtered Signal")
plt.grid(True)
plt.legend(loc='best')
plt.title("data[0:200]")


Out[5]:
<matplotlib.text.Text at 0x8bd4590>

Power Spectrum

Computing Power in each frequency bin


In [253]:
import pyeeg

a = pyeeg.bin_power(data_filtered[0:1000], [0.5,4,7,12,30],512)

Using collective data

Get all subject details


In [3]:
from os import listdir
from os.path import isfile, join
import re

def get_subject_list():
    onlyfiles = [ f for f in listdir(SAVE_URL) if isfile(join(SAVE_URL,f)) ]
    pat = re.compile("[0-9]*\.[0-9]*\.labelled\.csv")
    temp_dat = [f.split('.')[0:2] for f in onlyfiles if pat.match(f)]
    sub_dict = {i[1]: i[0] for i in temp_dat}
    return sub_dict

def get_data(subj_list):
    subj_data = {}
    for s_id in subj_list.keys():
        s_time = subj_list[s_id]
        s_file = s_time + "." + s_id + ".labelled.csv"
        with open(join(SAVE_URL,s_file), 'rb') as fi:
            fr = csv.reader(fi,delimiter="\t")
            next(fr) #header
            s_data = list(fr)
            subj_data[int(s_id)] = s_data
    return subj_data

def get_counts(data):
    return [(i,len(data[i])) for i in data]

subj_list = get_subject_list()
subj_data = get_data(subj_list)

In [4]:
def plot_subject(s_comb, title=None):
    fig, ax = plt.subplots()
    x_ax = [int(i[0].split('.')[0]) for i in s_comb]
    
    sig_q = [int(i[1]) for i in s_comb]
    atten = [int(i[2]) for i in s_comb]
    medit = [int(i[3]) for i in s_comb]
    diffi = [int(i[4])*50 for i in s_comb]
    taskid= [int(i[5]) for i in s_comb]
    taskid_set = list(set(taskid))
    taskid_norm = [taskid_set.index(i) for i in taskid]
    
    ax.plot(x_ax,sig_q, label='Quality')
    ax.plot(x_ax, atten, label='Attention')
    ax.plot(x_ax, medit, label='Meditation')
    ax.plot(x_ax, diffi, label='Difficulty')
    #ax.plot(x_ax, taskid_norm, label='taskid')
    
    ax.grid(True)
    fig.tight_layout()
    plt.legend(loc='upper left')
    plt.title(title)
    plt.show()
    return

def plot_subjects(data, count):
    for i in range(count):
        s1 = subj_list.keys()[i]
        plot_subject(data[int(s1)], "Subject: "+s1)
    return

plot_subjects(subj_data, 4)


Data Cleaning

Looks like some of the subjects like 24 have all '0' attention and meditation values. Also, at several places, the quality value of signals is 'non-zero'. In Tan's thesis, they have discarded the entire trial if the quality was non zero:

A trial was considered good if the quality signal indicated a value of 0 for the duration of the entire trial, that is, the level of noise present in the trial was acceptable; otherwise, the trial was excluded from analyses in this thesis.

However, initial plan is to remove just the data for the time slots corresponding to non-zero quality


In [8]:
import pickle
from os.path import join

def clean_subj(s_data):
    s_data = [i for i in s_data if int(i[1])==0 and \
                int(i[2])>0 and \
                int(i[3])>0 and \
                int(i[4]) >-1]
    return s_data

def clean_all():
    cln_data = {}
    for s in subj_list.keys():
        cln_data[int(s)] = clean_subj(subj_data[int(s)])
    return cln_data

cln_data = clean_all()
pickle.dump(cln_data,open(join(SAVE_URL,"cln_data.pickle"),'wb'))


cnt1=get_counts(subj_data)
cnt2=get_counts(cln_data)


plot([i[0] for i in cnt1], [i[1] for i in cnt1], "-o", label="original")
plot([i[0] for i in cnt2], [i[1] for i in cnt2], "-o", label="cleaned")
plt.xlabel("Id")
plt.ylabel("Size")
plt.legend()
plt.grid()
plt.title("Comparing original and new size of data")

plot_subjects(cln_data, 4)


Decision Trees

Random Forest classifier

This classifier is often as the first technique to find out obvious and sometimes not-so-obvious relationships

Classifying for a single subject first


In [15]:
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import numpy as np

subj_id = 27

df = pd.DataFrame({'att':[int(i[2]) for i in cln_data[subj_id]],
                    'med':[int(i[3]) for i in cln_data[subj_id]],
                    'difficulty':[int(i[4]) for i in cln_data[subj_id]]})
df['is_train'] = np.random.uniform(0, 1, len(df)) <= .75
df.head()


/usr/lib/python2.7/dist-packages/pandas/core/config.py:570: DeprecationWarning: height has been deprecated.

  warnings.warn(d.msg, DeprecationWarning)
/usr/lib/python2.7/dist-packages/pandas/core/config.py:570: DeprecationWarning: height has been deprecated.

  warnings.warn(d.msg, DeprecationWarning)
Out[15]:
att difficulty med is_train
0 44 1 40 True
1 40 1 35 False
2 41 1 35 True
3 57 1 21 False
4 66 1 30 True

Divide the data into train and test. Fit the random forest model on train data


In [16]:
from sklearn.metrics import classification_report

train, test = df[df['is_train']==True], df[df['is_train']==False]
features = df.columns[[0,2]]

clf = RandomForestClassifier(n_jobs=2)
clf.fit(train[features], train['difficulty'])

preds = clf.predict(test[features])

print classification_report(test['difficulty'], preds)
pd.crosstab(test['difficulty'], preds, rownames=['actual'], colnames=['preds'])


             precision    recall  f1-score   support

          1       0.51      0.51      0.51        75
          2       0.44      0.47      0.46        60
          3       0.29      0.20      0.24        10

avg / total       0.47      0.47      0.47       145

/usr/lib/python2.7/dist-packages/pandas/core/config.py:570: DeprecationWarning: height has been deprecated.

  warnings.warn(d.msg, DeprecationWarning)
/usr/lib/python2.7/dist-packages/pandas/core/config.py:570: DeprecationWarning: height has been deprecated.

  warnings.warn(d.msg, DeprecationWarning)
Out[16]:
preds 1 2 3
actual
1 38 33 4
2 31 28 1
3 6 2 2

Very bad :(

Justis' idea of analysing the correlation between time taken to read a sentence with the difficulty

Correlation - Pearson's R value

Header of task.xls

['machine', 'SUBJECT', 'start_time', 'end_time', 'stim', 'block', 'pool', 'modality', 'TEXT', 'difficulty']

In [52]:
from scipy.stats.stats import pearsonr

def format_time(ti):
    """
    Converts format '2010-12-14 16:56:36.996' to Decimal
    """
    to = datetime.strptime(ti, '%Y-%m-%d %H:%M:%S.%f')
    #Decimal not losing precision
    to = Decimal(to.strftime('%s.%f'))
    return str(to)
    
def get_num_words():
    path_task_xls = DATA_URL + "/task.xls"
    
    with open(path_task_xls, 'rb') as fi:
        fr = csv.reader(fi, delimiter='\t')
        next(fr)#header

        data = list(fr)
        data_cols = zip(*data)
        
        l=len(data_cols[0])
        num_words_stim = [float(len(i.split())) for i in data_cols[4]]
        num_chars_stim = [float(len(i)) for i in data_cols[4]]
        difficulty = [float(i) for i in data_cols[-1]]
        time_diff = [float(Decimal(format_time(data_cols[3][i]))-\
                    Decimal(format_time(data_cols[2][i])))\
                    for i in xrange(l)]
        
        time_per_word = [time_diff[i]/num_words_stim[i] for i in range(l)]
        time_per_char = [time_diff[i]/num_chars_stim[i] for i in range(l)]
        
        sentence_idx=[i for i in xrange(l) if num_words_stim[i] > 1]
        
        print pearsonr(time_per_word, difficulty)
        print pearsonr(time_per_char, difficulty)

        print pearsonr([time_per_word[i] for i in sentence_idx], [difficulty[i] for i in sentence_idx])
        print pearsonr([time_per_char[i] for i in sentence_idx], [difficulty[i] for i in sentence_idx])

        tpa = [difficulty[i] for i in sentence_idx]
        hist(tpa)

get_num_words()


(0.39954572289607321, 4.7962634348368499e-152)
(0.52805066277565826, 6.1439118752796796e-284)
(0.1042932478927345, 1.4676898915448957e-07)
(-0.011311405571011842, 0.56964249492174535)

In [36]: