Assignment 1 BMI260

Segmnetation of lung CT scans

Author: Yusuf Roohani (yroohani@stanford)

This is the primary test script

Read in the data


In [1]:
import dicom
import numpy as np
import cv2

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
%matplotlib inline

from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.color import label2rgb

from os import listdir
from os.path import join

from mpl_toolkits.mplot3d import Axes3D

In [9]:
# Read in filenames
#path_dcms = '../kaggle_data...'

# Print out all the .dcm's
#list_dcms = listdir(path_dcms)

# let's just read in the first one.
name_dcm = '../Sample_Imge.png'#list_dcms[10]
img=cv2.imread(name_dcm)
#path_dcm = join(path_dcms, name_dcm)
#dcm = dicom.read_file(path_dcm)

# Let's now display the image.
#img = dcm.pixel_array
plt.rcParams['figure.figsize'] = (10, 10)
plt.imshow(img)


Out[9]:
<matplotlib.image.AxesImage at 0x11b155cd0>

In [12]:
import mahotas
a = mahotas.features.haralick(img)

In [17]:
a[2].shape


Out[17]:
(13,)

In [4]:
dcm.pixel_array_houns = (dcm.RescaleSlope * dcm.pixel_array) + dcm.RescaleIntercept
dcm.pixel_array_houns[dcm.pixel_array_houns > -400] = -400

In [5]:
# Let's now display the image.
img = dcm.pixel_array_houns
plt.rcParams['figure.figsize'] = (8, 8)
plt.imshow(img)


Out[5]:
<matplotlib.image.AxesImage at 0x1191da590>

Sort the data into an ordered 3D volume

For a given patient, order the slices and store them in a 3D tensor

Normalize the image values to lie between 0 and 1 for easier processing


In [6]:
# Create variables to hold a 3D volume
width = 512; height = 512
depth = len(list_dcms)
lung_vol = np.zeros([width,height,depth])
lung_vol_houns = np.zeros([width,height,depth])

# Create a dictionary to link slice location with image
dcm_slices = {}
dcm_slices_houns = {}

In [7]:
# Stack up all the slices in order
for d in range(depth):
    dcm_file = dicom.read_file(path_dcms + list_dcms[d])
    
    # This is to club bone and other organs into a single broad non-lung class
    dcm_file.pixel_array_houns = (dcm_file.RescaleSlope * dcm_file.pixel_array) + dcm_file.RescaleIntercept
    dcm_file.pixel_array_houns[dcm_file.pixel_array_houns > -450] = -450
    dcm_slices_houns[dcm_file.SliceLocation] = dcm_file.pixel_array_houns
    
    dcm_slices[dcm_file.SliceLocation] = dcm_file.pixel_array
    
sorted_slices = np.sort(dcm_slices.keys())
sorted_slices_houns = np.sort(dcm_slices_houns.keys())

# Now let's make a single volume
for idx,s in enumerate(sorted_slices):
    lung_vol[:,:,idx] = dcm_slices[s]
    lung_vol_houns[:,:,idx] = dcm_slices_houns[s]

In [13]:
# let's see what these look like in order
fig,ax = plt.subplots(10,10, figsize = (15,15))
d = 0
for i in range(10):
    for j in range(10):
        img = lung_vol[:,:,d]
        ax[i,j].imshow(img)
        ax[i,j].axis('off')
        d = d+1
        
# Good, the image sequence makes sense



In [28]:
# let's see what these look like in order after clearing out higher hounsfield values
fig,ax = plt.subplots(10,10, figsize = (15,15))
d = 0
for i in range(10):
    for j in range(10):
        img = lung_vol_houns[:,:,d]
        ax[i,j].imshow(img)
        ax[i,j].axis('off')
        d = d+1
        
# Good, the image sequence makes sense



In [9]:
# Let's max min normalize these images
#lung_vol_norm = (lung_vol - np.min(lung_vol))/(np.max(lung_vol) - np.min(lung_vol))
lung_vol_norm = (lung_vol_houns - np.min(lung_vol_houns))/(np.max(lung_vol_houns) - np.min(lung_vol_houns))

# Let's check if the normalization worked
print np.min(lung_vol_norm),np.max(lung_vol_norm)


0.0 1.0

In [8]:
# Good, now that everything is in between 0 and 1, let's also check if the images look the same
fig,ax = plt.subplots(10,10, figsize = (15,15))
d = 0
for i in range(10):
    for j in range(10):
        img = lung_vol_norm[:,:,d]
        ax[i,j].imshow(img)
        ax[i,j].axis('off')
        d = d+1
        
# Yup!


Segmentation

Now, we'll try to separate out the lung from the rest of the slide using

  • Otsu thresholding

In [10]:
# Visualize the threshold for the 3d image
flat_volume = lung_vol_norm.flatten()
plt.figure(figsize = (7,7))
_ = plt.hist(lung_vol_norm.flatten(), bins = 100)



In [11]:
# There appears to be a clear bimodal distribution, as well as a lot of blank space/zero values
# We can remove those negligible values and perform an Otsu thresholding to get a good split

# By choosing an extremely low cut off (i.e 0.05), we can justify this arbitrary selection of a noise threshold
# since it is very unlikely to affect the signal but would produce a significantly more realistic
# prediction of an Otsu treshold
noise_thresh = 0.05
flat_volume_clean = flat_volume[flat_volume > noise_thresh]

In [12]:
# Here's the Otsu threshold for the cleaned set and it looks like a great estimate
plt.figure(figsize = (7,7))
_ = plt.hist(lung_vol_norm.flatten(), bins = 100)
otsu_thresh = threshold_otsu(flat_volume_clean)
plt.axvline(otsu_thresh, color='r')


Out[12]:
<matplotlib.lines.Line2D at 0x11b23f8d0>

In [13]:
# Use the previous calculated Otsu threshold to create a binary image
threshed_image = np.zeros(lung_vol_norm[:,:,100].shape)
threshed_image[lung_vol_norm[:,:,100]>= otsu_thresh]=1

In [14]:
# Let's take a look.
plt.figure(figsize = (7,7))
plt.imshow(threshed_image)

# Looks like a good segmentation of the lungs. Although the outer regions are also getting classfied
# under a similar class, but there's a clear separation between the inner and the outer class
# so it should be straightforward to separate


Out[14]:
<matplotlib.image.AxesImage at 0x1a6e91810>

In [26]:
# Let's close up the contours first. We choose a kernel size that's big enough to capture one lung
# but not too big such that it begins to merge the lungs together, or even merge the lung with the outer region

kernel = np.ones((9,9))
closed = cv2.morphologyEx(threshed_image, cv2.MORPH_OPEN, kernel)
plt.imshow(closed)


Out[26]:
<matplotlib.image.AxesImage at 0x11a003f90>

In [27]:
# Label the different regions using simple connectivity maps
labelled, no_regions = label(closed, background = 1, return_num='TRUE')
print no_regions

# Keep track of the pixels assigned to each region
regions = {}
region_areas = [None]*(no_regions+1)
for r in range(no_regions+1):
    regions[r] = np.where(labelled == r) # Track the pixels themselves
    region_areas[r] = len(regions[r][0]) # As well as the area of the region


2

In [28]:
# These are all the labels that had a presence along the edges
edges = np.concatenate([labelled[1,:], labelled[-1,:], labelled[:,1], labelled[:,-1]])
edge_regions = np.unique(edges)

# Add the background region '0' here too
edge_regions = np.append(edge_regions,0)

In [29]:
# Let's sort the region areas and then remove the edge regions
large_areas = np.argsort(region_areas)[::-1]
inner_areas = [i for i in large_areas if i not in edge_regions]

# Make sure there are enough remaining areas before referencing the list
inner_areas_num = np.min([len(inner_areas),2])
lung_idx = inner_areas[0:inner_areas_num]

In [30]:
# Now create an image with only the selected regions
if len(lung_idx) >= 1:
    labelled[labelled == lung_idx[0]] = 50
if len(lung_idx) == 2:
    labelled[labelled == lung_idx[1]] = 50
labelled[labelled != 50] = 0

In [31]:
plt.imshow(labelled)


Out[31]:
<matplotlib.image.AxesImage at 0x119a0a4d0>

In [18]:
# Now let's functionalize this whole procedure

