Bert Experiments: One Model

In this notebook, we continue our BERT experiments. This time we fine-tune a single BERT model on several of our data sets at once, which makes our solution easier to deploy in production.

As a first test, we'll simply train a BERT model that takes as input a response from any of the data sets and outputs probabilities for all labels across all data sets. This is slightly suboptimal (after all, we don't need probabilities for labels that are not relevant to a specific prompt), but as long as we're not working with thousands of different labels, I don't expect this to be a problem.

The setup and preprocessing are very similar to those in the first "Bert experiments" notebook. I'll highlight the areas where they differ.


In [1]:
import torch

from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.modeling_bert import BertForSequenceClassification

BERT_MODEL = 'bert-large-uncased'
# bert-large needs a smaller per-step batch size; gradient accumulation keeps
# the effective batch size at 16 (16 x 1 for base, 2 x 8 for large).
BATCH_SIZE = 16 if "base" in BERT_MODEL else 2
GRADIENT_ACCUMULATION_STEPS = 1 if "base" in BERT_MODEL else 8


tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
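
As a quick illustration (not part of the original run), the tokenizer maps raw text to WordPiece ids, which is exactly what we'll feed the model during preprocessing below:

print(tokenizer.encode("[CLS] Schools should sell junk food [SEP]"))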

Data

As we build one "big" model, we combine the training data from all of our input files. We keep the test files separate, because we want to be able to evaluate every prompt individually.

In addition, we remember which labels are relevant to each prompt, because in the prediction phase we will only look at the probabilities of those relevant labels (see the sketch below).
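
Here is a minimal sketch of that masking step, assuming `logits` is a 1-D numpy array of scores over all labels and that `label2idx`/`target_names` are filled in as in the cell below (`predict_for_prompt` is a hypothetical helper, not part of the pipeline):

import numpy as np

def predict_for_prompt(logits, prefix, label2idx, target_names):
    # Restrict the argmax to the label indices that belong to this prompt.
    relevant = [label2idx[label] for label in target_names[prefix]]
    best_idx = relevant[int(np.argmax(logits[relevant]))]
    idx2label = {idx: label for label, idx in label2idx.items()}
    return idx2label[best_idx]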


In [2]:
import ndjson
import glob

file_prefixes = ["eatingmeat_but_large", "eatingmeat_because_large",
                 "junkfood_but", "junkfood_because"]

train_data = []
dev_data = []
test_data = {}
label2idx = {}
target_names = {}

for prefix in file_prefixes:
    
    train_files = glob.glob(f"../data/interim/{prefix}_train_withprompt*.ndjson")
    dev_file = f"../data/interim/{prefix}_dev_withprompt.ndjson"
    test_file = f"../data/interim/{prefix}_test_withprompt.ndjson"

    target_names[prefix] = []
    for train_file in train_files:
        with open(train_file) as i:
            new_train_data = ndjson.load(i)
            for item in new_train_data:
                # Register each new label globally and under this prompt.
                # Note: this assumes the label sets of different prompts are
                # disjoint; a label shared by two prompts would only appear
                # in the first prompt's target_names.
                if item["label"] not in label2idx:
                    target_names[prefix].append(item["label"])
                    label2idx[item["label"]] = len(label2idx)
            train_data += new_train_data
                
    with open(dev_file) as i:
        dev_data += ndjson.load(i)

    with open(test_file) as i:
        test_data[prefix] = ndjson.load(i)
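
As a quick sanity check (illustrative, not in the original run), we can verify how many examples ended up in each split:

print(f"train: {len(train_data)}, dev: {len(dev_data)}")
for prefix, items in test_data.items():
    print(f"test[{prefix}]: {len(items)}")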

In [3]:
print(label2idx)
print(target_names)


{'Change without mentioning consumption': 0, 'Less meat consumption could harm economy and cut jobs': 1, 'The meat industry is important/thriving and/or exports/demand increasing': 2, 'Eating meat is necessary for good nutrition': 3, 'Eating meat is part of culture/tradition': 4, 'Meat creates jobs and benefits economy': 5, "Outside of article's scope": 6, 'People will or should still eat meat': 7, 'Flexitarian w/o connection to environment or jobs': 8, 'Flexitarians benefit environment': 9, 'Meat consumption harms environment': 10, 'Meat industry produces greenhouse gases and/or uses water - general': 11, 'Meat industry produces greenhouse gases and/or uses water - specific numbers': 12, 'Because as preposition': 13, 'Meat industry harms environment/uses resources w/o mentioning greenhouse gases or water': 14, 'Irrelevant fact from article': 15, 'Meat industry harms animals': 16, 'Unclassified Off-Topic': 17, 'School without generating money': 18, 'Schools providing healthy alternatives': 19, 'Student choice': 20, 'Students without choice': 21, 'Schools generate money': 22, 'Students can still bring/access junk food': 23, 'Unhealthy without Diabetes and Risk Factors': 24, 'Diabetes and Risk Factors': 25, 'Nutritional value without Diabetes and Risk Factors': 26, 'Obesity without Diabetes': 27}
{'eatingmeat_but_large': ['Change without mentioning consumption', 'Less meat consumption could harm economy and cut jobs', 'The meat industry is important/thriving and/or exports/demand increasing', 'Eating meat is necessary for good nutrition', 'Eating meat is part of culture/tradition', 'Meat creates jobs and benefits economy', "Outside of article's scope", 'People will or should still eat meat', 'Flexitarian w/o connection to environment or jobs', 'Flexitarians benefit environment', 'Meat consumption harms environment'], 'eatingmeat_because_large': ['Meat industry produces greenhouse gases and/or uses water - general', 'Meat industry produces greenhouse gases and/or uses water - specific numbers', 'Because as preposition', 'Meat industry harms environment/uses resources w/o mentioning greenhouse gases or water', 'Irrelevant fact from article', 'Meat industry harms animals'], 'junkfood_but': ['Unclassified Off-Topic', 'School without generating money', 'Schools providing healthy alternatives', 'Student choice', 'Students without choice', 'Schools generate money', 'Students can still bring/access junk food'], 'junkfood_because': ['Unhealthy without Diabetes and Risk Factors', 'Diabetes and Risk Factors', 'Nutritional value without Diabetes and Risk Factors', 'Obesity without Diabetes']}

Model


In [4]:
model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels=len(label2idx))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()


Out[4]:
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=4096, out_features=1024, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (1)-(23): 23 more BertLayer blocks, identical in structure to (0) above (repeated output truncated)
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1)
  (classifier): Linear(in_features=1024, out_features=28, bias=True)
)

Preprocessing


In [5]:
import logging
import numpy as np

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

MAX_SEQ_LENGTH = 100

class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        
        
def convert_examples_to_features(examples, label2idx, max_seq_length, tokenizer, verbose=0):
    """Loads a data file into a list of `InputBatch`s."""
    
    features = []
    for (ex_index, ex) in enumerate(examples):
        
        input_ids = tokenizer.encode("[CLS] " + ex["text"] + " [SEP]")
        # Truncate sequences longer than max_seq_length, keeping the final
        # [SEP] token; otherwise the length assertions below would fail.
        # TODO: a sliding-window approach would lose less information.
        if len(input_ids) > max_seq_length:
            input_ids = input_ids[:max_seq_length - 1] + input_ids[-1:]
        segment_ids = [0] * len(input_ids)
            
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label2idx[ex["label"]]
        if verbose and ex_index == 0:
            logger.info("*** Example ***")
            logger.info("text: %s" % ex["text"])
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label:" + str(ex["label"]) + " id: " + str(label_id))

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
    return features

train_features = convert_examples_to_features(train_data, label2idx, MAX_SEQ_LENGTH, tokenizer, verbose=0)
dev_features = convert_examples_to_features(dev_data, label2idx, MAX_SEQ_LENGTH, tokenizer)

test_features = {}
for prefix in test_data:
    test_features[prefix] = convert_examples_to_features(test_data[prefix], label2idx, MAX_SEQ_LENGTH, tokenizer, verbose=1)


08/01/2019 17:57:57 - INFO - __main__ -   *** Example ***
08/01/2019 17:57:57 - INFO - __main__ -   text: Large amounts of meat consumption are harming the environment, but decreasing meat consumption could harm meat industry, the economy, and decrease jobs.
08/01/2019 17:57:57 - INFO - __main__ -   input_ids: 101 2312 8310 1997 6240 8381 2024 7386 2075 1996 4044 1010 2021 16922 6240 8381 2071 7386 6240 3068 1010 1996 4610 1010 1998 9885 5841 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   label:Less meat consumption could harm economy and cut jobs id: 1
08/01/2019 17:57:57 - INFO - __main__ -   *** Example ***
08/01/2019 17:57:57 - INFO - __main__ -   text: Large amounts of meat consumption are harming the environment, because it creates one-fifth of the earth's greenhouse gases.
08/01/2019 17:57:57 - INFO - __main__ -   input_ids: 101 2312 8310 1997 6240 8381 2024 7386 2075 1996 4044 1010 2138 2009 9005 2028 1011 3587 1997 1996 3011 1005 1055 16635 15865 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   label:Meat industry produces greenhouse gases and/or uses water - specific numbers id: 12
08/01/2019 17:57:57 - INFO - __main__ -   *** Example ***
08/01/2019 17:57:57 - INFO - __main__ -   text: Schools should not allow junk food to be sold on campus but kids will still bring in unhealthy food
08/01/2019 17:57:57 - INFO - __main__ -   input_ids: 101 2816 2323 2025 3499 18015 2833 2000 2022 2853 2006 3721 2021 4268 2097 2145 3288 1999 4895 20192 24658 2100 2833 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   label:Students without choice id: 21
08/01/2019 17:57:57 - INFO - __main__ -   *** Example ***
08/01/2019 17:57:57 - INFO - __main__ -   text: Schools should not allow junk food to be sold on campus because it causes major health problems at a developmental stage in kid's lives.
08/01/2019 17:57:57 - INFO - __main__ -   input_ids: 101 2816 2323 2025 3499 18015 2833 2000 2022 2853 2006 3721 2138 2009 5320 2350 2740 3471 2012 1037 13908 2754 1999 4845 1005 1055 3268 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
08/01/2019 17:57:57 - INFO - __main__ -   label:Unhealthy without Diabetes and Risk Factors id: 24

In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

def get_data_loader(features, max_seq_length, batch_size): 

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    sampler = RandomSampler(data)
    dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
    return dataloader

train_dataloader = get_data_loader(train_features, MAX_SEQ_LENGTH, BATCH_SIZE)
dev_dataloader = get_data_loader(dev_features, MAX_SEQ_LENGTH, BATCH_SIZE)
test_dataloaders = {} 
for prefix in test_features:
    test_dataloaders[prefix] = get_data_loader(test_features[prefix], MAX_SEQ_LENGTH, BATCH_SIZE)
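
To check that the dataloaders behave as expected, we could inspect a single batch (illustrative, not in the original run):

input_ids, input_mask, segment_ids, label_ids = next(iter(train_dataloader))
print(input_ids.shape)   # (BATCH_SIZE, MAX_SEQ_LENGTH)
print(label_ids[:5])     # first few label indices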

Evaluation


In [7]:
def evaluate(model, dataloader):

    eval_loss = 0
    nb_eval_steps = 0
    predicted_labels, correct_labels = [], []

    for step, batch in enumerate(tqdm(dataloader, desc="Evaluation iteration")):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        with torch.no_grad():
            # Positional arguments follow the pytorch_transformers signature:
            # (input_ids, token_type_ids, attention_mask, labels).
            tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

        outputs = np.argmax(logits.to('cpu').numpy(), axis=1)
        label_ids = label_ids.to('cpu').numpy()
        
        predicted_labels += list(outputs)
        correct_labels += list(label_ids)
        
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    
    correct_labels = np.array(correct_labels)
    predicted_labels = np.array(predicted_labels)
        
    return eval_loss, correct_labels, predicted_labels
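
The function returns the average loss together with the arrays of gold and predicted labels, so metrics such as accuracy can be computed directly from its output (illustrative):

dev_loss, y_true, y_pred = evaluate(model, dev_dataloader)
print("dev accuracy:", (y_true == y_pred).mean())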

Training


In [8]:
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

NUM_TRAIN_EPOCHS = 100
LEARNING_RATE = 1e-5
WARMUP_PROPORTION = 0.1

def warmup_linear(x, warmup=0.002):
    """Linearly increase the learning-rate factor to 1 during warmup,
    then decay it linearly towards 0 over the rest of training."""
    if x < warmup:
        return x/warmup
    return 1.0 - x

num_train_steps = int(len(train_data) / BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS * NUM_TRAIN_EPOCHS)

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, correct_bias=False)
# Note: this scheduler is never stepped in the training loop below, which
# instead sets the learning rate manually via warmup_linear.
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=num_train_steps)

In [9]:
import os

OUTPUT_DIR = "/tmp/"
MODEL_FILE_NAME = "pytorch_model.bin"
output_model_file = os.path.join(OUTPUT_DIR, MODEL_FILE_NAME)

In [10]:
from tqdm import trange
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support


PATIENCE = 5

global_step = 0
model.train()
loss_history = []
best_epoch = 0
for epoch in trange(int(NUM_TRAIN_EPOCHS), desc="Epoch"):
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(tqdm(train_dataloader, desc="Training iteration")):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch
        outputs = model(input_ids, segment_ids, input_mask, label_ids)
        loss = outputs[0]
        
        if GRADIENT_ACCUMULATION_STEPS > 1:
            loss = loss / GRADIENT_ACCUMULATION_STEPS

        loss.backward()

        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            lr_this_step = LEARNING_RATE * warmup_linear(global_step/num_train_steps, WARMUP_PROPORTION)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

    # Evaluate in eval mode (dropout off), then return to training mode.
    model.eval()
    dev_loss, _, _ = evaluate(model, dev_dataloader)
    model.train()
    
    print("Loss history:", loss_history)
    print("Dev loss:", dev_loss)
    
    if len(loss_history) == 0 or dev_loss < min(loss_history):
        model_to_save = model.module if hasattr(model, 'module') else model
        torch.save(model_to_save.state_dict(), output_model_file)
        best_epoch = epoch
    
    if epoch-best_epoch >= PATIENCE: 
        print("No improvement on development set. Finish training.")
        break
    
    loss_history.append(dev_loss)


Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Loss history: []
Dev loss: 2.661243469064886
Epoch:   1%|          | 1/100 [05:42<9:25:18, 342.61s/it]

Loss history: [2.661243469064886]
Dev loss: 2.1294797619183856
Epoch:   2%|▏         | 2/100 [11:25<9:19:56, 342.83s/it]

Loss history: [2.661243469064886, 2.1294797619183856]
Dev loss: 1.6750584876898564
Epoch:   3%|▎         | 3/100 [17:09<9:14:28, 342.98s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564]
Dev loss: 1.3086172049695795
Epoch:   4%|▍         | 4/100 [22:52<9:08:50, 343.03s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795]
Dev loss: 1.0326541987332432
Epoch:   5%|▌         | 5/100 [28:36<9:03:26, 343.23s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432]
Dev loss: 0.8468931219794533
Epoch:   6%|▌         | 6/100 [34:19<8:57:40, 343.20s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533]
Dev loss: 0.6602746053175492
Epoch:   7%|▋         | 7/100 [40:02<8:51:54, 343.16s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492]
Dev loss: 0.5387974966656078
Epoch:   8%|▊         | 8/100 [45:45<8:46:09, 343.14s/it]

Epoch:   9%|▉         | 9/100 [51:27<8:40:03, 342.89s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078]
Dev loss: 0.583350766066349

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349]
Dev loss: 0.5258630470796065
Epoch:  10%|█         | 10/100 [57:11<8:34:45, 343.17s/it]

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065]
Dev loss: 0.48672709826267124
Epoch:  11%|█         | 11/100 [1:02:54<8:29:07, 343.23s/it]

Epoch:  12%|█▏        | 12/100 [1:08:36<8:22:46, 342.80s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124]
Dev loss: 0.5388190912477898

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898]
Dev loss: 0.47566417130556976
Epoch:  13%|█▎        | 13/100 [1:14:19<8:16:53, 342.68s/it]

Epoch:  14%|█▍        | 14/100 [1:20:00<8:10:35, 342.27s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976]
Dev loss: 0.48471841523141573

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573]
Dev loss: 0.47057445844014484
Epoch:  15%|█▌        | 15/100 [1:25:44<8:05:33, 342.75s/it]

Epoch:  16%|█▌        | 16/100 [1:31:26<7:59:39, 342.61s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573, 0.47057445844014484]
Dev loss: 0.5064784953088471

Epoch:  17%|█▋        | 17/100 [1:37:08<7:53:41, 342.43s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573, 0.47057445844014484, 0.5064784953088471]
Dev loss: 0.524652259277575

Epoch:  18%|█▊        | 18/100 [1:42:50<7:47:45, 342.26s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573, 0.47057445844014484, 0.5064784953088471, 0.524652259277575]
Dev loss: 0.5291007135853623

Epoch:  19%|█▉        | 19/100 [1:48:32<7:42:03, 342.26s/it]
Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573, 0.47057445844014484, 0.5064784953088471, 0.524652259277575, 0.5291007135853623]
Dev loss: 0.5433029131455855

Loss history: [2.661243469064886, 2.1294797619183856, 1.6750584876898564, 1.3086172049695795, 1.0326541987332432, 0.8468931219794533, 0.6602746053175492, 0.5387974966656078, 0.583350766066349, 0.5258630470796065, 0.48672709826267124, 0.5388190912477898, 0.47566417130556976, 0.48471841523141573, 0.47057445844014484, 0.5064784953088471, 0.524652259277575, 0.5291007135853623, 0.5433029131455855]
Dev loss: 0.519724946311026
No improvement on development set. Finish training.

Results


In [11]:
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support

device="cpu"
print("Loading model from", output_model_file)

model_state_dict = torch.load(output_model_file, map_location=lambda storage, loc: storage)
model = BertForSequenceClassification.from_pretrained(BERT_MODEL, state_dict=model_state_dict, num_labels=len(label2idx))
model.to(device)

model.eval()

#_, train_correct, train_predicted = evaluate(model, train_dataloader)
#_, dev_correct, dev_predicted = evaluate(model, dev_dataloader)

#print("Training performance:", precision_recall_fscore_support(train_correct, train_predicted, average="micro"))
#print("Development performance:", precision_recall_fscore_support(dev_correct, dev_predicted, average="micro"))

for prefix in test_dataloaders:
    print(prefix)
    _, test_correct, test_predicted = evaluate(model, test_dataloaders[prefix])

    print("Test performance:", precision_recall_fscore_support(test_correct, test_predicted, average="micro"))

    print(classification_report(test_correct, test_predicted, target_names=target_names[prefix]))


Loading model from /tmp/pytorch_model.bin
08/01/2019 19:52:13 - INFO - pytorch_transformers.modeling_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json from cache at /home/yves/.cache/torch/pytorch_transformers/6dfaed860471b03ab5b9acb6153bea82b6632fb9bbe514d3fff050fe1319ee6d.4c88e2dec8f8b017f319f6db2b157fee632c0860d9422e4851bd0d6999f9ce38
08/01/2019 19:52:13 - INFO - pytorch_transformers.modeling_utils -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_labels": 28,
  "output_attentions": false,
  "output_hidden_states": false,
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

08/01/2019 19:52:14 - INFO - pytorch_transformers.modeling_utils -   loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin from cache at /home/yves/.cache/torch/pytorch_transformers/54da47087cc86ce75324e4dc9bbb5f66c6e83a7c6bd23baea8b489acc8d09aa4.4d5343a4b979c4beeaadef17a0453d1bb183dd9b084f58b84c7cc781df343ae6
eatingmeat_but_large
Test performance: (0.9024390243902439, 0.9024390243902439, 0.9024390243902439, None)
                                                                          precision    recall  f1-score   support

                                   Change without mentioning consumption       0.00      0.00      0.00         1
                   Less meat consumption could harm economy and cut jobs       1.00      1.00      1.00        42
The meat industry is important/thriving and/or exports/demand increasing       0.94      0.94      0.94        18
                             Eating meat is necessary for good nutrition       1.00      0.33      0.50         6
                                Eating meat is part of culture/tradition       0.86      0.95      0.90        19
                                  Meat creates jobs and benefits economy       1.00      0.97      0.99        40
                                              Outside of article's scope       0.77      0.91      0.83        11
                                    People will or should still eat meat       0.40      0.40      0.40         5
                       Flexitarian w/o connection to environment or jobs       0.80      0.92      0.86        13
                                        Flexitarians benefit environment       0.86      0.75      0.80         8
                                      Meat consumption harms environment       0.00      0.00      0.00         1

                                                             avg / total       0.91      0.90      0.90       164

eatingmeat_because_large
Test performance: (0.8958333333333334, 0.8958333333333334, 0.8958333333333334, None)
                                                                                         precision    recall  f1-score   support

                    Meat industry produces greenhouse gases and/or uses water - general       1.00      0.33      0.50         3
           Meat industry produces greenhouse gases and/or uses water - specific numbers       0.85      0.92      0.88        49
                                                                 Because as preposition       0.96      0.95      0.95        56
Meat industry harms environment/uses resources w/o mentioning greenhouse gases or water       0.96      0.96      0.96        23
                                                           Irrelevant fact from article       0.56      0.56      0.56         9
                                                            Meat industry harms animals       1.00      1.00      1.00         3

                                                                            avg / total       0.89      0.90      0.89       144

junkfood_but
/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/classification.py:1428: UserWarning: labels size, 7, does not match size of target_names, 6
  .format(len(labels), len(target_names))
/opt/anaconda3/lib/python3.7/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.
  'precision', 'predicted', average, warn_for)
Test performance: (0.8496732026143791, 0.8496732026143791, 0.8496732026143791, None)
                                           precision    recall  f1-score   support

                   Unclassified Off-Topic       0.67      0.73      0.70        11
          School without generating money       1.00      0.44      0.61        16
   Schools providing healthy alternatives       0.94      0.99      0.96        75
                           Student choice       0.47      1.00      0.64         7
                  Students without choice       0.86      0.76      0.81        33
                   Schools generate money       0.89      1.00      0.94         8
Students can still bring/access junk food       0.50      0.33      0.40         3

                              avg / total       0.88      0.85      0.84       153

junkfood_because
Test performance: (0.9661016949152542, 0.9661016949152542, 0.9661016949152542, None)
                                                     precision    recall  f1-score   support

        Unhealthy without Diabetes and Risk Factors       1.00      0.95      0.97        75
                          Diabetes and Risk Factors       1.00      1.00      1.00        29
Nutritional value without Diabetes and Risk Factors       0.78      1.00      0.88         7
                           Obesity without Diabetes       0.78      1.00      0.88         7

                                        avg / total       0.97      0.97      0.97       118
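
(The sklearn warnings for junkfood_but most likely arise because the model's argmax runs over all 28 labels, so the set of labels appearing in a prompt's predictions need not line up with that prompt's target_names. Masking the logits to the relevant labels before taking the argmax, as sketched in the Data section above, would avoid this.)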


In [ ]: