In [1]:
%matplotlib inline
import time
import itertools as it
import numpy as np
import matplotlib.pylab as plt
import matplotlib.patches as patches
import operator as op
import scipy.cluster
from sklearn.preprocessing import StandardScaler
from multiprocessing import cpu_count

In [2]:
import nbutil
import golub_utils

# This data comes from:
# http://www.broadinstitute.org/mpr/publications/projects/Leukemia/data_set_ALL_AML_train.txt
# http://www.broadinstitute.org/mpr/publications/projects/Leukemia/table_ALL_AML_samples.txt
# Use the builtin `float` as the dtype: `np.float` was only an alias for it
# and has been removed from recent NumPy releases.
relation = np.array(golub_utils.GENE_DATA, dtype=float)
# transpose so that rows are people and columns are genes
relation = relation.T # people x genes

In [3]:
# (people, genes); a single parenthesised argument prints the same on Python 2 and 3
print(relation.shape)


(38, 7129)

In [4]:
# take top 100 highest variance genes
topK = 100
# per-gene variance across people (relation is people x genes)
gene_variances = relation.var(axis=0)
# (column index, variance) pairs, largest variance first; sorted() is stable,
# so genes with equal variance keep their original relative order
ranked_genes = sorted(enumerate(gene_variances), key=op.itemgetter(1), reverse=True)
# build a real list (not a lazy map object) so slicing works on Python 2 and 3
inds = [gene_idx for gene_idx, _ in ranked_genes[:topK]]

In [5]:
# Keep only the selected gene columns and standardise each of them
# (each column of gene expression) to zero mean and unit variance.
small_relation = StandardScaler().fit_transform(relation[:, inds])

In [6]:
# show the standardised matrix; a single parenthesised argument prints the
# same on Python 2 and 3
print(small_relation)


[[ 0.28533844  0.79394618  0.14882763 ...,  0.72536091 -0.15434762
  -0.50003737]
 [-0.33250523  0.60202902  0.66767738 ...,  0.72863169 -0.11204041
   2.01110035]
 [ 1.12047391 -0.94067332 -0.59478749 ...,  2.17163823 -0.18294968
   3.14777404]
 ..., 
 [ 0.53961778 -0.46653759 -0.7292541  ..., -0.6450939  -0.33102491
  -0.1929413 ]
 [ 0.22654184 -0.26373302  0.30280172 ...,  0.72417154 -0.19754864
   0.20556664]
 [-1.10117496  0.99248118 -1.03204937 ..., -0.99952699  0.81424915
  -0.47434645]]

In [7]:
from microscopes.common.rng import rng
from microscopes.common.relation.dataview import numpy_dataview
from microscopes.models import nich as normal_inverse_chi_squared
from microscopes.irm.definition import model_definition
from microscopes.irm import model, runner, query
from microscopes.kernels import parallel
from microscopes.common.query import zmatrix_reorder, zmatrix_heuristic_block_ordering

In [8]:
# Two domains sized by the data matrix (people, genes), tied together by one
# relation over domains (0, 1) that uses the normal-inverse-chi-squared model.
defn = model_definition(
    domains=small_relation.shape,
    relations=[((0, 1), normal_inverse_chi_squared)])
views = [numpy_dataview(small_relation)]
# NOTE(review): rng() is created without an explicit seed, so results
# presumably differ from run to run -- confirm, and seed it if the figures
# below need to be reproducible.
prng = rng()
nchains = 50
# one independently initialised latent state per chain
# (`range` instead of `xrange` so the cell also runs on Python 3)
latents = [model.initialize(defn, views, prng) for _ in range(nchains)]
kc = runner.default_kernel_config(defn)
runners = [runner.runner(defn, views, latent, kc) for latent in latents]

In [9]:
# Run all chains through the parallel runner and report wall-clock time.
start = time.time()
r = parallel.runner(runners)
r.run(r=prng, niters=1000) # burnin
# one pre-formatted string keeps the output identical on Python 2 and 3
# (a parenthesised multi-argument print would show a tuple on Python 2)
print("finished in %s seconds" % (time.time() - start))


finished in 560.621087074 seconds

In [10]:
# Pull the latents back out of the parallel runner and build one z-matrix
# per domain: domain 0 is people, domain 1 is genes.
infers = r.get_latents()
zmat_people, zmat_genes = (query.zmatrix(domain, infers) for domain in (0, 1))

In [11]:
# order ALL/B, ALL/T, and then AML
inds_allb = [i for i, t in enumerate(golub_utils.CANCER_TYPES) if t == "ALL/B"]
inds_allt = [i for i, t in enumerate(golub_utils.CANCER_TYPES) if t == "ALL/T"]
inds_aml = [i for i, t in enumerate(golub_utils.CANCER_TYPES) if t == "AML"]
# NOTE(review): a person whose type is none of the three labels above would be
# left out of `indices` -- confirm CANCER_TYPES only contains these values.
indices = inds_allb + inds_allt + inds_aml

