Building a successful machine learning (ML) system involves more than training a model. This two-part article discusses the role of the TensorFlow Data Validation (TFDV) library in performing data exploration and descriptive analytics during experimentation, as well as in validating incoming data for training or prediction in production.
This tutorial shows you step-by-step how to use TFDV to analyze and validate data for ML on Google Cloud Platform (GCP).
The objectives of this tutorial are to:
- Extract flights data from BigQuery, convert it to TFRecord files, and store it in Cloud Storage using Apache Beam.
- Compute descriptive statistics over the data with TFDV, locally or at scale with Cloud Dataflow.
- Visualize the statistics, infer a data schema, and refine it manually.
- Validate new data against the schema and baseline statistics to detect anomalies.
In [ ]:
!pip install -U tensorflow-data-validation==0.11.0
!pip install -U apache-beam[gcp]==2.8.0
!pip install -U google_cloud_bigquery==1.18.0
!pip install -U python-snappy==0.5.4
In [3]:
import os
import tensorflow as tf
import apache_beam as beam
import tensorflow_data_validation as tfdv
from google.cloud import bigquery
from datetime import datetime
Verify the versions of the installed packages:
In [4]:
print "TF version:", tf.__version__
print "TFDV version:", tfdv.__version__
print "Beam version:", beam.__version__
print "BQ SDK version:", bigquery.__version__
To get started, set the following variables: your GCP PROJECT_ID, BUCKET_NAME, and REGION. Create a GCP project if you don't have one, and create a regional Cloud Storage bucket if you don't have one.
In [5]:
LOCAL = True # Change to False to run on GCP
PROJECT_ID = 'validateflow' # Set your GCP Project Id
BUCKET_NAME = 'validateflow' # Set your Bucket name
REGION = 'europe-west1' # Set the region for Dataflow jobs
ROOT = './tfdv' if LOCAL else 'gs://{}/tfdv'.format(BUCKET_NAME)
DATA_DIR = ROOT + '/data/' # Location to store data
SCHEMA_DIR = ROOT + '/schema/' # Location to store data schema
STATS_DIR = ROOT +'/stats/' # Location to store stats
STAGING_DIR = ROOT + '/job/staging/' # Dataflow staging directory on GCP
TEMP_DIR = ROOT + '/job/temp/' # Dataflow temporary directory on GCP
Clean up the working directory...
In [6]:
if tf.gfile.Exists(ROOT):
    print("Removing {} contents...".format(ROOT))
    tf.gfile.DeleteRecursively(ROOT)
print("Creating working directory: {}".format(ROOT))
tf.gfile.MkDir(ROOT)
In this tutorial, we will use the flights data table, a publicly available sample dataset in BigQuery.
The table has more than 70 million records on domestic US flights, including information on date, airline, departure airport, arrival airport, departure schedule, actual departure time, arrival schedule, and actual arrival time.
You can use the BigQuery console to explore the data, or you can run the following cell, which counts the number of flights by year.
In [ ]:
%%bigquery
SELECT
EXTRACT(YEAR FROM CAST(date as DATE)) as year,
COUNT(*) as flight_count
FROM
`bigquery-samples.airline_ontime_data.flights`
GROUP BY
year
ORDER BY
year DESC
We have data from 2002 to 2012. The dataset is ~8GB, which might be too big to fit in memory for exploration. However, you can use TFDV to perform the data crunching at scale on GCP using Cloud Dataflow, producing statistics that can easily be loaded into memory, visualized, and analyzed.
In this step, we will extract the data we want to analyze from BigQuery, convert it to TFRecord files, and store the files in Cloud Storage (GCS). These data files will then be used by TFDV. We are going to use Apache Beam to accomplish this.
Let's say that you use this dataset to estimate the arrival delay of a particular flight using ML. Note that in this tutorial we are not focusing on building the model; rather, we are focusing on analyzing and validating how the data changes over time. We are going to use data from 2010-2011 to generate the schema, and validate data from 2012 against it to identify anomalies.
Note that in more realistic scenarios, new flights data arrives in your data warehouse on a daily or weekly basis, and you would validate each day's worth of data against the schema. The purpose of this example is to show how this can be done at scale (using a year's worth of data) to identify anomalies.
The data will be extracted with the columns produced by the following query:
In [7]:
def generate_query(date_from=None, date_to=None, limit=None):
    query = """
        SELECT
          CAST(date AS DATE) AS flight_date,
          FORMAT_DATE('%b', CAST(date AS DATE)) AS flight_month,
          EXTRACT(DAY FROM CAST(date AS DATE)) AS flight_day,
          FORMAT_DATE('%a', CAST(date AS DATE)) AS flight_day_of_week,
          airline,
          departure_airport,
          arrival_airport,
          CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 0, 2) AS INT64) AS departure_schedule_hour,
          CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 3, 2) AS INT64) AS departure_schedule_minute,
          CASE
            WHEN departure_schedule BETWEEN 600 AND 900 THEN '[6:00am - 9:00am]'
            WHEN departure_schedule BETWEEN 900 AND 1200 THEN '[9:00am - 12:00pm]'
            WHEN departure_schedule BETWEEN 1200 AND 1500 THEN '[12:00pm - 3:00pm]'
            WHEN departure_schedule BETWEEN 1500 AND 1800 THEN '[3:00pm - 6:00pm]'
            WHEN departure_schedule BETWEEN 1800 AND 2100 THEN '[6:00pm - 9:00pm]'
            WHEN departure_schedule BETWEEN 2100 AND 2400 THEN '[9:00pm - 12:00am]'
            ELSE '[12:00am - 6:00am]'
          END AS departure_time_slot,
          departure_delay,
          arrival_delay
        FROM
          `bigquery-samples.airline_ontime_data.flights`
    """
    if date_from:
        query += "WHERE CAST(date AS DATE) >= CAST('{}' AS DATE) \n".format(date_from)
        if date_to:
            query += "AND CAST(date AS DATE) < CAST('{}' AS DATE) \n".format(date_to)
    elif date_to:
        query += "WHERE CAST(date AS DATE) < CAST('{}' AS DATE) \n".format(date_to)
    if limit:
        query += "LIMIT {}".format(limit)
    return query
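For example, you can print the SQL that this helper generates for a small sample; the dates and row limit below are purely illustrative:
In [ ]:
# Illustration only: inspect the generated SQL for a small date range with a row limit.
print(generate_query(date_from='2010-01-01', date_to='2010-02-01', limit=10))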
You can run the following cell to see a sample of the data to be extracted...
In [ ]:
%%bigquery
SELECT
CAST(date AS DATE) AS flight_date,
FORMAT_DATE('%b', CAST(date AS DATE)) AS flight_month,
EXTRACT(DAY FROM CAST(date AS DATE)) AS flight_day,
FORMAT_DATE('%a', CAST(date AS DATE)) AS flight_day_of_week,
airline,
departure_airport,
arrival_airport,
CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 0, 2) AS INT64) AS departure_schedule_hour,
CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 3, 2) AS INT64) AS departure_schedule_minute,
CASE
WHEN departure_schedule BETWEEN 600 AND 900 THEN '[6:00am - 9:00am]'
WHEN departure_schedule BETWEEN 900 AND 1200 THEN '[9:00am - 12:00pm]'
WHEN departure_schedule BETWEEN 1200 AND 1500 THEN '[12:00pm - 3:00pm]'
WHEN departure_schedule BETWEEN 1500 AND 1800 THEN '[3:00pm - 6:00pm]'
WHEN departure_schedule BETWEEN 1800 AND 2100 THEN '[6:00pm - 9:00pm]'
WHEN departure_schedule BETWEEN 2100 AND 2400 THEN '[9:00pm - 12:00am]'
ELSE '[12:00am - 6:00am]'
END AS departure_time_slot,
departure_delay,
arrival_delay
FROM
`bigquery-samples.airline_ontime_data.flights`
LIMIT 5
In [8]:
def get_type_map(query):
    bq_client = bigquery.Client()
    query_job = bq_client.query("SELECT * FROM ({}) LIMIT 0".format(query))
    results = query_job.result()
    type_map = {}
    for field in results.schema:
        type_map[field.name] = field.field_type
    return type_map


def row_to_example(instance, type_map):
    feature = {}
    for key, value in instance.items():
        data_type = type_map[key]
        if value is None:
            feature[key] = tf.train.Feature()
        elif data_type == 'INTEGER':
            feature[key] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[value]))
        elif data_type == 'FLOAT':
            feature[key] = tf.train.Feature(
                float_list=tf.train.FloatList(value=[value]))
        else:
            feature[key] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(value)]))
    return tf.train.Example(features=tf.train.Features(feature=feature))
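To make the conversion concrete, here is a small, hypothetical illustration of how a single BigQuery row (as a dict) is turned into a tf.train.Example; the row values and type map below are made up for demonstration only:
In [ ]:
# Hypothetical illustration of row_to_example: a single row dict plus its BigQuery type map.
example_row = {'airline': 'AA', 'flight_day': 15, 'departure_delay': 7.0, 'arrival_delay': None}
example_types = {'airline': 'STRING', 'flight_day': 'INTEGER',
                 'departure_delay': 'FLOAT', 'arrival_delay': 'FLOAT'}
# None becomes an empty Feature, INTEGER -> Int64List, FLOAT -> FloatList, otherwise BytesList.
print(row_to_example(example_row, example_types))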
In [9]:
def run_pipeline(args):
    source_query = args.pop('source_query')
    sink_data_location = args.pop('sink_data_location')
    runner = args['runner']
    pipeline_options = beam.options.pipeline_options.GoogleCloudOptions(**args)
    print(pipeline_options)

    with beam.Pipeline(runner, options=pipeline_options) as pipeline:
        (pipeline
         | 'Read from BigQuery' >> beam.io.Read(beam.io.BigQuerySource(query=source_query, use_standard_sql=True))
         | 'Convert to tf Example' >> beam.Map(lambda instance: row_to_example(instance, type_map))
         | 'Serialize to String' >> beam.Map(lambda example: example.SerializeToString(deterministic=True))
         | 'Write as TFRecords to GCS' >> beam.io.WriteToTFRecord(
             file_path_prefix=sink_data_location + "extract",
             file_name_suffix=".tfrecords")
        )
In [10]:
runner = 'DirectRunner' if LOCAL else 'DataflowRunner'
job_name = 'tfdv-flights-data-extraction-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S'))
date_from = '2010-01-01'
date_to = '2011-12-31'
data_location = os.path.join(DATA_DIR,
"{}-{}/".format(date_from.replace('-',''), date_to.replace('-','')))
print("Data will be extracted to: {}".format(data_location))
print("Generating source query...")
limit = 100000 if LOCAL else None
source_query = generate_query(date_from, date_to, limit)
print("Retrieving data type...")
type_map = get_type_map(source_query)
args = {
'job_name': job_name,
'runner': runner,
'source_query': source_query,
'type_map': type_map,
'sink_data_location': data_location,
'project': PROJECT_ID,
'region': REGION,
'staging_location': STAGING_DIR,
'temp_location': TEMP_DIR,
'save_main_session': True,
'setup_file': './setup.py'
}
print("Pipeline args are set.")
The notebook cell will block until the Apache Beam job finishes...
In [11]:
tf.logging.set_verbosity(tf.logging.ERROR)
print("Running data extraction pipeline...")
run_pipeline(args)
print("Pipeline is done.")
You can list the extracted data files...
In [12]:
#!gsutil ls {DATA_DIR}/*
!ls {DATA_DIR}/*
In this step, we will use TFDV to analyze the data in GCS and compute various statistics from it. This operation requires one or more full passes over the data to compute mean, max, min, etc., which needs to run at scale when analyzing large datasets.
If we run the analysis on a sample of the data, we can use TFDV to compute the statistics locally. However, for scalability we can run the TFDV process on Cloud Dataflow. The generated statistics are stored as a protocol buffer in GCS.
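As a minimal sketch of the local option: if you don't pass any Dataflow pipeline options, generate_statistics_from_tfrecord runs the Beam job on the local DirectRunner, which is fine for a small extract. The output path and sample rate below are chosen just for illustration:
In [ ]:
# Minimal local sketch: no pipeline_options, so the Beam job runs on the DirectRunner.
# Assumes the data_location extract and STATS_DIR defined earlier in this notebook.
sample_stats = tfdv.generate_statistics_from_tfrecord(
    data_location=data_location,
    output_path=os.path.join(STATS_DIR, 'sample_stats.pb'),
    stats_options=tfdv.StatsOptions(sample_rate=.1))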
In [13]:
job_name = 'tfdv-flights-stats-gen-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S'))
args['job_name'] = job_name
stats_location = os.path.join(STATS_DIR, 'stats.pb')
pipeline_options = beam.options.pipeline_options.GoogleCloudOptions(**args)
print(pipeline_options)
print("Computing statistics...")
_ = tfdv.generate_statistics_from_tfrecord(
    data_location=data_location,
    output_path=stats_location,
    stats_options=tfdv.StatsOptions(
        sample_rate=.3
    ),
    pipeline_options=pipeline_options
)
print("Statistics are computed and saved to: {}".format(stats_location))
You can list the saved statistics file...
In [14]:
!ls {stats_location}
#!gsutil ls {stats_location}
In this step, we use TFDV's visualization capabilities to explore and analyze the data, using the statistics computed in the previous step, in order to identify value ranges, the vocabulary of categorical columns, the percentage of missing values, and so on. This step helps generate the expected schema of the data. TFDV uses Facets for visualization.
Using the visualization, you can identify these properties for each of the features:
In [15]:
stats = tfdv.load_statistics(stats_location)
tfdv.visualize_statistics(stats)
In [16]:
schema = tfdv.infer_schema(statistics=stats)
tfdv.display_schema(schema=schema)
You can manually alter the schema before you save it. For example:
In [22]:
from tensorflow_metadata.proto.v0 import schema_pb2
# Allow no missing values
tfdv.get_feature(schema, 'airline').presence.min_fraction = 1.0
# Only allow 10% of the values to be new
tfdv.get_feature(schema, 'departure_airport').distribution_constraints.min_domain_mass = 0.9
domain = schema_pb2.FloatDomain(
    min=-60,  # a flight can depart up to 1 hour early
    max=480   # maximum departure delay is 8 hours; otherwise the flight is cancelled
)
tfdv.set_domain(schema, 'departure_delay', domain)
tfdv.get_feature(schema, 'flight_month').drift_comparator.infinity_norm.threshold = 0.01
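As an optional sanity check, you can print one of the altered features to confirm the changes took effect; tfdv.get_feature is the same helper used above:
In [ ]:
# Optional sanity check: inspect the altered feature definitions in the schema.
print(tfdv.get_feature(schema, 'airline'))
print(tfdv.get_feature(schema, 'departure_delay'))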
In [23]:
from tensorflow.python.lib.io import file_io
from google.protobuf import text_format
tf.gfile.MkDir(dirname=SCHEMA_DIR)
schema_location = os.path.join(SCHEMA_DIR, 'schema.pb')
tfdv.write_schema_text(schema, schema_location)
print("Schema file saved to:{}".format(schema_location))
You can list the saved schema file...
In [24]:
!ls {schema_location}
#!gsutil ls {schema_location}
In this step, we are going to extract new data from BigQuery and store it in GCS as TFRecord files. This will be flights data from 2012; however, we are going to introduce the following alterations to the data schema and content to demonstrate the types of anomalies TFDV can detect:
- A new is_weekend column is added.
- The airline value 'MQ' is replaced with NULL, introducing missing values.
- departure_delay is multiplied by 60, changing its scale and range.
- departure_time_slot gains new values ('[12:00am - 3:00am]' and '[3:00am - 6:00am]') in place of '[12:00am - 6:00am]'.
- All February records are removed, shifting the flight_month distribution.
In [25]:
def generate_altered_query(date_from=None, date_to=None, limit=None):
    query = """
        SELECT * FROM (
          SELECT
            CAST(date AS DATE) AS flight_date,
            FORMAT_DATE('%b', CAST(date AS DATE)) AS flight_month,
            EXTRACT(DAY FROM CAST(date AS DATE)) AS flight_day,
            FORMAT_DATE('%a', CAST(date AS DATE)) AS flight_day_of_week,
            CASE WHEN EXTRACT(DAYOFWEEK FROM CAST(date AS DATE)) IN (1 , 7) THEN 'Yes' ELSE 'No' END AS is_weekend,
            CASE WHEN airline = 'MQ' THEN NULL ELSE airline END airline,
            departure_airport,
            arrival_airport,
            CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 0, 2) AS INT64) AS departure_schedule_hour,
            CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 3, 2) AS INT64) AS departure_schedule_minute,
            CASE
              WHEN departure_schedule BETWEEN 600 AND 900 THEN '[6:00am - 9:00am]'
              WHEN departure_schedule BETWEEN 900 AND 1200 THEN '[9:00am - 12:00pm]'
              WHEN departure_schedule BETWEEN 1200 AND 1500 THEN '[12:00pm - 3:00pm]'
              WHEN departure_schedule BETWEEN 1500 AND 1800 THEN '[3:00pm - 6:00pm]'
              WHEN departure_schedule BETWEEN 1800 AND 2100 THEN '[6:00pm - 9:00pm]'
              WHEN departure_schedule BETWEEN 2100 AND 2400 THEN '[9:00pm - 12:00am]'
              WHEN departure_schedule BETWEEN 0000 AND 300 THEN '[12:00am - 3:00am]'
              ELSE '[3:00am - 6:00am]'
            END AS departure_time_slot,
            departure_delay * 60 AS departure_delay,
            arrival_delay
          FROM
            `bigquery-samples.airline_ontime_data.flights`
          WHERE
            EXTRACT(MONTH FROM CAST(date AS DATE)) != 2
        )
    """
    if date_from:
        query += "WHERE flight_date >= CAST('{}' AS DATE) \n".format(date_from)
        if date_to:
            query += "AND flight_date < CAST('{}' AS DATE) \n".format(date_to)
    elif date_to:
        query += "WHERE flight_date < CAST('{}' AS DATE) \n".format(date_to)
    if limit:
        query += "LIMIT {}".format(limit)
    return query
You can run the following cell to see a sample of the data to be extracted...
In [ ]:
%%bigquery
SELECT
CAST(date AS DATE) AS flight_date,
FORMAT_DATE('%b', CAST(date AS DATE)) AS flight_month,
EXTRACT(DAY FROM CAST(date AS DATE)) AS flight_day,
FORMAT_DATE('%a', CAST(date AS DATE)) AS flight_day_of_week,
CASE WHEN EXTRACT(DAYOFWEEK FROM CAST(date AS DATE)) IN (1 , 7) THEN 'Yes' ELSE 'No' END AS is_weekend,
CASE WHEN airline = 'MQ' THEN NULL ELSE airline END airline,
departure_airport,
arrival_airport,
CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 0, 2) AS INT64) AS departure_schedule_hour,
CAST(SUBSTR(LPAD(CAST(departure_schedule AS STRING), 4, '0'), 3, 2) AS INT64) AS departure_schedule_minute,
CASE
WHEN departure_schedule BETWEEN 600 AND 900 THEN '[6:00am - 9:00am]'
WHEN departure_schedule BETWEEN 900 AND 1200 THEN '[9:00am - 12:00pm]'
WHEN departure_schedule BETWEEN 1200 AND 1500 THEN '[12:00pm - 3:00pm]'
WHEN departure_schedule BETWEEN 1500 AND 1800 THEN '[3:00pm - 6:00pm]'
WHEN departure_schedule BETWEEN 1800 AND 2100 THEN '[6:00pm - 9:00pm]'
WHEN departure_schedule BETWEEN 2100 AND 2400 THEN '[9:00pm - 12:00am]'
WHEN departure_schedule BETWEEN 0000 AND 300 THEN '[12:00am - 3:00am]'
ELSE '[3:00am - 6:00am]'
END AS departure_time_slot,
departure_delay * 60 AS departure_delay,
arrival_delay
FROM
`bigquery-samples.airline_ontime_data.flights`
WHERE
EXTRACT(MONTH FROM CAST(date AS DATE)) != 2
LIMIT 5
In [26]:
runner = 'DirectRunner' if LOCAL else 'DataflowRunner'
job_name = 'tfdv-flights-data-extraction-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S'))
date_from = '2012-01-01'
date_to = '2012-12-31'
data_location = os.path.join(DATA_DIR,
"{}-{}/".format(date_from.replace('-',''), date_to.replace('-','')))
print("Data will be extracted to: {}".format(data_location))
print("Generating altered source query...")
limit = 100000 if LOCAL else None
source_query = generate_altered_query(date_from, date_to, limit)
print("Retrieving data type...")
type_map = get_type_map(source_query)
args = {
'job_name': job_name,
'runner': runner,
'source_query': source_query,
'type_map': type_map,
'sink_data_location': data_location,
'project': PROJECT_ID,
'region': REGION,
'staging_location': STAGING_DIR,
'temp_location': TEMP_DIR,
'save_main_session': True,
'setup_file': './setup.py'
}
print("Pipeline args are set.")
In [27]:
print("Running data extraction pipeline...")
run_pipeline(args)
print("Pipeline is done.")
You can list the extracted data files...
In [28]:
#!gsutil ls {DATA_DIR}/*
!ls {DATA_DIR}/*
In [30]:
job_name = 'tfdv-flights-stats-gen-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S'))
args['job_name'] = job_name
new_stats_location = os.path.join(STATS_DIR, 'new_stats.pb')
pipeline_options = beam.options.pipeline_options.GoogleCloudOptions(**args)
print(pipeline_options)
print("Computing statistics...")
_ = tfdv.generate_statistics_from_tfrecord(
data_location=data_location,
output_path=new_stats_location,
stats_options=tfdv.StatsOptions(
sample_rate=.5
),
pipeline_options = pipeline_options
)
print("Statistics are computed and saved to: {}".format(new_stats_location))
In [31]:
schema = tfdv.utils.schema_util.load_schema_text(schema_location)
In [32]:
stats = tfdv.load_statistics(stats_location)
new_stats = tfdv.load_statistics(new_stats_location)
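Before running the validation, you can optionally compare the two sets of statistics side by side; tfdv.visualize_statistics accepts a second set of statistics for comparison, and the names below are just display labels for the two slices:
In [ ]:
# Optional: compare the new (2012) statistics against the baseline (2010-2011) side by side.
tfdv.visualize_statistics(
    lhs_statistics=new_stats, lhs_name='2012',
    rhs_statistics=stats, rhs_name='2010-2011')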
In [33]:
anomalies = tfdv.validate_statistics(
statistics=new_stats,
schema=schema,
previous_statistics=stats
)
In [34]:
tfdv.display_anomalies(anomalies)
How to handle these anomalies depends on the type of each anomaly: some indicate genuine problems in the upstream data that should be fixed at the source, while others reflect legitimate changes in the data, in which case you update the schema to accept the new values or distributions, as in the sketch below.
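As a hedged sketch of the second case (accepting a legitimate change rather than fixing the data), you could inspect the reported anomalies programmatically and relax the relevant schema entries, then re-validate; the specific features and thresholds below are illustrative, not prescriptive:
In [ ]:
# Sketch: list the reported anomalies, relax the schema, and re-validate.
# (The features and thresholds touched here are illustrative; adjust to the anomalies you actually see.)
for feature_name, anomaly_info in anomalies.anomaly_info.items():
    print(feature_name, ':', anomaly_info.short_description)

# Example: tolerate missing airline values and more unseen departure airports.
tfdv.get_feature(schema, 'airline').presence.min_fraction = 0.9
tfdv.get_feature(schema, 'departure_airport').distribution_constraints.min_domain_mass = 0.8

updated_anomalies = tfdv.validate_statistics(
    statistics=new_stats, schema=schema, previous_statistics=stats)
tfdv.display_anomalies(updated_anomalies)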
Authors: Khalid Salama and Eric van der Knaap
Disclaimer: This is not an official Google product. The sample code is provided for educational purposes only.
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.