In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Spherical Harmonics Approximation

View source on GitHub

This Colab covers an advanced topic and therefore focuses on controllable toy examples that build a high-level understanding of Spherical Harmonics and their use for lighting, rather than on step-by-step derivations. For those interested, these details are nevertheless available in the code. A great resource for building a solid understanding of Spherical Harmonics and their use for lighting is "Spherical Harmonics Lighting: the Gritty Details".

This Colab demonstrates how to approximate functions defined over a sphere using Spherical Harmonics. These can be used to approximate both lighting and reflectance (BRDFs), leading to very efficient rendering.

In more detail, the following cells demonstrate:

  • Approximation of lighting environments with Spherical Harmonics (SH)
  • Approximation of the Lambertian BRDF with Zonal Harmonics (ZH)
  • Rotation of Zonal Harmonics
  • Rendering via Spherical Harmonics convolution of the SH lighting and ZH BRDF

Setup & Imports

If TensorFlow Graphics is not installed on your system, the following cell installs the TensorFlow Graphics package for you.


In [0]:
!pip install tensorflow_graphics

Now that TensorFlow Graphics is installed, let's import everything needed to run the demo contained in this notebook.


In [0]:
###########
# Imports #
###########
import math

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow_graphics.geometry.representation import grid
from tensorflow_graphics.geometry.representation import ray
from tensorflow_graphics.geometry.representation import vector
from tensorflow_graphics.rendering.camera import orthographic
from tensorflow_graphics.math import spherical_harmonics
from tensorflow_graphics.math import math_helpers as tf_math

# Eager execution lets the tensors computed below be handed directly to NumPy
# (e.g. np.linalg.norm) and matplotlib. Presumably a no-op on TF 2.x, where
# eager execution is already the default — confirm for the installed version.
tf.compat.v1.enable_eager_execution()

Approximation of lighting with Spherical Harmonics


In [0]:
#@title Controls { vertical-output: false, run: "auto" }
max_band = 2  #@param { type: "slider", min: 0, max: 10 , step: 1 }
# `max_band` is the highest SH band handed to
# `spherical_harmonics.generate_l_m_permutations` / `generate_l_m_zonal` in
# every cell below, so it controls the size of every SH/ZH approximation.

#########################################################################
# This cell creates a lighting function which we approximate with an SH #
#########################################################################

def image_to_spherical_coordinates(image_width, image_height, dtype=np.float64):
  """Maps every pixel of an image grid to spherical coordinates.

  The pixel grid spans [0, image_width - 1] x [0, image_height - 1]. Each pixel
  is normalized to the unit square and then mapped onto the sphere with
  `tf_math.square_to_spherical_coordinates`.

  Args:
    image_width: Number of pixels along the first grid axis.
    image_height: Number of pixels along the second grid axis.
    dtype: Floating point type of the pixel grid. Passed explicitly so the
      function no longer depends on a global named `type` (which also shadows
      the Python builtin) being defined before the first call.

  Returns:
    The spherical coordinates of every pixel. Callers in this notebook read
    theta at index 1 and phi at index 2 of the last axis.
  """
  pixel_grid_start = np.array((0, 0), dtype=dtype)
  pixel_grid_end = np.array((image_width - 1, image_height - 1), dtype=dtype)
  pixel_nb = np.array((image_width, image_height))
  pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)
  # Normalizes pixel positions to [0, 1] x [0, 1].
  normalized_pixels = pixels / (image_width - 1, image_height - 1)
  return tf_math.square_to_spherical_coordinates(normalized_pixels)


def light_function(theta, phi):
  """Toy lighting environment evaluated at the angles (theta, phi).

  A smooth lobe is shifted down by 3 and clamped at zero, so only directions
  where the lobe exceeds 3 receive any light.

  Args:
    theta: First spherical angle; anything `tf.convert_to_tensor` accepts.
    phi: Second spherical angle, with a shape compatible with `theta`.

  Returns:
    A tensor of non-negative light intensities.
  """
  theta_t = tf.convert_to_tensor(theta)
  phi_t = tf.convert_to_tensor(phi)
  floor = tf.zeros_like(theta_t)
  lobe = -4.0 * tf.sin(theta_t - np.pi) * tf.cos(phi_t - 2.5) - 3.0
  return tf.maximum(floor, lobe)


