In [1]:
import pandas as pd
import numpy as np
import cv2
import tifffile as tif
import matplotlib.pyplot as plt

from importlib import reload
import os, gc, time, pdb

import dilated_nets
import utils
import global_vars


Using TensorFlow backend.

In [2]:
def get_patches_test(im, label_edge, buff):
    """Cut an image into reflect-padded, overlapping windows for tiled inference.

    The image is padded by ``buff`` pixels on every side of its two spatial
    axes (``mode='reflect'``) and cut into square windows of side
    ``label_edge + 2*buff``. The central ``label_edge x label_edge`` region of
    each window corresponds to one tile of the original image; the ``buff``
    margin supplies surrounding context.

    Windows are returned in the exact order ``reconstruct_im`` consumes them:
      1. the regular grid, row-major (``num_seg_y * num_seg_x`` windows);
      2. the bottom-right corner window;
      3. the right-edge column, top to bottom (``num_seg_y`` windows);
      4. the bottom-edge row, left to right (``num_seg_x`` windows).
    Corner/edge windows are anchored to the far image border, so they cover
    the leftover strip when a dimension is not a multiple of ``label_edge``
    (and merely duplicate a grid tile when it is).

    Parameters
    ----------
    im : np.ndarray
        Image of shape (H, W) or (H, W, ...). Only the first two axes are
        padded and tiled; trailing axes (e.g. channels) are kept whole.
    label_edge : int
        Side length of the central region of each window.
    buff : int
        Context margin added on each side of the central region.

    Returns
    -------
    list of np.ndarray
        ``num_seg_y*num_seg_x + 1 + num_seg_y + num_seg_x`` windows, each of
        shape ``(label_edge + 2*buff, label_edge + 2*buff) + im.shape[2:]``.
        The windows are views into one padded copy of ``im``.

    Raises
    ------
    ValueError
        If either spatial dimension is smaller than ``label_edge``; the
        edge windows would otherwise come out truncated and mismatched.
    """
    height, width = im.shape[:2]
    if height < label_edge or width < label_edge:
        raise ValueError(
            'image spatial size %s is smaller than label_edge=%d'
            % ((height, width), label_edge))

    num_seg_y = height // label_edge
    num_seg_x = width // label_edge
    win = label_edge + 2 * buff

    # Pad only the spatial axes; any trailing (channel) axes stay untouched.
    pad_width = [(buff, buff), (buff, buff)] + [(0, 0)] * (im.ndim - 2)
    padded = np.pad(im, pad_width, mode='reflect')

    row_starts = range(0, num_seg_y * label_edge, label_edge)
    col_starts = range(0, num_seg_x * label_edge, label_edge)

    # 1. regular grid, row-major
    patches = [padded[i:i + win, j:j + win] for i in row_starts for j in col_starts]
    # 2. bottom-right corner
    patches.append(padded[-win:, -win:])
    # 3. right-edge column
    patches.extend(padded[i:i + win, -win:] for i in row_starts)
    # 4. bottom-edge row
    patches.extend(padded[-win:, j:j + win] for j in col_starts)
    return patches

In [3]:
def reconstruct_im(preds_test, preds_mask, label_edge):
    """Stitch per-tile predictions back into a full-size mask, in place.

    ``preds_test`` must hold one ``label_edge x label_edge`` tile per window,
    in the order produced by ``get_patches_test``:
      1. the regular grid, row-major;
      2. the bottom-right corner tile;
      3. the right-edge column, top to bottom;
      4. the bottom-edge row, left to right.
    Tiles are written in that order, so where corner/edge tiles overlap grid
    tiles the later write wins.

    Parameters
    ----------
    preds_test : sequence of array-like
        Tiles, indexable along the first axis; each tile has shape
        ``(label_edge, label_edge) + preds_mask.shape[2:]``.
    preds_mask : np.ndarray
        Destination of shape (H, W) or (H, W, ...); modified in place.
    label_edge : int
        Side length of each tile.

    Returns
    -------
    np.ndarray
        ``preds_mask`` itself, after being filled.

    Raises
    ------
    ValueError
        If the number of tiles does not match the grid implied by
        ``preds_mask.shape[:2]`` and ``label_edge``.
    """
    height, width = preds_mask.shape[:2]
    num_seg_y = height // label_edge
    num_seg_x = width // label_edge

    expected = num_seg_y * num_seg_x + 1 + num_seg_y + num_seg_x
    if len(preds_test) != expected:
        raise ValueError(
            'expected %d tiles for a %dx%d mask with label_edge=%d, got %d'
            % (expected, height, width, label_edge, len(preds_test)))

    row_starts = range(0, num_seg_y * label_edge, label_edge)
    col_starts = range(0, num_seg_x * label_edge, label_edge)
    tiles = iter(preds_test)

    # 1. regular grid, row-major
    for i in row_starts:
        for j in col_starts:
            preds_mask[i:i + label_edge, j:j + label_edge] = next(tiles)
    # 2. bottom-right corner
    preds_mask[-label_edge:, -label_edge:] = next(tiles)
    # 3. right-edge column
    for i in row_starts:
        preds_mask[i:i + label_edge, -label_edge:] = next(tiles)
    # 4. bottom-edge row
    for j in col_starts:
        preds_mask[-label_edge:, j:j + label_edge] = next(tiles)

    return preds_mask

In [4]:
# Test-set image ids are taken from the sample submission, which may list an
# image more than once, hence the de-duplication. Training ids come from utils.
sample_submission_path = os.path.join(global_vars.DATA_DIR, 'sample_submission.csv')
ssubm = pd.read_csv(sample_submission_path)
test_names = ssubm.ImageId.unique()
train_names = utils.load_train_names()

In [5]:
# Re-execute the dilated_nets module so edits to it are picked up without a
# kernel restart, then build the network.
dilated_nets = reload(dilated_nets)
model = dilated_nets.atr_tiny_top(
    72,  # NOTE(review): presumably the context margin; matches buff=72 passed to get_patches_test
    48,  # presumably the output tile side; matches label_edge=48 used below
    8,   # presumably the number of input bands (TODO confirm)
    9,   # presumably the number of output channels; matches the 9-channel pred_mask
)

In [6]:
# Trained weights for the dilated model. The default is an absolute path on the
# original author's machine; set the DSTL_WEIGHTS_PATH environment variable to
# load weights from elsewhere without editing the notebook.
WEIGHTS_PATH = os.environ.get(
    'DSTL_WEIGHTS_PATH',
    '/media/d/ssd2/dstl/weights/196__dilated16x16_new_[0.67701916252661465, 0.010214600318620561, 0.77229527405285769, 0.24176935810189443, 0.59955639624761159, 0.92014752128027399, 0.84108460614381331, 0.0]',
)
model.load_weights(WEIGHTS_PATH)

In [7]:
# Predict a 9-channel mask for every training image and save it as a TIFF.
# Per image:
#   1. load via utils.load_m;
#   2. scale [0, 2047] -> [-1, 1];
#   3. nearest-neighbour resize to 835x835;
#   4. cut into 48x48 tiles with 72 px of reflect-padded context;
#   5. predict, stitch the tiles back together, and write to disk.
train_out_dir = os.path.join(global_vars.DATA_DIR, 'bin_masks', 'train_16x16')
os.makedirs(train_out_dir, exist_ok=True)

for name in train_names:
    im_m = utils.load_m(name).astype(np.float64)
    im_m = ((im_m / ((2.0**11) - 1)) - 0.5) * 2
    # Cast to float32 *before* patching. The patches then stack straight into
    # float32, instead of first materialising a float64 copy of every patch.
    # Values are identical: resize/pad only copy values.
    im_m = cv2.resize(im_m, (835, 835), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    shp = im_m.shape

    x_test = np.array(get_patches_test(im_m, 48, 72), dtype=np.float32)
    del im_m
    gc.collect()

    preds = model.predict(x_test)
    # Release the patch stack now, so it is not alive while the next image's is built.
    del x_test
    gc.collect()

    # NOTE(review): float64 doubles the file size versus float32 predictions;
    # kept to preserve the on-disk format downstream code may expect.
    pred_mask = np.zeros((shp[0], shp[1], 9))
    pred_mask = reconstruct_im(preds, pred_mask, 48)
    del preds

    tif.imsave(file=os.path.join(train_out_dir, name + '.tif'), data=pred_mask)
    gc.collect()

    print(name)


6040_2_2
6120_2_2
6120_2_0
6090_2_0
6040_1_3
6040_1_0
6100_1_3
6010_4_2
6110_4_0
6140_3_1
6110_1_2
6100_2_3
6150_2_3
6160_2_1
6140_1_2
6110_3_1
6010_4_4
6170_2_4
6170_4_1
6170_0_4
6060_2_3
6070_2_3
6010_1_2
6040_4_4
6100_2_2

In [8]:
%%time
for name in test_names:
    im_m = utils.load_m(name).astype(np.float64)
        
    im_m = ((im_m/((2.0**11)-1)) - 0.5)*2
    im_m = cv2.resize(im_m, (835, 835), interpolation=0)  
    shp = im_m.shape

    patches = get_patches_test(im_m, 48,72)
    del im_m
    
    x_test = np.array(patches).astype(np.float32)
    del patches
    
    gc.collect()
    preds = model.predict(x_test)
    gc.collect()
    
    pred_mask = np.zeros((shp[0],shp[1], 9))
    pred_mask = reconstruct_im(preds, pred_mask, 48)
    
    tif.imsave(file=os.path.join(global_vars.DATA_DIR, 'bin_masks','test_16x16', name+'.tif'), data=pred_mask)
    gc.collect()
    
    print(name)
    #break


6120_2_4
6120_2_3
6120_2_1
6180_2_4
6180_2_1
6180_2_0
6180_2_3
6180_2_2
6180_0_3
6180_0_2
6180_0_1
6180_0_0
6180_0_4
6080_4_4
6080_4_2
6080_4_3
6080_4_0
6080_4_1
6090_4_1
6090_4_0
6090_4_3
6090_4_2
6090_4_4
6180_4_4
6180_4_3
6180_4_2
6180_4_1
6180_4_0
6160_3_2
6160_3_3
6160_3_0
6160_3_1
6160_3_4
6080_2_4
6080_2_0
6080_2_1
6080_2_2
6080_2_3
6080_0_2
6080_0_3
6080_0_0
6080_0_1
6080_0_4
6010_0_4
6010_0_1
6010_0_0
6010_0_3
6010_0_2
6010_2_3
6010_2_2
6010_2_1
6010_2_0
6010_2_4
6010_4_1
6170_3_4
6010_4_3
6170_3_1
6170_3_0
6170_3_3
6170_3_2
6170_1_3
6170_1_2
6170_1_1
6170_1_0
6170_1_4
6130_4_2
6130_4_3
6130_4_0
6130_4_1
6130_4_4
6150_3_3
6150_3_2
6150_3_1
6150_3_0
6150_3_4
6130_2_0
6130_2_1
6130_2_2
6130_2_3
6130_2_4
6130_0_4
6130_0_2
6130_0_3
6130_0_0
6130_0_1
6150_1_4
6150_1_1
6150_1_0
6150_1_3
6150_1_2
6180_3_4
6180_3_0
6180_3_1
6180_3_2
6180_3_3
6180_1_2
6180_1_3
6180_1_0
6180_1_1
6180_1_4
6010_1_4
6010_1_0
6010_1_1
6010_1_2
6010_1_3
6010_3_2
6010_3_3
6010_3_0
6010_3_1
6010_3_4
6020_0_4
6020_0_0
6020_0_1
6020_0_2
6020_0_3
6020_2_2
6020_2_3
6020_2_0
6020_2_1
6020_2_4
6020_4_0
6020_4_1
6020_4_2
6020_4_3
6020_4_4
6110_0_1
6150_0_4
6150_0_0
6150_0_1
6150_0_2
6150_0_3
6100_2_1
6100_2_0
6150_2_2
6150_2_0
6150_2_1
6150_2_4
6170_4_2
6050_3_1
6170_4_3
6150_4_0
6150_4_1
6150_4_2
6150_4_3
6150_4_4
6170_4_0
6070_3_4
6070_3_0
6070_3_1
6070_3_2
6070_3_3
6070_1_2
6070_1_3
6070_1_0
6070_1_1
6070_1_4
6120_0_3
6110_1_1
6110_1_0
6110_1_3
6110_1_4
6110_3_4
6110_3_3
6110_3_2
6110_3_0
6060_2_1
6140_3_3
6020_1_4
6020_1_1
6020_1_0
6020_1_3
6020_1_2
6020_3_3
6020_3_2
6020_3_1
6020_3_0
6020_3_4
6050_2_4
6050_2_3
6050_2_2
6050_2_1
6050_2_0
6100_3_4
6100_3_0
6100_3_1
6100_3_2
6100_3_3
6060_1_1
6060_1_0
6060_1_3
6060_1_2
6060_1_4
6060_3_4
6060_3_3
6060_3_2
6060_3_1
6060_3_0
6070_2_4
6070_2_1
6070_2_0
6070_2_3
6070_2_2
6040_2_0
6040_2_1
6040_2_3
6040_2_4
6070_0_3
6070_0_2
6070_0_1
6070_0_0
6070_0_4
6100_4_3
6040_0_4
6040_0_2
6040_0_3
6040_0_0
6040_0_1
6050_0_1
6050_0_0
6050_0_3
6050_0_2
6050_0_4
6070_4_4
6070_4_3
6070_4_2
6070_4_1
6070_4_0
6140_3_0
6140_3_2
6100_4_4
6140_3_4
6100_4_2
6100_4_1
6100_4_0
6140_0_1
6140_1_4
6140_1_3
6140_1_0
6140_1_1
6040_4_2
6040_4_3
6040_4_0
6040_4_1
6040_4_4
6010_4_0
6110_0_0
6050_4_4
6110_0_2
6110_0_3
6110_0_4
6050_4_0
6050_4_3
6050_4_2
6030_3_0
6030_3_1
6030_3_2
6030_3_3
6030_3_4
6110_2_4
6110_2_2
6110_2_3
6110_2_0
6110_2_1
6120_4_1
6120_4_0
6120_4_3
6120_4_2
6110_4_4
6110_4_1
6110_4_2
6110_4_3
6060_4_4
6060_4_0
6060_4_1
6060_4_2
6060_4_3
6100_2_4
6050_3_4
6050_3_2
6050_3_3
6050_3_0
6100_2_2
6100_0_3
6100_0_2
6100_0_1
6100_0_0
6100_0_4
6060_0_0
6060_0_1
6060_0_2
6060_0_3
6060_0_4
6160_1_0
6060_2_4
6060_2_2
6060_2_0
6160_1_1
6040_3_1
6040_3_0
6040_3_3
6040_3_2
6040_3_4
6090_0_4
6160_1_3
6090_0_1
6090_0_0
6090_0_3
6090_0_2
6090_2_3
6090_2_2
6090_2_1
6040_1_4
6040_1_2
6040_1_1
6090_2_4
6030_0_4
6030_0_3
6030_0_2
6030_0_1
6030_0_0
6050_1_0
6050_1_1
6050_1_2
6050_1_3
6050_1_4
6120_1_0
6120_1_1
6120_1_2
6120_1_3
6120_1_4
6120_0_4
6160_1_4
6130_3_0
6140_0_4
6140_0_3
6140_0_2
6160_1_2
6140_0_0
6120_3_4
6120_3_2
6120_3_3
6120_3_0
6120_3_1
6100_1_2
6100_1_0
6100_1_1
6100_1_4
6140_4_3
6140_4_2
6140_4_1
6140_4_0
6140_4_4
6160_0_4
6140_2_1
6140_2_0
6140_2_3
6140_2_2
6140_2_4
6030_4_3
6030_4_2
6030_4_1
6030_4_0
6030_4_4
6130_3_4
6030_2_1
6030_2_0
6030_2_3
6030_2_2
6030_2_4
6090_3_2
6090_3_3
6090_3_0
6090_3_1
6090_3_4
6080_3_4
6170_4_4
6080_3_1
6080_3_0
6080_3_3
6080_3_2
6160_4_1
6160_4_0
6160_4_3
6160_4_2
6160_4_4
6080_1_3
6080_1_2
6080_1_1
6080_1_0
6080_1_4
6170_2_0
6170_2_1
6170_2_2
6170_2_3
6170_0_2
6170_0_3
6170_0_0
6170_0_1
6090_1_4
6120_4_4
6090_1_0
6090_1_1
6090_1_2
6090_1_3
6120_0_1
6120_0_0
6030_1_4
6120_0_2
6030_1_2
6030_1_3
6030_1_0
6030_1_1
6130_3_1
6050_4_1
6130_3_3
6130_3_2
6160_0_1
6160_0_0
6160_0_3
6160_0_2
6160_2_3
6160_2_2
6160_2_0
6160_2_4
6130_1_4
6130_1_3
6130_1_2
6130_1_1
6130_1_0
CPU times: user 29min 59s, sys: 5min 39s, total: 35min 38s
Wall time: 35min 20s

In [ ]: