In [1]:
%matplotlib inline
import os
import sys

# Backend configuration has to be set before theano and keras are imported,
# otherwise KERAS_BACKEND and THEANO_FLAGS are ignored.
os.environ['KERAS_BACKEND'] = 'theano'
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN,device=gpu0,floatX=float32,optimizer=fast_compile'

import numpy as np
import pandas as pd
import theano
from skimage.io import imread
from matplotlib import pyplot as plt
from keras import models
from keras.optimizers import SGD
SegNet implementation used as a starting point: https://github.com/imlab-uiip/keras-segnet
In [2]:
path = '../Pictures/'  # base folder containing train.csv and test.csv
img_w = 256
img_h = 256
n_labels = 2
n_train = 493
n_test = 49
WITH_AUGMENTATION = 0
TRAIN_ENABLED = 1
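prep_data() below expects train.csv and test.csv in the path folder, each row listing an input image path followed by its ground-truth mask path. A minimal sketch of generating such a CSV is shown here; the images/ and masks/ folder names and the build_csv helper are illustrative assumptions, not part of the original pipeline.
In [ ]:
# Hypothetical helper: build a CSV from parallel images/ and masks/ folders that
# share file names. Column 0 = input image, column 1 = ground-truth mask,
# matching how prep_data() indexes each row with item[0] and item[1].
import os
import pandas as pd

def build_csv(image_dir, mask_dir, out_csv):
    names = sorted(os.listdir(image_dir))
    rows = [(os.path.join(image_dir, n), os.path.join(mask_dir, n)) for n in names]
    pd.DataFrame(rows, columns=['image', 'mask']).to_csv(out_csv, index=False)

# build_csv('../Pictures/images', '../Pictures/masks', '../Pictures/train.csv')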
In [3]:
from keras import models
from keras.layers.core import Activation, Reshape, Permute
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K
import json
img_w = 256
img_h = 256
n_labels = 2
if K.image_data_format() == 'channels_first':
    input_shape = (3, img_w, img_h)
else:
    input_shape = (img_w, img_h, 3)
kernel = 3
encoding_layers = [
    # encoder block 1: 64 filters
    Conv2D(64, kernel, padding='same', input_shape=input_shape),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    # encoder block 2: 128 filters
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    # encoder block 3: 256 filters
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    # encoder block 4: 512 filters
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),

    # encoder block 5: 512 filters
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling2D(),
]

autoencoder = models.Sequential()
autoencoder.encoding_layers = encoding_layers
for l in autoencoder.encoding_layers:
    autoencoder.add(l)
decoding_layers = [
    # decoder block 1
    UpSampling2D(),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    # decoder block 2
    UpSampling2D(),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(512, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    # decoder block 3
    UpSampling2D(),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(256, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    # decoder block 4
    UpSampling2D(),
    Conv2D(128, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),

    # decoder block 5 + per-pixel classification
    UpSampling2D(),
    Conv2D(64, kernel, padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(n_labels, 1, padding='valid'),
    BatchNormalization(),
]

autoencoder.decoding_layers = decoding_layers
for l in autoencoder.decoding_layers:
    autoencoder.add(l)
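Before adding the classification head, a quick shape check: the five 2x2 poolings reduce 256x256 to 8x8 at the bottleneck and the five upsamplings restore 256x256, so the final 1x1 convolution emits one score map per label. A minimal sanity-check sketch, assuming the model was built as above:
In [ ]:
# The encoder halves the resolution five times (256 / 2**5 = 8) and the decoder
# restores it; the channel position depends on K.image_data_format().
print(autoencoder.output_shape)
# expected: (None, 2, 256, 256) for channels_first or (None, 256, 256, 2) for channels_last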
In [4]:
autoencoder.summary()
In [5]:
if K.image_data_format() == 'channels_first':
    # model output is (n_labels, h, w): flatten the spatial dims, then move the class axis last
    autoencoder.add(Reshape((n_labels, img_h * img_w)))
    autoencoder.add(Permute((2, 1)))
else:
    # model output is (h, w, n_labels): flatten the spatial dims directly
    autoencoder.add(Reshape((img_h * img_w, n_labels)))
autoencoder.add(Activation('softmax'))

with open('model_5l.json', 'w') as outfile:
    outfile.write(json.dumps(json.loads(autoencoder.to_json()), indent=2))
print('Model architecture saved: OK')
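The JSON file stores only the architecture, not the weights. A minimal sketch of rebuilding the model from it later (weights are restored separately with load_weights, as done further down):
In [ ]:
from keras.models import model_from_json

# Recreate the architecture from the saved JSON; this does not restore weights.
with open('model_5l.json') as f:
    restored = model_from_json(f.read())
restored.summary()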
In [6]:
from PIL import Image

def prep_data(path, mode):
    assert mode in {'test', 'train'}, \
        "mode should be either 'test' or 'train'"
    data = []
    label = []
    df = pd.read_csv(path + mode + '.csv')
    n = n_train if mode == 'train' else n_test
    count = 0
    for i, item in df.iterrows():
        if i >= n:
            print('\nstopping early: the CSV has more rows than n =', n)
            break
        img, gt = [imread(item[0])], np.clip(imread(item[1]), 0, 1)
        data.append(img)
        label.append(label_map(gt))
        sys.stdout.write('\r')
        sys.stdout.write(mode + ": [%-20s] %d%%" % ('=' * int(20. * (i + 1) / n - 1) + '>',
                                                    int(100. * (i + 1) / n)))
        sys.stdout.flush()
        count = count + 1
    sys.stdout.write('\r')
    sys.stdout.flush()
    data = np.array(data)
    data = data.reshape((data.shape[0] * data.shape[1], data.shape[2], data.shape[3], data.shape[4]))
    print("Images loaded:", count)
    label = np.array(label).reshape((n, img_h * img_w, n_labels))
    print(label.shape)
    print(mode + ': OK')
    print('\tshapes: {}, {}'.format(data.shape, label.shape))
    print('\ttypes: {}, {}'.format(data.dtype, label.dtype))
    print('\tmemory: {}, {} MB'.format(data.nbytes / 1048576, label.nbytes / 1048576))
    return data, label
In [7]:
def label_map(labels):
    # one-hot encode a (img_h, img_w) integer mask into (img_h, img_w, n_labels)
    label_map = np.zeros([img_h, img_w, n_labels])
    for r in range(img_h):
        for c in range(img_w):
            label_map[r, c, labels[r][c]] = 1
    return label_map
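The per-pixel Python loop is correct but slow for several hundred 256x256 masks. A vectorized equivalent using NumPy fancy indexing, sketched under the same input and output layout, would be:
In [ ]:
def label_map_fast(labels):
    # np.eye(n_labels) indexed by the integer mask yields the same
    # (img_h, img_w, n_labels) one-hot encoding without an explicit loop.
    return np.eye(n_labels)[labels]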
In [8]:
train_data, train_label = prep_data(path, 'train')
In [9]:
test_data, test_label = prep_data(path, 'test')
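prep_data() returns the raw uint8 pixel values in the 0..255 range. Whether to rescale is a design choice the original notebook leaves open; a minimal sketch of mapping the inputs to [0, 1] before training, if desired:
In [ ]:
# Optional: rescale intensities to [0, 1]; skip if training on raw 0..255 values is intended.
train_data = train_data.astype('float32') / 255.0
test_data = test_data.astype('float32') / 255.0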
In [14]:
# Disabled augmentation example, kept for reference only.
if 0:
    from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

    datagen = ImageDataGenerator(
        rotation_range=180,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

    img = load_img('data/train/cats/cat.0.jpg')  # this is a PIL image
    x = img_to_array(img)  # this is a Numpy array with shape (3, 150, 150)
    x = x.reshape((1,) + x.shape)  # this is a Numpy array with shape (1, 3, 150, 150)

    # the .flow() command below generates batches of randomly transformed images
    # and saves the results to the `preview/` directory
    i = 0
    for batch in datagen.flow(x, batch_size=1,
                              save_to_dir='preview', save_prefix='cat', save_format='jpeg'):
        i += 1
        if i > 20:
            break  # otherwise the generator would loop indefinitely
In [15]:
from keras.callbacks import ModelCheckpoint

if WITH_AUGMENTATION:
    # TODO: augment the input images and their masks and save them into the
    # tmp_training folder (a sketch of one possible approach follows this cell)
    print("TODO")
else:
    nb_epoch = 50
    batch_size = 18
    autoencoder.compile(optimizer=SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=False),
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])
    print(train_data.shape)
    print(train_label.shape)
    # ModelCheckpoint needs a file path rather than a directory, and 'val_acc' is only
    # available when fit() is given validation data, so monitor training accuracy instead
    # (the weights file name is arbitrary).
    checkpoint = ModelCheckpoint('model_5l_weights_best.hdf5', monitor='acc',
                                 verbose=1, save_best_only=True, mode='max')
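One possible way to fill in the WITH_AUGMENTATION branch, sketched rather than implemented: two ImageDataGenerators created with identical parameters and the same seed keep the random transforms of images and masks in sync. The paired_generator name, the (n, h, w, 1) mask layout and the augmentation parameters are assumptions for illustration; the resulting generator could be passed to fit_generator() instead of fit().
In [ ]:
from keras.preprocessing.image import ImageDataGenerator

def paired_generator(images, masks, batch_size=18, seed=1):
    # images: (n, h, w, 3) array, masks: (n, h, w, 1) binary array (channels_last assumed)
    aug = dict(rotation_range=180, width_shift_range=0.2, height_shift_range=0.2,
               zoom_range=0.2, horizontal_flip=True, fill_mode='nearest')
    img_flow = ImageDataGenerator(**aug).flow(images, batch_size=batch_size, seed=seed)
    mask_flow = ImageDataGenerator(**aug).flow(masks, batch_size=batch_size, seed=seed)
    for img_batch, mask_batch in zip(img_flow, mask_flow):
        # interpolation can leave fractional mask values; threshold back to {0, 1}
        binary = (mask_batch[..., 0] > 0.5).astype('uint8')
        # one-hot encode and flatten to the (pixels, n_labels) layout the model is trained on
        one_hot = np.eye(n_labels)[binary].reshape((binary.shape[0], img_h * img_w, n_labels))
        yield img_batch, one_hot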
In [16]:
if TRAIN_ENABLED:
    history = autoencoder.fit(train_data, train_label, batch_size=batch_size,
                              epochs=nb_epoch, verbose=1, callbacks=[checkpoint])
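fit() records the per-epoch metrics in history.history; a short sketch of plotting them, guarded by TRAIN_ENABLED since history only exists after training (the accuracy key is 'acc' in older Keras releases and 'accuracy' in newer ones):
In [ ]:
if TRAIN_ENABLED:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(history.history['loss'])
    ax1.set_title('training loss')
    ax1.set_xlabel('epoch')
    ax2.plot(history.history['acc'])  # use 'accuracy' on newer Keras versions
    ax2.set_title('training accuracy')
    ax2.set_xlabel('epoch')
    plt.show()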
In [17]:
autoencoder.save_weights('model_5l_weight_ep50_384_auto.hdf5')
In [18]:
autoencoder.load_weights('model_5l_weight_ep50_384_auto.hdf5')
from keras.utils.vis_utils import plot_model
plot_model(autoencoder, to_file='model_384_auto.png', show_shapes=True)
In [19]:
autoencoder.compile(optimizer=SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=False),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])
score = autoencoder.evaluate(test_data, test_label, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
In [20]:
output = autoencoder.predict_proba(test_data, verbose=0)
output = output.reshape((output.shape[0], img_h, img_w, n_labels))
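Pixel accuracy can look high on imbalanced masks even when the foreground class is segmented poorly, so a per-class intersection-over-union is a useful complement. This check is not part of the original notebook; it only uses the output and test_label arrays computed above.
In [ ]:
# Per-class IoU over all test pixels.
pred = np.argmax(output.reshape((-1, n_labels)), axis=-1)
true = np.argmax(test_label.reshape((-1, n_labels)), axis=-1)
for cls in range(n_labels):
    inter = np.sum((pred == cls) & (true == cls))
    union = np.sum((pred == cls) | (true == cls))
    print('class %d IoU: %.3f' % (cls, inter / float(union) if union else float('nan')))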
In [21]:
print(output.shape)
In [22]:
print(output)
In [23]:
from skimage import data, color, io, img_as_float
def plot_results(output):
    df = pd.read_csv(path + 'test.csv')
    plt.figure(figsize=(6, n_test * 2))
    for i, item in df.iterrows():
        if i >= n_test:
            break
        plt.subplot(n_test, 4, 4 * i + 1)
        plt.title('Ground Truth')
        plt.axis('off')
        gt = imread(item[1])
        plt.imshow(np.clip(gt, 0, 1))

        plt.subplot(n_test, 4, 4 * i + 2)
        plt.title('Prediction')
        plt.axis('off')
        labeled = np.argmax(output[i], axis=-1)
        plt.imshow(labeled)

        plt.subplot(n_test, 4, 4 * i + 3)
        plt.title('Heat map')
        plt.axis('off')
        plt.imshow(output[i][:, :, 1])

        plt.subplot(n_test, 4, 4 * i + 4)
        plt.title('Comparison')
        plt.axis('off')
        # overlay the predicted mask onto the input image via the HSV saturation channel
        img = imread(item[0])
        img_hsv = color.rgb2hsv(img)
        img_hsv[..., 1] = labeled
        img = color.hsv2rgb(img_hsv)
        plt.imshow(img)
    plt.savefig('result.png')
    plt.show()
In [24]:
plot_results(output)
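Besides the combined figure, the predicted masks can also be written out individually for later inspection; a minimal sketch using skimage.io.imsave (the prediction_*.png file names are arbitrary):
In [ ]:
from skimage.io import imsave

# Save each predicted binary mask as an 8-bit PNG.
for i in range(output.shape[0]):
    mask = (np.argmax(output[i], axis=-1) * 255).astype(np.uint8)
    imsave('prediction_%03d.png' % i, mask)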