Categorical Variables in Snorkel

This is a short tutorial on how to use categorical variables (i.e., variables that take more than two values) in Snorkel. We'll use a toy scenario with three sentences and two LFs just to demonstrate the mechanics. Please see the main tutorial for a more comprehensive introduction!

We'll highlight in bold all parts focusing on the categorical aspect.

Notes on Current Categorical Support:

  • The Viewer works in the categorical setting, but labeling Candidates in the Viewer does not.
    • Instead, you can import test/dev set labels from an external annotation tool, e.g. BRAT.
  • The LogisticRegression and SparseLogisticRegression end models have been extended to the categorical setting, but other end models in contrib may not have been.
    • Note: It's simple to make this change, so feel free to post an issue with requests for other end models!

In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np

from snorkel import SnorkelSession
session = SnorkelSession()

Step 1: Preprocessing the data


In [ ]:
from snorkel.parser import TSVDocPreprocessor, CorpusParser

doc_preprocessor = TSVDocPreprocessor('data/categorical_example.tsv') 
corpus_parser = CorpusParser()
%time corpus_parser.apply(doc_preprocessor)
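TSVDocPreprocessor reads documents from a tab-separated file. The sketch below assumes a two-column layout (document name, TAB, document text) with one document per line; the file contents and the helper read_tsv_docs are hypothetical illustrations, not the contents of data/categorical_example.tsv or Snorkel's own parser.

```python
import io

# Hypothetical toy contents in the assumed layout: name <TAB> text, one document per line.
TOY_TSV = (
    "doc1\tAlice Smith married Bob Jones last year.\n"
    "doc2\tCarol White is the boss of Dan Brown.\n"
)

def read_tsv_docs(fh):
    """Yield (doc_name, doc_text) pairs, skipping blank lines."""
    for line in fh:
        line = line.rstrip("\n")
        if not line:
            continue
        # Split only on the first tab so tabs inside the text are kept.
        doc_name, doc_text = line.split("\t", 1)
        yield doc_name, doc_text

docs = list(read_tsv_docs(io.StringIO(TOY_TSV)))
for name, text in docs:
    print(name, "->", text)
```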

Step 2: Defining candidates

We'll define candidate relations between person mentions that can now take on one of three values:

['Married', 'Employs', False]

Note the importance of including a value for "not a relation of interest": here we've used False, but any value would do. Also note that None is reserved to denote a labeling function abstaining, so it cannot be used as a candidate value.
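The constraints above can be made concrete with a small check. This is a minimal sketch, not Snorkel's code: the helper value_index_map is hypothetical, and it assumes values are encoded as 1-indexed integers with 0 meaning "abstain", which is consistent with LFs abstaining via 0 and with the gold labels 1 and 2 used in Step 5.

```python
def value_index_map(values):
    """Map each candidate value to a 1-indexed integer code; 0 is kept for 'abstain'."""
    if None in values:
        raise ValueError("None is reserved for abstaining and cannot be a candidate value")
    if len(set(values)) != len(values):
        raise ValueError("candidate values must be distinct")
    return {v: i + 1 for i, v in enumerate(values)}

values = ['Married', 'Employs', False]
codes = value_index_map(values)
print(codes)        # {'Married': 1, 'Employs': 2, False: 3}
print(len(values))  # the cardinality: 3
```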


In [ ]:
from snorkel.models import candidate_subclass
Relationship = candidate_subclass('Relationship', ['person1', 'person2'], values=['Married', 'Employs', False])

Now we extract candidates the same way as in the Intro Tutorial (slightly simplified here):


In [ ]:
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import PersonMatcher
from snorkel.models import Sentence

# Define a Person-Person candidate extractor
ngrams = Ngrams(n_max=3)
person_matcher = PersonMatcher(longest_match_only=True)
cand_extractor = CandidateExtractor(
    Relationship, 
    [ngrams, ngrams],
    [person_matcher, person_matcher],
    symmetric_relations=False
)

# Apply to all (three) of the sentences for this simple example
sents = session.query(Sentence).all()

# Run the candidate extractor
%time cand_extractor.apply(sents, split=0)

In [ ]:
train_cands = session.query(Relationship).filter(Relationship.split == 0).all()
print("Number of candidates:", len(train_cands))

In [ ]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(train_cands, session)
else:
    sv = None

In [ ]:
sv

Step 3: Writing Labeling Functions

The categorical labeling functions (LFs) we now write can output the following values:

  • Abstain: None OR 0
  • Categorical values: the literal values in Relationship.values, OR their integer indices (1-indexed, since 0 denotes abstaining).
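To see how these output conventions fit together, here is a hypothetical normalizer (normalize_lf_output is not a Snorkel function) that maps any allowed LF output to an integer code, assuming the 1-indexed encoding. One Python subtlety such a mapping must handle: bool is a subclass of int and False == 0, so the candidate value False must be checked before integer indices, or it would be mistaken for "abstain".

```python
VALUES = ['Married', 'Employs', False]  # mirrors Relationship's values

def normalize_lf_output(label, values=VALUES):
    """Map a raw LF output to an integer code: 0 = abstain, 1..len(values) = a candidate value."""
    if label is None:
        return 0
    # Check literal values (including the bool False) before treating ints as indices.
    if isinstance(label, bool) or not isinstance(label, int):
        return values.index(label) + 1
    if 0 <= label <= len(values):
        return label  # already an integer code (0 = abstain)
    raise ValueError("integer label out of range: %r" % label)

print([normalize_lf_output(x) for x in [None, 0, 'Married', 'Employs', False, 2]])
# [0, 0, 1, 2, 3, 2]
```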

We'll write two simple LFs to illustrate.

Tip: as we write the LFs, we can test them on an example candidate, either one taken directly from train_cands (as below) or the one highlighted in the Viewer above via sv.get_selected().


In [ ]:
import re
from snorkel.lf_helpers import get_between_tokens

# Get an example candidate (or use sv.get_selected() for the one highlighted in the Viewer)
c = train_cands[0]

# Traversing the context hierarchy...
print(c.get_contexts()[0].get_parent().text)

# Using a helper function
list(get_between_tokens(c))
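get_between_tokens(c) yields the tokens lying between the candidate's two person mentions. The simplified version below works on a plain token list and hypothetical (start, end) spans with an exclusive end; between_tokens is our own illustration, and Snorkel's helper operates on span objects and may differ, for example in case handling.

```python
def between_tokens(tokens, span1, span2):
    """Return the tokens strictly between two (start, end) spans, regardless of their order."""
    (s1, e1), (s2, e2) = sorted([span1, span2])
    return tokens[e1:s2]

tokens = "Alice Smith married Bob Jones last year .".split()
# Hypothetical spans: 'Alice Smith' = (0, 2), 'Bob Jones' = (3, 5)
print(between_tokens(tokens, (0, 2), (3, 5)))  # ['married']
print(between_tokens(tokens, (3, 5), (0, 2)))  # ['married']
```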

In [ ]:
def LF_married(c):
    return 'Married' if 'married' in get_between_tokens(c) else None

WORKPLACE_RGX = r'employ|boss|company'
def LF_workplace(c):
    sent = c.get_contexts()[0].get_parent()
    matches = re.search(WORKPLACE_RGX, sent.text)
    return 'Employs' if matches else None

LFs = [
    LF_married,
    LF_workplace
]
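Because LF_workplace only inspects the sentence text, its logic can be checked on plain strings before it is applied to candidates. The helper workplace_label below is a hypothetical mirror of that logic. One thing such a check reveals is that WORKPLACE_RGX is case-sensitive, so a capitalized "Employer" does not match; passing flags=re.IGNORECASE to re.search would change that, if desired.

```python
import re

WORKPLACE_RGX = r'employ|boss|company'

def workplace_label(sentence_text):
    """Mirror of LF_workplace's logic, applied to a raw sentence string instead of a candidate."""
    return 'Employs' if re.search(WORKPLACE_RGX, sentence_text) else None

print(workplace_label("Carol White is the boss of Dan Brown."))  # Employs
print(workplace_label("Carol White: Employer of Dan Brown."))    # None (case-sensitive match)
print(re.search(WORKPLACE_RGX, "Employer", flags=re.IGNORECASE) is not None)  # True
```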

Now we apply the LFs to the candidates to produce our label matrix $L$:


In [ ]:
from snorkel.annotations import LabelAnnotator

labeler = LabelAnnotator(lfs=LFs)
%time L_train = labeler.apply(split=0)
L_train

In [ ]:
L_train.todense()
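Conceptually, L has one row per candidate and one column per LF, with 0 for an abstain and otherwise the value's integer code. L_train is stored as a sparse matrix, so the abstain zeros are implicit until todense() fills them in. The pure-Python sketch below builds a dense version on hypothetical sentence strings standing in for candidates, assuming the 1-indexed value encoding; it is an illustration, not LabelAnnotator's implementation.

```python
VALUES = ['Married', 'Employs', False]

def to_code(label):
    """None -> 0 (abstain); a literal value -> its 1-indexed position in VALUES."""
    return 0 if label is None else VALUES.index(label) + 1

def build_label_matrix(candidates, lfs):
    """Dense label matrix: one row per candidate, one column per LF."""
    return [[to_code(lf(c)) for lf in lfs] for c in candidates]

# Hypothetical stand-ins: each "candidate" is just a sentence string here.
def lf_married(text):
    return 'Married' if 'married' in text.split() else None

def lf_workplace(text):
    return 'Employs' if 'boss' in text else None

toy_cands = [
    "Alice Smith married Bob Jones .",
    "Carol White is the boss of Dan Brown .",
    "Eve Black met Frank Green .",
]
L = build_label_matrix(toy_cands, [lf_married, lf_workplace])
for row in L:
    print(row)
# [1, 0]
# [0, 2]
# [0, 0]
```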

Step 4: Training the Generative Model


In [ ]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()

# Note: we pass cardinality explicitly here to be safe. It can usually be inferred from L_train,
# but none of our LFs ever outputs value 3 (i.e. False), so it cannot be inferred reliably here
gen_model.train(L_train, cardinality=3)
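To see why the explicit cardinality matters, consider a naive inference rule that takes the largest label code seen in the label matrix. This rule is an assumption for illustration, not necessarily what GenerativeModel does; naive_cardinality and the toy matrix are hypothetical. With LFs that only ever output 'Married' (1) or 'Employs' (2), it undercounts.

```python
def naive_cardinality(label_matrix):
    """Guess cardinality as the largest non-abstain code observed (an assumed, naive rule)."""
    return max(max(row) for row in label_matrix)

# Hypothetical label matrix where no LF ever outputs code 3 (the value False)
L = [[1, 0], [0, 2], [1, 2]]
print(naive_cardinality(L))  # 2, although Relationship has 3 values
```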

In [ ]:
train_marginals = gen_model.marginals(L_train)

# Each row is a probability distribution over the 3 values, so it should sum to 1
assert np.all(np.abs(train_marginals.sum(axis=1) - 1.0) < 1e-10)
train_marginals
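Each row of train_marginals gives one candidate's probabilities over Relationship.values, in order. The sketch below uses hypothetical numbers (not real model output) and a hypothetical helper, check_and_decode, to verify that each row is a distribution and to read off the most probable value.

```python
VALUES = ['Married', 'Employs', False]

def check_and_decode(marginals, values=VALUES, tol=1e-10):
    """Verify each row is a probability distribution, then return the most probable value per row."""
    decoded = []
    for row in marginals:
        assert len(row) == len(values), "one column per candidate value"
        assert abs(sum(row) - 1.0) < tol, "row must sum to 1"
        best = max(range(len(row)), key=lambda k: row[k])
        decoded.append(values[best])
    return decoded

# Hypothetical marginals for three candidates
toy_marginals = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.25, 0.25, 0.5],
]
print(check_and_decode(toy_marginals))  # ['Married', 'Employs', False]
```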

Next, we can save the training marginals:


In [ ]:
from snorkel.annotations import save_marginals, load_marginals

save_marginals(session, L_train, train_marginals)

And then reload (e.g. in another notebook):


In [ ]:
train_marginals = load_marginals(session, L_train)
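save_marginals persists the marginals through the Snorkel session. The idea that matters for reloading, as we understand it, is that each probability row is tied to a candidate ID, so it can be reloaded in a consistent order. The file-based analogue below is a hypothetical sketch (save_marginals_json and load_marginals_json are not Snorkel functions) using JSON instead of the database.

```python
import json
import os
import tempfile

def save_marginals_json(path, cand_ids, marginals):
    """Persist one probability row per candidate ID."""
    with open(path, "w") as fh:
        json.dump({str(cid): list(row) for cid, row in zip(cand_ids, marginals)}, fh)

def load_marginals_json(path, cand_ids):
    """Reload the rows in the order given by cand_ids."""
    with open(path) as fh:
        stored = json.load(fh)
    return [stored[str(cid)] for cid in cand_ids]

path = os.path.join(tempfile.mkdtemp(), "marginals.json")
save_marginals_json(path, [101, 102], [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
print(load_marginals_json(path, [102, 101]))  # [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]]
```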

Step 5: Training the End Model

Now we train an LSTM end model. Note that this just demonstrates the mechanics: with only three examples, don't expect anything spectacular!


In [ ]:
from snorkel.learning.pytorch import LSTM

train_kwargs = {
    'n_epochs':   10,
    'dropout':    0.25,
    'print_freq': 1,
    'seed':       1701
}

lstm = LSTM(cardinality=Relationship.cardinality, n_threads=None)
lstm.train(train_cands, train_marginals, **train_kwargs)
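The LSTM is trained on train_marginals, i.e. probabilistic labels, rather than on hard labels. One common noise-aware objective for this is cross-entropy against the soft targets. The sketch below shows that loss in plain Python; we assume, without checking the LSTM's source, that its training objective is equivalent, and soft_cross_entropy is a hypothetical helper.

```python
import math

def soft_cross_entropy(pred_probs, target_marginals, eps=1e-12):
    """Mean cross-entropy between predicted distributions and probabilistic (soft) targets."""
    total = 0.0
    for pred, target in zip(pred_probs, target_marginals):
        # eps guards against log(0) when a predicted probability is exactly zero
        total -= sum(t * math.log(max(p, eps)) for p, t in zip(pred, target))
    return total / len(pred_probs)

targets = [[0.7, 0.2, 0.1]]
matched = soft_cross_entropy([[0.7, 0.2, 0.1]], targets)     # prediction agrees with the target
mismatched = soft_cross_entropy([[0.1, 0.2, 0.7]], targets)  # prediction favors the wrong value
print(round(matched, 4), round(mismatched, 4))
```

The loss is smallest when the predicted distribution matches the marginals, so the model is pushed toward the generative model's confidence rather than toward a single hard label.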

In [ ]:
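# Gold labels for the three candidates, as 1-indexed positions in Relationship.values
# (1 = 'Married', 2 = 'Employs', 3 = False)
train_labels = [1, 2, 1]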
correct, incorrect = lstm.error_analysis(session, train_cands, train_labels)

In [ ]:
print("Accuracy:", lstm.score(train_cands, train_labels))
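Accuracy in the categorical setting compares the most probable value per candidate against the 1-indexed gold labels. The sketch below computes it from hypothetical marginals; accuracy_from_marginals is our own helper, and lstm.score may compute its result differently (e.g. in tie-breaking).

```python
def accuracy_from_marginals(marginals, gold):
    """Accuracy of argmax predictions (converted to 1-indexed codes) against gold labels."""
    preds = [max(range(len(row)), key=lambda k: row[k]) + 1 for row in marginals]
    return sum(p == g for p, g in zip(preds, gold)) / len(gold)

# Hypothetical predicted marginals; gold labels follow train_labels = [1, 2, 1]
toy_marginals = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.3, 0.4, 0.3]]
acc = accuracy_from_marginals(toy_marginals, [1, 2, 1])
print(acc)
```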

In [ ]:
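# End-model marginals; this toy example has no separate test split, so we predict on the training candidates
test_marginals = lstm.marginals(train_cands)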
test_marginals