# Resolution of the grid used to sample the lighting function.
light_image_width = 30
light_image_height = 30
# Notebook-wide float dtype. The name `type` shadows the Python builtin but is
# kept because helpers and later cells in this notebook read it.
type = np.float64

# Builds the pixels grid and computes the corresponding spherical coordinates.
spherical_coordinates = image_to_spherical_coordinates(light_image_width,
                                                       light_image_height)
theta = spherical_coordinates[:, :, 1]
phi = spherical_coordinates[:, :, 2]

# Samples the light function and the unit direction of every sample.
sampled_light_function = light_function(theta, phi)
ones_normal = tf.ones_like(theta)
spherical_coordinates_3d = tf.stack((ones_normal, theta, phi), axis=-1)
samples_direction_to_light = tf_math.spherical_to_cartesian_coordinates(
    spherical_coordinates_3d)

# Samples the SH basis at every pixel.
l, m = spherical_harmonics.generate_l_m_permutations(max_band)
l = tf.convert_to_tensor(l)
m = tf.convert_to_tensor(m)
sh_shape = [light_image_width, light_image_height] + l.shape.as_list()
l_broadcasted = tf.broadcast_to(l, sh_shape)
m_broadcasted = tf.broadcast_to(m, sh_shape)
theta_broadcasted = tf.expand_dims(theta, axis=-1)
phi_broadcasted = tf.expand_dims(phi, axis=-1)
sh_coefficients = spherical_harmonics.evaluate_spherical_harmonics(
    l_broadcasted, m_broadcasted, theta_broadcasted, phi_broadcasted)
sampled_light_function_broadcasted = tf.broadcast_to(
    tf.expand_dims(sampled_light_function, axis=-1), sh_shape)

# Integrates light * SH over the sphere: every sample gets the same 4*pi / N
# solid-angle weight (assumes the grid covers the sphere roughly uniformly).
num_light_samples = light_image_width * light_image_height
projection = (sh_coefficients * sampled_light_function_broadcasted * 4.0 *
              math.pi / num_light_samples)
light_coeffs = tf.reduce_sum(projection, (0, 1))

# Reconstructs the image from the SH coefficients.
reconstructed_light_function = tf.squeeze(
    vector.dot(sh_coefficients, light_coeffs))

print(
    "average l2 reconstruction error ",
    np.linalg.norm(sampled_light_function - reconstructed_light_function) /
    num_light_samples)

vmin = np.minimum(
    np.amin(np.minimum(sampled_light_function, reconstructed_light_function)),
    0.0)
vmax = np.maximum(
    np.amax(np.maximum(sampled_light_function, reconstructed_light_function)),
    1.0)

# Plots results. `plt.subplots` replaces `plt.subplot("131")`-style calls:
# string subplot positions are deprecated and rejected by recent matplotlib.
fig, axes = plt.subplots(1, 3, figsize=(10, 10))
panels = (
    ("Original lighting function", sampled_light_function),
    ("Spherical Harmonics approximation", reconstructed_light_function),
    ("Difference",
     np.abs(reconstructed_light_function - sampled_light_function)),
)
for ax, (title, image) in zip(axes, panels):
  ax.xaxis.set_visible(False)
  ax.yaxis.set_visible(False)
  ax.grid(False)
  ax.set_title(title)
  ax.imshow(image, vmin=vmin, vmax=vmax)
plt.show()

Approximation of the Lambertian BRDF with Zonal Harmonics


In [0]:
#################################################################
# This cell creates an SH that approximates the Lambertian BRDF #
#################################################################

