In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Closed Form Matting Energy


Matting is an important task in image editing: a foreground object is extracted from a given image so that it can be combined with a novel background into a new composite image. To achieve a plausible result, the foreground needs to be extracted carefully, e.g. preserving all of its thin structures, before being composited over the new background. In image matting, the input image $I$ is assumed to be a linear combination of a foreground image $F$ and a background image $B$. For a pixel $j$ of $I$, the color of the pixel can therefore be expressed as $I_j = \alpha_j F_j + (1-\alpha_j) B_j$, where $\alpha_j \in [0, 1]$ is the foreground opacity of pixel $j$. The opacity image made of all the $\alpha_j$ values is called a matte.
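
As a quick sanity check of the compositing equation, the NumPy sketch below blends a hypothetical red foreground over a blue background with three opacities (0, 0.5 and 1). NumPy is used only for illustration; the notebook itself works with TensorFlow tensors of shape `[batch, height, width, channels]`. The helper name `composite` is hypothetical.

```python
import numpy as np


def composite(alpha, foreground, background):
  """Blends two images per pixel: I = alpha * F + (1 - alpha) * B.

  alpha has shape [H, W, 1]; foreground and background have shape [H, W, 3].
  The trailing axis of size 1 broadcasts over the color channels.
  """
  return alpha * foreground + (1.0 - alpha) * background


# Hypothetical 1x3 images: a red foreground and a blue background.
red_image = np.tile(np.array([1.0, 0.0, 0.0]), (1, 3, 1))
blue_image = np.tile(np.array([0.0, 0.0, 1.0]), (1, 3, 1))
# Opacities of a pure background, a half-transparent and an opaque pixel.
alpha = np.array([0.0, 0.5, 1.0]).reshape(1, 3, 1)

image = composite(alpha, red_image, blue_image)
print(image[0])  # rows: blue, half red / half blue, red
```

An opacity of 0 returns the background color, 1 returns the foreground color, and intermediate values mix the two linearly.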

Given a trimap (white for foreground, black for background, and gray for unknown pixels) or a set of scribbles (user strokes), an optimization problem can be formulated to retrieve the opacities of the unknown pixels. This colab demonstrates how to use the image matting loss implemented in TensorFlow Graphics to precisely segment objects out of images and paste them on top of new backgrounds. This matting loss is derived from the paper "A Closed Form Solution to Natural Image Matting" by Levin et al., and its "tensorized" form was inspired by "Deep-Energy: Unsupervised Training of Deep Neural Networks" by Golts et al.
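
For intuition about what the loss measures: Levin et al. assume that the matte is locally an affine function of the image colors in every small window $w_k$. Eliminating the per-window affine coefficients turns their cost into the quadratic form $\alpha^\top L \alpha$, where $L$ is the matting Laplacian with entries $L_{ij} = \sum_{k \mid (i,j) \in w_k} \left(\delta_{ij} - \frac{1}{|w_k|}\left(1 + (I_i - \mu_k)^\top \left(\Sigma_k + \frac{\epsilon}{|w_k|}\mathrm{Id}_3\right)^{-1} (I_j - \mu_k)\right)\right)$, with $\mu_k$ and $\Sigma_k$ the mean and covariance of the colors in $w_k$ and $\epsilon$ a small regularizer. The NumPy sketch below builds this matrix densely for a tiny random image, straight from that formula. It is an independent illustration, not the TensorFlow Graphics implementation; `matting_laplacian` and `matting_energy` are hypothetical names, and 3x3 windows with $\epsilon = 10^{-5}$ are assumed.

```python
import numpy as np


def matting_laplacian(image, eps=1e-5, radius=1):
  """Dense matting Laplacian (Levin et al.) of a small [H, W, 3] image.

  Uses all (2 * radius + 1)^2 windows lying fully inside the image and
  returns an [H * W, H * W] matrix, so it is only practical for tiny images.
  """
  height, width, channels = image.shape
  num_pixels = height * width
  window_size = (2 * radius + 1) ** 2
  colors = image.reshape(num_pixels, channels)
  pixel_ids = np.arange(num_pixels).reshape(height, width)
  laplacian = np.zeros((num_pixels, num_pixels))
  for y in range(radius, height - radius):
    for x in range(radius, width - radius):
      window = pixel_ids[y - radius:y + radius + 1,
                         x - radius:x + radius + 1].ravel()
      centered = colors[window] - colors[window].mean(axis=0)
      covariance = centered.T @ centered / window_size
      inverse = np.linalg.inv(
          covariance + eps / window_size * np.eye(channels))
      # (1 + (I_i - mu_k)^T inverse (I_j - mu_k)) / |w_k| for all i, j in w_k.
      affinity = (1.0 + centered @ inverse @ centered.T) / window_size
      laplacian[np.ix_(window, window)] += np.eye(window_size) - affinity
  return laplacian


def matting_energy(alpha, laplacian):
  """Quadratic matting energy alpha^T L alpha of a matte."""
  flat = alpha.ravel()
  return float(flat @ laplacian @ flat)


rng = np.random.default_rng(0)
toy_image = rng.uniform(size=(5, 5, 3))
toy_laplacian = matting_laplacian(toy_image)

# A constant matte and a matte that is an affine function of the colors
# should have (near) zero energy; an unstructured random matte should not.
constant_energy = matting_energy(np.full((5, 5), 0.7), toy_laplacian)
affine_energy = matting_energy(
    toy_image @ np.array([0.2, -0.1, 0.3]) + 0.1, toy_laplacian)
noise_energy = matting_energy(rng.uniform(size=(5, 5)), toy_laplacian)
print(constant_energy, affine_energy, noise_energy)
```

The constant and color-affine mattes land at (or, because of $\epsilon$, very near) zero energy while the random matte does not, which is why minimizing this energy favors mattes that follow the image structure.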

Setup & Imports

If TensorFlow Graphics is not installed on your system, the following cell can install the TensorFlow Graphics package for you.


In [0]:
!pip install tensorflow_graphics

Now that TensorFlow Graphics is installed, let's import everything needed to run the demos contained in this notebook.


In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow_graphics.image import matting
from tqdm import tqdm

Import the image, trimap, and ground-truth matte

Download the input image, its trimap, and the ground-truth matte from alphamatting.com.


In [0]:
# Download dataset from alphamatting.com
!rm -rf input_training_lowres
!rm -rf trimap_training_lowres
!rm -rf gt_training_lowres

!wget -q http://www.alphamatting.com/datasets/zip/input_training_lowres.zip
!wget -q http://www.alphamatting.com/datasets/zip/trimap_training_lowres.zip
!wget -q http://www.alphamatting.com/datasets/zip/gt_training_lowres.zip

!unzip -q input_training_lowres.zip -d input_training_lowres
!unzip -q trimap_training_lowres.zip -d trimap_training_lowres
!unzip -q gt_training_lowres.zip -d gt_training_lowres

In [0]:
# Read and decode images
source = tf.io.read_file('input_training_lowres/GT07.png')
source = tf.cast(tf.io.decode_png(source), tf.float64) / 255.0
source = tf.expand_dims(source, axis=0)
trimap = tf.io.read_file('trimap_training_lowres/Trimap1/GT07.png')
trimap = tf.cast(tf.io.decode_png(trimap), tf.float64) / 255.0
trimap = tf.reduce_mean(trimap, axis=-1, keepdims=True)
trimap = tf.expand_dims(trimap, axis=0)
gt_matte = tf.io.read_file('gt_training_lowres/GT07.png')
gt_matte = tf.cast(tf.io.decode_png(gt_matte), tf.float64) / 255.0
gt_matte = tf.reduce_mean(gt_matte, axis=-1, keepdims=True)
gt_matte = tf.expand_dims(gt_matte, axis=0)

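# Downsample by a factor of 2 to speed up the optimization; nearest-neighbor
# resizing keeps the trimap values exact (no blending between labels)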
source = tf.image.resize(
    source,
    tf.shape(source)[1:3] // 2,
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
trimap = tf.image.resize(
    trimap,
    tf.shape(trimap)[1:3] // 2,
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
gt_matte = tf.image.resize(
    gt_matte,
    tf.shape(gt_matte)[1:3] // 2,
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

# Show images
figure = plt.figure(figsize=(22, 18))
axes = figure.add_subplot(1, 3, 1)
axes.grid(False)
axes.set_title('Input image', fontsize=14)
_ = plt.imshow(source[0, ...].numpy())
axes = figure.add_subplot(1, 3, 2)
axes.grid(False)
axes.set_title('Input trimap', fontsize=14)
_ = plt.imshow(trimap[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
axes = figure.add_subplot(1, 3, 3)
axes.grid(False)
axes.set_title('GT matte', fontsize=14)
_ = plt.imshow(gt_matte[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
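
The resize above uses `NEAREST_NEIGHBOR`, which matters for the next step: the constraints are extracted with exact equality tests (`tf.equal(trimap, 1.0)` and `tf.equal(trimap, 0.0)`), which only work if resizing does not create new intermediate values. The NumPy sketch below, on a hypothetical 4x4 trimap, contrasts keeping one pixel per 2x2 block (nearest-neighbor style; which pixel TensorFlow actually picks may differ) with averaging each block (similar in spirit to area or bilinear filtering).

```python
import numpy as np

# Hypothetical trimap: 0 = background, 0.5 = unknown, 1 = foreground.
toy_trimap = np.array([[0.0, 0.0, 0.5, 1.0],
                       [0.0, 0.5, 1.0, 1.0],
                       [0.0, 0.5, 1.0, 1.0],
                       [0.5, 1.0, 1.0, 1.0]])
labels = {0.0, 0.5, 1.0}

# Nearest-neighbor style: keep one original pixel of each 2x2 block.
nearest = toy_trimap[::2, ::2]
# Filtering style: average each 2x2 block, which mixes neighboring labels.
averaged = toy_trimap.reshape(2, 2, 2, 2).mean(axis=(1, 3))

nearest_keeps_labels = set(np.unique(nearest).tolist()) <= labels
averaged_keeps_labels = set(np.unique(averaged).tolist()) <= labels
print(nearest_keeps_labels, averaged_keeps_labels)  # True False
```

Averaged values such as 0.125 would fail both equality tests and silently turn constrained pixels into unknown ones.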

Extract the foreground and background constraints from the trimap image


In [0]:
# Extract the foreground and background constraints from the trimap image
foreground = tf.cast(tf.equal(trimap, 1.0), tf.float64)
background = tf.cast(tf.equal(trimap, 0.0), tf.float64)

# Show foreground and background constraints
figure = plt.figure(figsize=(22, 18))
axes = figure.add_subplot(1, 2, 1)
axes.grid(False)
axes.set_title('Foreground constraints', fontsize=14)
_ = plt.imshow(foreground[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
axes = figure.add_subplot(1, 2, 2)
axes.grid(False)
axes.set_title('Background constraints', fontsize=14)
_ = plt.imshow(background[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
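
The two masks cover the known pixels of the trimap; every remaining pixel is unknown and is only constrained by the matting loss. The NumPy sketch below mirrors the extraction on a hypothetical 3x3 trimap and evaluates the same masked squared-difference constraint term used in the optimization cell: because the target is 1 on foreground and 0 on background, the foreground mask itself serves as the target. The `toy_` names and `constraint_loss` are hypothetical.

```python
import numpy as np


def constraint_loss(matte, foreground_mask, background_mask):
  """Mean over all pixels of known * (matte - target)^2, target = foreground."""
  known_mask = foreground_mask + background_mask
  return np.mean(known_mask * (matte - foreground_mask) ** 2)


# Hypothetical trimap: 1 = foreground, 0 = background, anything else = unknown.
toy_trimap = np.array([[0.0, 0.0, 0.5],
                       [0.0, 0.5, 1.0],
                       [0.5, 1.0, 1.0]])

toy_foreground = (toy_trimap == 1.0).astype(np.float64)
toy_background = (toy_trimap == 0.0).astype(np.float64)
known = toy_foreground + toy_background  # 1 where the matte is constrained
unknown = 1.0 - known                    # pixels left to the matting loss

# A matte matching the trimap on known pixels has zero constraint loss,
# whatever its values in the unknown region; an inverted matte does not.
matching_matte = np.where(known == 1.0, toy_foreground, 0.3)
inverted_matte = 1.0 - toy_foreground
print(constraint_loss(matching_matte, toy_foreground, toy_background),
      constraint_loss(inverted_matte, toy_foreground, toy_background))
```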

Set up & run the optimization

Set up the matting loss function using TensorFlow Graphics, add a loss enforcing the trimap constraints, and run the Adam optimizer for 400 iterations.


In [0]:
# Initialize the matte with random values
matte_shape = tf.concat((tf.shape(source)[:-1], (1,)), axis=-1)
matte = tf.Variable(
    tf.random.uniform(
        shape=matte_shape, minval=0.0, maxval=1.0, dtype=tf.float64))
# Create the closed form matting Laplacian
laplacian, _ = matting.build_matrices(source)

# Function computing the loss and applying the gradient
@tf.function
def optimize(optimizer):
  with tf.GradientTape() as tape:
    tape.watch(matte)
    # Compute a loss enforcing the trimap constraints
    constraints = tf.reduce_mean((foreground + background) *
                                 tf.math.squared_difference(matte, foreground))
    # Compute the matting loss
    smoothness = matting.loss(matte, laplacian)
    # Sum up the losses, weighting the constraints so the known pixels are enforced
    total_loss = 100 * constraints + smoothness
  # Compute and apply the gradient to the matte
  gradient = tape.gradient(total_loss, [matte])
  optimizer.apply_gradients(zip(gradient, (matte,)))

# Run the Adam optimizer for 400 iterations
optimizer = tf.optimizers.Adam(learning_rate=1.0)
nb_iterations = 400
for _ in tqdm(range(nb_iterations)):
  optimize(optimizer)

# Clip the matte values to [0, 1]
matte = tf.clip_by_value(matte, 0.0, 1.0)

# Display the results
figure = plt.figure(figsize=(22, 18))
axes = figure.add_subplot(1, 3, 1)
axes.grid(False)
axes.set_title('Input image', fontsize=14)
_ = plt.imshow(source[0, ...].numpy())
axes = figure.add_subplot(1, 3, 2)
axes.grid(False)
axes.set_title('Input trimap', fontsize=14)
_ = plt.imshow(trimap[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
axes = figure.add_subplot(1, 3, 3)
axes.grid(False)
axes.set_title('Matte', fontsize=14)
_ = plt.imshow(matte[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)
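
Both terms of the objective above are quadratic in the matte, so its minimizer can also be obtained by solving a single sparse linear system instead of running Adam; this is the closed-form solution that gives the paper of Levin et al. its name. The NumPy sketch below illustrates the idea on a hypothetical 8-pixel 1D signal. To stay short, it swaps the matting Laplacian for the simpler path-graph Laplacian (a sum of squared neighbor differences), and it sums the constraint term instead of averaging it as the notebook does, so the weights are not directly comparable.

```python
import numpy as np

num_pixels = 8
# Stand-in smoothness term on a hypothetical 1D "image": the path-graph
# Laplacian, for which alpha^T L alpha = sum_i (alpha[i + 1] - alpha[i])^2.
# This is NOT the matting Laplacian; it only keeps the sketch short.
differences = np.diff(np.eye(num_pixels), axis=0)  # [num_pixels - 1, num_pixels]
smoothness_matrix = differences.T @ differences

# Constraints: the first two pixels are foreground, the last two background.
known = np.zeros(num_pixels)
known[[0, 1, -2, -1]] = 1.0
target = np.zeros(num_pixels)
target[[0, 1]] = 1.0
weight = 100.0  # plays the role of the factor 100 in the notebook's loss

# Setting the gradient of
#   alpha^T L alpha + weight * sum(known * (alpha - target)^2)
# to zero gives the linear system (L + weight * D) alpha = weight * D target,
# with D = diag(known).
constraint_matrix = np.diag(known)
alpha = np.linalg.solve(smoothness_matrix + weight * constraint_matrix,
                        weight * constraint_matrix @ target)
print(np.round(alpha, 3))
```

The known pixels stay close to their targets and the unknown ones interpolate smoothly between them; with the real matting Laplacian, the interpolation instead follows the local color statistics of the image.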

Compositing

Let's now composite our extracted object on top of a new background!


In [0]:
!wget -q https://p2.piqsels.com/preview/861/934/460/concrete-texture-background-backdrop.jpg
# Use a new name so the `background` constraint mask is not overwritten
new_background = tf.io.read_file('concrete-texture-background-backdrop.jpg')
new_background = tf.cast(tf.io.decode_jpeg(new_background), tf.float64) / 255.0
new_background = tf.expand_dims(new_background, axis=0)

# Resize the new background to the resolution of the source image
new_background = tf.image.resize(
    new_background,
    tf.shape(source)[1:3],
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

# Composite the foreground over black and over the new background
inpainted_black = matte * source
inpainted_concrete = matte * source + (1.0 - matte) * new_background

# Display the results
figure = plt.figure(figsize=(22, 18))
axes = figure.add_subplot(1, 2, 1)
axes.grid(False)
axes.set_title('Composite over black', fontsize=14)
_ = plt.imshow(inpainted_black[0, ...].numpy())
axes = figure.add_subplot(1, 2, 2)
axes.grid(False)
axes.set_title('Composite over concrete', fontsize=14)
_ = plt.imshow(inpainted_concrete[0, ...].numpy())

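Note that the composite is only approximate: we did not recover the true foreground colors $F_j = \frac{I_j - (1-\alpha_j)B_j}{\alpha_j}$ but multiplied the input colors $I_j$ by $\alpha_j$, so some of the original background color remains in partially transparent pixels. Recovering $F_j$ would also require an estimate of the background color $B_j$, and is ill-defined where $\alpha_j = 0$.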
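
Inverting the compositing equation is easy when $B_j$ is known, which is the (hypothetical) assumption the NumPy sketch below makes; for the input photo in this notebook, $B_j$ is not available. The sketch also guards the division by $\alpha_j$: where the opacity falls below a small threshold (`min_alpha`, an arbitrary choice here), the pixel carries almost no foreground information, so $F_j$ is set to 0 rather than to amplified noise. `estimate_foreground` is a hypothetical helper, not a TensorFlow Graphics function.

```python
import numpy as np


def estimate_foreground(image, alpha, background, min_alpha=1e-3):
  """Solves I = alpha * F + (1 - alpha) * B for F, given alpha and B.

  image and background have shape [H, W, 3]; alpha has shape [H, W, 1].
  Pixels with alpha below min_alpha are set to 0 instead of dividing by a
  (near-)zero opacity.
  """
  safe_alpha = np.maximum(alpha, min_alpha)
  foreground = (image - (1.0 - alpha) * background) / safe_alpha
  return np.where(alpha >= min_alpha, np.clip(foreground, 0.0, 1.0), 0.0)


# Round trip on hypothetical random data: composite, then recover F.
rng = np.random.default_rng(0)
true_foreground = rng.uniform(size=(4, 4, 3))
known_background = rng.uniform(size=(4, 4, 3))
toy_alpha = rng.uniform(size=(4, 4, 1))
toy_alpha[0, 0, 0] = 0.0  # pure background pixel: F cannot be recovered
toy_image = toy_alpha * true_foreground + (1.0 - toy_alpha) * known_background

recovered = estimate_foreground(toy_image, toy_alpha, known_background)
valid = toy_alpha[..., 0] >= 1e-3
print(np.abs(recovered[valid] - true_foreground[valid]).max())
```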