Blood Vessel Segmentation Using the Maximum Response of Matched-Filter Kernel Convolutions


In [163]:
%pylab inline


Populating the interactive namespace from numpy and matplotlib

Retinal Image Class


In [164]:
# Class that holds one retinal image read from the DRIVE dataset and runs the processing pipeline on it
from scipy import ndimage
import mahotas,time
    
class RetinalImage:
    """One DRIVE retinal image plus every image and statistic derived from it.

    Call order used in this notebook: detectVessels -> segmentVessels ->
    gatherMetrics, then getImageData to read the results. The benchmark time
    runs from the start of detectVessels to the end of the next
    segmentVessels call.
    """

    # Constructor
    def __init__(self, img_name, img_original, img_mask, manual_segm_1, manual_segm_2):
        """Store the inputs.

        img_name      -- identifier returned by getImageData.
        img_original  -- grayscale image; stored after a 5x5 averageFilter.
        img_mask      -- region mask; pixels equal to 0 are ignored.
        manual_segm_1 -- first manual segmentation (ground truth).
        manual_segm_2 -- second manual segmentation, or None if unavailable.
        """
        self.__img_name = img_name

        # Benchmark timestamps taken with time.time().
        self.__t_start = None
        self.__t_stop = None

        self.__images = {}
        self.__statistics = {}

        self.__images["original"] = averageFilter(img_original, 5)
        self.__images["mask"] = img_mask
        self.__images["ground_truth_1"] = manual_segm_1
        self.__images["ground_truth_2"] = manual_segm_2

    def detectVessels(self, kernels):
        """Compute the matched-filter response and store it as "matched".

        The response is the pixel-wise maximum, over all kernels, of the
        convolution with the signed image. Values are clipped below at 0,
        rescaled to 0..255 uint8 and zeroed outside the mask. This also
        starts the benchmark timer.
        """
        # Shift the image to signed values so that convolving with the
        # negative kernel lobes can produce the maximum response.
        img_original_signed = self.__images["original"] - 128.0

        # Benchmarking - Start
        self.__t_start = time.time()

        # Starting from zeros clips negative responses to 0.
        operational_image = np.zeros(self.__images["original"].shape, float)
        for k in kernels:
            response = ndimage.convolve(img_original_signed, k, mode="reflect")
            operational_image = np.maximum(operational_image, response)

        matched = normalizeTo(255, operational_image).astype(np.uint8)
        matched[self.__images["mask"] == 0] = 0
        self.__images["matched"] = matched

    def segmentVessels(self, th):
        """Threshold "matched" into the binary image "segmented" (0 / 255).

        th -- pixels with matched >= th become 255. If None, Otsu's
              threshold over the pixels inside the mask is used.

        Stops the benchmark timer if one is running.
        """
        matched = self.__images["matched"]
        segmented = np.zeros(matched.shape, np.uint8)
        if th is None:
            th = mahotas.otsu(matched[self.__images["mask"] != 0])
        segmented[matched >= th] = 255
        self.__images["segmented"] = segmented

        # Benchmarking - Stop (only if a benchmark is running)
        if self.__t_start is not None:
            self.__t_stop = time.time()

    def gatherMetrics(self):
        """Fill the statistics dict.

        Records the elapsed time (if a benchmark completed, after which the
        timer is reset). Also records generateStats of "segmented" against
        each available ground truth.
        """
        # Benchmarking - Gather statistics
        if self.__t_start is not None and self.__t_stop is not None:
            self.__statistics["elapsed_time"] = self.__t_stop - self.__t_start
            self.__t_start = None
            self.__t_stop = None

        segmented = self.__images["segmented"]
        mask = self.__images["mask"]
        self.__statistics["ground_truth_1_stats"] = generateStats(segmented,
                                                                  self.__images["ground_truth_1"],
                                                                  mask)
        self.__statistics["ground_truth_2_stats"] = None
        # `is not None` rather than `!= None`: on a NumPy array `!=` compares
        # element-wise, and the resulting array cannot be used in an `if`.
        if self.__images["ground_truth_2"] is not None:
            self.__statistics["ground_truth_2_stats"] = generateStats(segmented,
                                                                      self.__images["ground_truth_2"],
                                                                      mask)

    def getImageData(self):
        """Return the tuple (image name, images dict, statistics dict).

        The dicts are the object's own (not copies), so mutating them changes
        this RetinalImage.
        """
        return (self.__img_name, self.__images, self.__statistics)

Retinal Image Mask Calculation


In [165]:
# Function for calculation of the retinal image mask
import mahotas
from scipy.ndimage.filters import gaussian_gradient_magnitude
from scipy.ndimage.morphology import binary_dilation,binary_erosion,binary_fill_holes

def getMask(image2D, sigma):
    """Compute a binary region mask for image2D from the edges of its gradient.

    Presumably this estimates the retinal field of view; the result is used as
    "mask" by RetinalImage (TODO confirm against the DRIVE-provided masks).

    Steps:
      1. Rescale the image to 0..255 and shift it to signed values.
      2. Take the Gaussian gradient magnitude (std = sigma), rescaled to uint8.
      3. Otsu-threshold the gradient to get an edge map.
      4. Dilate the edge map (3 iterations) and keep only the band of pixels
         reached by the dilation but not on the edge map itself.
      5. Label that band and keep only the component(s) whose pixel count
         equals that of label 2.
      6. Fill holes in the kept component(s).

    sigma -- std of the Gaussian used for the gradient magnitude; callers tune
             it per image.

    Returns a uint8 array of image2D's shape (0/1 values, since
    binary_fill_holes produces a boolean array).
    """
    img_original_signed = normalizeTo(255,image2D) - 128.0
    img_gauss = normalizeTo(255,gaussian_gradient_magnitude(img_original_signed,sigma)).astype(uint8)
    img_bin = zeros(image2D.shape,float)
    img_dil = zeros(image2D.shape,float)
    
    # Edge map: 255 where the gradient exceeds its Otsu threshold.
    img_bin[nonzero(img_gauss > mahotas.otsu(img_gauss))] = 255.0
    # Third positional argument of binary_dilation is `iterations` (structure=None -> default).
    img_dil[binary_dilation(img_bin,None,3)] = 255.0
    
    # Label the dilation band minus the edges themselves.
    img_lbl,_ = mahotas.label((img_dil-img_bin).astype(uint8))
    # sizes[i] = pixel count of label i (presumably index 0 is the background).
    sizes = mahotas.labeled.labeled_size(img_lbl)
    # NOTE(review): keeps every region whose size equals label 2's; this assumes
    # label 2 is the wanted boundary band, which depends on labeling order -- fragile.
    img_lbl_filtered = mahotas.labeled.remove_regions(img_lbl,where(sizes != sizes[2]))
    img_lbl_filtered[nonzero(img_lbl_filtered)] = 255
    # Fill the enclosed interior of the kept band.
    img_lbl_filled = binary_fill_holes(img_lbl_filtered)
    
    return img_lbl_filled.astype(uint8)

Matched Filter Kernels Calculation


In [166]:
def isN(point, params):
    """Return True if point lies inside the matched-filter neighbourhood N.

    N is the rectangle |u| <= 3*sigma and |v| <= L/2. Here u is the axis along
    which the kernel's Gaussian profile varies, and v is the axis along which
    the profile is constant.

    point  -- (u, v), already rotated into the kernel frame.
    params -- (sigma, L).
    """
    # Tuple parameters in a def signature are Python 2 only (PEP 3113);
    # unpack explicitly so this works on both Python 2 and Python 3.
    u, v = point
    sigma, L = params
    return (abs(u) <= (3 * sigma)) and (abs(v) <= (L / 2.0))

def rot(angle_deg):
    """Return the standard 2-D rotation matrix [[cos, -sin], [sin, cos]]
    for an angle given in degrees."""
    theta = radians(angle_deg)
    c, s = cos(theta), sin(theta)
    return array([[c, -s],
                  [s, c]])

# Function for generating matched filter kernels
def generateMatchedFilterKernels(sigma, L, angle_step):
    """Build the bank of rotated matched-filter kernels.

    Inside the neighbourhood N tested by isN, each kernel cell holds the
    inverted Gaussian profile -exp(-u^2 / (2*sigma^2)). The values in N are
    then made zero-mean, scaled by 10 and rounded; cells outside N stay 0.

    sigma      -- spread of the Gaussian profile.
    L          -- length of N along the axis where the profile is constant.
    angle_step -- spacing in degrees; orientations are -90 + k*angle_step for
                  k = 1, 2, ... while the angle is <= 90.

    Returns a list of square float arrays, one per orientation.
    Raises ValueError if angle_step is not positive.
    """
    if angle_step <= 0:
        # A non-positive step never reaches +90 degrees (endless loop).
        raise ValueError("angle_step must be positive")

    # Half-extent of N (max of 3*round(sigma) and L/2), plus a 1-pixel border.
    half_extent = int(max(3 * round(sigma), L / 2.0))
    kernel_size = 2 + (1 + 2 * half_extent)
    # Floor division: `/` would give a float on Python 3 and break range().
    kernel_half_size = kernel_size // 2
    offsets = range(-kernel_half_size, kernel_half_size + 1)

    # Count the orientations up front and compute each angle as
    # -90 + k*step, avoiding accumulated floating-point error.
    n_angles = int(np.floor(180.0 / angle_step + 1e-9))

    # Generate matched filter kernels
    matched_filter_kernels = []
    for k in range(1, n_angles + 1):
        current_angle = -90 + k * angle_step
        rotation_matrix_t = rot(current_angle).T
        kernel = np.zeros((kernel_size, kernel_size), float)
        support = np.zeros((kernel_size, kernel_size), bool)
        for x in offsets:
            for y in offsets:
                p = rotation_matrix_t.dot(np.array([x, y]))
                if isN(p, (sigma, L)):
                    kernel[kernel_half_size + x, kernel_half_size + y] = -np.exp((-p[0] ** 2) / (2 * (sigma ** 2)))
                    support[kernel_half_size + x, kernel_half_size + y] = True
        # A = |N| (never 0: the centre cell always satisfies isN).
        A = np.count_nonzero(support)
        # Subtract the mean over N so the kernel sums to ~0 (exactly 0 before
        # rounding); a flat image region then gives ~no response.
        m = kernel.sum() / float(A)
        kernel[support] = ((kernel[support] - m) * 10).round()
        matched_filter_kernels.append(kernel)

    return matched_filter_kernels

Secondary Functions


In [167]:
# Support functions
def averageFilter(input_image, m):
    """Return input_image smoothed with an m x m mean (box) filter.

    Borders use mode="reflect"; the result is a float array.
    """
    weight = 1 / float(m * m)
    box_kernel = weight * ones((m, m), float)
    smoothed = ndimage.convolve(input_image, box_kernel, mode="reflect")
    return smoothed.astype(float)

def normalizeTo(val, image2D):
    """Linearly rescale image2D so its minimum maps to 0 and its maximum to val.

    Works for arrays of any shape. A constant image (max == min) has no range
    to stretch and is mapped to all zeros rather than NaN.

    Returns a new float array; the input is not modified.
    """
    image = np.asarray(image2D, dtype=float)
    minim = image.min()
    span = image.max() - minim
    if span == 0:
        return np.zeros(image.shape, float)
    return ((image - minim) / span) * float(val)

# Compares predicted image to Ground Truth / Gold Standard
# Returns confusion matrix, accuracy, true positive rate and false positive rate
def generateStats(p_img, t_img, mask):
    """Compute segmentation statistics of p_img against ground truth t_img.

    Only pixels inside the mask (mask != 0) are counted. In both images,
    vessel pixels equal 255 and background pixels equal 0.

    Returns a dict with:
      "confusion_matrix" -- array [[TP, FN],
                                  [FP, TN]]
      "ACC" -- (TP + TN) / (TP + TN + FP + FN)
      "TPR" -- TP / (TP + FN)
      "FPR" -- FP / (FP + TN)
    A ratio whose denominator is zero is reported as NaN instead of raising
    ZeroDivisionError.
    """
    p_img = np.asarray(p_img)
    t_img = np.asarray(t_img)
    # Restrict every count to the mask. Counting everywhere and subtracting
    # count(mask == 0) from TN alone is only right if every masked-out pixel
    # is predicted 0 AND labelled 0; ground-truth vessels outside the mask
    # would otherwise be counted as FN and also removed from TN.
    inside = np.asarray(mask) != 0
    truth_vessel = (t_img == 255) & inside
    truth_background = (t_img == 0) & inside
    pred_vessel = p_img == 255
    pred_background = p_img == 0

    TP = np.count_nonzero(pred_vessel & truth_vessel)
    FN = np.count_nonzero(pred_background & truth_vessel)
    FP = np.count_nonzero(pred_vessel & truth_background)
    TN = np.count_nonzero(pred_background & truth_background)

    def ratio(num, den):
        # NaN signals "undefined" (e.g. no vessel pixels in the ground truth).
        return num / float(den) if den else float("nan")

    stats = {}
    stats["confusion_matrix"] = np.array([[TP, FN],
                                          [FP, TN]])
    stats["ACC"] = ratio(TP + TN, TP + TN + FP + FN)
    stats["TPR"] = ratio(TP, TP + FN)
    stats["FPR"] = ratio(FP, FP + TN)

    return stats

import sklearn.metrics as metrics
# Calculates arrays of tpr and fpr at several thresholds in order to plot the ROC graph
# Returns both arrays and the ROC AUC
def calculateROC(r, t_img):
    """Sample the ROC curve of a RetinalImage against one ground truth.

    The matched-filter response of r is thresholded at every level from 255
    down to 1 (a pixel is vessel iff response >= threshold). generateStats is
    evaluated against t_img with r's mask, and a final (1.0, 1.0) point closes
    the curve.

    r     -- a RetinalImage on which detectVessels has already been run.
    t_img -- ground-truth image (255 = vessel, 0 = background).

    Returns (fpr, tpr, roc_auc): two lists of 256 rates and the area under
    the curve from sklearn.metrics.auc.

    The thresholding is done locally, so r's stored "segmented" image and
    benchmark timestamps are left untouched (r.segmentVessels would overwrite
    them).
    """
    images = r.getImageData()[1]
    matched = images["matched"]
    mask = images["mask"]

    tpr = []
    fpr = []
    for th in range(255, 0, -1):
        # Same rule as segmentVessels(th): 255 where matched >= th, else 0.
        segmented = np.where(matched >= th, 255, 0).astype(np.uint8)
        m = generateStats(segmented, t_img, mask)
        tpr.append(m["TPR"])
        fpr.append(m["FPR"])
    tpr.append(1.0)
    fpr.append(1.0)
    roc_auc = metrics.auc(fpr, tpr)

    return (fpr, tpr, roc_auc)

class GlobalROC():
    """Average several ROC curves point by point.

    Every curve passed to appendROC is assumed to be sampled at the same
    thresholds in the same order, so index i refers to the same threshold in
    every curve. calculateROC in this notebook produces 256 such points.
    """

    def __init__(self, length=256):
        self.__length = length
        # Running sums and per-index sample counts.
        self.__global_roc_tpr = [0.0] * self.__length
        self.__global_roc_fpr = [0.0] * self.__length
        self.__number_of_samples = [0.0] * self.__length

    def appendROC(self, fpr, tpr):
        """Add one curve's (fpr, tpr) samples to the running sums.

        Raises ValueError if fpr and tpr differ in length, and IndexError if
        the curve is longer than `length`. Both checks run before any state
        changes, so a rejected curve leaves the sums untouched.
        """
        if len(fpr) != len(tpr):
            raise ValueError("fpr and tpr must have the same length")
        if len(tpr) > self.__length:
            raise IndexError("ROC curve has %d points; at most %d are supported"
                             % (len(tpr), self.__length))
        for i, (f, t) in enumerate(zip(fpr, tpr)):
            self.__global_roc_tpr[i] += t
            self.__global_roc_fpr[i] += f
            self.__number_of_samples[i] += 1.0

    def _meanCurves(self):
        """Return (mean_fpr, mean_tpr) arrays over indices with >= 1 sample.

        Does not modify the accumulated sums.
        """
        counts = np.asarray(self.__number_of_samples, float)
        sampled = counts > 0
        mean_fpr = np.asarray(self.__global_roc_fpr, float)[sampled] / counts[sampled]
        mean_tpr = np.asarray(self.__global_roc_tpr, float)[sampled] / counts[sampled]
        return mean_fpr, mean_tpr

    def getGlobalROC(self):
        """Return (mean_fpr, mean_tpr, auc) of the averaged curve.

        Safe to call repeatedly: the averages are computed on copies instead
        of overwriting the running sums. Indices that never received a sample
        are skipped rather than producing 0/0 = NaN. With fewer than two
        averaged points, the AUC is NaN.
        """
        mean_fpr, mean_tpr = self._meanCurves()
        if mean_fpr.size < 2:
            global_roc_auc = float("nan")
        else:
            global_roc_auc = metrics.auc(mean_fpr, mean_tpr)

        return (mean_fpr, mean_tpr, global_roc_auc)

Main Functions


In [170]:
import os
import pandas as pd

# Perform blood vessel segmentation of all images within a folder
# Returns a list of RetinalImage objects filled with statistical metrics of the segmentation

# Per-image mask sigmas that replace the folder-wide default; presumably
# hand-tuned for images where the default produced a poor mask (TODO confirm).
DEFAULT_MASK_SIGMA_OVERRIDES = {
    "15_test.tif": 3,
    "04_test.tif": 4.5,
    "07_test.tif": 0.5,
    "34_training.tif": 1.5,
}

def ProcessSegmentationOfFolder(folder, mask_sigma, kernels, sigma_overrides=None):
    """Segment the vessels of every image in a DRIVE-style folder.

    Layout read by this function:
      <folder>/images/<name>                    input images
      <folder>/1st_manual/<digits>_manual1.gif  first manual segmentation
      <folder>/2nd_manual/<digits>_manual2.gif  second one (directory optional)
    where <digits> is the concatenation of all digits in <name>.

    folder          -- dataset root.
    mask_sigma      -- default sigma passed to getMask.
    kernels         -- matched-filter kernels passed to detectVessels.
    sigma_overrides -- optional {filename: sigma} replacing mask_sigma for
                       specific files; defaults to DEFAULT_MASK_SIGMA_OVERRIDES.

    Returns a list of RetinalImage objects, ordered by filename. Each has had
    detectVessels, Otsu segmentation and gatherMetrics run on it.
    """
    if sigma_overrides is None:
        sigma_overrides = DEFAULT_MASK_SIGMA_OVERRIDES

    images_dir = os.path.join(folder, "images")
    first_manual_dir = os.path.join(folder, "1st_manual")
    second_manual_dir = os.path.join(folder, "2nd_manual")
    has_second_manual = os.path.isdir(second_manual_dir)

    segmentation_list = []
    # os.listdir returns entries in arbitrary order; sort for reproducible results/plots.
    for filename in sorted(os.listdir(images_dir)):
        sigma = sigma_overrides.get(filename, mask_sigma)
        m_index = "".join(ch for ch in filename if ch.isdigit())

        img_manual_seg2 = None
        if has_second_manual:
            img_manual_seg2 = normalizeTo(255, imread(os.path.join(second_manual_dir, m_index + "_manual2.gif"))[:, :, 0]).astype(np.uint8)

        # Channel 1 of the input image (the green channel, assuming RGB order).
        img_original = normalizeTo(255, imread(os.path.join(images_dir, filename))[:, :, 1])
        img_manual_seg1 = normalizeTo(255, imread(os.path.join(first_manual_dir, m_index + "_manual1.gif"))).astype(np.uint8)

        # Create RetinalImage instance and process segmentation
        r = RetinalImage(filename,
                         img_original,
                         getMask(img_original, sigma).astype(np.uint8),
                         img_manual_seg1,
                         img_manual_seg2)
        r.detectVessels(kernels)
        r.segmentVessels(None)
        r.gatherMetrics()
        segmentation_list.append(r)

    return segmentation_list

# Displays the results and statistical metrics gathered from the segmentation process
def _plotROCPanel(ax, fpr, tpr):
    """Draw an ROC curve plus the chance diagonal on ax, with unit axes."""
    ax.plot(fpr, tpr)
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic')

def _writeStatsPanel(ax, roc_auc, otsu_stats):
    """Write the ROC AUC and the Otsu-threshold ACC/TPR/FPR as text on ax."""
    ax.set_title("Statistics")
    entries = [("ROC AUC: ", roc_auc),
               ("ACC (w/ otsu): ", otsu_stats["ACC"]),
               ("TPR (w/ otsu): ", otsu_stats["TPR"]),
               ("FPR (w/ otsu): ", otsu_stats["FPR"])]
    for y, (label, value) in zip((0.9, 0.8, 0.7, 0.6), entries):
        ax.text(0, y, label + str(value), fontsize=18, ha='left', va='center')
    ax.set_axis_off()

def _showImagePanel(fig, position, title, image):
    """Show image in grayscale at `position` of fig's 3x4 subplot grid."""
    ax = fig.add_subplot(3, 4, position)
    ax.set_title(title)
    ax.imshow(image, cmap=cm.gray)
    ax.set_axis_off()

def ProcessInformationPrint(data, kernels):
    """Plot every processed image and a global summary.

    For each image: the inputs and matched-filter response, then one row per
    available ground truth (segmentation, ground truth, ROC curve, statistics).
    At the end: a box plot of ACC/AUC, the averaged ROC curve, mean execution
    time and aggregate Otsu TPR/FPR.

    data    -- dict mapping a dataset name to a list of processed RetinalImage
               objects (as built by ProcessSegmentationOfFolder).
    kernels -- unused; kept for interface compatibility.
    """
    m_global_roc = GlobalROC()
    total_accuracy = []
    total_roc_auc = []

    total_otsu_tpr = []
    total_otsu_fpr = []

    algorithm_time = []
    for (group, retinal_images) in data.items():

        # Display Dataset Name
        f = figure(figsize=(15, 1))
        ax = f.add_subplot(1, 1, 1)
        ax.text(0, 0, "Dataset: " + group, fontsize=30, fontweight='bold', ha='left', va='center')
        ax.set_axis_off()
        tight_layout()

        for r in retinal_images:
            # Gather data. Grab every image now: calculateROC below
            # re-thresholds and may replace the stored "segmented" image.
            img_name, images, statistics = r.getImageData()
            img_original = images["original"]
            img_mask = images["mask"]
            img_matched = images["matched"]
            img_seg = images["segmented"]
            algorithm_time.append(statistics["elapsed_time"])

            # One (ground truth, Otsu stats) pair per available observer.
            # `is not None` is required: `array != None` compares element-wise
            # in NumPy and the resulting array cannot be used in an `if`.
            observers = [(images["ground_truth_1"], statistics["ground_truth_1_stats"])]
            if images["ground_truth_2"] is not None:
                observers.append((images["ground_truth_2"], statistics["ground_truth_2_stats"]))

            results = []
            for gt_image, otsu_stats in observers:
                total_otsu_tpr.append(otsu_stats["TPR"])
                total_otsu_fpr.append(otsu_stats["FPR"])
                total_accuracy.append(otsu_stats["ACC"])
                roc_info = calculateROC(r, gt_image)
                total_roc_auc.append(roc_info[2])
                m_global_roc.appendROC(roc_info[0], roc_info[1])
                results.append((gt_image, otsu_stats, roc_info))

            # Filename banner
            t = figure(figsize=(5, 1))
            a = t.add_subplot(1, 1, 1)
            a.text(0, 0, "File: " + img_name, fontsize=22, fontweight='bold', ha='left', va='center')
            a.set_axis_off()
            tight_layout()

            # First row of data
            fig = figure(figsize=(20, 15))
            _showImagePanel(fig, 1, "Original Image w/ 5x5 Average Filter", img_original)
            _showImagePanel(fig, 2, "Generated Mask", img_mask)
            _showImagePanel(fig, 3, "Matched Filter Result w/o Threshold", img_matched)

            # One further row per observer, starting at grid positions 5 and 9.
            for row, (gt_image, otsu_stats, roc_info) in enumerate(results):
                first = 5 + 4 * row
                _showImagePanel(fig, first, "Segmentation Result", img_seg)
                _showImagePanel(fig, first + 1, "Ground Truth " + str(row + 1), gt_image)
                _plotROCPanel(fig.add_subplot(3, 4, first + 2), roc_info[0], roc_info[1])
                _writeStatsPanel(fig.add_subplot(3, 4, first + 3), roc_info[2], otsu_stats)
            tight_layout()

    final_roc_data = m_global_roc.getGlobalROC()

    f_t_final = figure(figsize=(15, 1))
    ax = f_t_final.add_subplot(1, 1, 1)
    ax.text(0, 0, "Global Statistics: ", fontsize=30, fontweight='bold', ha='left', va='center')
    ax.set_axis_off()

    f_s_final = figure(figsize=(15, 5))
    sp = f_s_final.add_subplot(1, 3, 1)
    # list(): zip is a lazy iterator on Python 3.
    tt = pd.DataFrame(list(zip(total_accuracy, total_roc_auc)), columns=['ACC', 'AUC'])
    sp.boxplot(tt.values)
    for side in ('top', 'right', 'bottom'):
        sp.spines[side].set_visible(False)
    sp.xaxis.set_ticklabels(tt.columns)

    _plotROCPanel(f_s_final.add_subplot(1, 3, 2), final_roc_data[0], final_roc_data[1])

    t = f_s_final.add_subplot(1, 3, 3)
    t.set_title("Algorithm execution time mean")
    summary_lines = [
        (0.9, "Exec time: " + str(mean(algorithm_time)) + " s"),
        (0.8, "Global AUROC: " + str(final_roc_data[2])),
        (0.7, "TPR Mean: " + str(mean(total_otsu_tpr))),
        (0.6, "TPR Max: " + str(max(total_otsu_tpr))),
        (0.5, "TPR Min: " + str(min(total_otsu_tpr))),
        (0.4, "FPR Mean: " + str(mean(total_otsu_fpr))),
        (0.3, "FPR Max: " + str(max(total_otsu_fpr))),
        (0.2, "FPR Min: " + str(min(total_otsu_fpr))),
    ]
    for y, text in summary_lines:
        t.text(0, y, text, fontsize=18, ha='left', va='center')
    t.set_axis_off()

    tight_layout()

Main Program


In [171]:
if __name__ == "__main__":
    # Parameters for generateMatchedFilterKernels: Gaussian sigma,
    # neighbourhood length L and angular step in degrees.
    MF_SIGMA = 2
    MF_LENGTH = 9
    MF_ANGLE_STEP = 15
    # Default mask sigma handed to ProcessSegmentationOfFolder for every dataset.
    MASK_SIGMA = 2.5
    DATASET_FOLDERS = [("test", "./images/DRIVE/test/"),
                       ("training", "./images/DRIVE/training/")]

    mf_kernels = generateMatchedFilterKernels(MF_SIGMA, MF_LENGTH, MF_ANGLE_STEP)

    # Segment every dataset; each entry maps its name to a list of RetinalImage objects.
    img_group = {}
    for group_name, group_folder in DATASET_FOLDERS:
        img_group[group_name] = ProcessSegmentationOfFolder(group_folder, MASK_SIGMA, mf_kernels)

    # Show the information
    ProcessInformationPrint(img_group, mf_kernels)



In [ ]: