A Basic EEG classifier with Inception

Goals

  1. Collect Corvo data from one user, perhaps from just one long session so the electrodes are in the same position.
  2. Break data into short epochs (2s?)
  3. Convert each epoch into a spectrogram (sounds like we can just convert it into a matrix instead of an actual jpeg)
  4. Label each epoch with its associated current Corvo performance score
  5. Feed to Inception (I have a big GPU that might help here) https://www.youtube.com/watch?v=cSKfRcEDGUs
  6. Test

In [44]:
import pandas as pd
import numpy as np
import scipy.stats as stats
import scipy.signal as signal
import matplotlib.pyplot as plt
import sklearn as sk
import tensorflow as tf
from tensorflow.contrib import learn

EPOCH_LENGTH = 440
VARIANCE_THRESHOLD = 550

In [45]:
# Data has been collected, let's import it

open_data = pd.read_csv("../Muse Data/DanoThursdayOpenRawEEG0.csv", header=0, index_col=False)
closed_data = pd.read_csv("../Muse Data/DanoThursdayClosedRawEEG1.csv", header=0, index_col=False)

In [46]:
# Drop difficulty, timestamp, and performance columns since we're not using them

open_data = open_data.drop(['Difficulty', 'Performance', 'Timestamp (ms)'], axis=1)
closed_data = closed_data.drop(['Difficulty', 'Performance', 'Timestamp (ms)'], axis=1)


# Prune rows from tail of datasets so that they are all divisible by 440 (the desired size of our epochs)

open_overflow = open_data.shape[0] % EPOCH_LENGTH
open_data = open_data[0:-open_overflow]
closed_overflow = closed_data.shape[0] % EPOCH_LENGTH
closed_data = closed_data[0:-closed_overflow]

In [47]:
# Split DataFrames into many different dataframes 440 samples long

split_open_data = np.stack(np.array_split(open_data, EPOCH_LENGTH), axis=1)
split_closed_data = np.stack(np.array_split(closed_data, EPOCH_LENGTH), axis=1)

# Transform data into a 3D pandas Panel ( n epochs x 4 channels x 440 samples )

open_panel = pd.Panel(split_open_data)
closed_panel = pd.Panel(split_closed_data)


open_panel.shape


Out[47]:
(202, 440, 4)

In [48]:
# Remove epochs with too much variance

def removeNoise(panel):
    for frameIndex in panel:
        for columnIndex in panel[frameIndex]:
            if np.var(panel[frameIndex][columnIndex]) > VARIANCE_THRESHOLD:
                print('variance ', np.var(panel[frameIndex][columnIndex]), ' at electrode ', columnIndex, ' frame ', frameIndex)
                panel = panel.drop(frameIndex)
                break
    return panel
        
closed_panel = removeNoise(closed_panel)
open_panel = removeNoise(open_panel)


variance  565.2019284170628  at electrode  3  frame  72
variance  585.4320009092443  at electrode  3  frame  73
variance  596.9470221979449  at electrode  3  frame  74
variance  601.5649475614356  at electrode  3  frame  75
variance  564.6084621180363  at electrode  3  frame  76
variance  580.7225537057597  at electrode  3  frame  77
variance  581.1037362989539  at electrode  3  frame  78
variance  575.4028816989534  at electrode  3  frame  79
variance  593.2626671243911  at electrode  3  frame  80
variance  603.3318585880803  at electrode  3  frame  81
variance  607.0889083912873  at electrode  3  frame  82
variance  618.329252343474  at electrode  3  frame  83
variance  602.9789655801228  at electrode  3  frame  84
variance  607.3852222287647  at electrode  3  frame  85
variance  621.6071542852401  at electrode  3  frame  86
variance  627.8743596727223  at electrode  3  frame  87
variance  620.4504121536488  at electrode  3  frame  88
variance  622.3241359324637  at electrode  3  frame  89
variance  607.3761511117598  at electrode  3  frame  90
variance  590.0657343364229  at electrode  3  frame  91
variance  575.8104529961632  at electrode  3  frame  92
variance  569.5319541378939  at electrode  3  frame  133
variance  714.4534612214848  at electrode  3  frame  134
variance  916.0835804070445  at electrode  3  frame  135
variance  1136.7916004543645  at electrode  3  frame  136
variance  1392.0835266392182  at electrode  3  frame  137
variance  1653.0708144364266  at electrode  3  frame  138
variance  1818.564342976062  at electrode  3  frame  139
variance  1843.4597929769334  at electrode  3  frame  140
variance  1861.0930380873722  at electrode  3  frame  141
variance  1869.3656032782656  at electrode  3  frame  142
variance  1853.3512380693373  at electrode  3  frame  143
variance  1891.4706854395013  at electrode  3  frame  144
variance  1924.0795759878329  at electrode  3  frame  145
variance  1900.477060667249  at electrode  3  frame  146
variance  1934.0630041634647  at electrode  3  frame  147
variance  1964.8055648018537  at electrode  3  frame  148
variance  1959.9487433954341  at electrode  3  frame  149
variance  1948.833969420409  at electrode  3  frame  150
variance  1961.7047243550517  at electrode  3  frame  151
variance  1994.5437752789437  at electrode  3  frame  152
variance  1995.162008942119  at electrode  3  frame  153
variance  1985.7908462217063  at electrode  3  frame  154
variance  1970.695207656704  at electrode  3  frame  155
variance  1968.7063967015808  at electrode  3  frame  156
variance  1944.007953471432  at electrode  3  frame  157
variance  1922.0496444992243  at electrode  3  frame  158
variance  1900.3542162183069  at electrode  3  frame  159
variance  1890.8265522696743  at electrode  3  frame  160
variance  1888.3099113259416  at electrode  3  frame  161
variance  1871.1975076702074  at electrode  3  frame  162
variance  1868.9206852558796  at electrode  3  frame  163
variance  1868.6352875083571  at electrode  3  frame  164
variance  1980.0285344545682  at electrode  3  frame  165
variance  1964.022681279929  at electrode  3  frame  166
variance  867.6490569963588  at electrode  3  frame  167
variance  839.2241737414919  at electrode  3  frame  168
variance  776.6074690551703  at electrode  3  frame  169
variance  695.1209540925781  at electrode  3  frame  170
variance  637.9379742863572  at electrode  3  frame  171
variance  589.5089857945231  at electrode  3  frame  172

In [51]:
plt.figure()
plt.subplot(2,2,1)
plt.specgram(open_panel[20][0], NFFT=256, Fs=220, noverlap=198)
plt.ylim(0,55)
plt.subplot(2,2,2)
plt.specgram(open_panel[20][1], NFFT=256, Fs=220, noverlap=198)
plt.ylim(0,55)
plt.subplot(2,2,3)
plt.specgram(open_panel[20][2], NFFT=256, Fs=220, noverlap=198)
plt.ylim(0,55)
plt.subplot(2,2,4)
plt.specgram(open_panel[20][3], NFFT=256, Fs=220, noverlap=198)
plt.ylim(0,55)
plt.show


/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
Out[51]:
<function matplotlib.pyplot.show>

In [7]:
# Plot test spectograms of all 4 channels

def plotAndSave(frame, filename):
    plt.figure()
    plt.subplot(2,2,1)
    plt.specgram(frame[0], NFFT=256, Fs=220, noverlap=198)
    plt.ylim(0,55)
    plt.subplot(2,2,2)
    plt.specgram(frame[2], NFFT=256, Fs=220, noverlap=198)
    plt.ylim(0,55)
    plt.subplot(2,2,3)
    plt.specgram(frame[3], NFFT=256, Fs=220, noverlap=198)
    plt.ylim(0,55)
    plt.savefig('%s.jpg' % filename, pad_inches=0, bbox_inches='tight')
    
for frameIndex in open_panel:
    plotAndSave(open_panel[frameIndex], 'open%s' % frameIndex)

for frameIndex in closed_panel:
    plotAndSave(closed_panel[frameIndex], 'closed%s' % frameIndex)


/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-7-d8dae9e5199a> in <module>()
     15 
     16 for frameIndex in open_panel:
---> 17     plotAndSave(open_panel[frameIndex], 'open%s' % frameIndex)
     18 
     19 for frameIndex in closed_panel:

<ipython-input-7-d8dae9e5199a> in plotAndSave(frame, filename)
     12     plt.specgram(frame[3], NFFT=256, Fs=220, noverlap=198)
     13     plt.ylim(0,55)
