Domain Adaptation with GRL

In this lab, we will implement the paper Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015), which adapts a classifier between two domains using a Gradient Reversal Layer (GRL).
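
As a reminder from the paper, the Gradient Reversal Layer behaves as the identity during the forward pass, R(x) = x, but multiplies the incoming gradient by -λ during the backward pass, dR/dx = -λ·I. Inserted between the feature extractor and a domain classifier, it makes the learned features adversarial to domain discrimination, pushing them to be domain-invariant.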

Our source domain will be MNIST, while our target domain will be MNIST-M, a colored variant of MNIST in which the digits are blended with patches of natural images.

Loading data

We will use only a subset of the datasets to speed things up, but feel free to use the full datasets if you have a GPU.

In [ ]:
import numpy as np

# Work on a 10,000-sample subset for speed; set to False to use the full datasets:
USE_SUBSET = True

def get_subset(x, y):
    if not USE_SUBSET:
        return x, y

    subset_index = 10000
    indexes = np.random.permutation(len(x))[:subset_index]
    x, y = x[indexes], y[indexes]

    return x, y

Loading source dataset MNIST:

In [ ]:
from tensorflow.keras.datasets import mnist
from skimage.color import gray2rgb
from skimage.transform import resize
from sklearn.model_selection import train_test_split

(x_source_train, y_source_train), (x_source_test, y_source_test) = mnist.load_data()

def process_mnist(x):
    x = np.moveaxis(x, 0, -1)
    x = resize(x, (32, 32), anti_aliasing=True, mode='constant')
    x = np.moveaxis(x, -1, 0)
    return gray2rgb(x).astype("float32")

x_source_train = process_mnist(x_source_train)
x_source_test = process_mnist(x_source_test)

x_source_train, y_source_train = get_subset(x_source_train, y_source_train)
#x_source_test, y_source_test = get_subset(x_source_test, y_source_test)

x_source_train, x_source_val, y_source_train, y_source_val = train_test_split(
    x_source_train, y_source_train,
    test_size=int(0.1 * len(x_source_train))
)

x_source_train.shape, x_source_val.shape, x_source_test.shape

In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 15))
for i, digit in enumerate(np.unique(y_source_train), start=1):
    index = np.where(y_source_train == digit)[0][0]
    ax = plt.subplot(1, 10, i)
    ax.imshow(x_source_train[index])
    ax.set_title(digit)
    ax.axis("off")

Loading target dataset MNIST-M:

In [ ]:
import pickle as pkl

with open("mnistm_data.pkl", "rb") as f:
    mnist_m = pkl.load(f)
x_target_train, y_target_train = get_subset(mnist_m["x_train"], mnist_m["y_train"])
x_target_test, y_target_test = mnist_m["x_test"], mnist_m["y_test"]

x_target_train = resize(x_target_train, (x_target_train.shape[0], 32, 32, 3), anti_aliasing=True, mode='edge').astype("float32")
x_target_test = resize(x_target_test, (x_target_test.shape[0], 32, 32, 3), anti_aliasing=True, mode='edge').astype("float32")

x_target_train.shape, x_target_test.shape

In [ ]:
plt.figure(figsize=(20, 15))
for i, digit in enumerate(np.unique(y_target_train), start=1):
    index = np.where(y_target_train == digit)[0][0]
    ax = plt.subplot(1, 10, i)
    ax.imshow(x_target_train[index])
    ax.set_title(digit)
    ax.axis("off")

Naive model

In a first step, we will build a naive model trained only on the source domain, as depicted in the image below. Implement it as shown:

In [ ]:
from tensorflow.keras.layers import MaxPool2D, Conv2D, Dense, Dropout, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
import tensorflow as tf

def get_network(input_shape=x_source_train.shape[1:]):
    # TODO
    return Model(inputs=inputs, outputs=digits_classifier)

model = get_network()

model.compile(
    optimizer=SGD(lr=0.1, momentum=0.9, nesterov=True),
    # Integer digit labels, hence a sparse categorical loss:
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)


In [ ]:
# %load solutions/
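
In case you want a reference point, here is a minimal sketch of one possible get_network implementation. It assumes a small LeNet-like convnet; layer sizes are illustrative and may differ from the architecture in the figure and in the solution file.

def get_network(input_shape=x_source_train.shape[1:]):
    inputs = Input(shape=input_shape)

    # Convolutional feature extractor (sizes are illustrative):
    x = Conv2D(32, 5, activation="relu", padding="same")(inputs)
    x = MaxPool2D(2)(x)
    x = Conv2D(48, 5, activation="relu", padding="same")(x)
    x = MaxPool2D(2)(x)
    x = Flatten()(x)

    # Digit classification head:
    x = Dense(100, activation="relu")(x)
    x = Dropout(0.5)(x)
    digits_classifier = Dense(10, activation="softmax", name="digits")(x)

    return Model(inputs=inputs, outputs=digits_classifier)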

In [ ]:
model.fit(
    x_source_train, y_source_train,
    validation_data=(x_source_val, y_source_val),
    epochs=10,  # illustrative values, tune as needed
    batch_size=128
)

After training on our source dataset (MNIST), we evaluate the model's performance on both the source (MNIST) and target (MNIST-M) test sets:

In [ ]:
print("Loss & Accuracy on MNIST test set:")
model.evaluate(x_source_test, y_source_test, verbose=0)

In [ ]:
print("Loss & Accuracy on MNIST-M test set:")
model.evaluate(x_target_test, y_target_test, verbose=0)

Note that the two datasets are quite different: the model trained on MNIST alone does not generalize well to the target set.

Model with Gradient Reversal Layer

Let us first define a Gradient Reversal Layer, which reverses the gradient during the backward pass:

In [ ]:
@tf.custom_gradient
def grad_reverse(x):
    y = tf.identity(x)
    def custom_grad(dy):
        return None  # TODO
    return y, custom_grad

class GradReverse(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, x):
        return grad_reverse(x)

In [ ]:
# %load solutions/
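
For reference, a minimal sketch of what the custom gradient could look like: the gradient is simply negated on its way back to the feature extractor (the paper additionally scales it by a scheduled factor λ, omitted here).

@tf.custom_gradient
def grad_reverse(x):
    y = tf.identity(x)
    def custom_grad(dy):
        # Identity in the forward pass, negated gradient in the backward pass:
        return -1.0 * dy
    return y, custom_grad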

Then define the whole model: the convnet feature extractor, the digit classification branch, and the domain classification branch.

In [ ]:
def get_adaptable_network(input_shape=x_source_train.shape[1:]):
    # TODO
    return Model(inputs=inputs, outputs=None)

model = get_adaptable_network()

In [ ]:
# %load solutions/
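
As a rough guide, here is a sketch of one possible two-branch model. It reuses the illustrative feature extractor from the naive model sketch above; layer sizes and names are assumptions, but the digit-branch layers are deliberately named with a "digits" prefix so that the training loop below can exclude them from some updates.

def get_adaptable_network(input_shape=x_source_train.shape[1:]):
    inputs = Input(shape=input_shape)

    # Shared convolutional feature extractor:
    x = Conv2D(32, 5, activation="relu", padding="same")(inputs)
    x = MaxPool2D(2)(x)
    x = Conv2D(48, 5, activation="relu", padding="same")(x)
    x = MaxPool2D(2)(x)
    features = Flatten()(x)

    # Digit classification branch (layer names contain "digits"):
    d = Dense(100, activation="relu", name="digits_dense")(features)
    d = Dropout(0.5)(d)
    digits_classifier = Dense(10, activation="softmax", name="digits_output")(d)

    # Domain classification branch, behind the gradient reversal layer:
    g = GradReverse()(features)
    g = Dense(100, activation="relu")(g)
    domain_classifier = Dense(1, activation="sigmoid", name="domain_output")(g)

    return Model(inputs=inputs, outputs=[digits_classifier, domain_classifier])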

We now define our data generators. Note that we also add the domain labels: we arbitrarily set the source domain label to 1 and the target domain label to 0.

In [ ]:
batch_size = 128
epochs = 10

d_source_train = np.ones_like(y_source_train)
d_source_val = np.ones_like(y_source_val)

source_train_generator = tf.data.Dataset.from_tensor_slices(
    (x_source_train, y_source_train, d_source_train)).batch(batch_size)

d_target_train = np.zeros_like(y_target_train)

target_train_generator = tf.data.Dataset.from_tensor_slices(
    (x_target_train, d_target_train)).batch(batch_size)

We want to alternate training on the source and target datasets. Fill in the following block.

Note that, for training to behave properly, we weight the domain losses by a low factor of 0.2.

See the tf.GradientTape documentation (https://www.tensorflow.org/api_docs/python/tf/GradientTape) for more information on how to use it.
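
As a reminder (this is the generic pattern, not the solution), one custom training step with GradientTape typically looks like the snippet below, where x_batch, y_batch and loss_fn are placeholders:

with tf.GradientTape() as tape:
    predictions = model(x_batch, training=True)  # forward pass recorded by the tape
    loss = loss_fn(y_batch, predictions)

gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))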

In [ ]:
import collections

from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.metrics import Mean, Accuracy

optimizer = SGD(lr=0.01, momentum=0.9, nesterov=True)

cce = SparseCategoricalCrossentropy()
bce = BinaryCrossentropy()

lambda_factor = 0.2  # low weight applied to the domain losses (see the note above)
global_step = 0      # step counter, useful if you later schedule lambda_factor

# Compile so that model.evaluate() reports per-output losses and accuracies:
model.compile(
    optimizer=optimizer,
    loss=[cce, bce],
    metrics=["accuracy", "accuracy"]
)

def train_epoch(source_train_generator, target_train_generator):
    global lambda_factor, global_step

    # Keras provides helpful classes to monitor various metrics:
    epoch_source_digits = tf.keras.metrics.Mean()
    epoch_source_domains = tf.keras.metrics.Mean()
    epoch_target_domains = tf.keras.metrics.Mean()
    epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    # Fetch all trainable variables but those used uniquely for the digits classification:
    variables_but_classifier = list(filter(lambda x: "digits" not in x.name, model.trainable_variables))
    loss_record = collections.defaultdict(list)
    for i, data in enumerate(zip(source_train_generator, target_train_generator)):
        source_data, target_data = data
        # Training digits classifier & domain classifier on source:
        x_source, y_source, d_source = source_data

        # Remember that a forward pass is simply:
        #   outputs = model(inputs)
        with tf.GradientTape() as tape:
            # TODO

        gradients = tape.gradient(# TODO, # TODO)
        optimizer.apply_gradients(zip(# TODO, # TODO))

        # Training domain classifier on target:
        x_target, d_target = target_data
        with tf.GradientTape() as tape:
            # TODO

        gradients = tape.gradient(# TODO, # TODO)
        optimizer.apply_gradients(zip(# TODO, # TODO))

        # Log the various losses and accuracy
        epoch_accuracy(y_source, digits_prob)

    print("Source digits loss={}, Source Accuracy={}, Source domain loss={}, Target domain loss={}".format(
        epoch_source_digits.result(), epoch_accuracy.result(), 
        epoch_source_domains.result(), epoch_target_domains.result()))

for epoch in range(epochs):
    print("Epoch: {}".format(epoch), end=" ")
    loss_record = train_epoch(source_train_generator, target_train_generator)

This new model has more metrics and losses than the previous one. To know what they are, we can display its metrics_names:

In [ ]:
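model.metrics_names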

Evaluate the performance on both the source and target test sets:

In [ ]:
print("Loss & Accuracy on MNIST test set:")
model.evaluate(x_source_test, [y_source_test, np.ones_like(y_source_test)], verbose=0)

In [ ]:
print("Loss & Accuracy on MNIST-M test set:")
model.evaluate(x_target_test, [y_target_test, np.zeros_like(y_target_test)], verbose=0)

The model is still not as good on the target dataset (MNIST-M) as on the source dataset (MNIST), but its performance is much better! Without using any target labels, we improve the target accuracy from about 40% to more than 60%.


Going further:

  • Train on the full datasets (set USE_SUBSET to False)
  • Train for more epochs, and use callbacks such as EarlyStopping to know when to stop
  • Try to improve the model by scheduling the learning rate as done in the paper (see the sketch after this list)
  • Try to improve the model by scheduling the domain loss weight
  • Try other domain pairs, such as SVHN -> MNIST
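
As a reminder, the paper schedules both the learning rate and the domain loss weight as a function of the training progress p, which goes linearly from 0 to 1. A small sketch of those schedules (the function names are just illustrative):

def lr_schedule(p, mu_0=0.01, alpha=10.0, beta=0.75):
    # Learning rate schedule from the paper: mu_p = mu_0 / (1 + alpha * p) ** beta
    return mu_0 / (1.0 + alpha * p) ** beta

def lambda_schedule(p, gamma=10.0):
    # Domain loss weight schedule from the paper: lambda_p = 2 / (1 + exp(-gamma * p)) - 1
    return 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0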

In [ ]: