In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Neural Voxel Renderer

This notebook illustrates how to train the Neural Voxel Renderer (CVPR 2020) in TensorFlow 2.


Setup and imports

If TensorFlow Graphics is not installed on your system, the following cell installs the TensorFlow Graphics package for you.


In [0]:
!pip install tensorflow_graphics

In [0]:
import numpy as np
import tensorflow as tf
from tensorflow_graphics.projects.neural_voxel_renderer import helpers
from tensorflow_graphics.projects.neural_voxel_renderer import models

import datetime
import matplotlib.pyplot as plt
import os
import re
import time

VOXEL_SIZE = (128, 128, 128, 4)

Dataset loading

We store our data in TFRecords containing custom protobuf messages. Each training element holds the input voxels, the voxel rendering, the light position, and the target image. The data is already preprocessed (e.g., the colored voxels have been rendered and placed in the scene). See this colab for how to generate the training/testing TFRecords.


In [0]:
# Functions for dataset generation from a set of TFRecords.
decode_proto = tf.compat.v1.io.decode_proto


def tf_image_normalize(image):
  """Normalizes the image [-1, 1]."""
  return (2 * tf.cast(image, tf.float32) / 255.) - 1


def neural_voxel_plus_proto_get(element):
  """Extracts the contents from a VoxelSample proto to tensors."""
  _, values = decode_proto(element,
                           "giotto_blender.NeuralVoxelPlusSample",
                           ["name",
                            "voxel_data",
                            "rerendering_data",
                            "image_data",
                            "light_position"],
                           [tf.string,
                            tf.string,
                            tf.string,
                            tf.string,
                            tf.float32])
  filename = tf.squeeze(values[0])
  voxel_data = tf.squeeze(values[1])
  rerendering_data = tf.squeeze(values[2])
  image_data = tf.squeeze(values[3])
  light_position = values[4]
  voxels = tf.io.decode_raw(voxel_data, out_type=tf.uint8)
  voxels = tf.cast(tf.reshape(voxels, VOXEL_SIZE), tf.float32) / 255.0
  rerendering = tf.cast(tf.image.decode_image(rerendering_data, channels=3),
                        tf.float32)
  rerendering = tf_image_normalize(rerendering)
  image = tf.cast(tf.image.decode_image(image_data, channels=3), tf.float32)
  image = tf_image_normalize(image)
  return filename, voxels, rerendering, image, light_position


def _expand_tfrecords_pattern(tfr_pattern):
  """Helper function to expand a tfrecord patter"""
  def format_shards(m):
    return '{}-?????-of-{:0>5}{}'.format(*m.groups())
  tfr_pattern = re.sub(r'^([^@]+)@(\d+)([^@]+)$', format_shards, tfr_pattern)
  return tfr_pattern


def tfrecords_to_dataset(tfrecords_pattern,
                         mapping_func,
                         batch_size,
                         buffer_size=5000):
  """Generates a TF Dataset from a rio pattern."""
  with tf.name_scope('Input/'):
    tfrecords_pattern = _expand_tfrecords_pattern(tfrecords_pattern)
    dataset = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=True)
    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.map(mapping_func)
    dataset = dataset.batch(batch_size)
    return dataset
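
The @ shorthand in a pattern like train@100.tfrecord expands to the standard sharded-file glob; a quick illustration using the helper defined above:

In [0]:
# 'train@100.tfrecord' expands to 'train-?????-of-00100.tfrecord'.
print(_expand_tfrecords_pattern('train@100.tfrecord'))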

In [0]:
# Download the example data, licensed under the Apache License, Version 2.0
!rm -rf /tmp/tfrecords_dir/
!mkdir /tmp/tfrecords_dir/
!wget -P /tmp/tfrecords_dir/ https://storage.googleapis.com/tensorflow-graphics/notebooks/neural_voxel_renderer/train-00012-of-00100.tfrecord
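
As a quick check that the shard downloaded correctly, its raw records can be counted before any proto parsing:

In [0]:
# Count the raw records in the downloaded shard (no proto parsing involved).
raw_dataset = tf.data.TFRecordDataset(
    '/tmp/tfrecords_dir/train-00012-of-00100.tfrecord')
print('Records in shard:', sum(1 for _ in raw_dataset))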

In [0]:
tfrecords_dir = '/tmp/tfrecords_dir/'
tfrecords_pattern = os.path.join(tfrecords_dir, 'train@100.tfrecord')

batch_size = 5
mapping_function = neural_voxel_plus_proto_get
dataset = tfrecords_to_dataset(tfrecords_pattern, mapping_function, batch_size)
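
To sanity-check the pipeline, one batch can be inspected for tensor shapes (the voxel shape follows VOXEL_SIZE; the image resolution is determined by the encoded image bytes, so it is not asserted here):

In [0]:
# Peek at a single batch and report tensor shapes (sanity check only).
for filename, voxels, vox_render, target, light in dataset.take(1):
  print('voxels:     ', voxels.shape)  # (batch, 128, 128, 128, 4)
  print('rerendering:', vox_render.shape)
  print('target:     ', target.shape)
  print('light:      ', light.shape)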

In [0]:
# Visualize some examples
_, ax = plt.subplots(1, 4, figsize=(10, 10))
for i, example in enumerate(dataset.take(4)):
  filename, voxels, vox_render, target, light_position = example
  # Map the target image from [-1, 1] back to [0, 1] for display.
  ax[i].imshow(target[0]*0.5+0.5)
  ax[i].axis('off')
plt.show()
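
For comparison, the voxel rendering that serves as network input can be displayed next to its target image using the same de-normalization (element ordering follows neural_voxel_plus_proto_get above):

In [0]:
# Show the voxel rendering (network input) next to the ground-truth target.
_, ax = plt.subplots(1, 2, figsize=(6, 3))
for _, _, vox_render, target, _ in dataset.take(1):
  ax[0].imshow(vox_render[0]*0.5+0.5)
  ax[0].set_title('Voxel rendering')
  ax[1].imshow(target[0]*0.5+0.5)
  ax[1].set_title('Target')
for a in ax:
  a.axis('off')
plt.show()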

Train the model

NVR+ is trained with the Adam optimizer, minimizing an L1 reconstruction loss combined with a perceptual (VGG) loss.


In [0]:
# ==============================================================================
# Defining model and optimizer
LEARNING_RATE = 0.002

nvr_plus_model = models.neural_voxel_renderer_plus_tf2()
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Saving and logging directories
checkpoint_dir = '/tmp/checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=nvr_plus_model)
log_dir="/tmp/logs/"
summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# ==============================================================================
# VGG loss
VGG_LOSS_LAYER_NAMES = ['block1_conv1', 'block2_conv1']
VGG_LOSS_LAYER_WEIGHTS = [1.0, 0.1]
VGG_LOSS_WEIGHT = 0.001

def vgg_layers(layer_names):
  """Creates a VGG model that returns a list of intermediate output values."""
  # Load a pretrained VGG19 (ImageNet weights) and freeze it.
  vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
  vgg.trainable = False
  outputs = [vgg.get_layer(name).output for name in layer_names]
  model = tf.keras.Model([vgg.input], outputs)
  return model


vgg_extractor = vgg_layers(VGG_LOSS_LAYER_NAMES)

# ==============================================================================
# Total loss
def network_loss(output, target):
  """Returns the L1, perceptual (VGG), and total losses."""
  # L1 loss
  l1_loss = tf.reduce_mean(tf.abs(target - output))
  # VGG loss
  vgg_output = vgg_extractor((output*0.5+0.5)*255)
  vgg_target = vgg_extractor((target*0.5+0.5)*255)
  vgg_loss = 0
  for l in range(len(VGG_LOSS_LAYER_WEIGHTS)):
    layer_loss = tf.reduce_mean(tf.square(vgg_target[l] - vgg_output[l]))
    vgg_loss += VGG_LOSS_LAYER_WEIGHTS[l]*layer_loss
  # Final loss
  total_loss = l1_loss + VGG_LOSS_WEIGHT*vgg_loss
  return l1_loss, vgg_loss, total_loss
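
As a quick smoke test, the loss can be exercised on random tensors before training (the values are meaningless, and the 256x256 resolution is an arbitrary choice for illustration):

In [0]:
# Exercise the loss end-to-end on random data in [-1, 1] (smoke test only).
dummy_output = tf.random.uniform((1, 256, 256, 3), minval=-1, maxval=1)
dummy_target = tf.random.uniform((1, 256, 256, 3), minval=-1, maxval=1)
l1, vgg, total = network_loss(dummy_output, dummy_target)
print('l1: {}, vgg: {}, total: {}'.format(l1.numpy(), vgg.numpy(),
                                          total.numpy()))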

In [0]:
@tf.function
def train_step(input_voxels, input_rendering, input_light, target, epoch):
  with tf.GradientTape() as tape:
    network_output = nvr_plus_model([input_voxels, 
                                     input_rendering, 
                                     input_light],
                                     training=True)
    l1_loss, vgg_loss, total_loss = network_loss(network_output, target)
  network_gradients = tape.gradient(total_loss,
                                    nvr_plus_model.trainable_variables)
  optimizer.apply_gradients(zip(network_gradients,
                                nvr_plus_model.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('total_loss', total_loss, step=epoch)
    tf.summary.scalar('l1_loss', l1_loss, step=epoch)
    tf.summary.scalar('vgg_loss', vgg_loss, step=epoch)
    tf.summary.image('Vox_rendering', 
                     input_rendering*0.5+0.5, 
                     step=epoch, 
                     max_outputs=4)
    tf.summary.image('Prediction', 
                     network_output*0.5+0.5, 
                     step=epoch, 
                     max_outputs=4)
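
To monitor the scalar and image summaries written in train_step, TensorBoard can be pointed at the log directory configured above (in Colab, via the notebook extension; a minimal sketch):

In [0]:
%load_ext tensorboard
%tensorboard --logdir /tmp/logs/fit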

In [0]:
def training_loop(train_ds, epochs):
  for epoch in range(epochs):
    start = time.time()

    # Train
    for n, (_, voxels, vox_rendering, target, light) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(voxels, vox_rendering, light, target, epoch)
    print()

    # Save a checkpoint every 20 epochs.
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                       time.time() - start))
  checkpoint.save(file_prefix=checkpoint_prefix)

In [0]:
NUMBER_OF_EPOCHS = 100
training_loop(dataset, NUMBER_OF_EPOCHS)
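
After training, the most recent checkpoint can be restored to resume training or run inference (a minimal sketch using the checkpoint object defined above):

In [0]:
# Restore the latest checkpoint (if one exists) into the model and optimizer.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
  checkpoint.restore(latest)
  print('Restored from', latest)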