Imports


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import math
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
import pyspark.sql.functions as F

from breastcancer import input_data

plt.rcParams['figure.figsize'] = (10, 6)

In [ ]:
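# NOTE: `spark` is used below but is never created by active code in this notebook;
# it is presumably injected by the kernel (e.g. a PySpark kernel) -- TODO confirm.
# Uncomment the two lines below to create the session manually.
# from pyspark.sql import SparkSession
# spark = (SparkSession.builder.appName("KerasResNet50").getOrCreate())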

Settings


In [ ]:
# Image geometry: square patches of `size` x `size` pixels with `channels` color channels.
size = 256
channels = 3
features = channels * size * size  # length of one flattened sample

# Number of label classes.
classes = 3

# `p` / `val_p` are handed to the train / val readers below; a value < 1 also
# enables the duplicate check (presumably they are sampling fractions -- TODO confirm).
p = 1
val_p = 1

# Behaviour toggles and the seed passed to the readers.
use_caching = False
normalize_class_distribution = False
seed = 123

Read in train & val data


In [ ]:
# Read and sample from full DataFrames.
# `spark` is not created by any active code in this notebook (the SparkSession
# cell above is commented out), so it must already exist in the kernel.
# Both readers receive the same geometry (`size`, `channels`), the
# `normalize_class_distribution` flag and `seed`; only the `p` / `val_p` argument differs.
# TODO: Pull filenames out and simply pass them in as arguments.
# NOTE: ***Currently hacked read_* with updated data filenames.***
train_df = input_data.read_train_data(spark, size, channels, p, normalize_class_distribution, seed)
val_df = input_data.read_val_data(spark, size, channels, val_p, normalize_class_distribution, seed)

In [ ]:
# # Save DataFrames (Optional)
# mode = "error"
# tr_sample_filename = os.path.join("data", "train_{}_sample_{}.parquet".format(p, size))
# val_sample_filename = os.path.join("data", "val_{}_sample_{}.parquet".format(val_p, size))
# train_df.write.mode(mode).save(tr_sample_filename, format="parquet")
# val_df.write.mode(mode).save(val_sample_filename, format="parquet")

In [ ]:
# Optionally cache both DataFrames (controlled by `use_caching` in Settings);
# they are each used by several of the cells below.
if use_caching:
  train_df.cache()
  val_df.cache()

In [ ]:
# Explore class distributions: show the row count per `tumor_score` value,
# first for the training set and then for the validation set.
for df in [train_df, val_df]:
  df.select("tumor_score").groupBy("tumor_score").count().show()

In [ ]:
# Row counts of the training and validation DataFrames (reused by the duplicate check).
tc, vc = train_df.count(), val_df.count()
print(tc, vc)  # updated norm vs: 1801835 498183; original: 3560187 910918

In [ ]:
# Sanity check: whenever a split was read with a value < 1, dropping duplicate
# rows must leave its row count unchanged.  Splits read with a value >= 1 are skipped.
assert p >= 1 or train_df.dropDuplicates().count() == tc
assert val_p >= 1 or val_df.dropDuplicates().count() == vc

Normalize Staining


In [ ]:
def normalize_staining(x, beta=0.15, alpha=1, light_intensity=240):
  """
  Normalize the staining of an H&E histology slide image.

  The image is converted to optical density (OD), the plane spanned by the two
  dominant OD directions of the tissue pixels is estimated, the two stain
  vectors are taken as robust angular extremes within that plane, and the
  per-pixel stain saturations are rescaled and re-rendered with fixed reference
  stain vectors.

  Args:
    x: RGB image, array-like of shape (H, W, 3).  Values are converted to
      float64 before any arithmetic.
    beta: OD threshold; a pixel counts as tissue only if every channel has
      OD >= beta.
    alpha: Percentile (and 100 - alpha) used as the robust angular extremes.
    light_intensity: Value used to convert between RGB and OD.

  Returns:
    A uint8 array of shape (H, W, 3) holding the normalized image.  If the
    image has 2 or fewer tissue pixels (the same threshold `filter_empty`
    uses), there is no stain to estimate and the input is returned as-is
    (rounded and clipped to uint8) instead of failing on an undefined
    covariance matrix.

  References:
    - Macenko, Marc, et al. "A method for normalizing histology slides for
    quantitative analysis." Biomedical Imaging: From Nano to Macro, 2009.
    ISBI'09. IEEE International Symposium on. IEEE, 2009.
      - http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf
    - https://github.com/mitkovetta/staining-normalization/blob/master/normalizeStaining.m
  """
  # Setup.
  x = np.asarray(x)
  h, w, c = x.shape
  x = x.reshape(-1, c).astype(np.float64)  # shape (H*W, C)

  # Reference stain vectors and stain saturations.  We will normalize all slides
  # to these references.  To create these, grab the stain vectors and stain
  # saturations from a desirable slide.
  ## Values in reference implementation for use with eigendecomposition approach.
  stain_ref = np.array([0.5626, 0.2159, 0.7201, 0.8012, 0.4062, 0.5581]).reshape(3,2)
  max_sat_ref = np.array([1.9705, 1.0308]).reshape(2,1)
  ## Values for use with an SVD approach instead.  These were computed by (1)
  ## running the eigendecomposition approach to normalize an image, (2) running
  ## the SVD approach on the normalized image, and (3) recording the stain
  ## vectors and max saturations for this (ideal) normalized image.
  #   stain_ref = np.array([0.20730702, 0.56170196, 0.80308092, 0.72012455, 0.55864554, 0.4073224]).reshape(3,2)
  #   max_sat_ref = np.array([0.99818645, 1.96029115]).reshape(2,1)

  # Convert RGB to OD.  The +1 keeps the log finite for zero-valued pixels.
  OD = -np.log((x+1)/light_intensity)  # shape (H*W, C)

  # Remove data with OD intensity less than beta.
  # I.e. remove transparent pixels.
  # Note: This needs to be checked per channel, rather than
  # taking an average over all channels for a given pixel.
  OD_thresh = OD[np.all(OD >= beta, 1), :]  # shape (K, C)

  # With too few tissue pixels the covariance below is undefined (NaNs) and the
  # eigendecomposition raises.  Such a tile has no stain to normalize, so hand
  # the image back untouched.
  if OD_thresh.shape[0] <= 2:
    return np.clip(np.round(x), 0, 255).astype(np.uint8).reshape(h, w, c)

  # Calculate eigenvectors of the (C, C) covariance matrix.
  # NOTE(review): `np.linalg.eigh` is the routine meant for symmetric matrices;
  # `eig` is kept here so the eigenvector sign/order convention is unchanged.
  eigvals, eigvecs = np.linalg.eig(np.cov(OD_thresh.T))

  # Extract two largest eigenvectors.
  # Note: We swap the sign of the eigvecs here to be consistent
  # with other implementations.  Both +/- eigvecs are valid, with
  # the same eigenvalue, so this is okay.
  top_eigvecs = eigvecs[:, np.argsort(eigvals)[-2:]] * -1  # shape (C, 2)

  # Project thresholded optical density values onto plane spanned by
  # 2 largest eigenvectors.
  proj = np.dot(OD_thresh, top_eigvecs)  # shape (K, 2)

  # Calculate angle of each point wrt the first plane direction.
  # Note: the parameters are `np.arctan2(y, x)`
  angles = np.arctan2(proj[:, 1], proj[:, 0])  # shape (K,)

  # Find robust extremes (a and 100-a percentiles) of the angle.
  min_angle = np.percentile(angles, alpha)
  max_angle = np.percentile(angles, 100-alpha)

  # Convert min/max vectors (extremes) back to OD space.
  min_vec = np.dot(top_eigvecs, np.array([np.cos(min_angle), np.sin(min_angle)]).reshape(2,1))
  max_vec = np.dot(top_eigvecs, np.array([np.cos(max_angle), np.sin(max_angle)]).reshape(2,1))

  # Merge vectors with hematoxylin first, and eosin second, as a heuristic.
  if min_vec[0] > max_vec[0]:
    stains = np.hstack((min_vec, max_vec))
  else:
    stains = np.hstack((max_vec, min_vec))

  # Calculate saturations of each stain.
  # Note: Here, we solve
  #    OD = VS
  #     S = V^{-1}OD
  # where `OD` is the matrix of optical density values of our image,
  # `V` is the matrix of stain vectors, and `S` is the matrix of stain
  # saturations.  Since this is an overdetermined system, we use the
  # least squares solver, rather than a direct solve.
  # `rcond=None` selects NumPy's current default cutoff explicitly, which avoids
  # a FutureWarning on every call under NumPy 1.14-1.x.
  sats, _, _, _ = np.linalg.lstsq(stains, OD.T, rcond=None)  # shape (2, H*W)

  # Normalize stain saturations.
  max_sat = np.percentile(sats, 99, axis=1, keepdims=True)
  sats = sats / max_sat * max_sat_ref

  # Recreate image.
  # Note: If the image is immediately converted to uint8 with `.astype(np.uint8)`, it will
  # not return the correct values due to the initial values being outside of [0,255].
  # To fix this, we round to the nearest integer, and then clip to [0,255], which is the
  # same behavior as Matlab.
  x_norm = np.exp(np.dot(-stain_ref, sats)) * light_intensity  # shape (C, H*W)
  x_norm = np.clip(np.round(x_norm), 0, 255).astype(np.uint8)
  x_norm = x_norm.T.reshape(h,w,c)

  return x_norm

Compute image channel means


In [ ]:
# tr_means = input_data.compute_channel_means(train_df.rdd, channels, size)
# val_means = input_data.compute_channel_means(val_df.rdd, channels, size)
# print(tr_means.shape)
# print(tr_means, val_means)
# # Train: [ 194.27633667  145.3067627   181.27861023]
# # Val: [ 192.92971802  142.83534241  180.18870544]

In [ ]:
def array_to_img(x, channels, size):
  """Turn a flat pixel vector laid out as (C, H, W) into an RGB PIL image.

  `x` is reshaped to (channels, size, size), reordered to (H, W, C), cast to
  uint8 and wrapped in a PIL image.
  """
  chw = x.reshape((channels, size, size))
  hwc = chw.transpose((1, 2, 0))  # shape (H,W,C)
  return Image.fromarray(hwc.astype(np.uint8), 'RGB')

def img_to_array(img):
  """Flatten an (H, W, C) image into a 1-D float64 vector in (C, H, W) order."""
  hwc = np.asarray(img).astype(np.float64)  # shape (H,W,C)
  chw = np.transpose(hwc, (2, 0, 1))  # shape (C,H,W)
  return chw.ravel()  # shape (C*H*W)

In [ ]:
def filter_empty(row, beta=0.15, light_intensity=240):
  """
  Return True if the sample in `row` has enough tissue pixels to be processed.

  A pixel counts as tissue when its optical density (OD) is >= `beta` in every
  channel; the sample is kept when more than 2 such pixels exist.

  Args:
    row: Object whose `row.sample.values` is a flat array that reshapes to
      (channels, size, size), using the notebook-level `channels` and `size`.
    beta: OD threshold below which a channel is considered transparent.
    light_intensity: Value used to convert RGB to OD.

  Returns:
    bool: True to keep the sample, False to drop it.
  """
  # Convert to float first (as `normalize_staining` does): with an integer
  # dtype such as uint8, `x + 1` wraps 255 -> 0, the log becomes -inf and a
  # blank pixel would be misread as maximally dense tissue.
  x = np.asarray(row.sample.values, dtype=np.float64)
  x = x.reshape((channels,size,size)).transpose((1,2,0))  # shape (H,W,C)
  c = x.shape[-1]
  x = x.reshape(-1, c)  # shape (H*W, C)
  OD = -np.log((x+1)/light_intensity)  # shape (H*W, C)
  # Keep only pixels whose OD is at least beta in every channel,
  # i.e. discard transparent pixels.
  tissue_mask = np.all(OD >= beta, axis=1)  # shape (H*W,)
  return int(np.count_nonzero(tissue_mask)) > 2

In [ ]:
# Drop samples that `filter_empty` rejects (too few tissue pixels) from both splits.
train_rdd, val_rdd = (
    frame.rdd.filter(filter_empty) for frame in (train_df, val_df)
)

In [ ]:
# Sanity checks

# first = train_df.first()
# s = first.sample.values
# i = array_to_img(s, channels, size)
# s2 = img_to_array(i)
# assert np.allclose(s, s2)

# def assert_finite(row):
#   x = row.sample.values
#   x = x.reshape((channels,size,size)).transpose((1,2,0)) 
#   h, w, c = x.shape
#   x = x.reshape(-1, c).astype(np.float64)
#   OD = -np.log((x+1)/240)
#   OD_thresh = OD[np.all(OD >= 0.15, 1), :]
#   assert np.all(np.isfinite(OD_thresh.T))
# train_df.rdd.foreach(assert_finite)

In [ ]:
def compute_channel_means(rdd, channels, size):
  """Average each color channel over every stain-normalized image in `rdd`.

  Each row's `sample.values` is reshaped to (H, W, C), stain-normalized, and
  reduced to one mean per channel; the per-image means are collected to the
  driver and averaged again.

  Returns:
    Array of shape (channels,) with the dataset-wide channel means.
  """
  def per_image_means(row):
    flat = row.sample.values
    hwc = flat.reshape((channels, size, size)).transpose((1, 2, 0))  # shape (H,W,C)
    normalized = np.asarray(normalize_staining(hwc)).astype(np.float64)  # shape (H,W,C)
    return np.mean(normalized, axis=(0, 1))

  collected = np.array(rdd.map(per_image_means).collect())  # one row per image
  return np.mean(collected, axis=0)

In [ ]:
# Channel means of the filtered, stain-normalized train and validation samples.
tr_means, val_means = (
    compute_channel_means(rdd, channels, size) for rdd in (train_rdd, val_rdd)
)
print(tr_means.shape)
print(tr_means, val_means)
# Recorded results from earlier runs:
# Means: [194.27633667  145.3067627  181.27861023]
# Means with norm: train [189.54944625  152.73427159  176.89543273] val [187.45282379  150.25695602  175.23754894]
# Means with norm on updated data:
#    [ 177.27269518  136.06809866  165.07305029] [ 176.21991047  134.39199187  163.81433421]
# Means with norm on updated data v3:
#    [ 183.36777842  138.81743141  166.07406199] [ 182.41870536  137.15523608  164.81227273]

Save every image as a JPEG


In [ ]:
def helper(row, channels, size, save_dir):
  """Stain-normalize one sample and save it as a JPEG in its class directory.

  The file is written to `<save_dir>/<tumor_score>/<__INDEX>_<slide_num>_<rand>.jpeg`,
  where `<rand>` is a random integer below 1e4.  The class directory must
  already exist.
  """
  label = row.tumor_score
  flat = row.sample.values
  hwc = flat.reshape((channels, size, size)).transpose((1, 2, 0))  # shape (H,W,C)
  normalized = normalize_staining(hwc)
  img = Image.fromarray(normalized.astype(np.uint8), 'RGB')
  filename = '{index}_{slide_num}_{hash}.jpeg'.format(
      index=row["__INDEX"], slide_num=row.slide_num, hash=np.random.randint(1e4))
  img.save(os.path.join(save_dir, str(label), filename))

In [ ]:
# Output roots for the exported JPEGs: images/<stage>/<p or val_p>.
tr_save_dir, val_save_dir = (
    "images/{stage}/{p}".format(stage=stage, p=frac)
    for stage, frac in (("train_updated_norm_v3", p), ("val_updated_norm_v3", val_p))
)
print(tr_save_dir, val_save_dir)

In [ ]:
%%bash -s "$tr_save_dir" "$val_save_dir"
# $1 = training image directory, $2 = validation image directory.
# Create one sub-directory per class label (1, 2, 3) under each of them.
# All expansions are quoted so paths containing spaces or glob characters
# are passed through as a single argument.
for i in 1 2 3
do
  sudo mkdir -p "$1/$i"
  sudo mkdir -p "$2/$i"
done
# NOTE(review): mode 777 makes both trees world-writable; tighten this if the
# writing processes can share a user or group.
# `-R` is placed before the mode, the portable option order.
sudo chmod -R 777 "$1"
sudo chmod -R 777 "$2"

In [ ]:
# Note: Use this if the DataFrame doesn't have an __INDEX column yet.
# train_df = train_df.withColumn("__INDEX", F.monotonically_increasing_id())
# val_df = val_df.withColumn("__INDEX", F.monotonically_increasing_id())

In [ ]:
# Stain-normalize every non-empty sample and write it out as a JPEG via `helper`.
# The filter is applied to the DataFrames again rather than reusing `train_rdd` /
# `val_rdd` -- presumably so that an `__INDEX` column added by the optional cell
# above is present on the rows; TODO confirm before deduplicating.
train_df.rdd.filter(filter_empty).foreach(lambda row: helper(row, channels, size, tr_save_dir))
val_df.rdd.filter(filter_empty).foreach(lambda row: helper(row, channels, size, val_save_dir))


In [ ]:
def show_random_image(save_dir):
  """Print the path of, and display, one randomly chosen saved image.

  A class directory (1, 2 or 3) is drawn at random under `save_dir`, then a
  file is drawn at random from that directory and shown with matplotlib.
  """
  label = np.random.randint(1, 4)  # upper bound is exclusive
  class_dir = os.path.join(save_dir, str(label))
  candidates = os.listdir(class_dir)
  chosen = candidates[np.random.randint(0, len(candidates))]
  fname = os.path.join(class_dir, chosen)
  print(fname)
  plt.imshow(Image.open(fname))

In [ ]:
# Spot-check the export by displaying one random image from the training output.
show_random_image(tr_save_dir)