# This is a function that accepts a lung CT image and segments out the lungs from the background
# It requires a binary treshold value
def lung_segment(image, thresh):
    
    # Use the previous calculated global Otsu threshold to create a binary image
    threshed_image = np.zeros(image.shape)
    threshed_image[image >= thresh]=1
    
    # Let's close up the contours
    kernel = np.ones((4,4))
    closed = cv2.morphologyEx(threshed_image, cv2.MORPH_OPEN, kernel)
    
    # Label the different regions using simple connectivity maps
    labelled, no_regions = label(closed, background = 1, return_num='TRUE')

    # Keep track of the pixels assigned to each region
    regions = {}
    for r in range(no_regions+1):
        regions[r] = np.where(labelled == r) # Track the pixels themselves
        
    # Fine region labels that have a presence along the edges
    edges = np.concatenate([labelled[1,:], labelled[-1,:], labelled[:,1], labelled[:,-1]])
    edge_regions = np.unique(edges)

    # Add the background region '0' here too
    edge_regions = np.append(edge_regions,0)
    
    # Remove the edge regions and background
    select_regions = [i for i in range(no_regions+1) if i not in edge_regions]
    
    # Now create an image with only the selected regions
    for r in select_regions:
        labelled[labelled == r] = 50
    labelled[labelled != 50] = 0
    
    return labelled

Add a volume/area threshold to get rid of noise


In [32]:
# Let's create a volume of the segmented slices
segmented_slices = np.zeros(lung_vol_norm.shape)
for d in range(lung_vol_norm.shape[2]):
        segmented_slices[:,:,d] = lung_segment(lung_vol_norm[:,:,d], otsu_thresh)/50

In [15]:
# Let's visualize all the segmented slices
fig,ax = plt.subplots(12,10, figsize = (15,15))
d = 0
for i in range(11):
    for j in range(10):
        img = segmented_slices[:,:,d]
        ax[i,j].imshow(img)
        ax[i,j].axis('off')
        d = d+1



In [39]:
# It looks like there's still some noise to clean up. Using the volume instead of the areas may be a good approach
labelled_vols, no_regions = label(segmented_slices, return_num = 'TRUE')
print no_regions


98

In [40]:
regions = {}
volume_thresh = 100000
region_vols = np.zeros(no_regions+1)
for r in range(no_regions+1):
    regions[r] = np.where(labelled_vols == r) # Track the pixels themselves
    region_vols[r] = len(regions[r][0]) # As well as the volume of the region

print region_vols


[  3.76221520e+07   7.99131200e+06   6.00691000e+05   1.00000000e+00
   2.86820000e+04   8.36410000e+04   1.00000000e+00   2.07650000e+04
   2.00000000e+00   2.89600000e+03   6.00000000e+00   1.53200000e+04
   3.70000000e+01   1.87600000e+03   9.29700000e+03   3.00000000e+00
   1.34000000e+02   5.89000000e+02   1.54700000e+03   1.28500000e+03
   2.00000000e+00   1.00000000e+00   9.00000000e+00   1.00000000e+00
   1.30000000e+01   1.20000000e+01   2.00000000e+00   1.50000000e+01
   3.00000000e+00   1.20000000e+01   3.00000000e+00   5.00000000e+00
   5.00000000e+00   2.00000000e+00   1.00000000e+00   3.00000000e+00
   8.00000000e+00   5.00000000e+00   3.00000000e+00   1.00000000e+00
   2.00000000e+00   2.00000000e+01   6.00000000e+00   7.00000000e+00
   2.00000000e+00   1.00000000e+00   3.00000000e+00   2.51000000e+02
   3.10000000e+02   1.04000000e+02   7.45000000e+02   9.00000000e+01
   4.36000000e+02   2.94000000e+02   3.20000000e+01   1.83000000e+02
   1.68000000e+02   4.71000000e+02   6.07000000e+02   1.05000000e+02
   4.53500000e+03   1.33300000e+03   9.70000000e+01   1.24000000e+02
   8.60000000e+01   3.30000000e+01   9.30000000e+01   4.97000000e+02
   1.01000000e+03   2.42000000e+02   1.98000000e+02   9.60000000e+01
   2.43000000e+03   8.10000000e+01   9.69000000e+02   2.15800000e+03
   2.00000000e+00   9.00000000e+00   1.90000000e+01   2.50000000e+02
   1.31000000e+02   1.05000000e+02   9.10000000e+01   4.50000000e+01
   6.00000000e+00   1.27000000e+02   1.25000000e+02   1.25000000e+02
   8.00000000e+00   4.00000000e+00   2.20000000e+02   1.00000000e+00
   1.00000000e+00   1.00000000e+00   8.00000000e+00   1.50000000e+01
   1.00000000e+00   3.00000000e+01   2.00000000e+00]

In [47]:
select_vols = np.where(region_vols > volume_thresh)[0][1:] # Remove the first element becuase it's background
select_vols = [1]
for r in select_vols:
    print r
    print sum(labelled_vols[labelled_vols == r])
    labelled_vols[labelled_vols == r] = -1
labelled_vols[labelled_vols != -1] = 0
labelled_vols[labelled_vols == -1] = 1


1
7991312

3D visualizations


In [19]:
# Create a 3D visualization
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.set_zlim([0,120])
X = np.tile(np.arange(512),[512,1])
Y = np.transpose(np.tile(np.arange(512),[512,1]))

for i in range(labelled_vols.shape[2]):
    ax.contour(X, Y, labelled_vols[:,:,i]*(i+1), levels = [ i])



In [48]:
# Create a 3D visualization
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.set_zlim([0,120])
X = np.tile(np.arange(512),[512,1])
Y = np.transpose(np.tile(np.arange(512),[512,1]))

for i in range(labelled_vols.shape[2]):
    ax.contour(X, Y, labelled_vols[:,:,i]*(i+1), levels = [ i])


Marching cubes for 3D visualization


In [97]:
from skimage.measure import marching_cubes_lewiner
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# Use marching cubes to obtain the surface mesh of these ellipsoids
verts, faces, normals, values = marching_cubes_lewiner(labelled_vols, 0, step_size=1)

In [98]:
faces.shape


Out[98]:
(2050050, 3)

In [99]:
# Display resulting triangular mesh using Matplotlib. This can also be done
# with mayavi (see skimage.measure.marching_cubes docstring).
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Fancy indexing: `verts[faces]` to generate a collection of triangles
mesh = Poly3DCollection(verts[faces])
mesh.set_edgecolor('k')
ax.add_collection3d(mesh)

ax.set_xlabel("x-axis")
ax.set_ylabel("y-axis")
ax.set_zlabel("z-axis")

#ax.set_xlim(0, 24)  # a = 6 (times two for 2nd ellipsoid)
#ax.set_ylim(0, 20)  # b = 10
#ax.set_zlim(0, 32)  # c = 16

plt.tight_layout()
plt.show()


---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-99-34a170db2149> in <module>()
     18 
     19 plt.tight_layout()
---> 20 plt.show()

//anaconda/lib/python2.7/site-packages/matplotlib/pyplot.pyc in show(*args, **kw)
    251     """
    252     global _show
--> 253     return _show(*args, **kw)
    254 
    255 

//anaconda/lib/python2.7/site-packages/ipykernel/pylab/backend_inline.pyc in show(close, block)
     34     try:
     35         for figure_manager in Gcf.get_all_fig_managers():
---> 36             display(figure_manager.canvas.figure)
     37     finally:
     38         show._to_draw = []

//anaconda/lib/python2.7/site-packages/IPython/core/display.pyc in display(*objs, **kwargs)
    156             publish_display_data(data=obj, metadata=metadata)
    157         else:
--> 158             format_dict, md_dict = format(obj, include=include, exclude=exclude)
    159             if not format_dict:
    160                 # nothing to display (e.g. _ipython_display_ took over)

//anaconda/lib/python2.7/site-packages/IPython/core/formatters.pyc in format(self, obj, include, exclude)
    175             md = None
    176             try:
--> 177                 data = formatter(obj)
    178             except:
    179                 # FIXME: log the exception

<decorator-gen-9> in __call__(self, obj)

//anaconda/lib/python2.7/site-packages/IPython/core/formatters.pyc in catch_format_error(method, self, *args, **kwargs)
    220     """show traceback on failed format call"""
    221     try:
--> 222         r = method(self, *args, **kwargs)
    223     except NotImplementedError:
    224         # don't warn on NotImplementedErrors

//anaconda/lib/python2.7/site-packages/IPython/core/formatters.pyc in __call__(self, obj)
    337                 pass
    338             else:
--> 339                 return printer(obj)
    340             # Finally look for special method names
    341             method = _safe_get_formatter_method(obj, self.print_method)

//anaconda/lib/python2.7/site-packages/IPython/core/pylabtools.pyc in <lambda>(fig)
    226 
    227     if 'png' in formats:
--> 228         png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
    229     if 'retina' in formats or 'png2x' in formats:
    230         png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))

//anaconda/lib/python2.7/site-packages/IPython/core/pylabtools.pyc in print_figure(fig, fmt, bbox_inches, **kwargs)
    117 
    118     bytes_io = BytesIO()
--> 119     fig.canvas.print_figure(bytes_io, **kw)
    120     data = bytes_io.getvalue()
    121     if fmt == 'svg':

//anaconda/lib/python2.7/site-packages/matplotlib/backend_bases.pyc in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)
   2242                 orientation=orientation,
   2243                 bbox_inches_restore=_bbox_inches_restore,
-> 2244                 **kwargs)
   2245         finally:
   2246             if bbox_inches and restore_bbox:

//anaconda/lib/python2.7/site-packages/matplotlib/backends/backend_agg.pyc in print_png(self, filename_or_obj, *args, **kwargs)
    543 
    544     def print_png(self, filename_or_obj, *args, **kwargs):
--> 545         FigureCanvasAgg.draw(self)
    546         renderer = self.get_renderer()
    547         original_dpi = renderer.dpi

//anaconda/lib/python2.7/site-packages/matplotlib/backends/backend_agg.pyc in draw(self)
    462 
    463         try:
--> 464             self.figure.draw(self.renderer)
    465         finally:
    466             RendererAgg.lock.release()

//anaconda/lib/python2.7/site-packages/matplotlib/artist.pyc in draw_wrapper(artist, renderer, *args, **kwargs)
     61     def draw_wrapper(artist, renderer, *args, **kwargs):
     62         before(artist, renderer)
---> 63         draw(artist, renderer, *args, **kwargs)
     64         after(artist, renderer)
     65 

//anaconda/lib/python2.7/site-packages/matplotlib/figure.pyc in draw(self, renderer)
   1141 
   1142             mimage._draw_list_compositing_images(
-> 1143                 renderer, self, dsu, self.suppressComposite)
   1144 
   1145             renderer.close_group('figure')

//anaconda/lib/python2.7/site-packages/matplotlib/image.pyc in _draw_list_compositing_images(renderer, parent, dsu, suppress_composite)
    137     if not_composite or not has_images:
    138         for zorder, a in dsu:
--> 139             a.draw(renderer)
    140     else:
    141         # Composite any adjacent images together

//anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.pyc in draw(self, renderer)
    269         # Calculate projection of collections and zorder them
    270         zlist = [(col.do_3d_projection(renderer), col) \
--> 271                  for col in self.collections]
    272         zlist.sort(key=itemgetter(0), reverse=True)
    273         for i, (z, col) in enumerate(zlist):

//anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/art3d.pyc in do_3d_projection(self, renderer)
    640             PolyCollection.set_verts_and_codes(self, segments_2d, codes)
    641         else:
--> 642             PolyCollection.set_verts(self, segments_2d)
    643 
    644         self._facecolors2d = [fc for z, s, fc, ec, idx in z_segments_2d]

//anaconda/lib/python2.7/site-packages/matplotlib/collections.pyc in set_verts(self, verts, closed)
    933             # This is much faster than having Path do it one at a time.
    934         if closed:
--> 935             self._paths = []
    936             for xy in verts:
    937                 if len(xy):

KeyboardInterrupt: 

Saving images for Fiji visualization


In [173]:
from os import mkdir
from os.path import join,isdir
import scipy.misc

# Path of the saved images.  CHANGE THIS FOR YOUR COMPUTER.
path_save = './lung_3d/'
if not isdir(path_save):
    mkdir(path_save)

# Loop over all images to save.
for i in range(labelled_vols.shape[2]):
    path_save_img = path_save + str(i)+'.png'
    scipy.misc.imsave(path_save_img, labelled_vols[:,:,i])