Use gene sets from MSigDB to prune the number of genes/features.
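Each line of an MSigDB .gmt file is tab separated: the gene set name, a description/source URL, and then the member gene symbols, which is why the loader below keeps fields [2:] as the gene list. An illustrative, abbreviated line:
KEGG_P53_SIGNALING_PATHWAY    http://www.broadinstitute.org/gsea/msigdb/cards/KEGG_P53_SIGNALING_PATHWAY    TP53    MDM2    CDKN1A    ...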
In [12]:
import os
import json
import numpy as np
import pandas as pd
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
# fix random seed for reproducibility
np.random.seed(42)
# See https://github.com/h5py/h5py/issues/712
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
In [3]:
X = pd.read_hdf("data/tcga_target_gtex.h5", "expression")
X.head()
Out[3]:
In [4]:
Y = pd.read_hdf("data/tcga_target_gtex.h5", "labels")
Y.head()
Out[4]:
In [5]:
# Load gene sets from downloaded MSigDB gmt file
# KEGG only for now as it is experimentally curated (vs. computationally derived)
with open("data/c2.cp.kegg.v6.1.symbols.gmt") as f:
gene_sets = {line.strip().split("\t")[0]: line.strip().split("\t")[2:]
for line in f.readlines()}
print("Loaded {} gene sets".format(len(gene_sets)))
# Drop genes not in X - sort so order is the same as X_pruned.columns
gene_sets = {name: sorted([gene for gene in genes if gene in X.columns.values])
             for name, genes in gene_sets.items()}
# Find the union of all genes across all gene sets in order to filter our input columns
all_gene_set_genes = sorted(list(set().union(
    *[gene_set for gene_set in gene_sets.values()])))
print("Subsetting to {} genes".format(len(all_gene_set_genes)))
# Prune X to only include genes in the gene sets
X_pruned = X.drop(labels=(set(X.columns) - set(all_gene_set_genes)),
                  axis=1, errors="ignore")
assert X_pruned["TP53"]["TCGA-ZP-A9D4-01"] == X["TP53"]["TCGA-ZP-A9D4-01"]
print("X_pruned shape", X_pruned.shape)
# Make sure the genes are the same and in the same order
assert len(all_gene_set_genes) == len(X_pruned.columns.values)
assert list(X_pruned.columns.values) == all_gene_set_genes
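As a quick sanity check (a sketch, not part of the original pipeline), selecting the gene-set columns directly should produce an equivalent frame, since the assertions above show these columns already appear in sorted order:
X_check = X[all_gene_set_genes]
assert X_check.equals(X_pruned)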
In [6]:
from sklearn.preprocessing import LabelEncoder
primary_site_encoder = LabelEncoder()
Y["primary_site_value"] = pd.Series(
primary_site_encoder.fit_transform(Y["primary_site"]), index=Y.index)
Y.describe(include="all", percentiles=[])
Out[6]:
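The fitted encoder can also map integer codes back to primary site names, which is handy when inspecting predictions later; a minimal sketch:
# Decode the first few encoded primary sites back to their string labels
print(primary_site_encoder.inverse_transform(Y["primary_site_value"].values[:5]))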
In [7]:
# Create one-hot encodings for training
from keras.utils import np_utils
Y_primary_site_one_hot = np_utils.to_categorical(Y["primary_site_value"])
print(Y_primary_site_one_hot[0:20000:5000])
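A quick round-trip check (a sketch) that the one-hot rows recover the integer codes they were built from:
assert (Y_primary_site_one_hot.argmax(axis=1) == Y["primary_site_value"].values).all()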
In [8]:
# Split into stratified training and test sets based on primary site
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(X_pruned.values, Y["primary_site_value"]):
    X_train = X_pruned.values[train_index]
    X_test = X_pruned.values[test_index]
    y_train = Y_primary_site_one_hot[train_index]
    y_test = Y_primary_site_one_hot[test_index]
    primary_site_train = Y["primary_site_value"].values[train_index]
    primary_site_test = Y["primary_site_value"].values[test_index]
print(X_train.shape, X_test.shape)
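One way to confirm the stratification behaved as expected (a sketch, assuming every primary site appears in both splits) is to compare per-class proportions between train and test; they should be nearly identical:
# Per-class fractions in train vs. test; the largest gap should be close to zero
train_frac = np.bincount(primary_site_train) / len(primary_site_train)
test_frac = np.bincount(primary_site_test) / len(primary_site_test)
print("max class-proportion difference:", np.abs(train_frac - test_frac).max())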
In [13]:
# Let's see how big each class is based on primary site
plt.hist(primary_site_train, alpha=0.5, label='Train')
plt.hist(primary_site_test, alpha=0.5, label='Test')
plt.legend(loc='upper right')
plt.title("Primary Site distribution between train and test")
plt.show()
In [103]:
%%time
from keras.models import Model, Sequential
from keras.layers import InputLayer, Dense, BatchNormalization, Dropout
from keras.callbacks import EarlyStopping
from keras import regularizers
classify = [
    InputLayer(input_shape=(X_train.shape[1],)),
    BatchNormalization(),
    Dense(32, activation='relu'),
    Dropout(0.5),
    Dense(16, activity_regularizer=regularizers.l1(1e-5), activation='relu'),
    Dropout(0.5),
    Dense(y_train.shape[1], activation='sigmoid')
]
model = Sequential(classify)
print(model.summary())
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
callbacks = [EarlyStopping(monitor='acc', min_delta=0.05, patience=2, verbose=2, mode="max")]
model.fit(X_train, y_train, epochs=10, batch_size=128, shuffle="batch", callbacks=callbacks)
print(model.metrics_names, model.evaluate(X_test, y_test))
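Overall accuracy hides per-class behavior; a hedged sketch using scikit-learn's classification_report to get per-class precision and recall on the holdout set (assumes every primary site appears in the test split):
from sklearn.metrics import classification_report
test_pred = model.predict(X_test).argmax(axis=1)
print(classification_report(primary_site_test, test_pred,
                            target_names=primary_site_encoder.classes_.tolist()))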
In [104]:
%%time
predictions = model.predict(X_test)
labels = primary_site_encoder.classes_.tolist()
In [105]:
# Let's eyeball the top three predictions against the ground truth
[(labels[primary_site_test[i]],
  ", ".join(["{}({:0.2f})".format(labels[j], p[j]) for j in p.argsort()[-3:][::-1]]))
 for i, p in enumerate(predictions)][0:-1:200]
Out[105]:
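Since the cell above looks at the top three candidates, a small sketch of the corresponding top-3 accuracy over the whole test set:
# Fraction of test samples whose true primary site is among the three highest-scoring outputs
top3 = np.argsort(predictions, axis=1)[:, -3:]
top3_accuracy = np.mean([primary_site_test[i] in top3[i] for i in range(len(predictions))])
print("Top-3 accuracy: {:0.3f}".format(top3_accuracy))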
In [106]:
import sklearn.metrics
import matplotlib.ticker as ticker
confusion_matrix = sklearn.metrics.confusion_matrix(
    primary_site_test, np.array([np.argmax(p) for p in predictions]))
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion_matrix, cmap=plt.cm.gray)
ax.set_xticklabels(primary_site_encoder.classes_.tolist(), rotation=90)
ax.set_yticklabels(primary_site_encoder.classes_.tolist())
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.set_xlabel("Primary Site Confusion Matrix for Holdout Test Data")
plt.show()
In [107]:
# Show only where there are errors
row_sums = confusion_matrix.sum(axis=1, keepdims=True)
norm_conf_mx = confusion_matrix / row_sums
np.fill_diagonal(norm_conf_mx, 0)
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111)
cax = ax.matshow(norm_conf_mx, cmap=plt.cm.gray)
ax.set_xticklabels(primary_site_encoder.classes_.tolist(), rotation=90)
ax.set_yticklabels(primary_site_encoder.classes_.tolist())
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.set_xlabel("Primary Site Prediction Errors")
plt.show()
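To read the error matrix without squinting at the plot, a sketch that lists its largest off-diagonal entries, i.e. the class pairs most often confused (the diagonal was zeroed above):
site_labels = primary_site_encoder.classes_.tolist()
worst = np.argsort(norm_conf_mx, axis=None)[::-1][:5]
for true_idx, pred_idx in zip(*np.unravel_index(worst, norm_conf_mx.shape)):
    print("{} -> {} ({:0.2f})".format(site_labels[true_idx], site_labels[pred_idx],
                                      norm_conf_mx[true_idx, pred_idx]))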
In [108]:
# Save the model for separate inference
with open("models/primary_site.params.json", "w") as f:
f.write(json.dumps({
"labels": primary_site_encoder.classes_.tolist(),
"genes": all_gene_set_genes}))
with open("models/primary_site.model.json", "w") as f:
f.write(model.to_json())
model.save_weights("models/primary_site.weights.h5")
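A minimal sketch of how these artifacts might be reloaded for inference in a separate process (assumes the same Keras version that produced them):
from keras.models import model_from_json
with open("models/primary_site.params.json") as f:
    params = json.load(f)
with open("models/primary_site.model.json") as f:
    restored = model_from_json(f.read())
restored.load_weights("models/primary_site.weights.h5")
restored.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
print("Expecting {} genes as input, predicting {} primary sites".format(
    len(params["genes"]), len(params["labels"])))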