In [ ]:
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
This colab explores how to train autoencoders on a TPU device.
For this colab, consider the following scenario: you have an image classification model that you want to improve by adding some additional features. The features that you can add to the model could be image embeddings that can be separately trained on a TPU.
This example uses a fully-connected one layer model as the model that you want to make better with additional features trained on a TPU.
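In this Colab, you will learn how to train a convolutional autoencoder on a TPU, extract image embeddings from its encoder, and use those embeddings as additional input features for a simple classifier.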
In [ ]:
import os
# Fail fast with a readable hint when the runtime has no TPU attached.
# `.get` is used so that a missing variable triggers the assertion message
# below instead of a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from absl import logging
# Only let absl log messages of severity ERROR or higher through.
logging.set_verbosity(logging.ERROR)
# Initialize TPU Strategy.
# The resolver is created without arguments.
# NOTE(review): presumably it discovers the Colab TPU from the environment -- confirm.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
# Connect to the TPU cluster and initialize the TPU system before the
# strategy object is created from the same resolver.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# `strategy` provides the scope under which the models below are built.
strategy = tf.distribute.experimental.TPUStrategy(resolver)
In [ ]:
# Load the MNIST train/test split through the Keras datasets API.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Cast both label arrays to int32.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
In [ ]:
def show_img(img):
    """Display `img` in a new matplotlib figure with the grid turned off.

    Args:
        img: Image data accepted by `plt.imshow`.
    """
    plt.figure()
    plt.imshow(img)
    plt.grid(False)
    plt.show()
# Index of the example image displayed by later cells of this notebook.
img = 0
In [ ]:
# Show the selected test image at its full 28x28 size.
show_img(np.reshape(x_test[img], (28, 28)))
In [ ]:
NUM_CLASSES = 10

# input image dimensions
IMG_ROWS = 28
IMG_COLS = 28

# Add a trailing channel axis, convert to float32 and divide pixel values by 255.
x_train = x_train.reshape(x_train.shape[0], IMG_ROWS, IMG_COLS, 1).astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape[0], IMG_ROWS, IMG_COLS, 1).astype('float32') / 255.0
Here is a contrived example where the training happens only on the top-left corner (a 14 by 14 crop) of each MNIST image.
Suppose that your original model, the fully-connected one layer network, was too computationally heavy, in terms of resources, and thus you could only afford to train on parts of the images. Instead of training on 28 by 28 pixels (784 pixels), you train on 14 by 14 pixels (196 pixels). This colab will later show that just by adding 49 more values (the size of each embedding) to each training example, accuracy can be significantly increased.
This way you introduce minimal changes to an original model while gaining benefits from a heavy computational task that you can offload to a TPU.
In [ ]:
# Keep only the top-left 14x14 region of every image (all examples, all channels).
x_train_corners, x_test_corners = (
    images[:, :14, :14, :] for images in (x_train, x_test)
)
In [ ]:
# Show the 14x14 crop of the selected test image.
show_img(np.reshape(x_test_corners[img], (14, 14)))
In [ ]:
def get_model(input_shape):
    """Build the single-layer dense classifier used throughout this notebook.

    Args:
        input_shape: Shape of one input example, without the batch axis.

    Returns:
        An uncompiled `tf.keras.Model` that flattens its input and maps it to
        `NUM_CLASSES` class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    flat = tf.keras.layers.Flatten()(inputs)
    # Softmax rather than sigmoid: the digit classes are mutually exclusive
    # and the models are trained with SparseCategoricalCrossentropy in its
    # default from_logits=False mode, which expects the prediction to be a
    # probability distribution over the classes.
    probabilities = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(flat)
    return tf.keras.models.Model(inputs, probabilities)
In [ ]:
# Build and compile the baseline classifier (corner crops only) inside the
# TPU strategy scope.
with strategy.scope():
    model0 = get_model(x_train_corners.shape[1:])
    model0.compile(
        optimizer=tf.optimizers.SGD(learning_rate=0.05),
        loss=tf.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.metrics.SparseCategoricalAccuracy()],
    )
In [ ]:
# Train the baseline for three epochs, then evaluate it on the test crops.
model0.fit(x=x_train_corners, y=y_train, batch_size=128, epochs=3)
model0.evaluate(x=x_test_corners, y=y_test)
In [ ]:
def get_autoencoder_and_encoder(input_shape):
    """Build a convolutional autoencoder together with its encoder half.

    Args:
        input_shape: Shape of one input image, without the batch axis.

    Returns:
        A tuple `(autoencoder, encoder)` of uncompiled Keras models sharing
        the same input layer: `autoencoder` outputs the reconstruction and
        `encoder` outputs the bottleneck activation.
    """
    layers = tf.keras.layers
    inputs = layers.Input(shape=input_shape)

    # Encoder: two conv + 2x2 max-pool stages; the second conv has one filter.
    hidden = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    hidden = layers.MaxPooling2D((2, 2), padding='same')(hidden)
    hidden = layers.Conv2D(1, (3, 3), activation='relu', padding='same')(hidden)
    encoded = layers.MaxPooling2D((2, 2), padding='same')(hidden)

    # Decoder: two stride-2 transposed convs followed by a sigmoid output layer.
    hidden = layers.Conv2DTranspose(1, (3, 3), activation='relu', strides=2, padding='same')(encoded)
    hidden = layers.Conv2DTranspose(32, (3, 3), activation='relu', strides=2, padding='same')(hidden)
    decoded = layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')(hidden)

    autoencoder = tf.keras.models.Model(inputs, outputs=decoded)
    encoder = tf.keras.models.Model(inputs, encoded)
    return autoencoder, encoder
In [ ]:
# Reset Keras' global state before building the autoencoder.
tf.keras.backend.clear_session()

# Build and compile the autoencoder inside the TPU strategy scope.
with strategy.scope():
    autoencoder, encoder = get_autoencoder_and_encoder(x_train[0].shape)
    autoencoder.compile(
        optimizer=tf.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.metrics.BinaryAccuracy()])

# The images are both input and target. The number of steps per epoch is
# derived from the data (60000 // 128 == 468 for MNIST) instead of being
# hard-coded, so it stays correct if the dataset or batch size changes.
autoencoder_batch_size = 128
autoencoder.fit(
    x_train,
    x_train,
    batch_size=autoencoder_batch_size,
    epochs=3,
    steps_per_epoch=len(x_train) // autoencoder_batch_size,
    validation_data=(x_test, x_test))
In [ ]:
# Run the encoder over the train set, then the test set, to get the embeddings.
x_train_embeddings, x_test_embeddings = map(encoder.predict, (x_train, x_test))
In [ ]:
# Reconstruct the first eight test images with the autoencoder.
x_test_hat = autoencoder.predict(x=x_test[:8])
In [ ]:
# Show the autoencoder's reconstruction of the selected test image.
show_img(x_test_hat[img].reshape(IMG_ROWS, IMG_COLS))
In [ ]:
# Use the shared `img` index (not a hard-coded 0) so the original shown here
# is the same example as the reconstruction taken from x_test_hat[img].
show_img(x_test[img].reshape(28, 28))
In [ ]:
# Show the 7x7 embedding of the same example selected by `img`, rather than
# always the first test image.
show_img(x_test_embeddings[img].reshape(7, 7))
In [ ]:
# Flatten the corner crops and the embeddings to (num_examples, features, 1)
# and join them along the feature axis. Example counts are taken from the
# arrays themselves instead of the hard-coded 60000 / 10000, so this keeps
# working on a subset or a different dataset size.
x_train_augmented = np.concatenate(
    [
        x_train_corners.reshape(len(x_train_corners), -1, 1),
        x_train_embeddings.reshape(len(x_train_embeddings), -1, 1),
    ],
    axis=1)
x_test_augmented = np.concatenate(
    [
        x_test_corners.reshape(len(x_test_corners), -1, 1),
        x_test_embeddings.reshape(len(x_test_embeddings), -1, 1),
    ],
    axis=1)
In [ ]:
# Build and compile the classifier on the augmented features (corner crop +
# embedding) inside the TPU strategy scope.
with strategy.scope():
    model1 = get_model(x_train_augmented[0].shape)
    # NOTE(review): the learning rate here (0.06) differs from the 0.05 used
    # for model0 -- confirm this is intentional, since it affects how
    # comparable the two models are.
    model1.compile(
        optimizer=tf.optimizers.SGD(learning_rate=0.06),
        loss=tf.losses.SparseCategoricalCrossentropy(),
        # Same explicit metric object as model0, so both models report
        # accuracy under the same name.
        metrics=[tf.metrics.SparseCategoricalAccuracy()])

# Train for three epochs, then evaluate on the augmented test set.
model1.fit(x_train_augmented, y_train, epochs=3, batch_size=128)
model1.evaluate(x_test_augmented, y_test)
On Google Cloud Platform, in addition to GPUs and TPUs available on pre-configured deep learning VMs, you will find AutoML (beta) for training custom models without writing code, and Cloud ML Engine, which allows you to run parallel training jobs and hyperparameter tuning of your custom models on powerful distributed hardware.