In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
This Colab covers an advanced topic and hence focuses on providing a toy example to form a high-level understanding of how to estimate environment lighting using Spherical Harmonics, rather than providing step-by-step details. We refer the interested reader to the first Colab demo on Spherical Harmonics to get a high-level understanding of Spherical Harmonics.
Given an image of a known object (sphere) with a known reflectance function, this Colab illustrates how to perform optimization of spherical harmonics to recover the lighting environment.
In [0]:
!pip install tensorflow_graphics
Now that TensorFlow Graphics is installed, let's import everything needed to run the demo contained in this notebook.
In [0]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow_graphics.rendering.camera import orthographic
from tensorflow_graphics.geometry.representation import grid
from tensorflow_graphics.geometry.representation import ray
from tensorflow_graphics.geometry.representation import vector
from tensorflow_graphics.math import spherical_harmonics
from tensorflow_graphics.math import math_helpers as tf_math
tf.compat.v1.enable_v2_behavior()
In [0]:
def compute_intersection_normal_sphere(width, height, radius,
                                       center, dtype):
  """Estimates the sphere intersection point and normal for each pixel.

  One ray per pixel is cast using the orthographic camera model and
  intersected with the sphere. Only the first entry returned by
  `ray.intersection_ray_sphere` is kept, and NaN entries are replaced by zeros.

  Args:
    width: Image width in pixels.
    height: Image height in pixels.
    radius: Sphere radius, forwarded to `ray.intersection_ray_sphere`.
    center: Sphere center, forwarded to `ray.intersection_ray_sphere`.
    dtype: Floating point type used for the pixel grid and the depth buffer.

  Returns:
    A tuple `(intersection, normal)` of tensors shaped like the unprojected
    pixel positions, with zeros wherever the ray/sphere test produced NaN.
  """
  # Generates a 2d grid storing pixel coordinates (sampled at pixel centers).
  pixel_grid_start = np.array((0.5, 0.5), dtype=dtype)
  pixel_grid_end = np.array((width - 0.5, height - 0.5), dtype=dtype)
  pixel_nb = np.array((width, height))
  pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
  # Computes the ray direction of each pixel.
  pixel_ray = tf.math.l2_normalize(orthographic.ray(pixels), axis=-1)
  # Defines the position of pixels in world coordinates. The depth buffer must
  # use the requested dtype: NumPy's default float64 would otherwise be mixed
  # with a pixel grid of a different precision.
  zero_depth = np.zeros([width, height, 1], dtype=dtype)
  pixels_3d = orthographic.unproject(pixels, zero_depth)
  # Computes intersections with the sphere and surface normals for each ray.
  intersection, normal = ray.intersection_ray_sphere(
      center, radius, pixel_ray, pixels_3d)
  # Extracts data about the closest intersection.
  intersection = intersection[0, ...]
  normal = normal[0, ...]
  # Replaces NaNs with zeros.
  zeros = tf.zeros_like(pixels_3d)
  intersection = tf.where(tf.math.is_nan(intersection), zeros, intersection)
  normal = tf.where(tf.math.is_nan(normal), zeros, normal)
  return intersection, normal
In [0]:
# Resolution of the environment light image and the floating point precision
# used by the demo.
light_image_width = 100
light_image_height = 100
dtype = np.float64
############################################################################
# Builds the pixels grid and computes corresponding spherical coordinates. #
############################################################################
pixel_nb = np.array((light_image_width, light_image_height))
pixel_grid_start = np.zeros(2, dtype=dtype)
# The last pixel index along each axis, i.e. (width - 1, height - 1).
pixel_grid_end = (pixel_nb - 1).astype(dtype)
pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
# Rescales pixel indices to the unit square before mapping them to the sphere.
normalized_pixels = pixels / pixel_grid_end
spherical_coordinates = tf_math.square_to_spherical_coordinates(
    normalized_pixels)
# Channels 1 and 2 of the spherical coordinates hold the two angles.
theta = spherical_coordinates[..., 1]
phi = spherical_coordinates[..., 2]
################################################################################################
# Builds the Spherical Harmonics and sets coefficients for the light and reflectance functions. #
################################################################################################
max_band = 2
l, m = spherical_harmonics.generate_l_m_permutations(max_band)
# Replicates the (l, m) pairs for every pixel of the light image.
l_broadcasted = tf.broadcast_to(
    l, [light_image_width, light_image_height, *l.shape.as_list()])
m_broadcasted = tf.broadcast_to(
    m, [light_image_width, light_image_height, *l.shape.as_list()])
# Adds a trailing axis to both angles so they line up with the (l, m) axis.
theta = theta[..., tf.newaxis]
theta_broadcasted = tf.broadcast_to(
    theta, [light_image_width, light_image_height, 1])
phi = phi[..., tf.newaxis]
phi_broadcasted = tf.broadcast_to(
    phi, [light_image_width, light_image_height, 1])
# Evaluates every harmonic of the basis at every pixel direction.
sh_coefficients = spherical_harmonics.evaluate_spherical_harmonics(
    l_broadcasted, m_broadcasted, theta_broadcasted, phi_broadcasted)
# The lighting and BRDF coefficients come from the first Colab demo on Spherical
# Harmonics.
light_coeffs = np.array([
    2.17136424e-01, -2.06274278e-01, 3.10378283e-17,
    2.76236879e-01, -3.08694040e-01, -4.69862940e-17,
    -1.85866463e-01, 7.05744675e-17, 9.14290771e-02,
])
brdf_coeffs = np.array([0.28494423, 0.33231551, 0.16889377])
# Reconstruction of the light function: per-pixel dot product between the
# evaluated basis and the ground truth light coefficients.
reconstructed_light_function = tf.squeeze(
    vector.dot(sh_coefficients, light_coeffs))
###################################
# Setup the image, and the sphere #
###################################
# Image dimensions
image_width = 100
image_height = 80
# Sphere center and radius: the sphere sits in the middle of the image.
sphere_radius = np.array([30], dtype=dtype)
sphere_center = np.array(
    [image_width / 2.0, image_height / 2.0, 100.0], dtype=dtype)
# Builds the pixels grid and compute corresponding spherical coordinates.
pixel_nb = np.array((image_width, image_height))
pixel_grid_start = np.zeros(2, dtype=dtype)
# The last pixel index along each axis, i.e. (width - 1, height - 1).
pixel_grid_end = (pixel_nb - 1).astype(dtype)
pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
# Rescales pixel indices to the unit square.
normalized_pixels = pixels / pixel_grid_end
spherical_coordinates = tf_math.square_to_spherical_coordinates(
    normalized_pixels)
################################################################################################
# For each pixel in the image, estimate the corresponding surface point and associated normal. #
################################################################################################
intersection_3d, surface_normal = compute_intersection_normal_sphere(
    width=image_width,
    height=image_height,
    radius=sphere_radius,
    center=sphere_center,
    dtype=dtype)
# Expresses the Cartesian normals in spherical coordinates.
surface_normals_spherical_coordinates = (
    tf_math.cartesian_to_spherical_coordinates(surface_normal))
##########################################
# Estimates result using SH convolution. #
##########################################
# The BRDF zonal harmonics are rotated towards each surface normal and then
# integrated against the ground truth lighting. The `1:2` / `2:3` slices keep a
# trailing axis of size one on each angle.
target = spherical_harmonics.integration_product(
    light_coeffs,
    spherical_harmonics.rotate_zonal_harmonics(
        brdf_coeffs,
        surface_normals_spherical_coordinates[:, :, 1:2],
        surface_normals_spherical_coordinates[:, :, 2:3]),
    keepdims=False)
# Sets pixels not belonging to the sphere to 0.
target = tf.where(
    intersection_3d[:, :, 2] > 0.0, target, tf.zeros_like(target))
#########################################################################################
# Optimization of the lighting coefficients by minimization of the reconstruction error #
#########################################################################################
# Initial solution: only the first of the nine coefficients is non-zero.
recovered_light_coeffs = tf.Variable(np.array([1.0] + [0.0] * 8))
def reconstruct_image(recovered_light_coeffs):
  """Renders the sphere image produced by the given lighting coefficients.

  The BRDF zonal harmonics are rotated towards each surface normal and
  integrated against `recovered_light_coeffs`. Pixels whose intersection depth
  is not strictly positive are set to zero.

  Args:
    recovered_light_coeffs: Spherical harmonics coefficients of the lighting.

  Returns:
    The rendered image, with zeros outside of the sphere.
  """
  normal_theta = tf.expand_dims(
      surface_normals_spherical_coordinates[:, :, 1], axis=-1)
  normal_phi = tf.expand_dims(
      surface_normals_spherical_coordinates[:, :, 2], axis=-1)
  rotated_brdf = spherical_harmonics.rotate_zonal_harmonics(
      brdf_coeffs, normal_theta, normal_phi)
  shaded = spherical_harmonics.integration_product(
      recovered_light_coeffs, rotated_brdf, keepdims=False)
  # Only pixels with a strictly positive intersection depth lie on the sphere.
  on_sphere = tf.greater(intersection_3d[:, :, 2], 0.0)
  return tf.where(on_sphere, shaded, tf.zeros_like(target))
