This tutorial is based on the Gluon NLP one here: https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html
Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering and natural language inference. There are two strategies for applying pre-trained representations to these tasks: feature-based and fine-tuning.
While feature-based approaches such as ELMo [1] are effective in improving many downstream tasks, they require task-specific architectures. Devlin et al. proposed BERT [2] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters and obtained state-of-the-art results.
In this tutorial, we will focus on fine-tuning a pre-trained BERT model to classify semantically equivalent sentence pairs. Specifically, we will:
1. load the pre-trained BERT model that was exported from GluonNLP,
2. attach a classification layer for fine-tuning,
3. preprocess the sentence pair data into the format BERT expects, and
4. fine-tune the model and use it to predict whether two sentences are semantically equivalent.
To run this tutorial locally, in the example directory:
1. Run get_bert_data.sh
2. Run lein jupyter install-kernel
3. After that, you can open the notebook in the project directory with lein jupyter notebook
In [1]:
(ns bert.bert-sentence-classification
(:require [bert.util :as bert-util]
[clojure-csv.core :as csv]
[clojure.java.shell :refer [sh]]
[clojure.string :as string]
[org.apache.clojure-mxnet.callback :as callback]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.symbol :as sym]))
In this tutorial we will use a pre-trained BERT model that was exported from GluonNLP via the scripts/bert/staticbert/static_export_base.py script. For convenience, the model has already been downloaded for you by the get_bert_data.sh script in the root directory of this example.
Let’s first take a look at the BERT model architecture for sentence pair classification below:
The model takes a pair of sequences and pools the representation of the first token in the sequence. Note that the original BERT model was trained on masked language modeling and next sentence prediction tasks, which involve additional layers for language model decoding and classification. These layers will not be used when fine-tuning for sentence pair classification.
Let's load the pre-trained BERT using the module API in MXNet.
In [2]:
(def model-path-prefix "data/static_bert_base_net")
;; the vocabulary used in the model
(def vocab (bert-util/get-vocab))
;; the maximum length of the sequence
(def seq-length 128)
(def batch-size 32)
(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))
Out[2]:
In [3]:
(defn fine-tune-model
"msymbol: the pretrained network symbol
num-classes: the number of classes for the fine-tune datasets
dropout: the dropout rate"
[msymbol {:keys [num-classes dropout]}]
(as-> msymbol data
(sym/dropout {:data data :p dropout})
(sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
(sym/softmax-output "softmax" {:data data})))
(def model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1}))
Out[3]:
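If you want to confirm that the classification head was attached, you can list the arguments of the new symbol and look for the fine-tune layer's parameters. This is an optional sketch that is not part of the original notebook; it assumes the symbol API's list-arguments wrapper and MXNet's usual _weight/_bias naming for the fc-finetune layer.
;; Optional sanity check (a sketch): list the symbol's arguments and keep
;; only the ones belonging to the newly attached fine-tune layer. MXNet
;; typically names them "fc-finetune_weight" and "fc-finetune_bias".
(->> (sym/list-arguments model-sym)
     (filter #(string/includes? % "fc-finetune")))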
In [4]:
(-> (sh "head" "-n" "5" "data/dev.tsv")
:out
println)
The file contains 5 columns, separated by tabs (i.e. '\t'). The first line of the file explains each of these columns:
0. the label indicating whether the two sentences are semantically equivalent
1. the id of the first sentence in this sample
2. the id of the second sentence in this sample
3. the content of the first sentence
4. the content of the second sentence
For our task, we are interested in the 0th, 3rd and 4th columns.
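As a quick illustration of how clojure-csv handles a tab-delimited row, here is a minimal sketch with made-up values (not a line from the actual dataset):
;; Parse a single tab-separated row: label, id of sentence one, id of
;; sentence two, sentence one, sentence two.
(csv/parse-csv "1\t101\t102\tfirst sentence\tsecond sentence"
               :delimiter \tab)
;; => (["1" "101" "102" "first sentence" "second sentence"])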
In [5]:
(def raw-file
(csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "")
:delimiter \tab
:strict true))
(def data-train-raw (->> raw-file
(mapv #(vals (select-keys % [3 4 0])))
(rest) ; drop header
(into [])))
(def sample (first data-train-raw))
(println (nth sample 0)) ;;;sentence a
(println (nth sample 1)) ;; sentence b
(println (nth sample 2)) ;; 1 means equivalent, 0 means not equivalent
To use the pre-trained BERT model, we need to preprocess the data in the same way it was trained. The following figure shows the input representation in BERT:
We will preprocess the inputs to get them into the format BERT expects by performing the following transformations:
1. tokenize both sentences and lower-case them,
2. insert [CLS] at the beginning and [SEP] after each sentence,
3. generate segment (token type) ids that distinguish the first sentence from the second,
4. pad the token sequence to seq-length with [PAD] and record the valid length, and
5. translate the tokens into their vocabulary indices.
In [6]:
(defn pre-processing
"Preprocesses the sentences in the format that BERT is expecting"
[idx->token token->idx train-item]
(let [[sentence-a sentence-b label] train-item
;;; pre-processing tokenize sentence
token-1 (bert-util/tokenize (string/lower-case sentence-a))
token-2 (bert-util/tokenize (string/lower-case sentence-b))
valid-length (+ (count token-1) (count token-2))
;;; generate token types [0000...1111...0000]
qa-embedded (into (bert-util/pad [] 0 (count token-1))
(bert-util/pad [] 1 (count token-2)))
token-types (bert-util/pad qa-embedded 0 seq-length)
;;; make BERT pre-processing standard
token-2 (conj token-2 "[SEP]")
token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
tokens (bert-util/pad token-1 "[PAD]" seq-length)
;;; pre-processing - token to index translation
indexes (bert-util/tokens->idxs token->idx tokens)]
{:input-batch [indexes
token-types
[valid-length]]
:label (if (= "0" label)
[0]
[1])
:tokens tokens
:train-item train-item}))
(def idx->token (:idx->token vocab))
(def token->idx (:token->idx vocab))
(def dev (context/default-context))
(def processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw))
(def train-count (count processed-datas))
(println "Train Count is = " train-count)
(println "[PAD] token id = " (get token->idx "[PAD]"))
(println "[CLS] token id = " (get token->idx "[CLS]"))
(println "[SEP] token id = " (get token->idx "[SEP]"))
(println "token ids = \n"(-> (first processed-datas) :input-batch first))
(println "segment ids = \n"(-> (first processed-datas) :input-batch second))
(println "valid length = \n" (-> (first processed-datas) :input-batch last))
(println "label = \n" (-> (second processed-datas) :label))
Now that we have the input batches for each row, we are going to slice them up column-wise and create NDArray iterators that we can use in training.
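To make the column-wise slicing concrete, here is a toy illustration of the idea (a sketch with made-up values, not part of the original notebook): each row's :input-batch is [token-ids segment-ids [valid-length]], and we pull out the nth component of every row and flatten the results into one long vector.
;; Toy example of column-wise slicing: take component 0 (the token ids)
;; from each row and flatten them into a single vector.
(let [rows [{:input-batch [[1 2] [0 0] [2]]}
            {:input-batch [[3 4] [0 1] [2]]}]]
  (->> rows
       (mapv #(nth (:input-batch %) 0))
       (flatten)
       (into [])))
;; => [1 2 3 4]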
In [7]:
(defn slice-inputs-data
"Each sentence pair had to be processed as a row. This breaks all
the rows up into a column for creating an NDArray"
[processed-datas n]
(->> processed-datas
(mapv #(nth (:input-batch %) n))
(flatten)
(into [])))
(def prepared-data {:data0s (slice-inputs-data processed-datas 0)
:data1s (slice-inputs-data processed-datas 1)
:data2s (slice-inputs-data processed-datas 2)
:labels (->> (mapv :label processed-datas)
(flatten)
(into []))
:train-num (count processed-datas)})
(def train-data
(let [{:keys [data0s data1s data2s labels train-num]} prepared-data
data-desc0 (mx-io/data-desc {:name "data0"
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:layout layout/NT})
data-desc1 (mx-io/data-desc {:name "data1"
:shape [train-num seq-length]
:dtype dtype/FLOAT32
:layout layout/NT})
data-desc2 (mx-io/data-desc {:name "data2"
:shape [train-num]
:dtype dtype/FLOAT32
:layout layout/N})
label-desc (mx-io/data-desc {:name "softmax_label"
:shape [train-num]
:dtype dtype/FLOAT32
:layout layout/N})]
(mx-io/ndarray-iter {data-desc0 (ndarray/array data0s [train-num seq-length]
{:ctx dev})
data-desc1 (ndarray/array data1s [train-num seq-length]
{:ctx dev})
data-desc2 (ndarray/array data2s [train-num]
{:ctx dev})}
{:label {label-desc (ndarray/array labels [train-num]
{:ctx dev})}
:data-batch-size batch-size})))
train-data
Out[7]:
In [8]:
(def num-epoch 3)
(def fine-tune-model (m/module model-sym {:contexts [dev]
:data-names ["data0" "data1" "data2"]}))
(m/fit fine-tune-model {:train-data train-data :num-epoch num-epoch
:fit-params (m/fit-params {:allow-missing true
:arg-params (m/arg-params bert-base)
:aux-params (m/aux-params bert-base)
:optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})
:batch-end-callback (callback/speedometer batch-size 1)})})
Out[8]:
Now that our model is fitted, we can use it to infer the semantic equivalence of arbitrary sentence pairs. Note that for demonstration purposes we skipped the warmup learning rate schedule and the validation on the dev dataset used in the original implementation, so our model's performance will be significantly less than optimal. Please see the GluonNLP project for the complete fine-tuning scripts (using Python and GluonNLP).
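As a rough sketch of what validation could look like with the Clojure module API, you could score the fine-tuned module against a held-out iterator built the same way as train-data. This is not part of the original notebook and assumes the m/score and eval-metric/accuracy helpers; here we simply reuse train-data to show the call, which only measures training-set accuracy.
;; Sketch only: evaluate accuracy with the module API. For a real
;; validation score, build a separate iterator from held-out rows
;; instead of reusing train-data.
(m/score fine-tune-model {:eval-data train-data
                          :eval-metric (eval-metric/accuracy)})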
To do inference with our model we need a predictor. It must have a batch size of 1 so we can feed the model a single sentence pair.
In [14]:
(def fine-tuned-prefix "fine-tune-sentence-bert")
(m/save-checkpoint fine-tune-model {:prefix fine-tuned-prefix :epoch 3})
(def fine-tuned-predictor
(infer/create-predictor (infer/model-factory fine-tuned-prefix
[{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
{:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
{:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])
{:epoch 3}))
Out[14]:
Now we can write a function that feeds a sentence pair to the fine-tuned model:
In [15]:
(defn predict-equivalence
[predictor sentence1 sentence2]
(let [vocab (bert-util/get-vocab)
processed-test-data (mapv #(pre-processing (:idx->token vocab)
(:token->idx vocab) %)
[[sentence1 sentence2]])
prediction (infer/predict-with-ndarray predictor
[(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])
(ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])
(ndarray/array (slice-inputs-data processed-test-data 2) [1])])]
(ndarray/->vec (first prediction))))
Out[15]:
In [22]:
;; Modify an existing sentence pair to test:
;; ["1"
;; "69773"
;; "69792"
;; "Cisco pared spending to compensate for sluggish sales ."
;; "In response to sluggish sales , Cisco pared spending ."]
(predict-equivalence fine-tuned-predictor
"The company cut spending to compensate for weak sales ."
"In response to poor sales results, the company cut spending .")
Out[22]:
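The result is the model's softmax output for the pair: a vector of two probabilities, where the first entry corresponds to label 0 (not equivalent) and the second to label 1 (equivalent). A tiny helper to turn that vector into a readable label could look like the following (this helper is hypothetical and not part of the original notebook):
;; Hypothetical helper: pick the label with the higher probability.
;; probs is the two-element vector returned by predict-equivalence,
;; ordered [p(not-equivalent) p(equivalent)].
(defn equivalence-label
  [probs]
  (if (> (second probs) (first probs))
    "equivalent"
    "not equivalent"))

(equivalence-label
 (predict-equivalence fine-tuned-predictor
                      "The company cut spending to compensate for weak sales ."
                      "In response to poor sales results, the company cut spending ."))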