PSD Using mne features


In [ ]:
import os.path as op
import numpy as np
from scipy import stats
import mne

import matplotlib.pyplot as plt
import seaborn as sns

from mne.time_frequency import psd_welch, psd_multitaper

from jumeg.jumeg_utils import rescale_arr
from jumeg import get_jumeg_path
from jumeg.connectivity import plot_grouped_connectivity_circle

from mne_features.feature_extraction import extract_features
from mne_features.univariate import _freq_bands_helper

from rasmalai.network_utils import (make_count_based_stats,
                                    make_combined_average,
                                    get_con_metrics, get_coords_given_labels)

from glob import glob
import yaml
import pickle
import time
# record the wall-clock time at which this cell is executed
t_start = time.time()

# switch seaborn to its 'white' style for the figures drawn below
sns.set_style('white')

In [2]:
# NOTE(review): absolute, user-specific path; adjust before running elsewhere
epoch_fname = '/Users/psripad/kelsa/epochs_with_artifacts/207184_rest_EC_0.6-200bp_bcc,nr,ar,1-epo.fif'
# read the saved epochs from disk with the data preloaded into memory
epochs = mne.read_epochs(epoch_fname, preload=True)
# plain array copy of the epochs, used as input for mne-features below
data = epochs.get_data()


Reading /Users/psripad/kelsa/epochs_with_artifacts/207184_rest_EC_0.6-200bp_bcc,nr,ar,1-epo.fif ...
    Found the data of interest:
        t =    -499.87 ...     499.87 ms
        0 CTF compensation matrices available
430 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated

In [3]:
# Frequency-band configuration shared by the cells below.
# `freqs` lists the band edges; there is one more edge than there are names
# in `bands`, and the plotting cells pair band i with bands[i].
bands = ['delta', 'theta', 'alpha', 'beta1', 'beta2', 'gamma1', 'gamma2']
freqs = np.array([1., 4., 8., 13., 18., 30., 80., 200.])

# one colour per band, sampled evenly from the Pastel1 colormap
n_colors = len(bands)
color_list = plt.cm.Pastel1(np.linspace(0, 1, n_colors))

# the only mne-features feature function requested in this notebook
selected_funcs = {'pow_freq_bands'}

In [4]:
def get_psd_fingerprint(epochs, freq_bands=None, normalize=True):
    '''Function to compute band power on epochs and return the
       average power in each of the given frequency bands.

       The band power is computed by the mne-features function
       'pow_freq_bands' using Welch's method (n_fft of 512) and is
       then averaged across epochs and channels.

       Inputs
       ------
       epochs: mne.Epochs
           Instance of epochs. The data is taken from epochs.get_data()
           and the sampling rate from epochs.info['sfreq'].

       freq_bands: ndarray | list | None
           Band edges, one more edge than there are bands.
           If None (or empty), the default edges
           [1., 4., 8., 13., 18., 30., 80., 200.] are used.
           NOTE(review): only the 1D list-of-edges form is supported by
           the reshape below; other forms accepted by mne-features are not.

       normalize: bool
           Forwarded unchanged to mne-features as
           'pow_freq_bands__normalize'.

       Output
       ------
       X_avg: ndarray
           Array of shape (n_bands,) with the band power averaged
           across epochs and channels.

    '''
    from mne_features.feature_extraction import extract_features

    # `not freq_bands` raises ValueError for an ndarray, so test explicitly;
    # None and an empty sequence both fall back to the default edges
    if freq_bands is None or len(freq_bands) == 0:
        freq_bands = np.array([1., 4., 8., 13., 18., 30., 80., 200.])
    freq_bands = np.asarray(freq_bands)

    # use the data of the epochs that were passed in, not a notebook global
    epochs_data = epochs.get_data()

    selected_funcs = {'pow_freq_bands'}
    X_new = extract_features(epochs_data, epochs.info['sfreq'], selected_funcs,
                             funcs_params={'pow_freq_bands__freq_bands': freq_bands,
                                           'pow_freq_bands__normalize': normalize,
                                           'pow_freq_bands__ratios': None,
                                           'pow_freq_bands__psd_method': 'welch',
                                           'pow_freq_bands__psd_params': {'welch_n_fft': 512}})
    # split the flat feature axis into (epochs, channels, bands)
    # assumes the features are ordered channel by channel with the bands
    # innermost - TODO confirm against the mne-features documentation
    X_new = X_new.reshape((epochs_data.shape[0], -1, len(freq_bands) - 1))
    X_avg = X_new.mean(axis=(0, 1))
    return X_avg

In [20]:
def get_normalised_psds_welch(epochs, norm_freq=[80., 200.],
                              freq_bands=None, grand_average=False):
    '''Estimate Welch PSDs on the MEG channels of the epochs and divide
       them by their mean over a reference frequency range.

    Input
    -----
    epochs: mne.Epochs
        Epochs instance.

    norm_freq: list
        Lower and upper limit (inclusive) of the frequency range whose
        mean PSD is used as the normalization factor.

    freq_bands: ndarray | list | None
        Band edges used to average the normalized PSDs within bands.
        If None (the default), no band averaging is done and the
        normalized PSDs are returned together with the frequencies.

    grand_average: bool
        If True, the result is additionally averaged across trials
        and channels.

    Output
    ------
    If freq_bands is None, a tuple (norm_psds_welch, freqs_welch):

    norm_psds_welch: ndarray
        Normalized psds of shape (n_trials, n_channels, n_freqs), or
        (n_freqs,) when grand_average is True.

    freqs_welch: ndarray
        Frequency points for which PSDs are estimated.

    If freq_bands is given, a single ndarray with the band averages of
    shape (n_trials, n_channels, n_bands), or (n_bands,) when
    grand_average is True. The frequencies are not returned in this case.

    '''
    from mne.time_frequency import psd_welch
    from mne import pick_types
    from mne_features.univariate import _freq_bands_helper

    # Welch PSDs between 1 and 200 for the MEG channels not marked as bad
    meg_picks = pick_types(epochs.info, meg=True, exclude='bads')
    psds_welch, freqs_welch = psd_welch(epochs, fmin=1., fmax=200., n_fft=512,
                                        picks=meg_picks, n_jobs=4, average='mean')

    # mean PSD inside the reference range, one value per trial and channel
    in_norm_range = (freqs_welch >= norm_freq[0]) & (freqs_welch <= norm_freq[1])
    reference_power = psds_welch[:, :, in_norm_range].mean(axis=2)
    norm_psds_welch = psds_welch / reference_power[:, :, np.newaxis]

    # no bands requested: hand back the full spectrum and its frequency axis
    if freq_bands is None:
        if grand_average:
            return norm_psds_welch.mean(axis=(0, 1)), freqs_welch
        return norm_psds_welch, freqs_welch

    # average the normalized PSDs inside every band (both limits inclusive)
    n_trials, n_channels = norm_psds_welch.shape[0], norm_psds_welch.shape[1]
    psds_band_ = np.zeros((n_trials, n_channels, len(freq_bands) - 1))
    band_limits = _freq_bands_helper(epochs.info['sfreq'], freq_bands)
    for band_idx, (fmin, fmax) in enumerate(band_limits):
        in_band = (freqs_welch >= fmin) & (freqs_welch <= fmax)
        psds_band_[:, :, band_idx] = norm_psds_welch[:, :, in_band].mean(axis=2)

    if grand_average:
        # collapse trials and channels as well, leaving one value per band
        return psds_band_.mean(axis=(0, 1))
    return psds_band_

In [6]:
# no band edges given: get the full normalized spectrum plus its frequency axis
mypsds1, myfreqs = get_normalised_psds_welch(epochs, norm_freq=[80., 200.], freq_bands=None)
print(mypsds1.shape)


Effective window size : 0.755 (s)
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   16.4s remaining:   16.4s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   26.1s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   26.1s finished
(430, 247, 150)

In [21]:
# band edges given: the frequency axis is reduced to one value per band
# (`freqs` from the configuration cell holds these same edges)
mypsds2 = get_normalised_psds_welch(epochs, freq_bands=freqs, norm_freq=[80., 200.])
print(mypsds2.shape)


Effective window size : 0.755 (s)
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   12.7s remaining:   12.7s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   34.1s remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   34.1s finished
(430, 247, 7)

In [22]:
# band edges plus grand average: trials and channels are averaged out too
# (`freqs` from the configuration cell holds these same edges)
mypsds3 = get_normalised_psds_welch(epochs, freq_bands=freqs, norm_freq=[80., 200.],
                                    grand_average=True)
print(mypsds3.shape)


Effective window size : 0.755 (s)
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   34.9s remaining:   34.9s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:  1.1min remaining:    0.0s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:  1.1min finished
(7,)

In [ ]:
# un-normalized Welch PSDs (n_fft=512) for the MEG channels not marked as bad
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
psds_welch, freqs_welch = psd_welch(epochs, fmin=1., fmax=200., n_fft=512,
                                    picks=picks, n_jobs=4, average='mean')

In [ ]:
# un-normalized band power fingerprint using the function's default band edges
myavg = get_psd_fingerprint(epochs, freq_bands=None, normalize=False)
print(myavg)

In [ ]:
# dimensions of the epochs array: (epochs, channels, time points)
data.shape

So there are 430 epochs, 247 channels and 679 time points.

PSDs averaged across epochs and channels result in 7 data points, one for each frequency band.

Check PSDs using Welch's method

Note: `extract_features` returns the sum of the PSD values over each frequency band when it is not normalized. This leads to a difference in the values between the mne-features function and the MNE `time_frequency` function.


In [ ]:
# compute PSDs and plot using Welch's algorithm
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
psds_welch, freqs_welch = psd_welch(epochs, fmin=1., fmax=200., n_fft=256,
                                    picks=picks, n_jobs=4, average='mean')
psds_welch_avg = psds_welch.mean(axis=(0, 1))

In [ ]:
# IPython introspection: show the help of the private mne-features band helper
_freq_bands_helper?

In [ ]:
# un-normalized band power from mne-features, based on Welch PSDs (n_fft=256)
X_new = extract_features(data, epochs.info['sfreq'], selected_funcs,
                         funcs_params={'pow_freq_bands__freq_bands': freqs,
                                       'pow_freq_bands__normalize': False,
                                       'pow_freq_bands__ratios': None,
                                       'pow_freq_bands__psd_method': 'welch',
                                       'pow_freq_bands__psd_params': {'welch_n_fft': 256}})
# split the flat feature axis into (epochs, channels, bands); the epoch count
# comes from the data itself instead of a hard-coded 430, so the cell also
# works for recordings with a different number of epochs
X_new = X_new.reshape((data.shape[0], -1, len(freqs) - 1))
X_avg = X_new.mean(axis=(0, 1))
print('Welch:', X_avg)

plt.figure(figsize=(10, 8))
plt.plot(freqs_welch, psds_welch_avg);
for i, (fmin, fmax) in enumerate(_freq_bands_helper(epochs.info['sfreq'], freqs)):
    mask = np.logical_and(freqs_welch >= fmin, freqs_welch <= fmax)
    # the un-normalized band feature is a sum over the band's frequency bins,
    # so divide by the number of bins to compare it with the per-bin PSD curve
    plt.axhline(X_avg[i] / len(freqs_welch[mask]), 0, 1, label=bands[i], color=color_list[i]);

plt.legend();
plt.title('Welch');
plt.xlabel('freqs');
plt.ylabel('PSDs');

Compare PSDs using multitaper algorithm


In [ ]:
# compute PSDs and plot using multitaper algorithm
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
psds_multitaper, freqs_multi = psd_multitaper(epochs, fmin=1., fmax=200.,
                                              picks=picks, n_jobs=4,
                                              normalization='length')

In [ ]:
# un-normalized band power from mne-features, based on multitaper PSDs
X_new = extract_features(data, epochs.info['sfreq'], selected_funcs,
                         funcs_params={'pow_freq_bands__freq_bands': freqs,
                                       'pow_freq_bands__normalize': False,
                                       'pow_freq_bands__ratios': None,
                                       'pow_freq_bands__psd_method': 'multitaper'})
# split the flat feature axis into (epochs, channels, bands); the epoch count
# comes from the data itself instead of a hard-coded 430, so the cell also
# works for recordings with a different number of epochs
X_new = X_new.reshape((data.shape[0], -1, len(freqs) - 1))
X_avg = X_new.mean(axis=(0, 1))

plt.figure(figsize=(10, 8))
plt.plot(freqs_multi, psds_multitaper.mean(axis=(0, 1)));
for i, (fmin, fmax) in enumerate(_freq_bands_helper(epochs.info['sfreq'], freqs)):
    mask = np.logical_and(freqs_multi >= fmin, freqs_multi <= fmax)
    # the un-normalized band feature is a sum over the band's frequency bins,
    # so divide by the number of bins to compare it with the per-bin PSD curve
    plt.axhline(X_avg[i] / len(freqs_multi[mask]), 0, 1,
                label=bands[i], color=color_list[i]);
plt.legend();
plt.title('multitaper');
plt.xlabel('freqs');
plt.ylabel('PSDs');

In [ ]:
# overlay the epoch- and channel-averaged spectra of both estimators
plt.figure(figsize=(10, 8))

for f_axis, psd_values, line_color, line_label in (
        (freqs_multi, psds_multitaper, 'b', 'multitaper'),
        (freqs_welch, psds_welch, 'g', 'Welch')):
    plt.plot(f_axis, psd_values.mean(axis=(0, 1)), color=line_color, label=line_label);

plt.legend();
plt.title('Welch vs Multitaper');
plt.xlabel('freqs');
plt.ylabel('PSDs');

The normalization and averaging methods used create a large difference in PSD values between Welch and multitaper.

End.