# Sets the optimization problem up.
def my_loss(recovered_light_coeffs):
  """Returns the reconstruction error of the given lighting coefficients.

  The error is `tf.nn.l2_loss` of the difference between the reconstructed
  image and `target`, divided by the number of pixels in the image.
  """
  residual = reconstruct_image(recovered_light_coeffs) - target
  pixel_count = image_width * image_height
  return tf.nn.l2_loss(residual) / pixel_count
# Step size used by the Adam optimizer.
learning_rate = 0.1
with tf.name_scope("optimization"):
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
def gradient_loss(recovered_light_coeffs):
  """Returns the gradient of `my_loss` w.r.t. the lighting coefficients.

  The result is a single-element list, matching the list of sources handed to
  `tape.gradient`.
  """
  with tf.GradientTape() as grad_tape:
    reconstruction_error = my_loss(recovered_light_coeffs)
  return grad_tape.gradient(reconstruction_error, [recovered_light_coeffs])
####################
# Initial solution #
####################
# Images are transposed before display so that the first tensor axis (image
# width) runs horizontally in the plot.
target_transpose = np.transpose(target, (1, 0))
reconstructed_image = reconstruct_image(recovered_light_coeffs)
reconstructed_image = np.transpose(reconstructed_image, (1, 0))
plt.figure(figsize=(10, 20))
# Subplot positions are passed as integers (nrows, ncols, index): string specs
# such as "131" are rejected by current Matplotlib releases.
ax = plt.subplot(1, 3, 1)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Target")
_ = ax.imshow(target_transpose, vmin=0.0)
ax = plt.subplot(1, 3, 2)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Initial solution")
_ = ax.imshow(reconstructed_image, vmin=0.0)
################
# Optimization #
################
nb_iterations = 100
for it in range(nb_iterations):
  # One Adam step on the lighting coefficients.
  gradients_loss = gradient_loss(recovered_light_coeffs)
  optimizer.apply_gradients(zip(gradients_loss, (recovered_light_coeffs,)))
  # Plots a snapshot of the current estimate every 33 iterations.
  if it % 33 == 0:
    reconstructed_image = reconstruct_image(recovered_light_coeffs)
    reconstructed_image = np.transpose(reconstructed_image, (1, 0))
    # Displays the target and prediction. Subplot positions are passed as
    # integers (nrows, ncols, index): string specs such as "131" are rejected
    # by current Matplotlib releases.
    plt.figure(figsize=(10, 20))
    ax = plt.subplot(1, 3, 1)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Target")
    img = ax.imshow(target_transpose, vmin=0.0)
    ax = plt.subplot(1, 3, 2)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Prediction iteration " + str(it))
    img = ax.imshow(reconstructed_image, vmin=0.0)
    # Shows the difference between groundtruth and prediction, using the
    # largest value of both images as the upper bound of the color scale.
    vmax = np.maximum(np.amax(reconstructed_image), np.amax(target_transpose))
    ax = plt.subplot(1, 3, 3)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.grid(False)
    ax.set_title("Difference iteration " + str(it))
    img = ax.imshow(
        np.abs(reconstructed_image - target_transpose), vmin=0.0, vmax=vmax)
# Reconstructs the predicted environment map from the recovered coefficients,
# the same way the groundtruth map was built from the known coefficients.
reconstructed_predicted_light = tf.squeeze(
    vector.dot(sh_coefficients, recovered_light_coeffs))
# Displays the groundtruth and predicted environment maps. Subplot positions
# are passed as integers (nrows, ncols, index): string specs such as "121" are
# rejected by current Matplotlib releases.
plt.figure(figsize=(10, 20))
ax = plt.subplot(1, 2, 1)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Target light")
img = ax.imshow(reconstructed_light_function, vmin=0.0)
ax = plt.subplot(1, 2, 2)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.grid(False)
ax.set_title("Predicted light")
img = ax.imshow(reconstructed_predicted_light, vmin=0.0)