# The image dimensions control how many uniform samples we draw from the BRDF.
brdf_image_width = 30
brdf_image_height = 30
# Notebook-wide float dtype. The name `type` shadows the Python builtin but is
# kept because helpers and later cells in this notebook read it.
type = np.float64

# Builds the pixels grid and computes the corresponding spherical coordinates.
spherical_coordinates = image_to_spherical_coordinates(brdf_image_width,
                                                       brdf_image_height)

# Samples the BRDF: the clamped cosine lobe max(0, cos(theta)) / pi.
cos_theta = tf.cos(spherical_coordinates[:, :, 1])
sampled_brdf = tf.maximum(tf.zeros_like(cos_theta), cos_theta / np.pi)

# Samples the zonal SH basis at every pixel.
l, m = spherical_harmonics.generate_l_m_zonal(max_band)
l = tf.convert_to_tensor(l)
m = tf.convert_to_tensor(m)
sh_shape = [brdf_image_width, brdf_image_height] + l.shape.as_list()
l_broadcasted = tf.broadcast_to(l, sh_shape)
m_broadcasted = tf.broadcast_to(m, sh_shape)
theta_broadcasted = tf.expand_dims(spherical_coordinates[:, :, 1], axis=-1)
phi_broadcasted = tf.expand_dims(spherical_coordinates[:, :, 2], axis=-1)
sh_coefficients = spherical_harmonics.evaluate_spherical_harmonics(
    l_broadcasted, m_broadcasted, theta_broadcasted, phi_broadcasted)
sampled_brdf_broadcasted = tf.broadcast_to(
    tf.expand_dims(sampled_brdf, axis=-1), sh_shape)

# Integrates BRDF * SH over the sphere: every sample gets the same 4*pi / N
# solid-angle weight (assumes the grid covers the sphere roughly uniformly).
num_brdf_samples = brdf_image_width * brdf_image_height
projection = (sh_coefficients * sampled_brdf_broadcasted * 4.0 * math.pi /
              num_brdf_samples)
brdf_coeffs = tf.reduce_sum(projection, (0, 1))

# Reconstructs the image from the zonal coefficients.
reconstructed_brdf = tf.squeeze(vector.dot(sh_coefficients, brdf_coeffs))

print(
    "average l2 reconstruction error ",
    np.linalg.norm(sampled_brdf - reconstructed_brdf) / num_brdf_samples)

vmin = np.minimum(np.amin(np.minimum(sampled_brdf, reconstructed_brdf)), 0.0)
vmax = np.maximum(
    np.amax(np.maximum(sampled_brdf, reconstructed_brdf)), 1.0 / np.pi)

# Plots results. `plt.subplots` replaces `plt.subplot("131")`-style calls:
# string subplot positions are deprecated and rejected by recent matplotlib.
fig, axes = plt.subplots(1, 3, figsize=(10, 10))
panels = (
    ("Original reflectance function", sampled_brdf),
    ("Zonal Harmonics approximation", reconstructed_brdf),
    ("Difference", np.abs(sampled_brdf - reconstructed_brdf)),
)
for ax, (title, image) in zip(axes, panels):
  ax.xaxis.set_visible(False)
  ax.yaxis.set_visible(False)
  ax.grid(False)
  ax.set_title(title)
  ax.imshow(image, vmin=vmin, vmax=vmax)

# 1-D slice: values along the first grid axis (second index fixed at 0),
# plotted against the theta coordinate of those pixels.
slice_theta = spherical_coordinates[:, 0, 1]
plt.figure(figsize=(10, 5))
plt.plot(slice_theta, sampled_brdf[:, 0], label="max(0,cos(x) / pi)")
plt.plot(slice_theta, reconstructed_brdf[:, 0], label="SH approximation")
plt.title("Approximation quality")
plt.xlabel("theta")
plt.ylabel("reflectance")
plt.legend()
plt.show()

Rotation of Zonal Harmonics


In [0]:
###############################
# Rotation of zonal harmonics #
###############################

# Rotates the zonal BRDF coefficients with angles theta = pi / 2, phi = 0.
r_theta = tf.constant(np.pi / 2, shape=(1,), dtype=brdf_coeffs.dtype)
r_phi = tf.constant(0.0, shape=(1,), dtype=brdf_coeffs.dtype)
rotated_zonal_coefficients = spherical_harmonics.rotate_zonal_harmonics(
    brdf_coeffs, r_theta, r_phi)

# Builds the pixels grid and computes the corresponding spherical coordinates,
# reusing the helper from the lighting cell instead of duplicating it.
spherical_coordinates = image_to_spherical_coordinates(brdf_image_width,
                                                       brdf_image_height)

# Reconstruction. The SH basis is evaluated on the BRDF grid built above, so
# every broadcast must use the BRDF image dimensions, not the light image ones.
l, m = spherical_harmonics.generate_l_m_permutations(max_band)
l = tf.convert_to_tensor(l)
m = tf.convert_to_tensor(m)
sh_shape = [brdf_image_width, brdf_image_height] + l.shape.as_list()
l_broadcasted = tf.broadcast_to(l, sh_shape)
m_broadcasted = tf.broadcast_to(m, sh_shape)
theta_broadcasted = tf.expand_dims(spherical_coordinates[:, :, 1], axis=-1)
phi_broadcasted = tf.expand_dims(spherical_coordinates[:, :, 2], axis=-1)
sh_coefficients = spherical_harmonics.evaluate_spherical_harmonics(
    l_broadcasted, m_broadcasted, theta_broadcasted, phi_broadcasted)

reconstructed_rotated_brdf_function = tf.squeeze(
    vector.dot(sh_coefficients, rotated_zonal_coefficients))

# Plots results. `plt.subplots` replaces `plt.subplot("121")`-style calls:
# string subplot positions are deprecated and rejected by recent matplotlib.
fig, axes = plt.subplots(1, 2, figsize=(10, 10))
panels = (
    ("Zonal SH", reconstructed_brdf),
    ("Rotated version", reconstructed_rotated_brdf_function),
)
for ax, (title, image) in zip(axes, panels):
  ax.set_title(title)
  ax.xaxis.set_visible(False)
  ax.yaxis.set_visible(False)
  ax.grid(False)
  ax.imshow(image)
plt.show()

Reconstruction via Spherical Harmonics convolution of the SH lighting and ZH BRDF


In [0]:
############################################################################################
# Helper function allowing to estimate sphere normal and depth for each pixel in the image #
############################################################################################
def compute_intersection_normal_sphere(image_width, image_height, sphere_radius,
                                       sphere_center, type):
  """Casts one orthographic ray per pixel and intersects it with a sphere.

  Args:
    image_width: Number of pixels along the first image axis.
    image_height: Number of pixels along the second image axis.
    sphere_radius: Sphere radius, forwarded to `ray.intersection_ray_sphere`.
    sphere_center: Sphere center, forwarded to `ray.intersection_ray_sphere`.
    type: Floating point dtype used for the pixel grid and the zero depth map.

  Returns:
    A tuple `(intersection_points, normals)` holding entry 0 along the leading
    axis of what `ray.intersection_ray_sphere` returns (presumably the first of
    the ray/sphere intersections — TODO confirm against the library docs).
  """
  # Rays go through pixel centers, hence the half-pixel offsets.
  pixel_grid_start = np.array((0.5, 0.5), dtype=type)
  pixel_grid_end = np.array((image_width - 0.5, image_height - 0.5), dtype=type)
  pixel_nb = np.array((image_width, image_height))
  pixels = grid.generate(pixel_grid_start, pixel_grid_end, pixel_nb)

  pixel_ray = tf.math.l2_normalize(orthographic.ray(pixels), axis=-1)
  # Matches the pixel dtype: np.zeros defaults to float64, which may not
  # combine with the pixel grid when `type` is something else (e.g. float32).
  zero_depth = np.zeros([image_width, image_height, 1], dtype=type)
  pixels_3d = orthographic.unproject(pixels, zero_depth)

  intersections_points, normals = ray.intersection_ray_sphere(
      sphere_center, sphere_radius, pixel_ray, pixels_3d)
  return intersections_points[0, :, :, :], normals[0, :, :, :]


###############################
# Setup the image, and sphere #
###############################
# Image dimensions
image_width = 100
image_height = 80

# Sphere center and radius
sphere_radius = np.array((30.0,), dtype=type)
sphere_center = np.array((image_width / 2.0, image_height / 2.0, 100.0),
                         dtype=type)

################################################################################################
# For each pixel in the image, estimate the corresponding surface point and associated normal. #
################################################################################################
intersection_3d, surface_normal = compute_intersection_normal_sphere(
    image_width, image_height, sphere_radius, sphere_center, type)

surface_normals_spherical_coordinates = tf_math.cartesian_to_spherical_coordinates(
    surface_normal)

# Pixels whose intersection has a non-positive z are treated as not belonging
# to the sphere; both renderings below use this mask.
on_sphere = tf.greater(intersection_3d[:, :, 2], 0.0)

#################################################
# Estimates result using SH convolution - cheap #
#################################################

# Rotates the zonal BRDF towards each surface normal and integrates its product
# with the SH lighting.
sh_integration = spherical_harmonics.integration_product(
    light_coeffs,
    spherical_harmonics.rotate_zonal_harmonics(
        brdf_coeffs,
        tf.expand_dims(surface_normals_spherical_coordinates[:, :, 1], axis=-1),
        tf.expand_dims(surface_normals_spherical_coordinates[:, :, 2],
                       axis=-1)),
    keepdims=False)

# Sets pixels not belonging to the sphere to 0.
sh_integration = tf.where(on_sphere, sh_integration,
                          tf.zeros_like(sh_integration))
# Sets pixels with negative light to 0.
sh_integration = tf.where(
    tf.greater(sh_integration, 0.0), sh_integration,
    tf.zeros_like(sh_integration))

###########################################
# 'Brute force' solution - very expensive #
###########################################

# Each light sample carries the same 4*pi / N solid-angle weight that was used
# when the light was projected onto SH.
solid_angle_weight = 4.0 * np.pi / (light_image_width * light_image_height)
# The SH path uses the Lambertian BRDF max(0, cos) / pi from the BRDF cell,
# while the sum below only accumulates max(0, n . w) * L, so the same 1 / pi
# normalization is applied here. This accounts for the factor of pi between
# gt and sh_integration previously noted in TODO(b/124463095).
lambertian_normalization = 1.0 / np.pi
gt = tf.einsum(
    "hwn,uvn->hwuv", surface_normal,
    samples_direction_to_light *
    tf.expand_dims(sampled_light_function, axis=-1))
gt = tf.maximum(gt, 0.0)  # removes negative dot products
gt = tf.reduce_sum(gt, axis=(2, 3))
# Sets pixels not belonging to the sphere to 0.
gt = tf.where(on_sphere, gt, tf.zeros_like(gt))
gt *= solid_angle_weight * lambertian_normalization

# The grids are laid out as (width, height); transposes for display.
sh_integration = np.transpose(sh_integration, (1, 0))
gt = np.transpose(gt, (1, 0))

# Plots results. `plt.subplots` replaces `plt.subplot("121")`-style calls:
# string subplot positions are deprecated and rejected by recent matplotlib.
fig, axes = plt.subplots(1, 2, figsize=(10, 20))
panels = (
    ("SH light and SH BRDF", sh_integration),
    ("GT light and GT BRDF", gt),
)
for ax, (title, image) in zip(axes, panels):
  ax.xaxis.set_visible(False)
  ax.yaxis.set_visible(False)
  ax.grid(False)
  ax.set_title(title)
  ax.imshow(image, vmin=0.0)
plt.show()