For these experiments we use the pytorch_transformers package, which offers a variety of neural network architectures for transfer learning, together with pretrained models such as BERT and XLNet.
Two different BERT models are relevant for our experiments: bert-base-uncased and its larger counterpart, bert-large-uncased. The batch size and number of gradient accumulation steps below are chosen depending on which of the two is used.
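As a quick illustration of the pytorch_transformers API (this cell is not part of the experiments below), a pretrained model and its tokenizer are loaded with from_pretrained:
In [ ]:
# Illustration only: load a pretrained BERT and its tokenizer with
# pytorch_transformers and encode a single sentence.
import torch
from pytorch_transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

# encode returns a list of wordpiece ids; the model expects a batch tensor.
input_ids = torch.tensor([tokenizer.encode("Students should not eat junk food.")])
sequence_output, pooled_output = bert(input_ids)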
In [1]:
from multilabel import EATINGMEAT_BECAUSE_MAP, EATINGMEAT_BUT_MAP, JUNKFOOD_BECAUSE_MAP, JUNKFOOD_BUT_MAP
LABEL_MAP = JUNKFOOD_BUT_MAP
BERT_MODEL = 'bert-base-uncased'
# Use a smaller batch size, with gradient accumulation, for the larger bert-large model.
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8
MAX_SEQ_LENGTH = 100
PREFIX = "junkfood_but"
In [2]:
import ndjson
import glob
from collections import Counter
train_file = f"../data/interim/{PREFIX}_train_withprompt.ndjson"
synth_files = glob.glob(f"../data/interim/{PREFIX}_train_withprompt_allsynth.ndjson")
dev_file = f"../data/interim/{PREFIX}_dev_withprompt.ndjson"
test_file = f"../data/interim/{PREFIX}_test_withprompt.ndjson"
with open(train_file) as i:
    train_data = ndjson.load(i)

synth_data = []
for f in synth_files:
    with open(f) as i:
        synth_data += ndjson.load(i)

with open(dev_file) as i:
    dev_data = ndjson.load(i)

with open(test_file) as i:
    test_data = ndjson.load(i)
labels = Counter([item["label"] for item in train_data])
print(labels)
print(len(synth_data))
First we map every item's label to its multilabel representation with the label map. Next, we build the label vocabulary, which maps every label in the training data to an index.
In [3]:
def map_to_multilabel(items):
    # Replace each item's single label with the list of labels it maps to in LABEL_MAP.
    return [{"text": item["text"], "label": LABEL_MAP[item["label"]]} for item in items]
train_data = map_to_multilabel(train_data)
dev_data = map_to_multilabel(dev_data)
synth_data = map_to_multilabel(synth_data)
test_data = map_to_multilabel(test_data)
In [4]:
import sys
sys.path.append('../')
from quillnlp.models.bert.preprocessing import preprocess, create_label_vocabulary
label2idx = create_label_vocabulary(train_data)
idx2label = {v: k for k, v in label2idx.items()}
target_names = [idx2label[s] for s in range(len(idx2label))]
train_dataloader = preprocess(train_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
dev_dataloader = preprocess(dev_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE)
test_dataloader = preprocess(test_data, BERT_MODEL, label2idx, MAX_SEQ_LENGTH, BATCH_SIZE, shuffle=False)
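The create_label_vocabulary and preprocess functions come from the quillnlp package. As a rough sketch of the idea behind them (the actual quillnlp implementation may differ), a label vocabulary and the multi-hot target vectors that multilabel classification needs could be built like this:
In [ ]:
# Sketch only: a plausible label vocabulary and multi-hot encoding for
# multilabel data, where each item's "label" is a list of label strings.
# The actual quillnlp implementation may differ.
def build_label_vocabulary(items):
    label2idx = {}
    for item in items:
        for label in item["label"]:
            label2idx.setdefault(label, len(label2idx))
    return label2idx

def to_multihot(labels, label2idx):
    # A multi-hot vector has a 1 at the index of every label that applies.
    vector = [0] * len(label2idx)
    for label in labels:
        vector[label2idx[label]] = 1
    return vector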
In [ ]:
import torch
from quillnlp.models.bert.models import get_multilabel_bert_classifier
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), device=device)
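get_multilabel_bert_classifier is quillnlp's own wrapper. In broad strokes, a multilabel BERT classifier puts a linear layer with one output per label on top of BERT's pooled output and is trained with a sigmoid-based loss rather than a softmax. A minimal sketch, assuming pytorch_transformers' BertModel (the quillnlp implementation may differ):
In [ ]:
# Sketch only: the typical shape of a multilabel BERT classifier.
import torch.nn as nn
from pytorch_transformers import BertModel

class MultilabelBertClassifier(nn.Module):
    def __init__(self, bert_model, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids):
        _, pooled_output = self.bert(input_ids)
        # One logit per label; apply a sigmoid (not a softmax) at prediction
        # time and train with nn.BCEWithLogitsLoss.
        return self.classifier(pooled_output)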
In [ ]:
from quillnlp.models.bert.train import train
output_model_file = train(model, train_dataloader, dev_dataloader, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, device)
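Gradient accumulation is what lets the large model train with a per-step batch of only 2: the gradients of several small batches are summed before each optimizer step, emulating a larger effective batch size (here 2 × 8 = 16). A minimal sketch of the pattern that train() presumably applies internally, with a hypothetical compute_loss helper and an illustrative learning rate:
In [ ]:
# Sketch only: the gradient-accumulation pattern. compute_loss is a
# hypothetical helper standing in for the forward pass + loss computation.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
for step, batch in enumerate(train_dataloader):
    loss = compute_loss(model, batch)  # hypothetical helper
    (loss / GRADIENT_ACCUMULATION_STEPS).backward()
    if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
        # Only update the weights after accumulating gradients from
        # GRADIENT_ACCUMULATION_STEPS small batches.
        optimizer.step()
        optimizer.zero_grad()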
In [ ]:
from quillnlp.models.bert.train import evaluate
from sklearn.metrics import precision_recall_fscore_support, classification_report
print("Loading model from", output_model_file)
device = "cpu"
model = get_multilabel_bert_classifier(BERT_MODEL, len(label2idx), model_file=output_model_file, device=device)
model.eval()
_, test_correct, test_predicted = evaluate(model, test_dataloader, device)
print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))
print(classification_report(test_correct, test_predicted, target_names=target_names))
In [ ]:
all_correct = 0
fp, fn, tp, tn = 0, 0, 0, 0
for c, p in zip(test_correct, test_predicted):
    # Count items whose full label vector is predicted correctly.
    if sum(c == p) == len(c):
        all_correct += 1
    # Tally label-level true/false positives and negatives.
    for ci, pi in zip(c, p):
        if pi == 1 and ci == 1:
            tp += 1
        elif pi == 1 and ci == 0:
            fp += 1
        elif pi == 0 and ci == 1:
            fn += 1
        else:
            tn += 1

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print("P:", precision)
print("R:", recall)
print("A:", all_correct / len(test_correct))
In [ ]:
for item, predicted, correct in zip(test_data, test_predicted, test_correct):
    correct_labels = [idx2label[i] for i, l in enumerate(correct) if l == 1]
    predicted_labels = [idx2label[i] for i, l in enumerate(predicted) if l == 1]
    print("{}#{}#{}".format(item["text"], ";".join(correct_labels), ";".join(predicted_labels)))