# %% load required libraries
import tensorflow as tf
import math
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from skimage.transform import resize
# %%%%%%%%%%%%%%%%%%%%%%%input%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# list of images for training
# bool variable to check if the data is normalized or not
# test image for autoencoder
# dimensions of the network
# number of iterations
# batch size
# %%%%%%%%%%%%%%%%%%%%%%%output%%%%%%%%%%%%%%%%%%%%%%%%%%
# the decoded image using the autoencoder
# Some of the code is adapted from:
# https://github.com/pkmital/tensorflow_tutorials
# split the training list into batch_size roughly equal, contiguous chunks
def find_patches(training, batch_size=1):
    length = len(training)
    res = [training[i * length // batch_size: (i + 1) * length // batch_size]
           for i in range(batch_size)]
    return res
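# Quick illustration (not part of the original pipeline): find_patches slices a
# list into batch_size roughly equal, contiguous chunks.
example_batches = find_patches(list(range(10)), batch_size=3)
# example_batches == [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]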
def basic_auto_encoder(imgs_list, test_img, normalized=False, data_dimensions=[780, 512, 256, 64], n_epochs=50, batch_size=50):
    global img_w
    img_w = 50
    global img_h
    img_h = 50
    if normalized:
        train_imgs_list_normalized = imgs_list
        test_img_normalized = test_img
    else:
        train_imgs_list_normalized = []
        # prepare training images: keep a single channel and scale to [0, 1]
        for i, img in enumerate(imgs_list):
            n_channels = len(img.shape)
            if n_channels == 3:
                img_norm = img[:, :, 0]
            else:
                img_norm = img
            img_norm = img_norm / 255.0
            train_imgs_list_normalized.append(img_norm)
        # prepare the test image in the same way
        n_channels = len(test_img.shape)
        if n_channels == 3:
            test_img = test_img[:, :, 0]
        test_img_normalized = test_img / 255.0
    # get the mean image of the training samples
    global train_mean_img
    train_mean_img = np.mean(train_imgs_list_normalized, axis=0)
    global auto_encoder_architecture
    auto_encoder_architecture = autoencoder(dimensions=data_dimensions)
    # run the graph using the test image
    learning_rate = 0.001
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(auto_encoder_architecture['cost'])
    global sess
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # learning from the training samples
    print('########## Learning Encoders and Decoders ####################')
    for epoch_i in range(n_epochs):
        np.random.shuffle(train_imgs_list_normalized)
        batch_xs = find_patches(train_imgs_list_normalized, batch_size)
        for img_i, batch in enumerate(batch_xs):
            # mean-center each image and flatten it to a row vector
            pre_processed_batch = np.array([img - train_mean_img for img in batch])
            current_size = len(pre_processed_batch)
            pre_processed_batch = np.reshape(pre_processed_batch, (current_size, img_w * img_h))
            sess.run(optimizer, feed_dict={
                auto_encoder_architecture['x']: pre_processed_batch
            })
        current_cost = sess.run(auto_encoder_architecture['cost'], feed_dict={
            auto_encoder_architecture['x']: pre_processed_batch
        })
        print('Epoch num: ', epoch_i, ' cost: ', current_cost)
    # print('########## Decoding Test Image ###################')
    test_img_normalized_mean = test_img_normalized - train_mean_img
    test_img_normalized_mean = np.reshape(test_img_normalized_mean, (1, img_w * img_h))
    test_reconstructed = sess.run(auto_encoder_architecture['y'], feed_dict={
        auto_encoder_architecture['x']: test_img_normalized_mean
    })
    test_reconstructed = np.reshape(test_reconstructed, (img_w, img_h))
    reconstructed_img = test_reconstructed + train_mean_img
    # plots: test image, training mean image, reconstruction
    fig, axs = plt.subplots(3, 1)
    axs[0].imshow(test_img, cmap='gray', interpolation='nearest')
    axs[1].imshow(train_mean_img, cmap='gray', interpolation='nearest')
    axs[2].imshow(reconstructed_img, cmap='gray', interpolation='nearest')
    plt.draw()
    plt.show()
    print('Done!')
    return reconstructed_img
# reconstruct an already-normalized image with the trained network
# (relies on the globals set by basic_auto_encoder: sess, auto_encoder_architecture,
# train_mean_img, img_w and img_h)
def reconstruct_img(test_img_normalized):
    test_img_normalized_mean = test_img_normalized - train_mean_img
    # invert the intensities of the mean-centered image before encoding
    test_img_normalized_mean1 = 1 - test_img_normalized_mean
    test_img_normalized_mean = np.reshape(test_img_normalized_mean1, (1, img_w * img_h))
    test_reconstructed = sess.run(auto_encoder_architecture['y'], feed_dict={
        auto_encoder_architecture['x']: test_img_normalized_mean
    })
    test_reconstructed = np.reshape(test_reconstructed, (img_w, img_h))
    reconstructed_img = test_reconstructed + train_mean_img
    # reconstructed_img = test_reconstructed
    # plots: input image, training mean image, reconstruction
    fig, axs = plt.subplots(3, 1)
    axs[0].imshow(test_img_normalized, cmap='gray', interpolation='nearest')
    axs[1].imshow(train_mean_img, cmap='gray', interpolation='nearest')
    axs[2].imshow(reconstructed_img, cmap='gray', interpolation='nearest')
    plt.draw()
    plt.show()
    print('Done!')
    return reconstructed_img
# basic autoencoder
def autoencoder(dimensions=[780, 512, 256, 64]):
    # define the network architecture
    x = tf.placeholder(tf.float32, [None, dimensions[0]], name='x_input')
    current_input = x
    # build the encoder using the normalized training data
    encoded_dataset = []
    for layer_i, n_output in enumerate(dimensions[1:]):
        n_input = int(current_input.get_shape()[1])
        W = tf.Variable(
            tf.random_uniform([n_input, n_output],
                              -1.0 / math.sqrt(n_input),
                              1.0 / math.sqrt(n_input)))
        b = tf.Variable(tf.zeros([n_output]))
        encoded_dataset.append(W)
        output = tf.nn.tanh(tf.matmul(current_input, W) + b)
        current_input = output
    # latent representation
    z = current_input
    # build the decoder with tied (transposed) encoder weights
    encoded_dataset.reverse()
    decoder_dimensions = dimensions[::-1]  # reversed copy; avoids mutating the caller's list
    for layer_i, n_output in enumerate(decoder_dimensions[1:]):
        W = tf.transpose(encoded_dataset[layer_i])
        b = tf.Variable(tf.zeros([n_output]))
        output = tf.nn.tanh(tf.matmul(current_input, W) + b)
        current_input = output
    # decoded image
    y = current_input
    # cost function: sum of squared reconstruction errors
    cost = tf.reduce_sum(tf.square(y - x))
    # return the graph handles
    return {'x': x, 'z': z, 'y': y, 'cost': cost}
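# Quick shape check (illustrative only, not part of the original pipeline): because
# the decoder reuses transposed encoder weights, the reconstruction y has the same
# width as the input x. Built in a throwaway graph so no extra variables end up in
# the training graph below; 2500 = 50 * 50 flattened pixels.
with tf.Graph().as_default():
    example_net = autoencoder(dimensions=[2500, 512, 64])
    print(example_net['x'].get_shape())  # (?, 2500)
    print(example_net['z'].get_shape())  # (?, 64)
    print(example_net['y'].get_shape())  # (?, 2500)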
# %% load the data and run the autoencoder
from skimage.transform import resize
import skimage.io
import matplotlib.pyplot as plt
data = skimage.io.imread_collection('path/your/training/file/*.png')
test_sample = skimage.io.imread('path/to/your/test/image')
train_samples = data[0:212]
train_data = []
for i, img in enumerate(train_samples):
    train_data.append(img)
plt.imshow(test_sample, cmap='gray')
plt.show()
# resize everything to 50x50; preserve_range keeps the original 0-255 intensities
# so the scaling done inside basic_auto_encoder (normalized=False) stays correct
train_samples_resized = []
for i, img in enumerate(train_data):
    img = resize(img, (50, 50), preserve_range=True)
    train_samples_resized.append(img)
test_sample = resize(test_sample, (50, 50), preserve_range=True)
dimensions = [2500, 2300, 2000, 1700, 1400, 1100, 800, 500, 300]
res = basic_auto_encoder(train_samples_resized, test_img=test_sample, normalized=False,
                         data_dimensions=dimensions, n_epochs=10, batch_size=70)
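# Illustrative reuse of the trained network (the path below is a placeholder, not
# from the original script): reconstruct_img expects a single-channel 50x50 image
# already scaled to [0, 1] and uses the globals set by basic_auto_encoder above.
another_sample = skimage.io.imread('path/to/your/other/image')
if len(another_sample.shape) == 3:
    another_sample = another_sample[:, :, 0]
another_sample = resize(another_sample, (50, 50))  # resize returns floats in [0, 1]
res2 = reconstruct_img(another_sample)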