In [0]:
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
This notebook shows how to use TPUEstimator to build a simple classification model that trains, evaluates, and generates predictions on a Cloud TPU. It uses the Iris dataset to predict flower species, and it also demonstrates how to feed your own data rather than a pre-loaded dataset. The model uses 4 input features (SepalLength, SepalWidth, PetalLength, PetalWidth) to classify each flower as one of three species (Setosa, Versicolor, Virginica).
The model trains for 1,000 steps (the `train_steps` parameter set below) and completes in approximately 2 minutes.
This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select File > View on GitHub.
Create a Cloud Storage bucket for your TensorBoard logs and model checkpoints at http://console.cloud.google.com/storage, and fill in the bucket name in the configuration cell below.
On the main menu, click Runtime and select Change runtime type. Set "TPU" as the hardware accelerator.
TPUs are located in Google Cloud; for optimal performance, they read data directly from Google Cloud Storage (GCS).
In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import pandas as pd
import pprint
import tensorflow as tf
import time
In [0]:
use_tpu = True #@param {type:"boolean"}
bucket = '' #@param {type:"string"}

assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))

if use_tpu:
  assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'

MODEL_DIR = 'gs://{}/{}'.format(bucket, time.strftime('tpuestimator-dnn/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(MODEL_DIR))

from google.colab import auth
auth.authenticate_user()

if 'COLAB_TPU_ADDR' in os.environ:
  TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

  # Upload credentials to the TPU.
  with tf.Session(TF_MASTER) as sess:
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.
else:
  TF_MASTER = ''

with tf.Session(TF_MASTER) as session:
  print('List of devices:')
  pprint.pprint(session.list_devices())
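The device list printed above should typically show eight TPU cores (plus a TPU_SYSTEM device and the local CPU). As an optional sanity check, here is a minimal sketch, reusing the `TF_MASTER` address from the cell above, that counts those cores; this is the number of shards across which `CrossShardOptimizer` later averages gradients.
In [0]:
# Optional sanity check (illustrative only): count the TPU cores visible at TF_MASTER.
with tf.Session(TF_MASTER) as session:
  tpu_cores = [d.name for d in session.list_devices() if d.device_type == 'TPU']
print('Number of TPU cores: {}'.format(len(tpu_cores)))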
In [0]:
# Model-specific parameters

# TPU address
tpu_address = TF_MASTER

# Estimator's model_dir
model_dir = MODEL_DIR

# This is the global batch size, not the per-shard batch size.
batch_size = 128

# Total number of training steps.
train_steps = 1000

# Total number of evaluation steps. If 0, evaluation after training is skipped.
eval_steps = 4

# Number of iterations per TPU training loop.
iterations = 500
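A note on how these parameters interact: `TPUEstimator` hands control to the TPU for `iterations` steps at a time, so logging and checkpointing only happen at those loop boundaries. The quick check below is illustrative arithmetic only, using the variables defined above; with `train_steps = 1000` and `iterations = 500`, training runs as two TPU loops of 500 steps each.
In [0]:
# Illustrative arithmetic only; nothing here affects training.
import math
print('TPU training loops: {}'.format(int(math.ceil(train_steps / float(iterations)))))
print('Examples seen during training: {}'.format(train_steps * batch_size))
print('Examples seen during evaluation: {}'.format(eval_steps * batch_size))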
In [0]:
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

PREDICTION_INPUT_DATA = {
    'SepalLength': [6.9, 5.1, 5.9],
    'SepalWidth': [3.1, 3.3, 3.0],
    'PetalLength': [5.4, 1.7, 4.2],
    'PetalWidth': [2.1, 0.5, 1.5],
}

PREDICTION_OUTPUT_DATA = ['Virginica', 'Setosa', 'Versicolor']


def maybe_download():
  train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
  test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
  return train_path, test_path


def load_data(y_name='Species'):
  """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
  train_path, test_path = maybe_download()

  dtypes = {'SepalLength': pd.np.float32, 'SepalWidth': pd.np.float32,
            'PetalLength': pd.np.float32, 'PetalWidth': pd.np.float32,
            'Species': pd.np.int32}

  train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0, dtype=dtypes)
  train_x, train_y = train, train.pop(y_name)

  test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0, dtype=dtypes)
  test_x, test_y = test, test.pop(y_name)

  return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
  """An input function for training."""
  # Convert the inputs to a Dataset.
  dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

  # Shuffle, repeat, and batch the examples. TPUs require a fixed batch shape,
  # so drop any remainder that would produce a smaller final batch.
  dataset = dataset.shuffle(1000).repeat()
  dataset = dataset.apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))

  # Return the dataset.
  return dataset


def eval_input_fn(features, labels, batch_size):
  """An input function for evaluation."""
  features = dict(features)
  inputs = (features, labels)

  # Convert the inputs to a Dataset.
  dataset = tf.data.Dataset.from_tensor_slices(inputs)

  # Shuffle, repeat, and batch the examples.
  dataset = dataset.shuffle(1000).repeat()
  dataset = dataset.apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))

  # Return the dataset.
  return dataset


def predict_input_fn(features, batch_size):
  """An input function for prediction."""
  dataset = tf.data.Dataset.from_tensor_slices(features)
  dataset = dataset.batch(batch_size)
  return dataset
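The one TPU-specific detail in these input functions is `batch_and_drop_remainder`: XLA compiles the model for a fixed batch shape, so a smaller final batch would fail. The small sketch below (illustrative only, not part of the training pipeline) shows the difference on a toy dataset; a plain `.batch()` leaves the batch dimension unknown, while `batch_and_drop_remainder` makes it static.
In [0]:
# Illustrative only: compare the static shapes produced by the two batching strategies.
toy = tf.data.Dataset.range(10)
print(toy.batch(4).output_shapes)  # (?,) -- batch dimension unknown
print(toy.apply(tf.contrib.data.batch_and_drop_remainder(4)).output_shapes)  # (4,) -- static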
In [0]:
def metric_fn(labels, logits):
  """Function to return metrics for evaluation."""
  predicted_classes = tf.argmax(logits, 1)
  accuracy = tf.metrics.accuracy(labels=labels,
                                 predictions=predicted_classes,
                                 name='acc_op')
  return {'accuracy': accuracy}


def my_model(features, labels, mode, params):
  """DNN with fully connected hidden layers sized by params['hidden_units']."""
  # Build the hidden layers.
  net = tf.feature_column.input_layer(features, params['feature_columns'])
  for units in params['hidden_units']:
    net = tf.layers.dense(net, units=units, activation=tf.nn.relu)

  # Compute logits (1 per class).
  logits = tf.layers.dense(net, params['n_classes'], activation=None)

  # Compute predictions.
  predicted_classes = tf.argmax(logits, 1)
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class_ids': predicted_classes[:, tf.newaxis],
        'probabilities': tf.nn.softmax(logits),
        'logits': logits,
    }
    return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

  # Compute loss.
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

  # Create training op.
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    if use_tpu:
      # Average gradients across TPU cores before applying them.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
In [0]:
def main():
  # Fetch the data.
  (train_x, train_y), (test_x, test_y) = load_data()

  # Feature columns describe how to use the input.
  my_feature_columns = []
  for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

  # Resolve the TPU cluster and build a RunConfig for it.
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu_address)

  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      cluster=tpu_cluster_resolver,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(iterations),
  )

  # Build a 2-hidden-layer DNN with 10 units in each layer.
  classifier = tf.contrib.tpu.TPUEstimator(
      model_fn=my_model,
      use_tpu=use_tpu,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      config=run_config,
      params={
          'feature_columns': my_feature_columns,
          # Two hidden layers of 10 nodes each.
          'hidden_units': [10, 10],
          # The model must choose between 3 classes.
          'n_classes': 3,
          'use_tpu': use_tpu,
      })

  # Train the model.
  classifier.train(
      input_fn=lambda params: train_input_fn(
          train_x, train_y, params['batch_size']),
      max_steps=train_steps)

  # Evaluate the model.
  eval_result = classifier.evaluate(
      input_fn=lambda params: eval_input_fn(
          test_x, test_y, params['batch_size']),
      steps=eval_steps)
  print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

  # Generate predictions from the model.
  predictions = classifier.predict(
      input_fn=lambda params: predict_input_fn(
          PREDICTION_INPUT_DATA, params['batch_size']))

  for pred_dict, expec in zip(predictions, PREDICTION_OUTPUT_DATA):
    template = '\nPrediction is "{}" ({:.1f}%), expected "{}"'
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print(template.format(SPECIES[class_id], 100 * probability, expec))
In [0]:
main()
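Training wrote checkpoints and TensorBoard event files to `MODEL_DIR` in your GCS bucket. As a rough sketch (it assumes the local runtime can read the bucket using the credentials set up earlier), you can list what was written and point a TensorBoard instance at the same `gs://` path to inspect the loss curve.
In [0]:
# Illustrative only: list the artifacts that training wrote to MODEL_DIR.
# Assumes the local runtime has GCS read access via the earlier authentication step.
for name in tf.gfile.ListDirectory(MODEL_DIR):
  print(name)
print('Latest checkpoint: {}'.format(tf.train.latest_checkpoint(MODEL_DIR)))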
On Google Cloud Platform, in addition to the GPUs and TPUs available on pre-configured deep learning VMs, you will find AutoML (beta) for training custom models without writing code, and Cloud ML Engine, which lets you run parallel training and hyperparameter tuning of your custom models on powerful distributed hardware.