---> 14     plt.savefig('%s.jpg' % filename, pad_inches=0, bbox_inches='tight')
     15 
     16 for frameIndex in open_panel:

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/pyplot.py in savefig(*args, **kwargs)
    686 def savefig(*args, **kwargs):
    687     fig = gcf()
--> 688     res = fig.savefig(*args, **kwargs)
    689     fig.canvas.draw_idle()   # need this if 'transparent=True' to reset colors
    690     return res

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/figure.py in savefig(self, *args, **kwargs)
   1563             self.set_frameon(frameon)
   1564 
-> 1565         self.canvas.print_figure(*args, **kwargs)
   1566 
   1567         if frameon:

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/backends/backend_qt5agg.py in print_figure(self, *args, **kwargs)
    195     def print_figure(self, *args, **kwargs):
    196         FigureCanvasAgg.print_figure(self, *args, **kwargs)
--> 197         self.draw()
    198 
    199 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/backends/backend_qt5agg.py in draw(self)
    156         # The Agg draw is done here; delaying causes problems with code that
    157         # uses the result of the draw() to update plot elements.
--> 158         FigureCanvasAgg.draw(self)
    159         self.update()
    160 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/backends/backend_agg.py in draw(self)
    472 
    473         try:
--> 474             self.figure.draw(self.renderer)
    475         finally:
    476             RendererAgg.lock.release()

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     59     def draw_wrapper(artist, renderer, *args, **kwargs):
     60         before(artist, renderer)
---> 61         draw(artist, renderer, *args, **kwargs)
     62         after(artist, renderer)
     63 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/figure.py in draw(self, renderer)
   1157         dsu.sort(key=itemgetter(0))
   1158         for zorder, a, func, args in dsu:
-> 1159             func(*args)
   1160 
   1161         renderer.close_group('figure')

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     59     def draw_wrapper(artist, renderer, *args, **kwargs):
     60         before(artist, renderer)
---> 61         draw(artist, renderer, *args, **kwargs)
     62         after(artist, renderer)
     63 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/axes/_base.py in draw(self, renderer, inframe)
   2322 
   2323         for zorder, a in dsu:
-> 2324             a.draw(renderer)
   2325 
   2326         renderer.close_group('axes')

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     59     def draw_wrapper(artist, renderer, *args, **kwargs):
     60         before(artist, renderer)
---> 61         draw(artist, renderer, *args, **kwargs)
     62         after(artist, renderer)
     63 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/axis.py in draw(self, renderer, *args, **kwargs)
   1109 
   1110         for tick in ticks_to_draw:
-> 1111             tick.draw(renderer)
   1112 
   1113         # scale up the axis label box to also find the neighbors, not

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     59     def draw_wrapper(artist, renderer, *args, **kwargs):
     60         before(artist, renderer)
---> 61         draw(artist, renderer, *args, **kwargs)
     62         after(artist, renderer)
     63 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/axis.py in draw(self, renderer)
    245                 self.gridline.draw(renderer)
    246             if self.tick1On:
--> 247                 self.tick1line.draw(renderer)
    248             if self.tick2On:
    249                 self.tick2line.draw(renderer)

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
     59     def draw_wrapper(artist, renderer, *args, **kwargs):
     60         before(artist, renderer)
---> 61         draw(artist, renderer, *args, **kwargs)
     62         after(artist, renderer)
     63 

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/lines.py in draw(self, renderer)
    791                 snap = marker.get_snap_threshold()
    792                 if type(snap) == float:
--> 793                     snap = renderer.points_to_pixels(self._markersize) >= snap
    794                 gc.set_snap(snap)
    795                 gc.set_joinstyle(marker.get_joinstyle())

/home/dano/anaconda3/lib/python3.5/site-packages/matplotlib/backends/backend_agg.py in points_to_pixels(self, points)
    299         if __debug__: verbose.report('RendererAgg.points_to_pixels',
    300                                      'debug-annoying')
--> 301         return points*self.dpi/72.0
    302 
    303     def tostring_rgb(self):

KeyboardInterrupt: 

In [231]:



Questions to answer before continuing:

  • Do these spectrograms look alright?
  • Shouldn't the axis go all the way to 2 if there are 440 samples at 220hz sampling rate?