This is a short tutorial on how to use categorical variables (i.e., variables that take on more than two values) in Snorkel. We'll use a completely toy scenario with three sentences and two LFs just to demonstrate the mechanics; please see the main tutorial for a more comprehensive intro!
We'll highlight in bold all parts focusing on the categorical aspect.
A couple of notes on current categorical support:
- The Viewer works in the categorical setting, but labeling Candidates in the Viewer does not.
- The LogisticRegression and SparseLogisticRegression end models have been extended to the categorical setting, but other end models in contrib may not have been.
In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()
In [ ]:
from snorkel.parser import TSVDocPreprocessor, CorpusParser
doc_preprocessor = TSVDocPreprocessor('data/categorical_example.tsv')
corpus_parser = CorpusParser()
%time corpus_parser.apply(doc_preprocessor)
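We can quickly sanity-check the parse by counting what was loaded (a small check, not part of the original flow):
In [ ]:
from snorkel.models import Document, Sentence

# Counts of parsed documents and sentences
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())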
We'll define candidate relations between person mentions that now can take on one of three values:
['Married', 'Employs', False]
Note the importance of including a value for "not a relation of interest" -- here we've used False, but any value could do.
Also note that None is a protected value -- denoting a labeling function abstaining -- so it cannot be used as a value.
In [ ]:
from snorkel.models import candidate_subclass
Relationship = candidate_subclass('Relationship', ['person1', 'person2'], values=['Married', 'Employs', False])
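As a quick sanity check, the new subclass exposes both the value set we passed in and the resulting cardinality:
In [ ]:
# The value set and its cardinality (here 3)
print(Relationship.values)
print(Relationship.cardinality)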
Now we extract candidates in the same way as in the Intro Tutorial (simplified slightly here):
In [ ]:
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import PersonMatcher
from snorkel.models import Sentence
# Define a Person-Person candidate extractor
ngrams = Ngrams(n_max=3)
person_matcher = PersonMatcher(longest_match_only=True)
cand_extractor = CandidateExtractor(
    Relationship,
    [ngrams, ngrams],
    [person_matcher, person_matcher],
    symmetric_relations=False
)
# Apply to all (three) of the sentences for this simple example
sents = session.query(Sentence).all()
# Run the candidate extractor
%time cand_extractor.apply(sents, split=0)
In [ ]:
train_cands = session.query(Relationship).filter(Relationship.split == 0).all()
print("Number of candidates:", len(train_cands))
In [ ]:
from snorkel.viewer import SentenceNgramViewer
# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(train_cands, session)
else:
    sv = None
In [ ]:
sv
The categorical labeling functions (LFs) we now write can output the following values:
- None OR 0 (an abstain)
- An element of Relationship.values, OR its integer index (illustrated below)

We'll write two simple LFs to illustrate.
Tip: we can get a random candidate (see below), or the example highlighted in the viewer above via sv.get_selected(), and then use this to test as we write the LFs!
In [ ]:
import re
from snorkel.lf_helpers import get_between_tokens
# Getting an example candidate from the Viewer
c = train_cands[0]
# Traversing the context hierarchy...
print(c.get_contexts()[0].get_parent().text)
# Using a helper function
list(get_between_tokens(c))
In [ ]:
def LF_married(c):
    # Vote 'Married' if the word "married" appears between the two person mentions
    return 'Married' if 'married' in get_between_tokens(c) else None

WORKPLACE_RGX = r'employ|boss|company'

def LF_workplace(c):
    # Vote 'Employs' if the sentence containing the candidate mentions a workplace term
    sent = c.get_contexts()[0].get_parent()
    matches = re.search(WORKPLACE_RGX, sent.text)
    return 'Employs' if matches else None

LFs = [
    LF_married,
    LF_workplace
]
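As noted above, an LF can equivalently vote with the integer index of a value rather than the value itself. Here is a minimal sketch (LF_married_idx is a hypothetical name; this assumes 1-based indices into Relationship.values, with 0/None meaning abstain) -- we'll stick with the two string-valued LFs above:
In [ ]:
# Sketch: equivalent to LF_married, but voting with the 1-based index
# of 'Married' in Relationship.values; 0 means abstain
def LF_married_idx(c):
    return 1 if 'married' in get_between_tokens(c) else 0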
Now we apply the LFs to the candidates to produce our label matrix $L$:
In [ ]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)
%time L_train = labeler.apply(split=0)
L_train
In [ ]:
L_train.todense()
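To sanity-check the LFs, we can also compute summary statistics over the label matrix:
In [ ]:
# Per-LF summary statistics (e.g. coverage, overlaps, conflicts)
L_train.lf_stats(session)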
In [ ]:
from snorkel.learning import GenerativeModel
gen_model = GenerativeModel()
# Note: We pass cardinality explicitly here to be safe
# Can usually be inferred, except we have no labels with value=3
gen_model.train(L_train, cardinality=3)
In [ ]:
train_marginals = gen_model.marginals(L_train)
# Each row of the marginals matrix should sum to one
assert np.all(np.abs(train_marginals.sum(axis=1) - 1) < 1e-10)
train_marginals
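Each row of train_marginals is a predicted distribution over the three candidate values. As a sketch (assuming column j of the marginals lines up with Relationship.values[j]), we can map each candidate's most probable column back to a value:
In [ ]:
# Map each candidate's highest-probability column back to its value
# (assumes column j corresponds to Relationship.values[j])
[Relationship.values[i] for i in train_marginals.argmax(axis=1)]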
Next, we can save the training marginals:
In [ ]:
from snorkel.annotations import save_marginals, load_marginals
save_marginals(session, L_train, train_marginals)
And then reload (e.g. in another notebook):
In [ ]:
load_marginals(session, L_train)
Now we train an LSTM -- note this is just to demonstrate the mechanics... since we only have three examples, don't expect anything spectacular!
In [ ]:
from snorkel.learning.pytorch import LSTM
train_kwargs = {
    'n_epochs': 10,
    'dropout': 0.25,
    'print_freq': 1,
    'seed': 1701
}
lstm = LSTM(cardinality=Relationship.cardinality, n_threads=None)
lstm.train(train_cands, train_marginals, **train_kwargs)
In [ ]:
# Gold labels as 1-based indices into Relationship.values (1='Married', 2='Employs')
train_labels = [1, 2, 1]
correct, incorrect = lstm.error_analysis(session, train_cands, train_labels)
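error_analysis returns the sets of correctly and incorrectly classified candidates, which we can inspect directly:
In [ ]:
# How many candidates the LSTM got right vs. wrong
print("Correct:", len(correct))
print("Incorrect:", len(incorrect))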
In [ ]:
print("Accuracy:", lstm.score(train_cands, train_labels))
In [ ]:
# Inspect the LSTM's predicted marginals (again on the training candidates,
# since this toy example has no separate test set)
test_marginals = lstm.marginals(train_cands)
test_marginals
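Finally, as a sketch, we can turn these marginals into hard predictions by taking the argmax over values (adding 1 to recover the 1-based category indices used above):
In [ ]:
# Hard predictions as 1-based category indices (1='Married', 2='Employs', 3=False)
preds = test_marginals.argmax(axis=1) + 1
preds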