# In [3]:
#######################################################################################
# unet_conv_cifar10rgb_mc.py
# Convolutional-layer UNET for the RGB CIFAR-10 dataset, written as a Keras Model subclass
#######################################################################################
###########################
# AE (autoencoder) modeling
###########################
from keras import models, backend
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, \
UpSampling2D, BatchNormalization, Concatenate, Activation
# backend.set_image_data_format('channels_first')
class UNET(models.Model):
    """Two-level convolutional U-Net built with the Keras functional API.

    The encoder has three stages (16, 32 and 64 filters); the second and third
    stages start with a 2x2 max-pooling. The decoder upsamples twice and
    concatenates the matching encoder feature map (skip connection) at each
    level. A final 3x3 sigmoid convolution produces ``n_ch`` output channels.
    The model is compiled with the Adadelta optimizer and binary
    cross-entropy loss as soon as it is constructed.

    Because the input is pooled twice and upsampled twice, the spatial
    dimensions of ``org_shape`` must be divisible by 4 for the skip
    connections to line up.

    Args:
        org_shape: Shape of a single input sample (without the batch axis),
            laid out according to ``backend.image_data_format()``.
        n_ch: Number of channels of the reconstructed output image.
    """

    def __init__(self, org_shape, n_ch):
        # Index of the channel axis in a 4-D (batch, ...) tensor.
        ic = 3 if backend.image_data_format() == 'channels_last' else 1

        def conv_bn_tanh(x, n_f):
            """Apply a 3x3 'same' convolution, batch norm and tanh."""
            x = Conv2D(n_f, (3, 3), padding='same')(x)
            # Normalize over the channel axis. The layer default (-1) is only
            # correct for 'channels_last'; with 'channels_first' it would
            # normalize over the width axis instead.
            x = BatchNormalization(axis=ic)(x)
            return Activation('tanh')(x)

        def conv(x, n_f, mp_flag=True):
            """Encoder stage: optional 2x2 pooling, then two conv units."""
            if mp_flag:
                x = MaxPooling2D((2, 2), padding='same')(x)
            x = conv_bn_tanh(x, n_f)
            x = Dropout(0.05)(x)
            return conv_bn_tanh(x, n_f)

        def deconv_unet(x, e, n_f):
            """Decoder stage: upsample, merge skip tensor ``e``, two conv units."""
            x = UpSampling2D((2, 2))(x)
            x = Concatenate(axis=ic)([x, e])
            x = conv_bn_tanh(x, n_f)
            return conv_bn_tanh(x, n_f)

        # Input
        original = Input(shape=org_shape)

        # Encoding (c1 keeps the input resolution, c2 is at half resolution)
        c1 = conv(original, 16, mp_flag=False)
        c2 = conv(c1, 32)

        # Bottleneck at quarter resolution
        encoded = conv(c2, 64)

        # Decoding with skip connections back to c2 and c1
        x = deconv_unet(encoded, c2, 32)
        x = deconv_unet(x, c1, 16)

        decoded = Conv2D(n_ch, (3, 3), activation='sigmoid',
                         padding='same')(x)

        super().__init__(original, decoded)
        self.compile(optimizer='adadelta', loss='binary_crossentropy')
###########################
# Load data
###########################
from keras import datasets, utils
class DATA():
    """CIFAR-10 images prepared as (input, target) pairs for the U-Net.

    Pixel values are converted to float32 and divided by 255. The targets are
    always the full images; the inputs depend on ``in_ch``:

    * ``in_ch == 1`` (with 3-channel data): weighted gray-scale conversion.
    * ``in_ch == 2`` (with 3-channel data): the first two channels only.
    * anything else: the inputs are the targets themselves.

    Attributes set on the instance:
        input_shape: Per-sample input shape for the active image data format.
        x_train_in, x_train_out: Training inputs and targets.
        x_test_in, x_test_out: Test inputs and targets.
        n_ch: Number of channels in the target images.
        in_ch: Number of channels in the input images.

    Args:
        in_ch: Requested number of input channels; ``None`` means "same as
            the dataset".
    """

    def __init__(self, in_ch=None):
        (train_imgs, _), (test_imgs, _) = datasets.cifar10.load_data()

        fmt = backend.image_data_format()
        channels_first = fmt == 'channels_first'

        # Derive the image geometry; 3-D data is treated as single-channel.
        if train_imgs.ndim != 4:
            img_rows, img_cols = train_imgs.shape[1:]
            n_ch = 1
        elif channels_first:
            n_ch, img_rows, img_cols = train_imgs.shape[1:]
        else:
            img_rows, img_cols, n_ch = train_imgs.shape[1:]

        # in_ch can be 1 for changing BW to color image using UNet
        if in_ch is None:
            in_ch = n_ch

        # Scale pixel values to [0, 1] as float32.
        train_imgs = train_imgs.astype('float32') / 255
        test_imgs = test_imgs.astype('float32') / 255

        def channels(images, start, stop):
            """Slice channels [start, stop) along the channel axis of ``fmt``."""
            if channels_first:
                return images[:, start:stop]
            return images[..., start:stop]

        def to_gray(images):
            """Collapse the three colour channels into one weighted channel."""
            red = channels(images, 0, 1)
            green = channels(images, 1, 2)
            blue = channels(images, 2, 3)
            return 0.299 * red + 0.587 * green + 0.114 * blue

        def first_two_channels(images):
            """Keep only the first two channels."""
            return channels(images, 0, 2)

        # Per-sample layouts for targets and for the network input.
        if channels_first:
            sample_dims = (n_ch, img_rows, img_cols)
            input_shape = (in_ch, img_rows, img_cols)
        else:
            sample_dims = (img_rows, img_cols, n_ch)
            input_shape = (img_rows, img_cols, in_ch)

        x_train_out = train_imgs.reshape((train_imgs.shape[0],) + sample_dims)
        x_test_out = test_imgs.reshape((test_imgs.shape[0],) + sample_dims)

        # Build the inputs from the targets according to the requested in_ch.
        if n_ch == 3 and in_ch == 1:
            x_train_in = to_gray(x_train_out)
            x_test_in = to_gray(x_test_out)
        elif n_ch == 3 and in_ch == 2:
            x_train_in = first_two_channels(x_train_out)
            x_test_in = first_two_channels(x_test_out)
        else:
            x_train_in = x_train_out
            x_test_in = x_test_out

        self.input_shape = input_shape
        self.x_train_in = x_train_in
        self.x_train_out = x_train_out
        self.x_test_in = x_test_in
        self.x_test_out = x_test_out
        self.n_ch = n_ch
        self.in_ch = in_ch
