In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Spherical Harmonics Optimization

This Colab covers an advanced topic: rather than giving step-by-step details, it focuses on a toy example that builds a high-level understanding of how to estimate environment lighting using Spherical Harmonics. For an introduction to Spherical Harmonics themselves, we refer the interested reader to the first Colab demo on Spherical Harmonics.

Given an image of a known object (a sphere) with a known reflectance function, this Colab illustrates how to optimize Spherical Harmonics lighting coefficients to recover the lighting environment.
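
Under this model, the image intensity at a surface point reduces to a dot product between the Spherical Harmonics coefficients of the light and those of the reflectance function evaluated at the surface normal. The cell below is a minimal sketch of that idea (not part of the original demo; the coefficient values are made up), using plain numpy:

In [0]:
# Toy illustration of the Spherical Harmonics (SH) lighting model used below:
# once the light and the reflectance function are expressed in the same
# orthonormal SH basis, shading a point reduces to a dot product of their
# coefficient vectors.
import numpy as np

light_sh = np.array([0.8, 0.0, 0.3, 0.0])           # hypothetical light coefficients
brdf_sh_at_normal = np.array([0.5, 0.1, 0.2, 0.0])  # hypothetical reflectance coefficients
pixel_intensity = np.dot(light_sh, brdf_sh_at_normal)
print(pixel_intensity)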

Setup & Imports

If Tensorflow Graphics is not installed on your system, the following cell can install the Tensorflow Graphics package for you.


In [0]:
!pip install tensorflow_graphics

Now that Tensorflow Graphics is installed, let's import everything needed to run the demo contained in this notebook.


In [0]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow_graphics.rendering.camera import orthographic
from tensorflow_graphics.geometry.representation import grid
from tensorflow_graphics.geometry.representation import ray
from tensorflow_graphics.geometry.representation import vector
from tensorflow_graphics.math import spherical_harmonics
from tensorflow_graphics.math import math_helpers as tf_math

tf.compat.v1.enable_v2_behavior()

In [0]:
def compute_intersection_normal_sphere(width, height, radius,
                                       center, dtype):
  """Estimates sphere normal and depth for each pixel in the image."""
  # Generates a 2d grid storing pixel coordinates.
  pixel_grid_start = np.array((0.5, 0.5), dtype=dtype)
  pixel_grid_end = np.array((width - 0.5, height - 0.5), dtype=dtype)
  pixel_nb = np.array((width, height))
  pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
  # Computes the ray direction of each pixel.
  pixel_ray = tf.math.l2_normalize(orthographic.ray(pixels), axis=-1)
  # Defines the position of pixels in world coordinates.
  zero_depth = np.zeros([width, height, 1])
  pixels_3d = orthographic.unproject(pixels, zero_depth)
  # Computes intersections with the sphere and surface normals for each ray.
  intersection, normal = ray.intersection_ray_sphere(
      center, radius, pixel_ray, pixels_3d)
  # Extracts data about the closest intersection.
  intersection = intersection[0, ...]
  normal = normal[0, ...]
  # Replaces NaNs with zeros.
  zeros = tf.zeros_like(pixels_3d)
  intersection = tf.where(tf.math.is_nan(intersection), zeros, intersection)
  normal = tf.where(tf.math.is_nan(normal), zeros, normal)
  return intersection, normal
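
To make the helper above concrete, here is a usage sketch (not part of the original notebook; the image size, sphere radius, and center below are arbitrary) showing how it is called and the shape of what it returns:

In [0]:
# Hypothetical example: a 10x8 image with a sphere of radius 3 placed in front
# of the camera. Both returned tensors have shape [width, height, 3]; pixels
# whose rays miss the sphere contain zeros.
test_radius = np.array((3.0,), dtype=np.float64)
test_center = np.array((5.0, 4.0, 20.0), dtype=np.float64)
points, normals = compute_intersection_normal_sphere(10, 8, test_radius,
                                                     test_center, np.float64)
print(points.shape, normals.shape)  # (10, 8, 3) (10, 8, 3)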

Spherical Harmonics Optimization


In [0]:
light_image_width = 100
light_image_height = 100
dtype = np.float64

############################################################################
# Builds the pixels grid and computes corresponding spherical coordinates. #
############################################################################
pixel_grid_start = np.array((0, 0), dtype=dtype)
pixel_grid_end = np.array((light_image_width - 1, light_image_height - 1),
                          dtype=dtype)
pixel_nb = np.array((light_image_width, light_image_height))
pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
normalized_pixels = pixels / (light_image_width - 1, light_image_height - 1)
spherical_coordinates = tf_math.square_to_spherical_coordinates(
    normalized_pixels)
theta = spherical_coordinates[:, :, 1]
phi = spherical_coordinates[:, :, 2]

################################################################################################
# Builds the Spherical Harmonics and sets coefficients for the light and reflectance functions. #
################################################################################################
max_band = 2
l, m = spherical_harmonics.generate_l_m_permutations(max_band)
l_broadcasted = tf.broadcast_to(l, [light_image_width, light_image_height] +
                                l.shape.as_list())
m_broadcasted = tf.broadcast_to(m, [light_image_width, light_image_height] +
                                l.shape.as_list())
theta = tf.expand_dims(theta, axis=-1)
theta_broadcasted = tf.broadcast_to(theta,
                                    [light_image_width, light_image_height, 1])
phi = tf.expand_dims(phi, axis=-1)
phi_broadcasted = tf.broadcast_to(phi,
                                  [light_image_width, light_image_height, 1])
sh_coefficients = spherical_harmonics.evaluate_spherical_harmonics(
    l_broadcasted, m_broadcasted, theta_broadcasted, phi_broadcasted)

# The lighting and BRDF coefficients come from the first Colab demo on Spherical
# Harmonics.
light_coeffs = np.array((2.17136424e-01, -2.06274278e-01, 3.10378283e-17,
                         2.76236879e-01, -3.08694040e-01, -4.69862940e-17,
                         -1.85866463e-01, 7.05744675e-17, 9.14290771e-02))
brdf_coeffs = np.array((0.28494423, 0.33231551, 0.16889377))

# Reconstruction of the light function.
reconstructed_light_function = tf.squeeze(
    vector.dot(sh_coefficients, light_coeffs))

###################################
# Setup the image, and the sphere #
###################################
# Image dimensions
image_width = 100
image_height = 80

# Sphere center and radius
sphere_radius = np.array((30,), dtype=dtype)
sphere_center = np.array((image_width / 2.0, image_height / 2.0, 100.0),
                         dtype=dtype)

# Builds the pixels grid and computes corresponding spherical coordinates.
pixel_grid_start = np.array((0, 0), dtype=dtype)
pixel_grid_end = np.array((image_width - 1, image_height - 1), dtype=dtype)
pixel_nb = np.array((image_width, image_height))
pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
normalized_pixels = pixels / (image_width - 1, image_height - 1)
spherical_coordinates = tf_math.square_to_spherical_coordinates(
    normalized_pixels)

################################################################################################
# For each pixel in the image, estimate the corresponding surface point and associated normal. #
################################################################################################
intersection_3d, surface_normal = compute_intersection_normal_sphere(
    image_width, image_height, sphere_radius, sphere_center, dtype)
surface_normals_spherical_coordinates = tf_math.cartesian_to_spherical_coordinates(
    surface_normal)

##########################################
# Estimates result using SH convolution. #
##########################################
target = spherical_harmonics.integration_product(
    light_coeffs,
    spherical_harmonics.rotate_zonal_harmonics(
        brdf_coeffs,
        tf.expand_dims(surface_normals_spherical_coordinates[:, :, 1], axis=-1),
        tf.expand_dims(surface_normals_spherical_coordinates[:, :, 2],
                       axis=-1)),
    keepdims=False)
# Sets pixels not belonging to the sphere to 0.
target = tf.where(
    tf.greater(intersection_3d[:, :, 2], 0.0), target, tf.zeros_like(target))

#########################################################################################
# Optimization of the lighting coefficients by minimization of the reconstruction error #
#########################################################################################
# Initial solution.
recovered_light_coeffs = tf.Variable(
    np.array((1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)))


def reconstruct_image(recovered_light_coeffs):
  reconstructed_image = spherical_harmonics.integration_product(
      recovered_light_coeffs,
      spherical_harmonics.rotate_zonal_harmonics(
          brdf_coeffs,
          tf.expand_dims(
              surface_normals_spherical_coordinates[:, :, 1], axis=-1),
          tf.expand_dims(
              surface_normals_spherical_coordinates[:, :, 2], axis=-1)),
      keepdims=False)
  return tf.where(
      tf.greater(intersection_3d[:, :, 2], 0.0), reconstructed_image,
      tf.zeros_like(target))


# Sets the optimization problem up.
def my_loss(recovered_light_coeffs):
  reconstructed_image = reconstruct_image(recovered_light_coeffs)
  return tf.nn.l2_loss(reconstructed_image - target) / (
      image_width * image_height)


learning_rate = 0.1
with tf.name_scope("optimization"):
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)


def gradient_loss(recovered_light_coeffs):
  with tf.GradientTape() as tape:
    loss_value = my_loss(recovered_light_coeffs)
  return tape.gradient(loss_value, [recovered_light_coeffs])


####################
# Initial solution #
####################
target_transpose = np.transpose(target, (1, 0))
reconstructed_image = reconstruct_image(recovered_light_coeffs)
reconstructed_image = np.transpose(reconstructed_image, (1, 0))
plt.figure(figsize=(10, 20))
ax = plt.subplot("131")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Target")
_ = ax.imshow(target_transpose, vmin=0.0)
ax = plt.subplot("132")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Initial solution")
_ = ax.imshow(reconstructed_image, vmin=0.0)

################
# Optimization #
################
nb_iterations = 100
for it in range(nb_iterations):
  gradients_loss = gradient_loss(recovered_light_coeffs)
  optimizer.apply_gradients(zip(gradients_loss, (recovered_light_coeffs,)))

  if it % 33 == 0:
    reconstructed_image = reconstruct_image(recovered_light_coeffs)
    reconstructed_image = np.transpose(reconstructed_image, (1, 0))
    # Displays the target and prediction.
    plt.figure(figsize=(10, 20))
    ax = plt.subplot("131")
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Target")
    img = ax.imshow(target_transpose, vmin=0.0)
    ax = plt.subplot("132")
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Prediction iteration " + str(it))
    img = ax.imshow(reconstructed_image, vmin=0.0)
    # Shows the difference between groundtruth and prediction.
    vmax = np.maximum(np.amax(reconstructed_image), np.amax(target_transpose))
    ax = plt.subplot("133")
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Difference iteration " + str(it))
    img = ax.imshow(
        np.abs(reconstructed_image - target_transpose), vmin=0.0, vmax=vmax)

# Reconstructs the groundtruth and predicted environment maps.
reconstructed_predicted_light = tf.squeeze(
    vector.dot(sh_coefficients, recovered_light_coeffs))

# Displays the groundtruth and predicted environment maps.
plt.figure(figsize=(10, 20))
ax = plt.subplot("121")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Target light")
img = ax.imshow(reconstructed_light_function, vmin=0.0)
ax = plt.subplot("122")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Predicted light")
img = ax.imshow(reconstructed_predicted_light, vmin=0.0)
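
Since this toy setup uses known ground-truth lighting coefficients, a direct way to quantify the result (a sketch, not part of the original notebook) is to compare them with the recovered coefficients:

In [0]:
# Prints the ground-truth and recovered SH lighting coefficients side by side,
# along with the largest absolute difference between them.
recovered = recovered_light_coeffs.numpy()
print("Ground truth:", np.round(light_coeffs, 4))
print("Recovered:   ", np.round(recovered, 4))
print("Max abs diff:", np.max(np.abs(recovered - light_coeffs)))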