In [1]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '../Code/'))
from LadickyDataset import *
In [11]:
import tensorflow as tf
from keras.models import Model, load_model
from keras.applications.vgg16 import VGG16
from keras.layers import Input , Flatten, Dense, Reshape, Lambda
from keras.layers.convolutional import Conv2D
In [3]:
from math import ceil
In [4]:
from PIL import Image
In [5]:
def show_image(npimg):
    """Render an image array as a PIL image for inline display.

    Args:
        npimg: array of pixel intensities; presumably (height, width, 3)
            with values in 0-255 -- TODO confirm against
            LadickyDataset.get_batch.

    Returns:
        A ``PIL.Image.Image`` built from ``npimg`` converted to ``uint8``.
    """
    import numpy as np  # don't rely on `np` leaking in through a star import

    # Clip before casting: converting out-of-range floats to uint8 is
    # undefined in NumPy and shows up as wrapped-around garbage pixels.
    return Image.fromarray(np.clip(npimg, 0, 255).astype(np.uint8))
In [6]:
def show_normals(npnorms):
    """Render a surface-normal map as a PIL RGB image.

    Each component is mapped linearly from [-1, 1] to [0, 255] via
    ``(n + 1) / 2 * 255`` and truncated to uint8, so e.g. the normal
    (0, 0, 1) becomes the pixel (127, 127, 255).

    Args:
        npnorms: normal-map array whose last axis holds the vector
            components (presumably (height, width, 3)).

    Returns:
        A ``PIL.Image.Image`` visualising the normals.
    """
    import numpy as np  # don't rely on `np` leaking in through a star import

    scaled = (npnorms + 1) / 2 * 255
    # Components slightly outside [-1, 1] (float error, unnormalised
    # predictions) would otherwise overflow the uint8 cast.
    return Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))
In [17]:
# Location of the Ladicky dataset .mat file, relative to the current working directory.
file = "../Data/LadickyDataset.mat"
In [18]:
# Load the dataset with indices 0..31 (presumably the first 32 samples --
# TODO confirm the meaning of LadickyDataset's second argument).
dataset = LadickyDataset(file, [*range(32)])
In [20]:
# Batch configuration for this exploratory run.
batchSize, epochs = 32, 1
# Number of mini-batches needed to cover the dataset `epochs` times.
totalBatches = ceil(dataset.size / batchSize) * epochs
In [22]:
def mean_dot_product(y_true, y_pred):
    """Negative mean dot product between true and predicted normals.

    Arguments shape: (batchSize, height, width, components). The dot
    product is taken over the last axis and averaged over the pixels that
    carry ground truth. The result is negated so that minimising the loss
    maximises alignment between prediction and target.

    Pixels whose ground-truth vector is all zeros are treated as invalid
    and excluded from the average. This assumes the dataset encodes
    missing normals that way -- TODO confirm against LadickyDataset.

    Returns:
        Scalar tensor. When no pixel is valid, returns 0 instead of NaN.
    """
    dot = tf.reduce_sum(y_true * y_pred, axis=-1)  # per-pixel dot product
    # Count valid pixels from the ground truth, not from `dot`: a valid
    # pixel whose prediction is zero (e.g. dead ReLU outputs) or orthogonal
    # to the target has dot == 0 but must still count in the denominator.
    valid = tf.count_nonzero(tf.reduce_sum(tf.abs(y_true), axis=-1))
    n = tf.maximum(tf.cast(valid, dot.dtype), 1.0)  # avoid 0 / 0 -> NaN
    mean = tf.reduce_sum(dot) / n
    return -1 * mean
In [14]:
# Draw a small batch of 3 samples (images and their normal maps) to sanity-check the loss on.
(imgs, norms) = dataset.get_batch(3)
In [28]:
# Inspect the shape of the ground-truth normal batch.
np.shape(norms)
Out[28]:
In [59]:
# Fake prediction batch for checking `mean_dot_product`: standard-normal
# random vectors (these are NOT model weights) for every sample except the
# last, which is replaced by the ground truth so the batch contains one
# perfectly matching sample. The shape follows `norms` instead of being
# hard-coded.
rnd = np.random.randn(*norms.shape)
rnd[-1] = norms[-1]
In [60]:
# Evaluate the loss on the fake batch built above (random samples plus one
# perfect match). Both operands are cast to float32 so their dtypes agree.
# The `with` block closes the session once the value is computed.
with tf.Session() as sess:
    true_normals = tf.constant(norms, dtype=tf.float32)
    pred_normals = tf.nn.l2_normalize(tf.constant(rnd, dtype=tf.float32), 3)
    sanity_loss = mean_dot_product(true_normals, pred_normals).eval(session=sess)
sanity_loss
Out[60]:
In [23]:
def vgg16_model(input_shape=(240, 320, 3)):
    """Build and compile a VGG16-based surface-normal regressor.

    The convolutional base of VGG16 (ImageNet weights, no top) is followed
    by two dense layers. These predict a 3-channel map at 1/4 of the input
    resolution. That map is bilinearly resized to the input resolution and
    L2-normalised per pixel.

    Args:
        input_shape: (height, width, channels) of the input images.
            Defaults to (240, 320, 3), which gives a 60x80 coarse map.

    Returns:
        A Keras ``Model`` compiled with the ``mean_dot_product`` loss and
        the 'sgd' optimizer, mapping images to (height, width, 3) unit
        normal maps.
    """
    height, width = input_shape[0], input_shape[1]
    coarse_h, coarse_w = height // 4, width // 4

    input_tensor = Input(shape=input_shape)
    base_model = VGG16(input_tensor=input_tensor, weights='imagenet', include_top=False)
    x = Flatten()(base_model.output)
    x = Dense(4096, activation='relu', name='fc1')(x)
    # Linear output on purpose: normal components lie in [-1, 1]. A ReLU
    # here would force every predicted normal into the non-negative octant
    # after the normalisation below.
    x = Dense(coarse_h * coarse_w * 3, name='fc2')(x)
    x = Reshape((coarse_h, coarse_w, 3))(x)
    # Pass the target size via `arguments` so it is stored in the layer config.
    x = Lambda(lambda t, size: tf.image.resize_bilinear(t, size),
               arguments={'size': [height, width]})(x)
    pred = Lambda(lambda t: tf.nn.l2_normalize(t, 3))(x)
    model = Model(inputs=base_model.input, outputs=pred)
    model.compile(loss=mean_dot_product, optimizer='sgd')
    return model
In [26]:
# Training configuration: batches of 4 samples, a single pass over the data.
batchSize, epochs = 4, 1
totalBatches = ceil(dataset.size / batchSize) * epochs
# Build and compile the VGG16-based normal estimator.
model = vgg16_model()
In [29]:
# Plain training loop: one train_on_batch step per freshly drawn mini-batch.
for step in range(1, totalBatches + 1):
    print('Batch:{} of {}'.format(step, totalBatches))
    imgs, norms = dataset.get_batch(batchSize)
    model.train_on_batch(imgs, norms)
In [13]:
# Save the trained model to an HDF5 file.
model.save(filepath='../Data/model.h5')
In [13]:
# Reload the saved model. The custom loss and the `tf` module referenced
# inside the Lambda layers must be supplied for deserialisation.
model = load_model(
    '../Data/model.h5',
    custom_objects={'mean_dot_product': mean_dot_product, 'tf': tf},
)
In [15]:
# Predict the normal map for a single freshly drawn sample.
imgs, norms = dataset.get_batch(1)
pred = model.predict(x=imgs, batch_size=1)
In [16]:
# Distinct per-pixel vector lengths of the first prediction. Since the model
# ends with l2_normalize, these should be ~1 (0 only for all-zero vectors).
np.unique(np.linalg.norm(pred[0], axis=-1))
Out[16]:
In [2]:
%%writefile ../Code/Experiments/Training.py
# Imports
import tensorflow as tf
from math import ceil
from PIL import Image
import time
# Utility functions
def show_image(npimg):
    """Render an image array as a PIL image.

    Args:
        npimg: array of pixel intensities; presumably (height, width, 3)
            with values in 0-255 -- TODO confirm.

    Returns:
        A ``PIL.Image.Image`` built from ``npimg`` converted to ``uint8``.
    """
    # This module never imports numpy at top level, so `np` would
    # otherwise be an undefined name here.
    import numpy as np

    # Clip before casting: out-of-range floats -> uint8 is undefined in NumPy.
    return Image.fromarray(np.clip(npimg, 0, 255).astype(np.uint8))
def show_normals(npnorms):
    """Render a surface-normal map as a PIL RGB image.

    Components are mapped linearly from [-1, 1] to [0, 255] via
    ``(n + 1) / 2 * 255`` and truncated to uint8.

    Args:
        npnorms: normal-map array whose last axis holds the vector components.

    Returns:
        A ``PIL.Image.Image`` visualising the normals.
    """
    # This module never imports numpy at top level, so `np` would
    # otherwise be an undefined name here.
    import numpy as np

    scaled = (npnorms + 1) / 2 * 255
    # Guard the uint8 cast against components slightly outside [-1, 1].
    return Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))
# Loss function
def mean_dot_product(y_true, y_pred):
    '''
    Negative mean dot product between true and predicted normals.

    Arguments shape: (batchSize, height, width, components)

    The dot product is taken over the last axis and averaged over the
    pixels that carry ground truth; the result is negated so minimising
    the loss maximises alignment. Pixels whose ground-truth vector is all
    zeros are treated as invalid and excluded (assumes missing normals are
    encoded that way -- TODO confirm). When no pixel is valid, returns 0
    instead of NaN.
    '''
    dot = tf.reduce_sum(y_true * y_pred, axis=-1)  # per-pixel dot product
    # Count valid pixels from the ground truth, not from `dot`: a valid
    # pixel with a zero or orthogonal prediction has dot == 0 but must
    # still count in the denominator.
    valid = tf.count_nonzero(tf.reduce_sum(tf.abs(y_true), axis=-1))
    n = tf.maximum(tf.cast(valid, dot.dtype), 1.0)  # avoid 0 / 0 -> NaN
    mean = tf.reduce_sum(dot) / n
    return -1 * mean
# Training
def Train(ID, Dataset, model, loss, optimizer, batchSize, epochs,
          save_every=5, output_dir='Experiments/Outputs'):
    '''
    Build, compile and train a model, checkpointing it to ``<output_dir>/<ID>.h5``.

    Arguments:
        ID: run identifier, used in log messages and as the checkpoint file name.
        Dataset: zero-argument callable returning an object with a ``size``
            attribute and a ``get_batch(n) -> (inputs, targets)`` method.
        model: zero-argument callable returning a model exposing
            ``compile``, ``train_on_batch`` and ``save`` (Keras-style).
        loss: loss passed to ``compile``; the string 'mean_dot_product' is
            resolved to this module's ``mean_dot_product`` function.
        optimizer: optimizer passed to ``compile``.
        batchSize: samples drawn per ``train_on_batch`` call.
        epochs: number of epochs; each runs ceil(size / batchSize) batches.
        save_every: save every this many epochs. The final epoch is always
            saved, so training is not lost when ``epochs`` is not a multiple
            of ``save_every``. Pass 0 or None to save only at the end.
        output_dir: directory for the checkpoint; created if it is missing.
    '''
    import os

    # Load data set
    print('Loading the data set...')
    dataset = Dataset()
    # Build model
    print('Building the model...')
    net = model()
    if loss == 'mean_dot_product':
        loss = mean_dot_product
    net.compile(optimizer, loss)
    # Parameters
    totalBatches = ceil(dataset.size / batchSize)
    checkpoint = os.path.join(output_dir, ID + '.h5')
    # Training loop
    print('Training ' + ID + '...')
    for epoch in range(epochs):
        print('------------------------------------------')
        start = time.perf_counter()
        for batch in range(totalBatches):
            print('*** Epoch: ' + str(epoch + 1) + '/' + str(epochs)
                  + ' *** Batch: ' + str(batch + 1) + '/' + str(totalBatches) + ' ***')
            imgs, norms = dataset.get_batch(batchSize)
            # Separate name so the `loss` argument is not shadowed.
            batch_loss = net.train_on_batch(imgs, norms)
            print('Loss: ' + str(batch_loss))
        is_last_epoch = epoch + 1 == epochs
        periodic = bool(save_every) and (epoch + 1) % save_every == 0
        if periodic or is_last_epoch:
            # Saving the model
            print('Saving the model...')
            os.makedirs(output_dir, exist_ok=True)  # avoid failing after a long run
            net.save(checkpoint)
        # Estimate the remaining time from this epoch's duration
        end = time.perf_counter()
        rem = divmod((epochs - epoch - 1) * (end - start), 60)
        print('Remaining time: ' + str(round(rem[0])) + ' minute(s) and '
              + str(round(rem[1])) + ' seconds')
In [ ]: