In [ ]:
import openslide as ops
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import matplotlib.patches as plp
import matplotlib.image as mpimg
import mpld3
#mpld3.enable_notebook()
import pickle
import boto3
import pandas as pd
from scipy import stats, integrate
import seaborn as sns
import cv2
import math
import heapq
import sys
import scandir
import datetime

In [ ]:
def download_dir(client, resource, amountOfFiles=False, local='../example_images', bucket='tupac.image.bucket'):
    """Download the usable TUPAC training slides from an S3 bucket.

    Every listed object whose key starts with ``TUPAC-TR-`` is stored as
    ``<local>/<stem>/<name>``, where ``<name>`` is the key without that prefix
    and ``<stem>`` is ``<name>`` without its last four characters (i.e. a
    three-letter extension such as ``.svs``). Keys whose three-character id is
    in the exclusion list below are skipped, as are keys without the prefix.

    Args:
        client: Object exposing ``get_paginator('list_objects')``
            (e.g. ``boto3.client('s3')``).
        resource: Object exposing ``meta.client.download_file``
            (e.g. ``boto3.resource('s3')``).
        amountOfFiles: Maximum number of objects to *list*. Any falsy value
            (the default ``False``, ``None`` or ``0``) means "no limit".
            Skipped keys still count towards the limit because it is applied
            to the listing, not to the downloads.
        local: Directory under which the slides are stored.
        bucket: Name of the bucket to list and download from.
    """
    key_prefix = 'TUPAC-TR-'
    bad_ids = ['002', '045', '091', '112', '205', '242', '256', '280', '313', '329', '467']

    # Only forward MaxItems when a real limit was requested.
    # NOTE(review): the old code always sent MaxItems=amountOfFiles, i.e.
    # MaxItems=False by default; presumably botocore coerces that to a limit
    # of 0 items -- confirm against the paginator documentation.
    pagination = {'MaxItems': amountOfFiles} if amountOfFiles else {}

    paginator = client.get_paginator('list_objects')
    for page in paginator.paginate(Bucket=bucket, PaginationConfig=pagination):
        # The listing is requested without a Delimiter, so it is flat and no
        # 'CommonPrefixes' handling is needed. The former recursive branch
        # passed its arguments in the wrong positions (Prefix as
        # amountOfFiles, amountOfFiles as local, local as bucket) and has been
        # removed.
        for entry in page.get('Contents') or []:
            key = entry.get('Key')
            if not key.startswith(key_prefix):
                continue  # not a training slide: nothing to create or fetch
            new_name = key[len(key_prefix):]
            if new_name[:3] in bad_ids:
                continue
            target = os.path.join(local, new_name[:-4], new_name)
            target_dir = os.path.dirname(target)
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
            resource.meta.client.download_file(bucket, key, target)
            print('Downloaded file: {0}'.format(new_name))

In [ ]:
# List what is currently inside the local image directory.
# print(...) with parentheses matches the rest of the notebook and runs on
# both Python 2 and Python 3.
print(os.listdir('../example_images'))
# S3 handles expected by download_dir.
client = boto3.client('s3')
resource = boto3.resource('s3')
# Uncomment to fetch slides (the third argument caps the number of listed objects).
#download_dir(client, resource, 10)

In [ ]:
# The eventual patch size will actually be 512, but with overlap of 256
def sumPixelsInPatch(image, stride=256):
    """Reduce an image to a grid of per-window pixel sums.

    The image is covered by ``stride`` x ``stride`` windows laid out every
    ``stride`` pixels. A window that would run past the bottom/right edge is
    shifted back so that it ends exactly on the edge (it then overlaps its
    neighbour). If the image is smaller than ``stride`` along an axis, the
    single window on that axis covers the whole axis.

    Args:
        image: Array of shape ``(height, width, channels)``.
        stride: Side length of a window, in pixels.

    Returns:
        Integer array of shape ``(ceil(height/stride), ceil(width/stride),
        channels)``. Each cell holds the sum of the window over *all*
        channels; that one total is written to every channel of the cell.
    """
    height, width, channels = image.shape
    new_height = int(math.ceil(float(height)/float(stride)))
    new_width = int(math.ceil(float(width)/float(stride)))
    new_image = np.zeros((new_height, new_width, channels), dtype=int)
    for i in range(new_height):
        # Clamp to the last full window, but never below 0: a negative offset
        # (image smaller than stride) would slice from the end of the axis
        # and silently drop pixels.
        row = max(0, min(stride*i, height-stride))
        for j in range(new_width):
            col = max(0, min(stride*j, width-stride))
            new_image[i, j, :] = np.sum(image[row:row+stride, col:col+stride, :])
    return new_image

In [ ]:
# Class to hold the patch coordinates and density
class Patch(object):
    """A grid position together with the density measured there.

    Instances compare by ``density`` only, which is what lets a list of
    patches be sorted from least to most dense.
    """

    def __init__(self, x=0, y=0, density=0):
        self.x, self.y, self.density = x, y, density

    def __gt__(self, patch2):
        # Only "greater than" is defined; "<" falls back to this method with
        # the operands swapped.
        return patch2.density < self.density

    def __str__(self):
        # Rendered like a list literal: [(x, y), density]
        return '[{0!r}, {1!r}]'.format((self.x, self.y), self.density)

    def __repr__(self):
        return str(self)

In [ ]:
# Returns the minimum Manhattan distance between all items in the heap and the coord. Returns smallest distance and the corresponding element.
def mdist(patches, new_patch):
    """Find the patch that is closest to ``new_patch`` in Manhattan distance.

    Args:
        patches: Sequence of objects with ``x`` and ``y`` attributes.
        new_patch: Object with ``x`` and ``y`` attributes to measure from.

    Returns:
        Tuple ``(min_dist, closest_patch_id)``: the smallest distance found
        and the index in ``patches`` of the first patch at that distance.
        For an empty ``patches`` the result is ``(float('inf'), None)``.
    """
    # float('inf') replaces sys.maxint, which only exists on Python 2.
    min_dist = float('inf')
    closest_patch_id = None
    for i, patch in enumerate(patches):
        dist = abs(patch.x - new_patch.x) + abs(patch.y - new_patch.y)
        if dist < min_dist:  # strict: on ties the earliest patch wins
            min_dist, closest_patch_id = dist, i
    return min_dist, closest_patch_id

In [ ]:
# Store the densest patches in a min heap, check that they're at least a dist_th Manhattan distance away from all others
# Now just sorting, because heapq doesn't offer replace functionality. Should probably implement my own min heap.
def findKDensestPatches(image, k=10, dist_th = 4):
    """Pick up to ``k`` dense, mutually distant cells from a density grid.

    The grid is scanned row by row, skipping a border of 1/20 of each
    dimension (patches close to the borders often have shadows on them) and
    the last row and column. The candidate list is kept sorted by ascending
    density, so ``p[0]`` is always the weakest kept patch.

    For a candidate denser than ``p[0]``:
      * closer than ``dist_th`` to its nearest kept patch: it replaces that
        patch if it is denser, otherwise it is dropped;
      * at least ``dist_th`` away from every kept patch: it replaces ``p[0]``.

    NOTE: the first ``k`` scanned cells are accepted without any distance
    check; they are only displaced later by denser candidates.

    Args:
        image: Array of shape ``(height, width, channels)``; the density of a
            cell is the sum over its channels.
        k: Maximum number of patches to return.
        dist_th: Minimum Manhattan distance (in grid cells) wanted between
            kept patches.

    Returns:
        List of ``Patch`` objects sorted by ascending density (at most ``k``;
        empty when ``k <= 0``).
    """
    if k <= 0:
        return []  # previously fell through to p[0] on an empty list
    height, width, _ = image.shape
    p = []
    # Floor division keeps the paddings integral on Python 3 as well, where
    # "/" would hand floats to range().
    pad_h = height // 20 #padding to ignore patches close to borders (they often have shadows on them)
    pad_w = width // 20
    for i in range(pad_h, height-1-pad_h):
        for j in range(pad_w, width-1-pad_w):
            density = np.sum(image[i, j, :])
            patch = Patch(i, j, density)
            if len(p) < k: # Populate the list with the first k patches
                p.append(patch)
                p.sort()
            elif p[0].density < density: # This patch's density is higher than the min density in the list
                dist, closest_patch_id = mdist(p, patch)
                if dist < dist_th:
                    if p[closest_patch_id].density < density: # Replace the closest patch with this patch
                        p[closest_patch_id] = patch
                        p.sort()
                else:
                    # dist >= dist_th: far enough from everything kept.
                    # The old test was "dist > dist_th", which silently
                    # discarded candidates exactly dist_th away.
                    p[0] = patch # Push this patch and pop the one with min density
                    p.sort()
    return p

In [ ]:
def getROICoords(densestPatches, stride=256):
    """Scale patch grid positions by ``stride``.

    Args:
        densestPatches: Iterable of objects with ``x`` and ``y`` attributes.
        stride: Factor applied to both coordinates.

    Returns:
        List of ``[x*stride, y*stride]`` pairs, in the order of the input.
    """
    return [[item.x * stride, item.y * stride] for item in densestPatches]

In [ ]:
%matplotlib inline
def plotIntermediaryResults(sum_arr, wsi_med, selected_region):
    """Draw the three intermediate images next to each other.

    One 20x20 inch figure with a single row of three panels, titled
    'Sum Array', 'WSI' and 'Selected region' in that order.

    Args:
        sum_arr: Image shown in the first panel.
        wsi_med: Image shown in the second panel.
        selected_region: Image shown in the third panel.
    """
    fig = plt.figure()
    fig.set_size_inches(20,20)
    panels = (
        ('Sum Array', sum_arr),
        ('WSI', wsi_med),
        ('Selected region', selected_region),
    )
    for position, (title, picture) in enumerate(panels, 1):
        axes = fig.add_subplot(1, 3, position)
        # plt.imshow draws on the current axes, i.e. the subplot just added.
        plt.imshow(picture)
        axes.set_title(title)

In [ ]:
def plotROIs(slide, coords):
    """Show a slide overview with every region of interest outlined in green.

    Args:
        slide: Slide object providing ``level_dimensions``,
            ``level_downsamples`` and ``read_region`` (as ``ops.OpenSlide``
            does).
        coords: Iterable of ``[row, col]`` level-0 pixel positions of the
            top-left corner of each region; each region is drawn as a
            1024 x 1024 level-0 square.
    """
    # Level 2 is the preferred overview; slides with fewer pyramid levels
    # fall back to their coarsest level instead of raising IndexError.
    level = min(2, len(slide.level_dimensions) - 1)
    downsample = float(slide.level_downsamples[level])

    # Create figure and axes
    fig, ax = plt.subplots(1)
    fig.set_size_inches(10,10)
    # Display the image
    wsi_med = slide.read_region((0, 0), level, slide.level_dimensions[level])
    ax.imshow(wsi_med)
    # Side of a region, converted from level-0 pixels to overview pixels.
    box_side = int(math.ceil(1024.0/downsample))

    for coord in coords:
        # coords are stored as [row, col]; matplotlib wants (x, y) = (col, row).
        x = int(math.ceil(float(coord[1])/downsample))
        y = int(math.ceil(float(coord[0])/downsample))
        # Create a Rectangle patch and add it to the Axes
        rect = plp.Rectangle((x, y), box_side, box_side, linewidth=2,
                             edgecolor='g', facecolor='g', fill=False)
        ax.add_patch(rect)

    plt.show()

In [ ]:
def preprocess(from_dir, debug=False):
    """Extract the densest regions of every ``.svs`` slide below ``from_dir``.

    For each slide the coarsest pyramid level is read, turned into a binary
    mask of RGB values below 200, summed per window with ``sumPixelsInPatch``
    and searched with ``findKDensestPatches`` for up to 10 regions. Each
    region is read at level 0 as a 1024 x 1024 image and saved next to the
    slide as ``<slide name>_<index>.tiff``.

    Args:
        from_dir: Directory that is walked recursively for slides.
        debug: When true, nothing is saved. Instead the regions of the first
            slide found in the first visited directory are plotted and the
            function stops after that directory.
    """
    roi_stride = 512   # level-0 distance between neighbouring candidate windows
    roi_side = 1024    # level-0 side length of a saved region
    banner = '-'*90

    print(banner)
    print('Preprocessing slides from {0} at {1} ...'.format(from_dir, datetime.datetime.utcnow()))
    print(banner)
    from_dir = from_dir if from_dir.endswith('/') else from_dir + '/'
    for subdir, _, files in scandir.walk(from_dir):
        subdir = subdir.replace('\\', '/')  # windows path fix
        # The study id is the name of the directory holding the slide. Taking
        # the basename works for any depth of from_dir (the old fixed index
        # subdir.split('/')[3] raised IndexError for other layouts).
        study_id = os.path.basename(subdir.rstrip('/'))
        for f in files:
            # Each message is emitted with a single print so that the output
            # is the same on Python 2 and Python 3.
            print('{0} Preprocessing slide {1} at {2}'.format('*'*5, study_id, datetime.datetime.utcnow()))
            svs_path = os.path.join(subdir, f)
            if not svs_path.endswith('.svs'):
                print('{0} is not a svs file!'.format(svs_path))
                continue
            slide = ops.OpenSlide(svs_path)
            try:
                print('{0} Loaded slide {1} at {2}'.format('*'*10, study_id, datetime.datetime.utcnow()))
                level = len(slide.level_dimensions)-1
                wsi = np.asarray(slide.read_region((0, 0), level,
                                  slide.level_dimensions[level]))
                print('{0} Loaded region {1} at {2}'.format('*'*10, study_id, datetime.datetime.utcnow()))
                binary_arr = np.asarray((wsi[:, :, 0:3] < 200).astype(int)) # Get rid of opacity, threshold at 200
                # Window size at the coarsest level that corresponds to
                # roi_stride pixels at level 0.
                level_stride = int(math.ceil(
                            float(roi_stride)/float(slide.level_downsamples[level])))
                sum_arr = sumPixelsInPatch(binary_arr, level_stride)
                densestPatches = findKDensestPatches(sum_arr, 10)
                coords = getROICoords(densestPatches, roi_stride)
                if debug:
                    plotROIs(slide, coords)
                for i, coord in enumerate(coords):
                    # coords are [row, col]; read_region wants (x, y).
                    region_args = ((coord[1], coord[0]), 0, (roi_side, roi_side))
                    selected_region = slide.read_region(*region_args)
                    if debug:
                        print(str(region_args))
                        # Prefer level 2 for the overview, but never go past
                        # the coarsest level the slide actually has.
                        overview_level = min(2, level)
                        wsi_med = np.asarray(slide.read_region(
                            (0, 0), overview_level, slide.level_dimensions[overview_level]))
                        plotIntermediaryResults(sum_arr, wsi_med, selected_region)
                        break
                    else:
                        tiff_path = svs_path[:-4]+'_'+str(i)+'.tiff'
                        selected_region.save(tiff_path)
                        print('{0} Region saved as {1}'.format('*'*15, tiff_path))
                print('{0} Preprocessed slide {1}'.format('*'*5, study_id))
            finally:
                # Release the slide handle even if reading or saving failed.
                slide.close()
            if debug:
                break
        if debug:
            break
    print(banner)
    print('All slides in {0} have been preprocessed.'.format(from_dir))
    print(banner)

In [ ]:
# Single-slide debug run:
#preprocess('../example_images/train/025', True)
# Full run over the three dataset splits, in this order.
for subset in ('train', 'val', 'test'):
    preprocess('../example_images/' + subset)

In [ ]: