In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook illustrates how to train Neural Voxel Renderer (CVPR 2020) in TensorFlow 2.
In [0]:
!pip install tensorflow_graphics
In [0]:
import numpy as np
import tensorflow as tf
from tensorflow_graphics.projects.neural_voxel_renderer import helpers
from tensorflow_graphics.projects.neural_voxel_renderer import models
import datetime
import matplotlib.pyplot as plt
import os
import re
import time
VOXEL_SIZE = (128, 128, 128, 4)
We store our data in TFRecords with custom protobuf messages. Each training element contains the input voxels, the voxel rendering, the light position, and the target image. The data is preprocessed (e.g., the colored voxels have already been rendered and placed accordingly). See this colab on how to generate the training/testing TFRecords.
In [0]:
# Functions for dataset generation from a set of TFRecords.
decode_proto = tf.compat.v1.io.decode_proto


def tf_image_normalize(image):
  """Normalizes the image to the range [-1, 1]."""
  return (2 * tf.cast(image, tf.float32) / 255.) - 1
def neural_voxel_plus_proto_get(element):
  """Extracts the contents from a VoxelSample proto to tensors."""
  _, values = decode_proto(element,
                           "giotto_blender.NeuralVoxelPlusSample",
                           ["name",
                            "voxel_data",
                            "rerendering_data",
                            "image_data",
                            "light_position"],
                           [tf.string,
                            tf.string,
                            tf.string,
                            tf.string,
                            tf.float32])
  filename = tf.squeeze(values[0])
  voxel_data = tf.squeeze(values[1])
  rerendering_data = tf.squeeze(values[2])
  image_data = tf.squeeze(values[3])
  light_position = values[4]
  voxels = tf.io.decode_raw(voxel_data, out_type=tf.uint8)
  voxels = tf.cast(tf.reshape(voxels, VOXEL_SIZE), tf.float32) / 255.0
  rerendering = tf.cast(tf.image.decode_image(rerendering_data, channels=3),
                        tf.float32)
  rerendering = tf_image_normalize(rerendering)
  image = tf.cast(tf.image.decode_image(image_data, channels=3), tf.float32)
  image = tf_image_normalize(image)
  return filename, voxels, rerendering, image, light_position
def _expand_tfrecords_pattern(tfr_pattern):
  """Helper function to expand a `name@shards` TFRecord pattern."""
  def format_shards(m):
    return '{}-?????-of-{:0>5}{}'.format(*m.groups())
  tfr_pattern = re.sub(r'^([^@]+)@(\d+)([^@]+)$', format_shards, tfr_pattern)
  return tfr_pattern
def tfrecords_to_dataset(tfrecords_pattern,
                         mapping_func,
                         batch_size,
                         buffer_size=5000):
  """Generates a tf.data.Dataset from a TFRecord file pattern."""
  with tf.name_scope('Input/'):
    tfrecords_pattern = _expand_tfrecords_pattern(tfrecords_pattern)
    dataset = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=True)
    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.map(mapping_func)
    dataset = dataset.batch(batch_size)
  return dataset
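As a quick illustration (an addition, not from the original notebook), the `name@N` convention used below expands into a sharded file pattern that tf.data can glob:
In [0]:
# Illustration only: expanding a `name@N` pattern into its sharded form.
print(_expand_tfrecords_pattern('/tmp/tfrecords_dir/train@100.tfrecord'))
# Prints: /tmp/tfrecords_dir/train-?????-of-00100.tfrecord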
In [0]:
# Download the example data, licensed under the Apache License, Version 2.0
!rm -rf /tmp/tfrecords_dir/
!mkdir /tmp/tfrecords_dir/
!wget -P /tmp/tfrecords_dir/ https://storage.googleapis.com/tensorflow-graphics/notebooks/neural_voxel_renderer/train-00012-of-00100.tfrecord
In [0]:
tfrecords_dir = '/tmp/tfrecords_dir/'
tfrecords_pattern = os.path.join(tfrecords_dir, 'train@100.tfrecord')
batch_size = 5
mapping_function = neural_voxel_plus_proto_get
dataset = tfrecords_to_dataset(tfrecords_pattern, mapping_function, batch_size)
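Optionally, pull a single batch to verify that parsing produces the expected tensor shapes (this sanity check is an addition to the original pipeline):
In [0]:
# Optional sanity check: inspect the shapes of one parsed batch.
for filename, voxels, vox_render, target, light in dataset.take(1):
  print('voxels:', voxels.shape)  # Expected: (5, 128, 128, 128, 4).
  print('rendering:', vox_render.shape)
  print('target:', target.shape)
  print('light:', light.shape)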
In [0]:
# Visualize some examples.
_, ax = plt.subplots(1, 4, figsize=(10, 10))
i = 0
for a in dataset.take(4):
  (filename,
   voxels,
   vox_render,
   target,
   light_position) = a
  ax[i].imshow(target[0] * 0.5 + 0.5)
  ax[i].axis('off')
  i += 1
plt.show()
In [0]:
# ==============================================================================
# Defining model and optimizer
LEARNING_RATE = 0.002
nvr_plus_model = models.neural_voxel_renderer_plus_tf2()
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
# Saving and logging directories
checkpoint_dir = '/tmp/checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=nvr_plus_model)
log_dir = "/tmp/logs/"
summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
# ==============================================================================
# VGG loss
VGG_LOSS_LAYER_NAMES = ['block1_conv1', 'block2_conv1']
VGG_LOSS_LAYER_WEIGHTS = [1.0, 0.1]
VGG_LOSS_WEIGHT = 0.001
def vgg_layers(layer_names):
  """Creates a VGG model that returns a list of intermediate output values."""
  # Load a VGG19 pretrained on ImageNet, without the classification head.
  vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
  vgg.trainable = False
  outputs = [vgg.get_layer(name).output for name in layer_names]
  model = tf.keras.Model([vgg.input], outputs)
  return model
vgg_extractor = vgg_layers(VGG_LOSS_LAYER_NAMES)
# ==============================================================================
# Total loss
def network_loss(output, target):
  """Combines an L1 loss with a perceptual (VGG feature) loss."""
  # L1 loss.
  l1_loss = tf.reduce_mean(tf.abs(target - output))
  # VGG loss: compare intermediate features on images mapped back to [0, 255].
  vgg_output = vgg_extractor((output * 0.5 + 0.5) * 255)
  vgg_target = vgg_extractor((target * 0.5 + 0.5) * 255)
  vgg_loss = 0
  for l in range(len(VGG_LOSS_LAYER_WEIGHTS)):
    layer_loss = tf.reduce_mean(tf.square(vgg_target[l] - vgg_output[l]))
    vgg_loss += VGG_LOSS_LAYER_WEIGHTS[l] * layer_loss
  # Final loss.
  total_loss = l1_loss + VGG_LOSS_WEIGHT * vgg_loss
  return l1_loss, vgg_loss, total_loss
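As a quick smoke test (an addition; 256x256 is an assumed image size, not taken from the notebook), the loss can be evaluated on random tensors in [-1, 1]:
In [0]:
# Smoke test: evaluate the combined loss on random images.
fake_output = tf.random.uniform([1, 256, 256, 3], minval=-1., maxval=1.)
fake_target = tf.random.uniform([1, 256, 256, 3], minval=-1., maxval=1.)
l1, vgg, total = network_loss(fake_output, fake_target)
print(l1.numpy(), vgg.numpy(), total.numpy())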
In [0]:
@tf.function
def train_step(input_voxels, input_rendering, input_light, target, epoch):
  with tf.GradientTape() as tape:
    network_output = nvr_plus_model([input_voxels,
                                     input_rendering,
                                     input_light],
                                    training=True)
    l1_loss, vgg_loss, total_loss = network_loss(network_output, target)
  network_gradients = tape.gradient(total_loss,
                                    nvr_plus_model.trainable_variables)
  optimizer.apply_gradients(zip(network_gradients,
                                nvr_plus_model.trainable_variables))
  with summary_writer.as_default():
    tf.summary.scalar('total_loss', total_loss, step=epoch)
    tf.summary.scalar('l1_loss', l1_loss, step=epoch)
    tf.summary.scalar('vgg_loss', vgg_loss, step=epoch)
    tf.summary.image('Vox_rendering',
                     input_rendering * 0.5 + 0.5,
                     step=epoch,
                     max_outputs=4)
    tf.summary.image('Prediction',
                     network_output * 0.5 + 0.5,
                     step=epoch,
                     max_outputs=4)
In [0]:
def training_loop(train_ds, epochs):
  for epoch in range(epochs):
    start = time.time()
    # Train over the dataset once per epoch.
    for n, (_, voxels, vox_rendering, target, light) in train_ds.enumerate():
      print('.', end='')
      if (n + 1) % 100 == 0:
        print()
      train_step(voxels, vox_rendering, light, target, epoch)
    print()
    # Save a checkpoint every 20 epochs.
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)
    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                       time.time() - start))
  checkpoint.save(file_prefix=checkpoint_prefix)
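Before launching the training loop, TensorBoard can be pointed at the log directory to monitor the scalar and image summaries written in train_step (the notebook extension below is one way to do this in Colab):
In [0]:
%load_ext tensorboard
%tensorboard --logdir /tmp/logs/fit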
In [0]:
NUMBER_OF_EPOCHS = 100
training_loop(dataset, NUMBER_OF_EPOCHS)
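To resume training or run inference later, model and optimizer state can be restored from the latest checkpoint with the standard tf.train API (a minimal sketch, assuming the checkpoint directory defined above):
In [0]:
# Sketch: restore model and optimizer state from the latest checkpoint.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is not None:
  checkpoint.restore(latest)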