###########################
# UNET validation
###########################
from ann_mnist_cl import plot_loss
import matplotlib.pyplot as plt
###########################
# Check UNET operation
###########################
import numpy as np
def show_images(data, unet, n=10):
    """Plot inputs, reconstructions and targets for the first test images.

    The figure has three rows of ``n`` images each: the network input (top),
    the network output (middle) and the ground-truth target (bottom). Axes
    ticks are hidden.

    Args:
        data: Object exposing ``x_test_in``, ``x_test_out`` and ``in_ch``
            (for example a ``DATA`` instance).
        unet: Model exposing ``predict``.
        n: Maximum number of test images to display (default 10). Fewer are
            shown when the test set is smaller; nothing is drawn when there
            are no images to show.
    """
    n = min(max(n, 0), len(data.x_test_out))
    if n == 0:
        return

    # Only the displayed samples are needed, so slice before predicting
    # instead of running inference over the whole test set.
    x_test_in = data.x_test_in[:n]
    x_test_out = data.x_test_out[:n]
    decoded_imgs = unet.predict(x_test_in)

    if backend.image_data_format() == 'channels_first':
        def to_channels_last(images):
            """Reorder (N, C, H, W) to (N, H, W, C), as imshow expects."""
            return images.transpose(0, 2, 3, 1)

        x_test_out = to_channels_last(x_test_out)
        decoded_imgs = to_channels_last(decoded_imgs)
        if data.in_ch == 1:
            # Single channel: drop the channel axis to get 2-D images.
            x_test_in = x_test_in[:, 0, ...]
        else:
            x_test_in = to_channels_last(x_test_in)
    elif data.in_ch == 1:
        # Single channel: drop the channel axis to get 2-D images.
        x_test_in = x_test_in[..., 0]

    if data.in_ch == 2:
        # Two-channel inputs cannot be shown directly; pad the missing
        # channel(s) with zeros so the array has the target's shape.
        padded = np.zeros_like(x_test_out)
        padded[..., :2] = x_test_in
        x_test_in = padded

    plt.figure(figsize=(2 * n, 6))
    rows = (x_test_in, decoded_imgs, x_test_out)
    for i in range(n):
        for row, images in enumerate(rows):
            ax = plt.subplot(3, n, i + 1 + row * n)
            plt.imshow(images[i])
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()
def main(in_ch, epochs=10, batch_size=128, fig=True):
    """Train the U-Net on CIFAR-10 and optionally visualize the results.

    Args:
        in_ch: Number of input channels handed to ``DATA``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size used for training.
        fig: When true, plot the loss history and show sample images after
            training.
    """
    ###########################
    # Training and verification
    ###########################
    dataset = DATA(in_ch=in_ch)
    print(dataset.input_shape, dataset.x_train_in.shape)

    model = UNET(dataset.input_shape, dataset.n_ch)

    # Hold out 20% of the training samples for validation.
    fit_options = dict(epochs=epochs,
                       batch_size=batch_size,
                       shuffle=True,
                       validation_split=0.2)
    training_log = model.fit(dataset.x_train_in, dataset.x_train_out,
                             **fit_options)

    if not fig:
        return

    plot_loss(training_log)
    show_images(dataset, model)
# In [4]:
# Notebook-cell run: train with two input channels and the default settings.
main(in_ch=2)
# In [6]:
def str2bool(value):
    """Convert a command-line truth string to ``bool``.

    Accepts the same spellings as the removed ``distutils.util.strtobool``
    (case-insensitive): y/yes/t/true/on/1 and n/no/f/false/off/0.

    Args:
        value: Text to interpret.

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: If ``value`` is not one of the recognized spellings.
    """
    text = value.lower()
    if text in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if text in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (value,))


if __name__ == '__main__':
    # Command-line entry point: parse the options and run one training session.
    import argparse

    parser = argparse.ArgumentParser(description='UNET for Cifar-10: RG to RGB')
    parser.add_argument('--input_channels', type=int, default=2,
                        help='input channels (default: 2)')
    parser.add_argument('--epochs', type=int, default=128,
                        help='training epochs (default: 128)')
    parser.add_argument('--batch_size', type=int, default=20,
                        help='batch size (default: 20)')
    # distutils was removed in Python 3.12, so the flag is parsed with the
    # local str2bool helper instead of distutils.util.strtobool.
    parser.add_argument('--fig', type=str2bool,
                        default=True, help='flag to show figures (default: True)')

    args = parser.parse_args()
    print("Aargs:", args)

    main(args.input_channels, args.epochs, args.batch_size, args.fig)
# In [ ]: