In [ ]:
# Integer index of each signature category; it is embedded in the GSM
# feature-file name (gsm_X_<idx>.npz) read further below.
LABEL2IDX = {
    'gene perturbation': 2,
    'drug perturbation': 1,
    'disease signature': 0,
}
# Category processed by this notebook run (must be a key of LABEL2IDX).
RUN_LABEL = 'disease signature'
# Directory holding the GSM matrices and the signature spreadsheets.
DATA_PATH = '../../data/gesgnext'
import os
import numpy as np
import scipy as sp
import pandas as pd
from sklearn.neighbors import radius_neighbors_graph
from sklearn.metrics.pairwise import pairwise_distances as pdist
from bionlp.util import io
from bionlp.model import kallima
from bionlp import dstclc
# import matplotlib.pyplot as plt
# from bionlp.util import plot
# plot.MON = False
# plt.rcParams['axes.labelweight'] = 'bold'
In [ ]:
# GSM Vectors: load the feature matrix for RUN_LABEL and keep only the
# columns whose names start with 'title'.
gsm_X_path = os.path.join(DATA_PATH, 'gsm_X_%i.npz' % LABEL2IDX[RUN_LABEL])
gsm_X = io.read_df(gsm_X_path, with_idx=True, sparse_fmt='csr')
title_cols = list(filter(lambda col: col.startswith('title'), gsm_X.columns))
gsm_X = gsm_X[title_cols]
In [ ]:
# Signatures: read the signature table for RUN_LABEL from the Excel workbook
# in DATA_PATH named after the label (spaces replaced by underscores).
# A CSV export of the same table can be used instead:
# sgn_df = pd.read_csv(os.path.join(DATA_PATH, '%s.csv' % RUN_LABEL.replace(' ', '_')))
sgn_xlsx_path = os.path.join(DATA_PATH, '%s.xlsx' % RUN_LABEL.replace(' ', '_'))
sgn_df = pd.read_excel(sgn_xlsx_path)
In [ ]:
# Outlier filter. Only GEO series (geo_id) with more than one signature are
# scored. Each signature gets a Ward-style distance between the mean vector
# of its control GSMs and the mean vector of its perturbation GSMs. Within a
# series, signatures whose distance lies above the lower edge of the last
# histogram bin are collected in useless_sgnids.
# sgn_df = sgn_df[sgn_df.geo_id == 'GSE48301']  # debug: restrict to one series
useless_sgnids = []
# Iterate the groups mapping's values directly; .iteritems() is Python 2 only.
for sgn_ids in sgn_df.groupby('geo_id').groups.values():
    # A lone signature has nothing to be compared against.
    if len(sgn_ids) == 1:
        continue
    # Select the subset of signatures belonging to this series.
    sub_sgn_df = sgn_df.loc[sgn_ids]
    pw_dist = []
    for _, sgn in sub_sgn_df.iterrows():
        # 'ctrl_ids' / 'pert_ids' hold '|'-separated GSM ids that index gsm_X.
        ctrl_ids, pert_ids = sgn['ctrl_ids'].split('|'), sgn['pert_ids'].split('|')
        num_ctrl, num_pert = len(ctrl_ids), len(pert_ids)
        centroid_gap = np.linalg.norm(gsm_X.loc[ctrl_ids].mean(axis=0) - gsm_X.loc[pert_ids].mean(axis=0))
        # NOTE(review): textbook Ward linkage multiplies by the *squared*
        # centroid distance; the unsquared form is kept so the set of
        # filtered signatures is unchanged -- confirm which was intended.
        pw_dist.append(1.0 * (num_ctrl * num_pert) / (num_ctrl + num_pert) * centroid_gap)
    # Find a cut value for filtering.
    # plot.plot_hist(np.array(pw_dist), 'Distance of Pairwise Sample Group', 'Number of Group Pairs', title='', style='ggplot', facecolor='skyblue', fmt='pdf', plot_cfg={'xlabel_fontsize':14,'ylabel_fontsize':14})
    # np.histogram defaults to 10 equal-width bins; bin_edges[-2] is the lower
    # edge of the last bin, so only signatures in that top bin can exceed it.
    bin_edges = np.histogram(pw_dist)[1]
    cut_val = bin_edges[-2]
    # Filter out the signatures.
    useless_sgnids.extend([sgn_id for sgn_id, dist in zip(sub_sgn_df['id'], pw_dist) if dist > cut_val])
In [ ]:
# Drop the outlier signatures collected above (indexed by 'id') and save the
# remaining table as CSV and Excel in the current working directory.
out_stem = RUN_LABEL.replace(' ', '_')
fltr_sgn_df = sgn_df.set_index('id').drop(useless_sgnids, axis=0)
fltr_sgn_df.to_csv('%s.csv' % out_stem, encoding='utf8')
# to_excel's ``encoding`` argument was deprecated and later removed from
# pandas; .xlsx workbooks are written as UTF-8 XML regardless.
fltr_sgn_df.to_excel('%s.xlsx' % out_stem)
# Parenthesized form prints the same string on Python 2 and Python 3.
print('Filter out %i signatures!' % (sgn_df.shape[0] - fltr_sgn_df.shape[0]))
In [ ]:
# Reload the filtered signatures and drop exact duplicates. A row is a
# duplicate when an earlier row has the same (ctrl_ids, pert_ids) pair; the
# first occurrence is kept. The result overwrites the CSV/Excel files.
out_stem = RUN_LABEL.replace(' ', '_')
fltr_sgn_df = pd.read_csv('%s.csv' % out_stem, encoding='utf8').set_index('id')
# sgn_dict maps each distinct (ctrl_ids, pert_ids) pair to the id of its first row.
sgn_dict, duplc_sgn, keep_mask = {}, [], []
for idx, ctrl_ids, pert_ids in zip(fltr_sgn_df.index, fltr_sgn_df['ctrl_ids'], fltr_sgn_df['pert_ids']):
    key = (ctrl_ids, pert_ids)
    # ``in`` replaces dict.has_key(), which no longer exists in Python 3.
    is_duplicate = key in sgn_dict
    keep_mask.append(not is_duplicate)
    if is_duplicate:
        duplc_sgn.append(idx)
        print('Duplicate signature: %s' % idx)
    else:
        sgn_dict[key] = idx
# Select rows positionally: dropping by label would also remove the kept
# first row if an 'id' label happened to repeat.
final_sgn_df = fltr_sgn_df.loc[np.array(keep_mask, dtype=bool)]
final_sgn_df.to_csv('%s.csv' % out_stem, encoding='utf8')
# No ``encoding`` for to_excel: the argument was removed from pandas and
# .xlsx output is UTF-8 anyway.
final_sgn_df.to_excel('%s.xlsx' % out_stem)
print('Filter out %i signatures!' % (fltr_sgn_df.shape[0] - final_sgn_df.shape[0]))