In [18]:
# change these to try this notebook out
PROJECT = '<YOUR-GCS-BUCKET>'
BUCKET = '<YOUR-GCS-BUCKET>'
In [19]:
import os
os.environ['BUCKET'] = BUCKET
os.environ['PROJECT'] = PROJECT
os.environ['TFVERSION'] = '2.1'
In [131]:
import shutil
import pandas as pd
import tensorflow as tf
from google.cloud import bigquery
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow_hub import KerasLayer
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model
print(tf.__version__)
%matplotlib inline
For this notebook, we'll build a text classification model using the Hacker News dataset. Each training example consists of an article title and the article source. The model will be trained to classify a given article title as belonging to either nytimes
, github
or techcrunch
.
In [25]:
DATASET_NAME = "titles_full.csv"
COLUMNS = ['title', 'source']
titles_df = pd.read_csv(DATASET_NAME, header=None, names=COLUMNS)
titles_df.head()
Out[25]:
We one-hot encode the label...
In [27]:
CLASSES = {
'github': 0,
'nytimes': 1,
'techcrunch': 2
}
N_CLASSES = len(CLASSES)
In [28]:
def encode_labels(sources):
classes = [CLASSES[source] for source in sources]
one_hots = to_categorical(classes, num_classes=N_CLASSES)
return one_hots
In [29]:
encode_labels(titles_df.source[:4])
Out[29]:
...and create a train/test split.
In [30]:
N_TRAIN = int(len(titles_df) * 0.80)
titles_train, sources_train = (
titles_df.title[:N_TRAIN], titles_df.source[:N_TRAIN])
titles_valid, sources_valid = (
titles_df.title[N_TRAIN:], titles_df.source[N_TRAIN:])
In [31]:
X_train, Y_train = titles_train.values, encode_labels(sources_train)
X_valid, Y_valid = titles_valid.values, encode_labels(sources_valid)
In [32]:
X_train[:3]
Out[32]:
We'll build a simple text classification model using a Tensorflow Hub embedding module derived from Swivel. Swivel is an algorithm that essentially factorizes word co-occurrence matrices to create the words embeddings. TF-Hub hosts the pretrained gnews-swivel-20dim-with-oov 20-dimensional Swivel module.
In [37]:
SWIVEL = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1"
swivel_module = KerasLayer(SWIVEL, output_shape=[20], input_shape=[], dtype=tf.string, trainable=True)
The build_model
function is written so that the TF Hub module can easily be exchanged with another module.
In [46]:
def build_model(hub_module, model_name):
inputs = Input(shape=[], dtype=tf.string, name="text")
module = hub_module(inputs)
h1 = Dense(16, activation='relu', name="h1")(module)
outputs = Dense(N_CLASSES, activation='softmax', name='outputs')(h1)
model = Model(inputs=inputs, outputs=[outputs], name=model_name)
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
return model
In [47]:
def train_and_evaluate(train_data, val_data, model, batch_size=5000):
tf.random.set_seed(33)
X_train, Y_train = train_data
history = model.fit(
X_train, Y_train,
epochs=100,
batch_size=batch_size,
validation_data=val_data,
callbacks=[EarlyStopping()],
)
return history
In [51]:
txtcls_model = build_model(swivel_module, model_name='txtcls_swivel')
In [52]:
txtcls_model.summary()
In [43]:
# set up train and validation data
train_data = (X_train, Y_train)
val_data = (X_valid, Y_valid)
For training we'll call train_and_evaluate
on txtcls_model
.
In [45]:
txtcls_history = train_and_evaluate(train_data, val_data, txtcls_model)
In [53]:
history = txtcls_history
pd.DataFrame(history.history)[['loss', 'val_loss']].plot()
pd.DataFrame(history.history)[['accuracy', 'val_accuracy']].plot()
Out[53]:
Calling predicition from model head produces output from final dense layer. This final layer is used to compute categorical cross-entropy when training.
In [54]:
txtcls_model.predict(x=["YouTube introduces Video Chapters to make it easier to navigate longer videos"])
Out[54]:
We can save the model artifacts in the local directory called ./txtcls_swivel
.
In [55]:
tf.saved_model.save(txtcls_model, './txtcls_swivel/')
....and examine the model's serving default signature. As expected the model takes as input a text string (e.g. an article title) and retrns a 3-dimensional vector of floats (i.e. the softmax output layer).
In [57]:
!saved_model_cli show \
--tag_set serve \
--signature_def serving_default \
--dir ./txtcls_swivel/
To simplify the returned predictions, we'll modify the model signature so that the model outputs the predicted article source (either nytimes
, techcrunch
, or github
) rather than the final softmax layer. We'll also return the 'confidence' of the model's prediction. This will be the softmax value corresonding to the predicted article source.
In [59]:
@tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.string)])
def source_name(text):
labels = tf.constant(['github', 'techcrunch', 'nytimes'], dtype=tf.string)
probs = txtcls_model(text, training=False)
indices = tf.argmax(probs, axis=1)
pred_source = tf.gather(params=labels, indices=indices)
pred_confidence = tf.reduce_max(probs, axis=1)
return {'source': pred_source,
'confidence': pred_confidence}
Now, we'll re-save the new Swivel model that has this updated model signature by referencing the source_name
function for the model's serving_default
.
In [60]:
shutil.rmtree('./txtcls_swivel', ignore_errors=True)
txtcls_model.save('./txtcls_swivel', signatures={'serving_default': source_name})
Examine the model signature to confirm the changes:
In [61]:
!saved_model_cli show \
--tag_set serve \
--signature_def serving_default \
--dir ./txtcls_model/
Now when we call predictions using the updated serving input function, the model will return the predicted article source as a readable string, and the model's confidence for that prediction.
In [66]:
title1 = "House Passes Sweeping Policing Bill Targeting Racial Bias and Use of Force"
title2 = "YouTube introduces Video Chapters to make it easier to navigate longer videos"
title3 = "A native Mac app wrapper for WhatsApp Web"
restored = tf.keras.models.load_model('./txtcls_swivel')
infer = restored.signatures['serving_default']
outputs = infer(text=tf.constant([title1, title2, title3]))
In [67]:
print(outputs['source'].numpy())
print(outputs['confidence'].numpy())
Once the model is trained and the assets saved, deploying the model to GCP is straightforward. After some time you should be able to see your deployed model and its version on the model page of GCP console.
In [69]:
%%bash
MODEL_NAME="txtcls"
MODEL_VERSION="swivel"
MODEL_LOCATION="./txtcls_swivel/"
gcloud ai-platform versions create ${MODEL_VERSION} \
--model ${MODEL_NAME} \
--origin ${MODEL_LOCATION} \
--staging-bucket gs://${BUCKET} \
--runtime-version=2.1
Now that the model is deployed, go to Cloud AI Platform to see the model version you've deployed and set up an evaluation job by clicking on the button called "Create Evaluation Job". You will be asked to provide some relevant information:
txtcls_eval.swivel
. If you enter a BigQuery table that doesn’t exist, one with that name will be created with the correct schema. text
.source
.confidence
.Once the evaluation job is set up, the table will be made in BigQuery to capture the online prediction requests.
In [70]:
%load_ext google.cloud.bigquery
In [99]:
%%bigquery --project $PROJECT
SELECT * FROM `txtcls_eval.swivel`
Out[99]:
Now, every time this model version receives an online prediction request, this information will be captured and stored in the BQ table. Note, this happens everytime because we set the sampling proportion to 100%.
Here are some article titles and their groundtruth sources that we can test with prediciton.
title | groundtruth |
---|---|
YouTube introduces Video Chapters to make it easier to navigate longer videos | techcrunch |
A Filmmaker Put Away for Tax Fraud Takes Us Inside a British Prison | nytimes |
A native Mac app wrapper for WhatsApp Web | github |
Astronauts Dock With Space Station After Historic SpaceX Launch | nytimes |
House Passes Sweeping Policing Bill Targeting Racial Bias and Use of Force | nytimes |
Scrollability | github |
iOS 14 lets deaf users set alerts for important sounds, among other clever accessibility perks | techcrunch |
In [100]:
%%writefile input.json
{"text": "YouTube introduces Video Chapters to make it easier to navigate longer videos"}
In [101]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [102]:
%%writefile input.json
{"text": "A Filmmaker Put Away for Tax Fraud Takes Us Inside a British Prison"}
In [103]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [104]:
%%writefile input.json
{"text": "A native Mac app wrapper for WhatsApp Web"}
In [105]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [106]:
%%writefile input.json
{"text": "Astronauts Dock With Space Station After Historic SpaceX Launch"}
In [107]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [108]:
%%writefile input.json
{"text": "House Passes Sweeping Policing Bill Targeting Racial Bias and Use of Force"}
In [109]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [110]:
%%writefile input.json
{"text": "Scrollability"}
In [111]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
In [112]:
%%writefile input.json
{"text": "iOS 14 lets deaf users set alerts for important sounds, among other clever accessibility perks"}
In [113]:
!gcloud ai-platform predict \
--model txtcls \
--json-instances input.json \
--version swivel
Summarizing the results from our model:
title | groundtruth | predicted |
---|---|---|
YouTube introduces Video Chapters to make it easier to navigate longer videos | techcrunch | techcrunch |
A Filmmaker Put Away for Tax Fraud Takes Us Inside a British Prison | nytimes | techcrunch |
A native Mac app wrapper for WhatsApp Web | github | techcrunch |
Astronauts Dock With Space Station After Historic SpaceX Launch | nytimes | techcrunch |
House Passes Sweeping Policing Bill Targeting Racial Bias and Use of Force | nytimes | nytimes |
Scrollability | github | techcrunch |
iOS 14 lets deaf users set alerts for important sounds, among other clever accessibility perks | techcrunch | nytimes |
In [115]:
%%bigquery --project $PROJECT
SELECT * FROM `txtcls_eval.swivel`
Out[115]:
In [117]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "techcrunch"}]}'
WHERE
raw_data = '{"instances": [{"text": "YouTube introduces Video Chapters to make it easier to navigate longer videos"}]}';
Out[117]:
In [118]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "nytimes"}]}'
WHERE
raw_data = '{"instances": [{"text": "A Filmmaker Put Away for Tax Fraud Takes Us Inside a British Prison"}]}';
Out[118]:
In [125]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "github"}]}'
WHERE
raw_data = '{"instances": [{"text": "A native Mac app wrapper for WhatsApp Web"}]}';
Out[125]:
In [119]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "nytimes"}]}'
WHERE
raw_data = '{"instances": [{"text": "Astronauts Dock With Space Station After Historic SpaceX Launch"}]}';
Out[119]:
In [120]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "nytimes"}]}'
WHERE
raw_data = '{"instances": [{"text": "House Passes Sweeping Policing Bill Targeting Racial Bias and Use of Force"}]}';
Out[120]:
In [121]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "github"}]}'
WHERE
raw_data = '{"instances": [{"text": "Scrollability"}]}';
Out[121]:
In [122]:
%%bigquery --project $PROJECT
UPDATE `txtcls_eval.swivel`
SET
groundtruth = '{"predictions": [{"source": "techcrunch"}]}'
WHERE
raw_data = '{"instances": [{"text": "iOS 14 lets deaf users set alerts for important sounds, among other clever accessibility perks"}]}';
Out[122]:
We can confirm that the ground truch has been properly added to the table.
In [126]:
%%bigquery --project $PROJECT
SELECT * FROM `txtcls_eval.swivel`
Out[126]:
In [145]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import classification_report
Using regex we can extract the model predictions, to have an easier to read format:
In [128]:
%%bigquery --project $PROJECT
SELECT
model,
model_version,
time,
REGEXP_EXTRACT(raw_data, r'.*"text": "(.*)"') AS text,
REGEXP_EXTRACT(raw_prediction, r'.*"source": "(.*?)"') AS prediction,
REGEXP_EXTRACT(raw_prediction, r'.*"confidence": (0.\d{2}).*') AS confidence,
REGEXP_EXTRACT(groundtruth, r'.*"source": "(.*?)"') AS groundtruth,
FROM
`txtcls_eval.swivel`
Out[128]:
In [132]:
query = '''
SELECT
model,
model_version,
time,
REGEXP_EXTRACT(raw_data, r'.*"text": "(.*)"') AS text,
REGEXP_EXTRACT(raw_prediction, r'.*"source": "(.*?)"') AS prediction,
REGEXP_EXTRACT(raw_prediction, r'.*"confidence": (0.\d{2}).*') AS confidence,
REGEXP_EXTRACT(groundtruth, r'.*"source": "(.*?)"') AS groundtruth,
FROM
`txtcls_eval.swivel`
'''
client = bigquery.Client()
df_results = client.query(query).to_dataframe()
In [133]:
df_results.head(20)
Out[133]:
In [134]:
prediction = list(df_results.prediction)
groundtruth = list(df_results.groundtruth)
In [135]:
precision, recall, fscore, support = score(groundtruth, prediction)
In [140]:
from tabulate import tabulate
sources = list(CLASSES.keys())
results = list(zip(sources, precision, recall, fscore, support))
print(tabulate(results, headers = ['source', 'precision', 'recall', 'fscore', 'support'],
tablefmt='orgtbl'))
Or a full classification report from the sklearn library:
In [142]:
print(classification_report(y_true=groundtruth, y_pred=prediction))
Can also examine a confusion matrix:
In [144]:
cm = confusion_matrix(groundtruth, prediction, labels=sources)
ax= plt.subplot()
sns.heatmap(cm, annot=True, ax = ax, cmap="Blues")
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(sources)
ax.yaxis.set_ticklabels(sources)
plt.savefig("./txtcls_cm.png")
By specifying the same evaluation table, two different model versions can be evaluated. Also, since the timestamp is captured, it is straightforward to evaluation model performance over time.
In [152]:
now = pd.Timestamp.now(tz='UTC')
one_week_ago = now - pd.DateOffset(weeks=1)
one_month_ago = now - pd.DateOffset(months=1)
In [156]:
df_prev_week = df_results[df_results.time > one_week_ago]
df_prev_month = df_results[df_results.time > one_month_ago]
In [157]:
df_prev_month
Out[157]:
Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License