with nbutil.figsize(10.0, 8.0) as _:
    ax_people = plt.subplot(2, 2, 1)
    plt.imshow(zmatrix_reorder(zmat_people, indices),
               cmap=plt.cm.binary, interpolation='nearest')

    # Highlight the diagonal block of each cancer type.  imshow centres pixel k
    # on coordinate k, so a group occupying rows/columns
    # [offset, offset + size) spans exactly [offset - 0.5, offset + size - 0.5].
    # Using the exact pixel edges avoids relying on the axes to clip
    # rectangles that overshoot the image.
    offset = 0
    for group, colour in ((inds_allb, "red"),
                          (inds_allt, "green"),
                          (inds_aml, "blue")):
        size = len(group)
        ax_people.add_patch(patches.Rectangle(
            (offset - 0.5, offset - 0.5),
            width=size,
            height=size,
            alpha=0.5,
            fc=colour))
        offset += size

    plt.xlabel('people')
    plt.ylabel('people')
    plt.title('Z-matrix')

    # genes are shown in their original (unsorted) column order
    plt.subplot(2, 2, 2)
    plt.imshow(zmat_genes, cmap=plt.cm.binary, interpolation='nearest')
    plt.xlabel('genes')
    plt.ylabel('genes')
    plt.title('Z-matrix')
    plt.tight_layout()


The red patch marks people with the "ALL/B" type of cancer.

The green patch marks people with the "ALL/T" type of cancer.

The blue patch marks people with the "AML" type of cancer.


In [12]:
def subst(name):
    """Shorten a cancer-type label for use as a tick label.

    "ALL/B" becomes "B" and "ALL/T" becomes "T"; any other value is
    returned unchanged.
    """
    for full_label, short_label in (("ALL/B", "B"), ("ALL/T", "T")):
        if name == full_label:
            return short_label
    return name

# heuristic block orderings computed from each z-matrix
ordering_people = zmatrix_heuristic_block_ordering(zmat_people)
ordering_genes = zmatrix_heuristic_block_ordering(zmat_genes)

# list comprehension rather than map(): np.array() needs a real sequence,
# which map() no longer returns on Python 3
people_labels = np.array([subst(t) for t in golub_utils.CANCER_TYPES])[ordering_people]
# ordering_genes permutes the columns of small_relation, i.e. the genes picked
# by `inds`, so restrict the identifiers to `inds` before applying the ordering;
# indexing the full list directly would label the plot with the first genes of
# the file instead of the selected ones.
# NOTE(review): assumes GENE_IDENTS is aligned with the rows of GENE_DATA -- confirm.
genes_labels = np.array(golub_utils.GENE_IDENTS)[inds][ordering_genes]

In [13]:
with nbutil.figsize(21., 16.) as _:
    # left panel: people z-matrix with rows and columns in block order
    plt.subplot(1, 2, 1)
    people_ticks = range(zmat_people.shape[0])
    sorted_zmat_people = zmatrix_reorder(zmat_people, ordering_people)
    plt.imshow(sorted_zmat_people, cmap=plt.cm.binary, interpolation='nearest')
    plt.title("clustered Z-matrix of people")
    plt.xlabel("people (sorted)")
    plt.ylabel("people (sorted)")
    plt.xticks(people_ticks, people_labels, rotation='vertical')
    plt.yticks(people_ticks, people_labels)

    # right panel: genes z-matrix in block order (no tick labels)
    plt.subplot(1, 2, 2)
    sorted_zmat_genes = zmatrix_reorder(zmat_genes, ordering_genes)
    plt.imshow(sorted_zmat_genes, cmap=plt.cm.binary, interpolation='nearest')
    plt.title("clustered Z-matrix of genes")
    plt.xlabel("genes (sorted)")
    plt.ylabel("genes (sorted)")

    plt.tight_layout()



In [14]:
# Standardised expression matrix in its original (unsorted) order:
# rows are people, columns are genes.
with nbutil.figsize(14., 10.) as _:
    heatmap_kwargs = dict(cmap=plt.cm.BrBG, interpolation='nearest')
    plt.imshow(small_relation, **heatmap_kwargs)
    plt.ylabel("people")
    plt.xlabel("genes")



In [15]:
# Apply the block orderings to the data itself: rows follow ordering_people,
# columns follow ordering_genes.  small_relation is left untouched.
z = small_relation[ordering_people][:, ordering_genes]

# Split the gene axis into two halves, one per panel.  Floor division keeps
# the split point an integer on Python 3 as well (a float is not a valid
# slice index).
half = z.shape[1] // 2
z1 = z[:, :half]
z2 = z[:, half:]

with nbutil.figsize(20., 16.) as _:
    # each panel pairs a half of the matrix with the matching slice of gene labels
    panels = ((z1, genes_labels[:half]), (z2, genes_labels[half:]))
    for row, (block, block_gene_labels) in enumerate(panels, 1):
        plt.subplot(2, 1, row)
        plt.imshow(block, cmap=plt.cm.BrBG, interpolation='nearest')
        plt.xlabel("genes (sorted)")
        plt.ylabel("people (sorted)")
        plt.xticks(range(block.shape[1]), block_gene_labels,
                   size=8, rotation='vertical')
        plt.yticks(range(len(people_labels)), people_labels, size=8)

    plt.tight_layout()


Dark brown indicates low values; dark green indicates high values.


In [ ]: