Goal

Create a notebook that segments cat faces using SegNet


In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import theano
import sys

from skimage.io import imread
from matplotlib import pyplot as plt

import os
os.environ['KERAS_BACKEND'] = 'theano'
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN, device=gpu0, floatX=float32, optimizer=fast_compile'

from keras import models
from keras.optimizers import SGD


Using Theano backend.

Load Data

Data comes from mom and I labeling, and google scraping using bin/scrape_google_images.py

Training:

Pictures/train

Testing:

Pictures/test

Define Network

SegNet to start with: https://github.com/imlab-uiip/keras-segnet


In [2]:
path = 'Pictures/'
img_w = 256
img_h = 256
n_labels = 2

n_train = 493
n_test = 49

WITH_AUGMENTATION = 0
TRAIN_ENABLED = 1

In [3]:
from keras import models
from keras.layers.core import Activation, Reshape, Permute
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K

import json

img_w = 256
img_h = 256
n_labels = 2


if K.image_data_format() == 'channels_first':
    input_shape = (3, img_w, img_h)
else:
    input_shape = (img_w, img_h, 3)
    
kernel = 3

encoding_layers = [
    Conv2D(64, kernel, padding='same', input_shape=input_shape),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),
]

autoencoder = models.Sequential()
autoencoder.encoding_layers = encoding_layers

for l in autoencoder.encoding_layers:
    autoencoder.add(l)

decoding_layers = [
    UpSampling2D(),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    UpSampling2D(),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    UpSampling2D(),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    UpSampling2D(),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    UpSampling2D(),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(n_labels, 1, padding='valid'),
    BatchNormalization(),
]
autoencoder.decoding_layers = decoding_layers
for l in autoencoder.decoding_layers:
    autoencoder.add(l)

In [4]:
print(autoencoder.summary())


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 256, 256, 64)      1792      
_________________________________________________________________
batch_normalization_1 (Batch (None, 256, 256, 64)      256       
_________________________________________________________________
activation_1 (Activation)    (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 64)      36928     
_________________________________________________________________
batch_normalization_2 (Batch (None, 256, 256, 64)      256       
_________________________________________________________________
activation_2 (Activation)    (None, 256, 256, 64)      0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 128, 128, 128)     73856     
_________________________________________________________________
batch_normalization_3 (Batch (None, 128, 128, 128)     512       
_________________________________________________________________
activation_3 (Activation)    (None, 128, 128, 128)     0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 128, 128, 128)     147584    
_________________________________________________________________
batch_normalization_4 (Batch (None, 128, 128, 128)     512       
_________________________________________________________________
activation_4 (Activation)    (None, 128, 128, 128)     0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 64, 64, 256)       295168    
_________________________________________________________________
batch_normalization_5 (Batch (None, 64, 64, 256)       1024      
_________________________________________________________________
activation_5 (Activation)    (None, 64, 64, 256)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 64, 64, 256)       590080    
_________________________________________________________________
batch_normalization_6 (Batch (None, 64, 64, 256)       1024      
_________________________________________________________________
activation_6 (Activation)    (None, 64, 64, 256)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 64, 64, 256)       590080    
_________________________________________________________________
batch_normalization_7 (Batch (None, 64, 64, 256)       1024      
_________________________________________________________________
activation_7 (Activation)    (None, 64, 64, 256)       0         
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 32, 32, 512)       1180160   
_________________________________________________________________
batch_normalization_8 (Batch (None, 32, 32, 512)       2048      
_________________________________________________________________
activation_8 (Activation)    (None, 32, 32, 512)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 32, 32, 512)       2359808   
_________________________________________________________________
batch_normalization_9 (Batch (None, 32, 32, 512)       2048      
_________________________________________________________________
activation_9 (Activation)    (None, 32, 32, 512)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 32, 32, 512)       2359808   
_________________________________________________________________
batch_normalization_10 (Batc (None, 32, 32, 512)       2048      
_________________________________________________________________
activation_10 (Activation)   (None, 32, 32, 512)       0         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_11 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_11 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_12 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_12 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_13 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_13 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 8, 8, 512)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_14 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_14 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_15 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_15 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 16, 16, 512)       2359808   
_________________________________________________________________
batch_normalization_16 (Batc (None, 16, 16, 512)       2048      
_________________________________________________________________
activation_16 (Activation)   (None, 16, 16, 512)       0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 32, 32, 512)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 32, 32, 512)       2359808   
_________________________________________________________________
batch_normalization_17 (Batc (None, 32, 32, 512)       2048      
_________________________________________________________________
activation_17 (Activation)   (None, 32, 32, 512)       0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 32, 32, 512)       2359808   
_________________________________________________________________
batch_normalization_18 (Batc (None, 32, 32, 512)       2048      
_________________________________________________________________
activation_18 (Activation)   (None, 32, 32, 512)       0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 32, 32, 256)       1179904   
_________________________________________________________________
batch_normalization_19 (Batc (None, 32, 32, 256)       1024      
_________________________________________________________________
activation_19 (Activation)   (None, 32, 32, 256)       0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 64, 64, 256)       0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 64, 64, 256)       590080    
_________________________________________________________________
batch_normalization_20 (Batc (None, 64, 64, 256)       1024      
_________________________________________________________________
activation_20 (Activation)   (None, 64, 64, 256)       0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 64, 64, 256)       590080    
_________________________________________________________________
batch_normalization_21 (Batc (None, 64, 64, 256)       1024      
_________________________________________________________________
activation_21 (Activation)   (None, 64, 64, 256)       0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 64, 64, 128)       295040    
_________________________________________________________________
batch_normalization_22 (Batc (None, 64, 64, 128)       512       
_________________________________________________________________
activation_22 (Activation)   (None, 64, 64, 128)       0         
_________________________________________________________________
up_sampling2d_4 (UpSampling2 (None, 128, 128, 128)     0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 128, 128, 128)     147584    
_________________________________________________________________
batch_normalization_23 (Batc (None, 128, 128, 128)     512       
_________________________________________________________________
activation_23 (Activation)   (None, 128, 128, 128)     0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 128, 128, 64)      73792     
_________________________________________________________________
batch_normalization_24 (Batc (None, 128, 128, 64)      256       
_________________________________________________________________
activation_24 (Activation)   (None, 128, 128, 64)      0         
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 256, 256, 64)      36928     
_________________________________________________________________
batch_normalization_25 (Batc (None, 256, 256, 64)      256       
_________________________________________________________________
activation_25 (Activation)   (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 256, 256, 2)       130       
_________________________________________________________________
batch_normalization_26 (Batc (None, 256, 256, 2)       8         
=================================================================
Total params: 29,459,018
Trainable params: 29,443,142
Non-trainable params: 15,876
_________________________________________________________________
None

In [5]:
autoencoder.add(Reshape((n_labels, img_h * img_w)))
autoencoder.add(Permute((2, 1)))
autoencoder.add(Activation('softmax'))

with open('model_5l.json', 'w') as outfile:
    outfile.write(json.dumps(json.loads(autoencoder.to_json()), indent=2))
print('Compiled: OK')


Compiled: OK

In [6]:
from PIL import Image

def prep_data(path, mode):
    assert mode in {'test', 'train'}, \
        'mode should be either \'test\' or \'train\''
    data = []
    label = []
    df = pd.read_csv(path + mode + '.csv')
    n = n_train if mode == 'train' else n_test
    count = 0
    for i, item in df.iterrows():
        if i >= n:
            print("broken")
            break
        img, gt = [imread(item[0])], np.clip(imread(item[1]), 0, 1)
        data.append(img)
        label.append(label_map(gt))
        sys.stdout.write('\r')
        sys.stdout.write(mode + ": [%-20s] %d%%" % ('=' * int(20. * (i + 1) / n - 1) + '>',
                                                    int(100. * (i + 1) / n)))
        sys.stdout.flush()
        count = count + 1
    sys.stdout.write('\r')
    sys.stdout.flush()
    
    data = np.array(data)
    data = data.reshape((data.shape[0] * data.shape[1], data.shape[2], data.shape[3], data.shape[4]))
    print("There are counts: ", str(count))
    label = np.array(label).reshape((n, img_h * img_w, n_labels))
    print(label.shape)

    print(mode + ': OK')
    print('\tshapes: {}, {}'.format(data.shape, label.shape))
    print('\ttypes:  {}, {}'.format(data.dtype, label.dtype))
    print('\tmemory: {}, {} MB'.format(data.nbytes / 1048576, label.nbytes / 1048576))

    return data, label

In [7]:
def label_map(labels):
    label_map = np.zeros([img_h, img_w, n_labels])    
    for r in range(img_h):
        for c in range(img_w):
            label_map[r, c, labels[r][c]] = 1
    return label_map

In [8]:
train_data, train_label = prep_data('../Pictures/', 'train')


There are counts:  493=====>] 100%
(493, 65536, 2)
train: OK
	shapes: (493, 256, 256, 3), (493, 65536, 2)
	types:  uint8, float64
	memory: 92.4375, 493.0 MB

In [9]:
test_data, test_label = prep_data('../Pictures/', 'test')


There are counts:  49=====>] 100%
(49, 65536, 2)
test: OK
	shapes: (49, 256, 256, 3), (49, 65536, 2)
	types:  uint8, float64
	memory: 9.1875, 49.0 MB

In [14]:
if 0:
  datagen = ImageDataGenerator(
        rotation_range=180,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

  img = load_img('data/train/cats/cat.0.jpg')  # this is a PIL image
  x = img_to_array(img)  # this is a Numpy array with shape (3, 150, 150)
  x = x.reshape((1,) + x.shape)  # this is a Numpy array with shape (1, 3, 150, 150)

  # the .flow() command below generates batches of randomly transformed images
  # and saves the results to the `preview/` directory
  i = 0
  for batch in datagen.flow(x, batch_size=1,
                            save_to_dir='preview', save_prefix='cat', save_format='jpeg'):
      i += 1
      if i > 20:
          break  # otherwise the generator would loop indefinitely

In [15]:
from keras.callbacks import ModelCheckpoint

if WITH_AUGMENTATION:
  # augment the input images and their masks and save them into the tmp_training folder
  print("TODO")
else:
  nb_epoch = 50
  batch_size = 18
  autoencoder.compile(optimizer=SGD(lr = 0.1, decay = 1e-6, momentum = 0.9, nesterov = False),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
  print(train_data.shape)
  print(train_label.shape)

  checkpoint = ModelCheckpoint('.', monitor='val_acc', verbose=1, save_best_only=True, mode='max')


(493, 256, 256, 3)
(493, 65536, 2)

In [16]:
if TRAIN_ENABLED:
  history = autoencoder.fit(train_data, train_label, batch_size=batch_size, epochs=nb_epoch, verbose=1,  callbacks=[checkpoint])


Epoch 1/50
493/493 [==============================] - 3565s - loss: 0.4671 - acc: 0.7774     
/usr/local/lib/python3.6/site-packages/keras/callbacks.py:390: RuntimeWarning: Can save best model only with val_acc available, skipping.
  'skipping.' % (self.monitor), RuntimeWarning)
Epoch 2/50
493/493 [==============================] - 3576s - loss: 0.3805 - acc: 0.8194     
Epoch 3/50
493/493 [==============================] - 3577s - loss: 0.3649 - acc: 0.8315     
Epoch 4/50
493/493 [==============================] - 3576s - loss: 0.3510 - acc: 0.8374     
Epoch 5/50
493/493 [==============================] - 3576s - loss: 0.3449 - acc: 0.8378     
Epoch 6/50
493/493 [==============================] - 3578s - loss: 0.3425 - acc: 0.8382     
Epoch 7/50
493/493 [==============================] - 3577s - loss: 0.3319 - acc: 0.8447     
Epoch 8/50
493/493 [==============================] - 3578s - loss: 0.3230 - acc: 0.8495     
Epoch 9/50
493/493 [==============================] - 3577s - loss: 0.3189 - acc: 0.8532     
Epoch 10/50
493/493 [==============================] - 3579s - loss: 0.3140 - acc: 0.8557     
Epoch 11/50
493/493 [==============================] - 3584s - loss: 0.3206 - acc: 0.8517     
Epoch 12/50
493/493 [==============================] - 3591s - loss: 0.3075 - acc: 0.8592     
Epoch 13/50
493/493 [==============================] - 3590s - loss: 0.2977 - acc: 0.8644     
Epoch 14/50
493/493 [==============================] - 3589s - loss: 0.2956 - acc: 0.8663     
Epoch 15/50
493/493 [==============================] - 3590s - loss: 0.2933 - acc: 0.8681     
Epoch 16/50
493/493 [==============================] - 3593s - loss: 0.2777 - acc: 0.8758     
Epoch 17/50
493/493 [==============================] - 3585s - loss: 0.2703 - acc: 0.8788     
Epoch 18/50
493/493 [==============================] - 3585s - loss: 0.2756 - acc: 0.8771     
Epoch 19/50
493/493 [==============================] - 3587s - loss: 0.2661 - acc: 0.8812     
Epoch 20/50
493/493 [==============================] - 3585s - loss: 0.2533 - acc: 0.8891     
Epoch 21/50
493/493 [==============================] - 3661s - loss: 0.2558 - acc: 0.8878     
Epoch 22/50
493/493 [==============================] - 3666s - loss: 0.2520 - acc: 0.8886     
Epoch 23/50
493/493 [==============================] - 3595s - loss: 0.2455 - acc: 0.8918     
Epoch 24/50
493/493 [==============================] - 3594s - loss: 0.2320 - acc: 0.8989     
Epoch 25/50
493/493 [==============================] - 3594s - loss: 0.2192 - acc: 0.9043     
Epoch 26/50
493/493 [==============================] - 3594s - loss: 0.2145 - acc: 0.9070     
Epoch 27/50
493/493 [==============================] - 3595s - loss: 0.2092 - acc: 0.9084     
Epoch 28/50
493/493 [==============================] - 3595s - loss: 0.2103 - acc: 0.9088     
Epoch 29/50
493/493 [==============================] - 3595s - loss: 0.1969 - acc: 0.9148     
Epoch 30/50
493/493 [==============================] - 3594s - loss: 0.1995 - acc: 0.9142     
Epoch 31/50
493/493 [==============================] - 3596s - loss: 0.1863 - acc: 0.9197     
Epoch 32/50
493/493 [==============================] - 3596s - loss: 0.1778 - acc: 0.9231     
Epoch 33/50
493/493 [==============================] - 3598s - loss: 0.1769 - acc: 0.9243     
Epoch 34/50
493/493 [==============================] - 3594s - loss: 0.1704 - acc: 0.9268     
Epoch 35/50
493/493 [==============================] - 3595s - loss: 0.1655 - acc: 0.9289     
Epoch 36/50
493/493 [==============================] - 3597s - loss: 0.1546 - acc: 0.9336     
Epoch 37/50
493/493 [==============================] - 3594s - loss: 0.1481 - acc: 0.9365     
Epoch 38/50
493/493 [==============================] - 3595s - loss: 0.1494 - acc: 0.9363     
Epoch 39/50
493/493 [==============================] - 3597s - loss: 0.1417 - acc: 0.9391     
Epoch 40/50
493/493 [==============================] - 3596s - loss: 0.1286 - acc: 0.9453     
Epoch 41/50
493/493 [==============================] - 3598s - loss: 0.1334 - acc: 0.9434     
Epoch 42/50
493/493 [==============================] - 3599s - loss: 0.1264 - acc: 0.9459     
Epoch 43/50
493/493 [==============================] - 3599s - loss: 0.1215 - acc: 0.9485     
Epoch 44/50
493/493 [==============================] - 3598s - loss: 0.1201 - acc: 0.9489     
Epoch 45/50
493/493 [==============================] - 3613s - loss: 0.1194 - acc: 0.9491     
Epoch 46/50
493/493 [==============================] - 3754s - loss: 0.1142 - acc: 0.9513     
Epoch 47/50
493/493 [==============================] - 3662s - loss: 0.1107 - acc: 0.9532     
Epoch 48/50
493/493 [==============================] - 3600s - loss: 0.0992 - acc: 0.9576     
Epoch 49/50
493/493 [==============================] - 3601s - loss: 0.1041 - acc: 0.9561     
Epoch 50/50
493/493 [==============================] - 3601s - loss: 0.1348 - acc: 0.9425     

In [17]:
autoencoder.save_weights('model_5l_weight_ep50_384_auto.hdf5')

In [18]:
autoencoder.load_weights('model_5l_weight_ep50_384_auto.hdf5')

from keras.utils.vis_utils import plot_model

plot_model(autoencoder, to_file='model_384_auto.png', show_shapes=True)

In [19]:
autoencoder.compile(optimizer=SGD(lr = 0.1, decay = 1e-6, momentum = 0.9, nesterov = False),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
score = autoencoder.evaluate(test_data, test_label, verbose=0)

print('Test score:',  score[0])
print('Test accuracy:', score[1])


Test score: 0.306264275191
Test accuracy: 0.880756606861

In [20]:
output = autoencoder.predict_proba(test_data, verbose=0)
output = output.reshape((output.shape[0], img_h, img_w, n_labels))

In [21]:
print(output.shape)


(49, 256, 256, 2)

In [22]:
print(output)


[[[[  9.55174088e-01   4.48259115e-02]
   [  9.47924137e-01   5.20759337e-02]
   [  9.77878332e-01   2.21216958e-02]
   ..., 
   [  7.39570975e-01   2.60429054e-01]
   [  7.30044246e-01   2.69955695e-01]
   [  7.41879702e-01   2.58120298e-01]]

  [[  7.31947780e-01   2.68052250e-01]
   [  7.42395580e-01   2.57604480e-01]
   [  7.33038962e-01   2.66961068e-01]
   ..., 
   [  5.58610976e-01   4.41389024e-01]
   [  4.41110104e-01   5.58889925e-01]
   [  4.95685726e-01   5.04314303e-01]]

  [[  9.74148393e-01   2.58516166e-02]
   [  9.68751907e-01   3.12480982e-02]
   [  9.77363825e-01   2.26362366e-02]
   ..., 
   [  6.94656432e-01   3.05343539e-01]
   [  6.82779729e-01   3.17220271e-01]
   [  6.96286082e-01   3.03713888e-01]]

  ..., 
  [[  9.99622107e-01   3.77884629e-04]
   [  9.99667883e-01   3.32111871e-04]
   [  9.99612868e-01   3.87117238e-04]
   ..., 
   [  9.98562872e-01   1.43717916e-03]
   [  9.95829284e-01   4.17071767e-03]
   [  9.95300174e-01   4.69975919e-03]]

  [[  9.54465985e-01   4.55339700e-02]
   [  9.57260847e-01   4.27390970e-02]
   [  9.79070663e-01   2.09292844e-02]
   ..., 
   [  9.99523759e-01   4.76244342e-04]
   [  9.99507666e-01   4.92310908e-04]
   [  9.99508023e-01   4.91915911e-04]]

  [[  9.99478638e-01   5.21330803e-04]
   [  9.99485493e-01   5.14525105e-04]
   [  9.99465168e-01   5.34822291e-04]
   ..., 
   [  9.96862173e-01   3.13789840e-03]
   [  9.90729511e-01   9.27052740e-03]
   [  9.89386439e-01   1.06135951e-02]]]


 [[[  7.55261898e-01   2.44738102e-01]
   [  7.37434268e-01   2.62565792e-01]
   [  8.68624270e-01   1.31375656e-01]
   ..., 
   [  9.14857686e-01   8.51423591e-02]
   [  9.23912704e-01   7.60872960e-02]
   [  9.21438575e-01   7.85614923e-02]]

  [[  9.33426321e-01   6.65737242e-02]
   [  9.31979716e-01   6.80202395e-02]
   [  9.38175261e-01   6.18247651e-02]
   ..., 
   [  9.30407405e-01   6.95925727e-02]
   [  8.83784890e-01   1.16215050e-01]
   [  8.90888810e-01   1.09111249e-01]]

  [[  8.45786691e-01   1.54213309e-01]
   [  8.27639878e-01   1.72360122e-01]
   [  8.69440258e-01   1.30559713e-01]
   ..., 
   [  9.10786510e-01   8.92134681e-02]
   [  9.14600730e-01   8.53992924e-02]
   [  9.16326106e-01   8.36739168e-02]]

  ..., 
  [[  9.99457061e-01   5.42966416e-04]
   [  9.99573290e-01   4.26675950e-04]
   [  9.99458373e-01   5.41673333e-04]
   ..., 
   [  9.99059737e-01   9.40238533e-04]
   [  9.97112870e-01   2.88716401e-03]
   [  9.96700227e-01   3.29976459e-03]]

  [[  9.94957030e-01   5.04299253e-03]
   [  9.94944155e-01   5.05584944e-03]
   [  9.97782171e-01   2.21783388e-03]
   ..., 
   [  9.99268234e-01   7.31812266e-04]
   [  9.99193370e-01   8.06659926e-04]
   [  9.99262273e-01   7.37669237e-04]]

  [[  9.99197781e-01   8.02194816e-04]
   [  9.99263108e-01   7.36865855e-04]
   [  9.99199569e-01   8.00404232e-04]
   ..., 
   [  9.97996628e-01   2.00335425e-03]
   [  9.93797958e-01   6.20202068e-03]
   [  9.92767334e-01   7.23270932e-03]]]


 [[[  9.36498880e-01   6.35011867e-02]
   [  9.26522553e-01   7.34775066e-02]
   [  9.65486050e-01   3.45139466e-02]
   ..., 
   [  9.97065365e-01   2.93466565e-03]
   [  9.97225404e-01   2.77466793e-03]
   [  9.97129142e-01   2.87086470e-03]]

  [[  9.97326016e-01   2.67397473e-03]
   [  9.97243524e-01   2.75653298e-03]
   [  9.97372508e-01   2.62751570e-03]
   ..., 
   [  9.73301947e-01   2.66980473e-02]
   [  9.56058443e-01   4.39415425e-02]
   [  9.55001295e-01   4.49987240e-02]]

  [[  9.60159242e-01   3.98407839e-02]
   [  9.52445447e-01   4.75546122e-02]
   [  9.62929308e-01   3.70707214e-02]
   ..., 
   [  9.97468948e-01   2.53108237e-03]
   [  9.97378588e-01   2.62142089e-03]
   [  9.97500837e-01   2.49917782e-03]]

  ..., 
  [[  9.99111354e-01   8.88585171e-04]
   [  9.99228716e-01   7.71274499e-04]
   [  9.99112308e-01   8.87665374e-04]
   ..., 
   [  9.96774137e-01   3.22582969e-03]
   [  9.89711523e-01   1.02884173e-02]
   [  9.89381790e-01   1.06182070e-02]]

  [[  9.69720721e-01   3.02792098e-02]
   [  9.71764147e-01   2.82358732e-02]
   [  9.87635553e-01   1.23643903e-02]
   ..., 
   [  9.98879015e-01   1.12097734e-03]
   [  9.98815656e-01   1.18436955e-03]
   [  9.98876274e-01   1.12367352e-03]]

  [[  9.98818457e-01   1.18149468e-03]
   [  9.98878539e-01   1.12141087e-03]
   [  9.98821080e-01   1.17885729e-03]
   ..., 
   [  9.93742526e-01   6.25742460e-03]
   [  9.80195701e-01   1.98042449e-02]
   [  9.79122341e-01   2.08776407e-02]]]


 ..., 
 [[[  7.57497132e-01   2.42502853e-01]
   [  7.48212993e-01   2.51787037e-01]
   [  8.70336771e-01   1.29663169e-01]
   ..., 
   [  9.96443808e-01   3.55615164e-03]
   [  9.96416092e-01   3.58393462e-03]
   [  9.96086359e-01   3.91363213e-03]]

  [[  9.95885789e-01   4.11420781e-03]
   [  9.95459974e-01   4.54002433e-03]
   [  9.95547295e-01   4.45271051e-03]
   ..., 
   [  9.35005665e-01   6.49943873e-02]
   [  8.40834558e-01   1.59165382e-01]
   [  8.51789594e-01   1.48210347e-01]]

  [[  8.49776864e-01   1.50223121e-01]
   [  8.39764118e-01   1.60235837e-01]
   [  8.72451842e-01   1.27548158e-01]
   ..., 
   [  9.96550322e-01   3.44965747e-03]
   [  9.96366501e-01   3.63348448e-03]
   [  9.96296465e-01   3.70356860e-03]]

  ..., 
  [[  9.99644041e-01   3.55982746e-04]
   [  9.99723613e-01   2.76383653e-04]
   [  9.99630451e-01   3.69517569e-04]
   ..., 
   [  9.99675393e-01   3.24589288e-04]
   [  9.99144435e-01   8.55508843e-04]
   [  9.98967528e-01   1.03246607e-03]]

  [[  9.94546771e-01   5.45322988e-03]
   [  9.94272113e-01   5.72796259e-03]
   [  9.97512698e-01   2.48734024e-03]
   ..., 
   [  9.99617696e-01   3.82312515e-04]
   [  9.99557674e-01   4.42323653e-04]
   [  9.99603570e-01   3.96475050e-04]]

  [[  9.99481738e-01   5.18268906e-04]
   [  9.99537110e-01   4.62931552e-04]
   [  9.99461472e-01   5.38485532e-04]
   ..., 
   [  9.99328613e-01   6.71377289e-04]
   [  9.98275876e-01   1.72415213e-03]
   [  9.97874498e-01   2.12547369e-03]]]


 [[[  9.20827210e-01   7.91727901e-02]
   [  9.33188915e-01   6.68110102e-02]
   [  9.71795440e-01   2.82046329e-02]
   ..., 
   [  9.96606827e-01   3.39313038e-03]
   [  9.96385217e-01   3.61482147e-03]
   [  9.96263564e-01   3.73647991e-03]]

  [[  9.95583594e-01   4.41640895e-03]
   [  9.95404124e-01   4.59587201e-03]
   [  9.95067000e-01   4.93302336e-03]
   ..., 
   [  9.05439317e-01   9.45606604e-02]
   [  7.85491645e-01   2.14508325e-01]
   [  8.02410603e-01   1.97589427e-01]]

  [[  9.62641716e-01   3.73582914e-02]
   [  9.67866600e-01   3.21333893e-02]
   [  9.76312995e-01   2.36870460e-02]
   ..., 
   [  9.97066438e-01   2.93358462e-03]
   [  9.96745586e-01   3.25439498e-03]
   [  9.96832192e-01   3.16776126e-03]]

  ..., 
  [[  9.99855757e-01   1.44163045e-04]
   [  9.99885440e-01   1.14586481e-04]
   [  9.99854445e-01   1.45541082e-04]
   ..., 
   [  9.99746382e-01   2.53582024e-04]
   [  9.99306917e-01   6.93057955e-04]
   [  9.99144316e-01   8.55624967e-04]]

  [[  9.94522929e-01   5.47708431e-03]
   [  9.93264556e-01   6.73544127e-03]
   [  9.97787714e-01   2.21225875e-03]
   ..., 
   [  9.99815524e-01   1.84484044e-04]
   [  9.99799192e-01   2.00840950e-04]
   [  9.99809086e-01   1.90903796e-04]]

  [[  9.99786079e-01   2.13947307e-04]
   [  9.99797165e-01   2.02883835e-04]
   [  9.99783814e-01   2.16242828e-04]
   ..., 
   [  9.99541402e-01   4.58669558e-04]
   [  9.98791039e-01   1.20896543e-03]
   [  9.98467624e-01   1.53236242e-03]]]


 [[[  7.69851625e-01   2.30148420e-01]
   [  7.59410262e-01   2.40589783e-01]
   [  8.75618398e-01   1.24381587e-01]
   ..., 
   [  9.63360310e-01   3.66397500e-02]
   [  9.66665506e-01   3.33345197e-02]
   [  9.63190019e-01   3.68099138e-02]]

  [[  9.65887249e-01   3.41127403e-02]
   [  9.62487280e-01   3.75127457e-02]
   [  9.65431273e-01   3.45687196e-02]
   ..., 
   [  8.07469249e-01   1.92530751e-01]
   [  6.56512380e-01   3.43487650e-01]
   [  6.79840565e-01   3.20159405e-01]]

  [[  8.49437892e-01   1.50562122e-01]
   [  8.35940301e-01   1.64059758e-01]
   [  8.67160916e-01   1.32839099e-01]
   ..., 
   [  9.62924778e-01   3.70752402e-02]
   [  9.62459922e-01   3.75401527e-02]
   [  9.62468386e-01   3.75315808e-02]]

  ..., 
  [[  9.99243140e-01   7.56894355e-04]
   [  9.99403834e-01   5.96187892e-04]
   [  9.99241710e-01   7.58351351e-04]
   ..., 
   [  9.99926448e-01   7.35397116e-05]
   [  9.99739826e-01   2.60177534e-04]
   [  9.99671340e-01   3.28645663e-04]]

  [[  9.92325187e-01   7.67477043e-03]
   [  9.92034256e-01   7.96572492e-03]
   [  9.96433735e-01   3.56625835e-03]
   ..., 
   [  9.99178708e-01   8.21271213e-04]
   [  9.99058068e-01   9.41931445e-04]
   [  9.99170780e-01   8.29272263e-04]]

  [[  9.99067962e-01   9.32022580e-04]
   [  9.99182165e-01   8.17821361e-04]
   [  9.99066293e-01   9.33775911e-04]
   ..., 
   [  9.99852180e-01   1.47838757e-04]
   [  9.99509931e-01   4.90017235e-04]
   [  9.99366343e-01   6.33643649e-04]]]]

In [23]:
from skimage import data, color, io, img_as_float

def plot_results(output):
    gt = []
    df = pd.read_csv('../Pictures/test.csv')
    for i, item in df.iterrows():
        gt.append(np.clip(imread(item[1]), 0, 1))

    plt.figure(figsize=(6, n_test*2))
    for i, item in df.iterrows():
        plt.subplot(n_test, 4, 4 * i + 1)
        plt.title('Ground Truth')
        plt.axis('off')
        gt = imread(item[1])
        plt.imshow(np.clip(gt, 0, 1))
        
        plt.subplot(n_test, 4, 4 * i + 2)
        plt.title('Prediction')
        plt.axis('off')
        labeled = np.argmax(output[i], axis=-1)
        plt.imshow(labeled)

        plt.subplot(n_test, 4, 4 * i + 3)
        plt.title('Heat map')
        plt.axis('off')
        plt.imshow(output[i][:, :, 1])

        plt.subplot(n_test, 4, 4 * i + 4)
        plt.title('Comparison')
        plt.axis('off')
#         rgb = np.empty((img_h, img_w, 3))
#         rgb[:, :, 0] = labeled
        img = imread(item[0])
        img_hsv = color.rgb2hsv(img)
        img_hsv[..., 1] = labeled
#         img_hsv[..., 1] = gt
#         rgb[:, :, 1] = imread(item[0])
#         rgb[:, :, 2] = gt
        img = color.hsv2rgb(img_hsv)
        plt.imshow(img)

    plt.savefig('result.png')
    plt.show()

In [24]:
plot_results(output)



In [ ]: