In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The TensorFlow Models NLP library is a collection of tools for building and training modern, high-performance natural language models.
The TransformerEncoder is the core of this library, and many new network architectures have been proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures.
In [ ]:
!pip install -q tf-nightly
!pip install -q tf-models-nightly
In [ ]:
import numpy as np
import tensorflow as tf
from official.modeling import activations
from official.nlp import modeling
from official.nlp.modeling import layers, losses, models, networks
In [ ]:
cfg = {
    "vocab_size": 100,
    "hidden_size": 32,
    "num_layers": 3,
    "num_attention_heads": 4,
    "intermediate_size": 64,
    "activation": activations.gelu,
    "dropout_rate": 0.1,
    "attention_dropout_rate": 0.1,
    "sequence_length": 16,
    "type_vocab_size": 2,
    "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
bert_encoder = modeling.networks.TransformerEncoder(**cfg)
def build_classifier(bert_encoder):
  return modeling.models.BertClassifier(bert_encoder, num_classes=2)
canonical_classifier_model = build_classifier(bert_encoder)
canonical_classifier_model can now be trained with task training data. For details about how to train the model, please see the colab fine_tuning_bert.ipynb. We skip the training code here.
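That said, the cell below is a minimal sketch of what a training step could look like on randomly generated dummy inputs and labels; the optimizer, learning rate, batch size, and number of epochs are illustrative assumptions, not the fine-tuning recipe from fine_tuning_bert.ipynb.
In [ ]:
# A minimal training sketch on random dummy data. The labels and
# hyperparameters here are illustrative assumptions only.
batch_size = 8
np.random.seed(0)
train_word_ids = np.random.randint(
    cfg["vocab_size"], size=(batch_size, cfg["sequence_length"]))
train_mask = np.random.randint(2, size=(batch_size, cfg["sequence_length"]))
train_type_ids = np.random.randint(
    cfg["type_vocab_size"], size=(batch_size, cfg["sequence_length"]))
train_labels = np.random.randint(2, size=(batch_size,))

canonical_classifier_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])
canonical_classifier_model.fit(
    x=[train_word_ids, train_mask, train_type_ids],
    y=train_labels,
    batch_size=batch_size,
    epochs=1)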
After training, we can apply the model to do prediction.
In [ ]:
def predict(model):
  batch_size = 3
  np.random.seed(0)
  word_ids = np.random.randint(
      cfg["vocab_size"], size=(batch_size, cfg["sequence_length"]))
  mask = np.random.randint(2, size=(batch_size, cfg["sequence_length"]))
  type_ids = np.random.randint(
      cfg["type_vocab_size"], size=(batch_size, cfg["sequence_length"]))
  print(model([word_ids, mask, type_ids], training=False))
predict(canonical_classifier_model)
A canonical Transformer-based encoder is essentially composed of an embedding network and a stack of Transformer layers. We provide easy ways to customize each of these components via (1) EncoderScaffold and (2) TransformerScaffold.
In [ ]:
default_hidden_cfg = dict(
    num_attention_heads=cfg["num_attention_heads"],
    intermediate_size=cfg["intermediate_size"],
    intermediate_activation=activations.gelu,
    dropout_rate=cfg["dropout_rate"],
    attention_dropout_rate=cfg["attention_dropout_rate"],
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
)
default_embedding_cfg = dict(
    vocab_size=cfg["vocab_size"],
    type_vocab_size=cfg["type_vocab_size"],
    hidden_size=cfg["hidden_size"],
    seq_length=cfg["sequence_length"],
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    dropout_rate=cfg["dropout_rate"],
    max_seq_length=cfg["sequence_length"],
)
default_kwargs = dict(
    hidden_cfg=default_hidden_cfg,
    embedding_cfg=default_embedding_cfg,
    num_hidden_instances=cfg["num_layers"],
    pooled_output_dim=cfg["hidden_size"],
    return_all_layer_outputs=True,
    pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
)
encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)
classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)
classifier_model_from_encoder_scaffold.set_weights(
    canonical_classifier_model.get_weights())
predict(classifier_model_from_encoder_scaffold)
We start by customizing the embedding network: the cell below builds a new embedding network that takes only word ids and an input mask (no type ids) and returns the word embeddings together with a self-attention mask.
In [ ]:
word_ids = tf.keras.layers.Input(
    shape=(cfg['sequence_length'],), dtype=tf.int32, name="input_word_ids")
mask = tf.keras.layers.Input(
    shape=(cfg['sequence_length'],), dtype=tf.int32, name="input_mask")
embedding_layer = modeling.layers.OnDeviceEmbedding(
    vocab_size=cfg['vocab_size'],
    embedding_width=cfg['hidden_size'],
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    name="word_embeddings")
word_embeddings = embedding_layer(word_ids)
attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
new_embedding_network = tf.keras.Model([word_ids, mask],
                                       [word_embeddings, attention_mask])
Inspecting new_embedding_network, we can see that it takes two inputs: input_word_ids and input_mask.
In [ ]:
tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)
We can then build a new encoder using the above new_embedding_network.
In [ ]:
kwargs = dict(default_kwargs)
# Use new embedding network.
kwargs['embedding_cls'] = new_embedding_network
kwargs['embedding_data'] = embedding_layer.embeddings
encoder_with_customized_embedding = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder_with_customized_embedding)
# ... Train the model ...
print(classifier_model.inputs)
# Assert that there are only two inputs.
assert len(classifier_model.inputs) == 2
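Note that the three-input predict helper defined earlier no longer applies, because this classifier expects only input_word_ids and input_mask. The cell below is a small sketch of running prediction with the two-input model; the variable names are purely illustrative.
In [ ]:
# Predict with the two-input classifier (no type ids). Variable names are
# chosen so as not to clobber the Keras Input tensors defined above.
batch_size = 3
np.random.seed(0)
word_id_data = np.random.randint(
    cfg["vocab_size"], size=(batch_size, cfg["sequence_length"]))
mask_data = np.random.randint(2, size=(batch_size, cfg["sequence_length"]))
print(classifier_model([word_id_data, mask_data], training=False))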
Users can also override the hidden_cls argument in EncoderScaffold's constructor to employ a customized Transformer layer. See ReZeroTransformer for how to implement a customized Transformer layer. The following is an example of using ReZeroTransformer:
In [ ]:
kwargs = dict(default_kwargs)
# Use ReZeroTransformer.
kwargs['hidden_cls'] = modeling.layers.ReZeroTransformer
encoder_with_rezero_transformer = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder_with_rezero_transformer)
# ... Train the model ...
predict(classifier_model)
# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.
assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])
The above approach to customizing the Transformer requires rewriting the whole Transformer layer, but sometimes you may only want to customize the attention layer or the feedforward block. In that case, TransformerScaffold can be used.
Users can also override the attention_cls argument in TransformerScaffold's constructor to employ a customized attention layer. See TalkingHeadsAttention for how to implement a customized attention layer. The following is an example of using TalkingHeadsAttention:
In [ ]:
# Use TalkingHeadsAttention
hidden_cfg = dict(default_hidden_cfg)
hidden_cfg['attention_cls'] = modeling.layers.TalkingHeadsAttention
kwargs = dict(default_kwargs)
kwargs['hidden_cls'] = modeling.layers.TransformerScaffold
kwargs['hidden_cfg'] = hidden_cfg
encoder = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder)
# ... Train the model ...
predict(classifier_model)
# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.
assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])
Similarly, one can also customize the feedforward layer. See GatedFeedforward for how to implement a customized feedforward layer. The following is an example of using GatedFeedforward:
In [ ]:
# Use GatedFeedforward
hidden_cfg = dict(default_hidden_cfg)
hidden_cfg['feedforward_cls'] = modeling.layers.GatedFeedforward
kwargs = dict(default_kwargs)
kwargs['hidden_cls'] = modeling.layers.TransformerScaffold
kwargs['hidden_cfg'] = hidden_cfg
encoder_with_gated_feedforward = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder_with_gated_feedforward)
# ... Train the model ...
predict(classifier_model)
# Assert that the variable `gate` from GatedFeedforward exists.
assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])
Finally, you could also build a new encoder using building blocks in the modeling library.
See AlbertTransformerEncoder as an example:
In [ ]:
albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)
classifier_model = build_classifier(albert_encoder)
# ... Train the model ...
predict(classifier_model)
Inspecting the albert_encoder, we see that it stacks the same Transformer layer multiple times.
In [ ]:
tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)
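As a rough sanity check of that weight sharing, one could compare the number of trainable variables in albert_encoder against the bert_encoder built earlier; the exact counts depend on the library version, so treat the assertion below as an illustrative assumption rather than a guaranteed invariant.
In [ ]:
# Because ALBERT reuses one Transformer layer across all `num_layers` steps,
# it should expose fewer trainable variables than the 3-layer BERT encoder.
print("BERT encoder trainable variables:  ", len(bert_encoder.trainable_weights))
print("ALBERT encoder trainable variables:", len(albert_encoder.trainable_weights))
assert len(albert_encoder.trainable_weights) < len(bert_encoder.trainable_weights)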