Multilabel BERT Experiments

In this notebook we run some first experiments with BERT: we fine-tune a BERT model with a classifier on each of our datasets separately and measure the performance of the resulting classifier on the test data.

For these experiments we use the pytorch_transformers package. It provides a variety of neural network architectures for transfer learning, together with pretrained models, including BERT and XLNet.

Two different BERT models are relevant for our experiments:

  • BERT-base-uncased: a relatively small BERT model that should already give reasonable results,
  • BERT-large-uncased: a larger model that typically gives better results but needs more memory, hence the smaller batch size in the settings below.

In [1]:
from multilabel import EATINGMEAT_BECAUSE_MAP, EATINGMEAT_BUT_MAP, JUNKFOOD_BECAUSE_MAP, JUNKFOOD_BUT_MAP

LABEL_MAP = JUNKFOOD_BUT_MAP
BERT_MODEL = 'bert-base-uncased'
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8
MAX_SEQ_LENGTH = 100
PREFIX = "junkfood_but"
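
The batch size and the number of gradient accumulation steps are set together, so their product is the same for both models. The sketch below only does that arithmetic. `effective_batch_size` is a helper name introduced here for illustration, and it assumes that training accumulates gradients in the usual way (one optimizer step per `GRADIENT_ACCUMULATION_STEPS` batches).

```python
def effective_batch_size(bert_model: str) -> int:
    """Examples per optimizer step, assuming standard gradient accumulation."""
    # Same rules as in the configuration cell above.
    batch_size = 16 if "base" in bert_model else 2
    gradient_accumulation_steps = 1 if "base" in bert_model else 8
    return batch_size * gradient_accumulation_steps


for name in ("bert-base-uncased", "bert-large-uncased"):
    print(name, effective_batch_size(name))
```

Under that assumption both configurations update the weights on 16 examples at a time; the large model just reaches that number in smaller pieces.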

Data

We use the same data as for all our previous experiments. Here we load the training, development and test data for a particular prompt.


In [2]:
import ndjson
import glob
from collections import Counter

train_file = f"../data/interim/{PREFIX}_train_withprompt.ndjson"
synth_files = glob.glob(f"../data/interim/{PREFIX}_train_withprompt_allsynth.ndjson")
dev_file = f"../data/interim/{PREFIX}_dev_withprompt.ndjson"
test_file = f"../data/interim/{PREFIX}_test_withprompt.ndjson"

with open(train_file) as i:
    train_data = ndjson.load(i)

synth_data = []
for f in synth_files:
    with open(f) as i:
        synth_data += ndjson.load(i)
    
with open(dev_file) as i:
    dev_data = ndjson.load(i)
    
with open(test_file) as i:
    test_data = ndjson.load(i)
    
labels = Counter([item["label"] for item in train_data])
print(labels)
print(len(synth_data))


Counter({'Schools providing healthy alternatives': 137, 'Students without choice': 46, 'Unclassified Off-Topic': 32, 'School without generating money': 26, 'Student choice': 24, 'Schools generate money': 16, 'Students can still bring/access junk food': 3})
2556

Next, we convert each item's label to its multilabel form using LABEL_MAP, and then build the label vocabulary, which maps every label in the training data to an index.


In [3]:
def map_to_multilabel(items):
    return [{"text": item["text"], "label": LABEL_MAP[item["label"]]} for item in items]

train_data = map_to_multilabel(train_data)
dev_data = map_to_multilabel(dev_data)
synth_data = map_to_multilabel(synth_data)
test_data = map_to_multilabel(test_data)

In [4]:
import sys
sys.path.append('../')

from quillnlp.models.bert.preprocessing import preprocess, create_label_vocabulary

label2idx = create_label_vocabulary(train_data)
idx2label = {v:k for k,v in label2idx.items()}
target_names = [idx2label[s] for s in range(len(idx2label))]

# MAX_SEQ_LENGTH and BATCH_SIZE are set in the configuration cell.
train_dataloader = preprocess(train_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
dev_dataloader = preprocess(dev_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
test_dataloader = preprocess(test_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
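
The internals of `create_label_vocabulary` and `preprocess` are not shown in this notebook. As a rough illustration of the idea only, the sketch below builds a label vocabulary and encodes each item as a 0/1 vector. It assumes that, after the mapping step, every item's `label` is a list of label strings; `build_label_vocabulary`, `to_multi_hot` and the toy items are hypothetical and not part of quillnlp.

```python
from typing import Dict, List


def build_label_vocabulary(items: List[dict]) -> Dict[str, int]:
    """Give every distinct label a consecutive index, in order of first appearance."""
    vocabulary: Dict[str, int] = {}
    for item in items:
        for label in item["label"]:
            if label not in vocabulary:
                vocabulary[label] = len(vocabulary)
    return vocabulary


def to_multi_hot(labels: List[str], vocabulary: Dict[str, int]) -> List[int]:
    """Encode a list of labels as a 0/1 vector with one position per known label."""
    vector = [0] * len(vocabulary)
    for label in labels:
        vector[vocabulary[label]] = 1
    return vector


# Toy items, made up for illustration (not taken from the data set).
toy_items = [
    {"text": "first example", "label": ["A", "B"]},
    {"text": "second example", "label": ["B"]},
    {"text": "third example", "label": ["C"]},
]
toy_label2idx = build_label_vocabulary(toy_items)
toy_vectors = [to_multi_hot(item["label"], toy_label2idx) for item in toy_items]
print(toy_label2idx)
print(toy_vectors)
```

An item with several labels gets several ones in its vector, which is what distinguishes this setup from single-label classification.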

Model

We load the pretrained model and put it on a GPU if one is available. We also put the model in "training" mode, so that we can correctly update its internal parameters on the basis of our data sets.


In [ ]:
import torch
from quillnlp.models.bert.models import get_multilabel_bert_classifier

# Use the GPU if one is available; BERT_MODEL is set in the configuration cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), device=device)

Training


In [ ]:
from quillnlp.models.bert.train import train

# BATCH_SIZE and GRADIENT_ACCUMULATION_STEPS are set in the configuration cell.
output_model_file = train(model, train_dataloader, dev_dataloader, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, device)


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss history: []
Dev loss: 0.46770793199539185
Epoch:   5%|▌         | 1/20 [00:04<01:30,  4.76s/it]

Loss history: [0.46770793199539185]
Dev loss: 0.3351816892623901
Epoch:  10%|█         | 2/20 [00:09<01:24,  4.69s/it]

Loss history: [0.46770793199539185, 0.3351816892623901]
Dev loss: 0.2613398343324661
Epoch:  15%|█▌        | 3/20 [00:13<01:18,  4.65s/it]

Loss history: [0.46770793199539185, 0.3351816892623901, 0.2613398343324661]
Dev loss: 0.2466729998588562
Epoch:  20%|██        | 4/20 [00:18<01:13,  4.62s/it]

Loss history: [0.46770793199539185, 0.3351816892623901, 0.2613398343324661, 0.2466729998588562]
Dev loss: 0.21184660196304322
Epoch:  25%|██▌       | 5/20 [00:22<01:08,  4.59s/it]

Loss history: [0.46770793199539185, 0.3351816892623901, 0.2613398343324661, 0.2466729998588562, 0.21184660196304322]
Dev loss: 0.20426980555057525
Epoch:  30%|███       | 6/20 [00:27<01:04,  4.58s/it]

Loss history: [0.46770793199539185, 0.3351816892623901, 0.2613398343324661, 0.2466729998588562, 0.21184660196304322, 0.20426980555057525]
Dev loss: 0.1862332671880722
Epoch:  35%|███▌      | 7/20 [00:32<00:59,  4.57s/it]
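
The log prints the history of dev losses after every epoch, which is the kind of record an early-stopping check needs. Whether `train` actually stops early, and by what rule, is not shown here. The following is a hypothetical sketch of one common rule (stop when the last `patience` epochs bring no improvement); `should_stop` and `patience` are names introduced for illustration.

```python
from typing import List


def should_stop(loss_history: List[float], patience: int = 2) -> bool:
    """True if none of the last `patience` dev losses beats the best earlier loss."""
    if len(loss_history) <= patience:
        return False
    best_before = min(loss_history[:-patience])
    return all(loss >= best_before for loss in loss_history[-patience:])


# Dev losses from the first epochs of the run above, rounded to three decimals.
improving = [0.468, 0.335, 0.261, 0.247, 0.212, 0.204, 0.186]
# Made-up continuation in which the dev loss stops improving.
stalled = improving + [0.190, 0.195]

print(should_stop(improving))  # still improving, so keep training
print(should_stop(stalled))    # two epochs without improvement
```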

Evaluation


In [ ]:
from quillnlp.models.bert.train import evaluate
from sklearn.metrics import precision_recall_fscore_support, classification_report

print("Loading model from", output_model_file)
device = "cpu"  # evaluate on the CPU

# Reload the saved model and switch it to evaluation mode.
model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), model_file=output_model_file, device=device)
model.eval()

_, test_correct, test_predicted = evaluate(model, test_dataloader, device)

print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))
print(classification_report(test_correct, test_predicted, target_names=target_names))

In [ ]:
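# Count exact matches per item and true/false positives/negatives per label.
all_correct = 0
fp, fn, tp, tn = 0, 0, 0, 0
for c, p in zip(test_correct, test_predicted):
    # An item counts as correct only if every one of its labels matches.
    if sum(c == p) == len(c):
        all_correct += 1
    for ci, pi in zip(c, p):
        if pi == 1 and ci == 1:
            tp += 1
        elif pi == 1 and ci == 0:
            fp += 1
        elif pi == 0 and ci == 1:
            fn += 1
        else:
            tn += 1

# Micro-averaged precision and recall; fall back to 0.0 when undefined.
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
print("P:", precision)
print("R:", recall)
print("A:", all_correct / len(test_correct))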
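
To make the difference between these three numbers concrete, here is a small worked example on made-up label vectors. `micro_scores` and the toy data are introduced for illustration only. Precision and recall pool the counts over all labels of all items (micro-averaging), while accuracy credits an item only if its whole label vector is predicted exactly.

```python
def micro_scores(correct, predicted):
    """Return (precision, recall, exact-match accuracy) for lists of 0/1 label vectors."""
    tp = fp = fn = exact = 0
    for c, p in zip(correct, predicted):
        # Exact match: every label of the item must agree.
        exact += int(list(c) == list(p))
        for ci, pi in zip(c, p):
            tp += int(ci == 1 and pi == 1)
            fp += int(ci == 0 and pi == 1)
            fn += int(ci == 1 and pi == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall, exact / len(correct)


# Three toy items with three possible labels each.
toy_correct = [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
toy_predicted = [[1, 0, 0], [0, 1, 0], [1, 0, 1]]

toy_p, toy_r, toy_a = micro_scores(toy_correct, toy_predicted)
print("P:", toy_p)  # 3 true positives, 1 false positive -> 0.75
print("R:", toy_r)  # 3 true positives, 1 false negative -> 0.75
print("A:", toy_a)  # only the first item matches exactly -> 1/3
```

The second and third items each get most labels right, which precision and recall reward, but neither counts towards exact-match accuracy.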

In [ ]:
for item, predicted, correct in zip(test_data, test_predicted, test_correct):
    correct_labels = [idx2label[i] for i, l in enumerate(correct) if l == 1]
    predicted_labels = [idx2label[i] for i, l in enumerate(predicted) if l == 1]
    print("{}#{}#{}".format(item["text"], ";".join(correct_labels), ";".join(predicted_labels)))
