In this part of the tutorial, we will write labeling functions which express various heuristics, patterns, and weak supervision strategies to label our data.
In most real-world settings, hand-labeled training data is prohibitively expensive and slow to collect. A common scenario, though, is to have access to tons of unlabeled training data, and to have some idea of how to label it programmatically. For example:
We may be able to think of text patterns that suggest two people mentioned in a sentence are married, such as seeing the word "spouse" between the mentions.
We may have access to an external list of known spouse pairs, and use it to heuristically label some candidates.
Our labeling functions will capture these types of strategies. We know that these labeling functions will not be perfect, and some may be quite low-quality, so we will model their accuracies with a generative model, which Snorkel will help us easily apply.
This will ultimately produce a single set of noise-aware training labels, which we will then use to train an end extraction model in the next notebook. For more technical details of this overall approach, see our NIPS 2016 paper.
In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# os.environ['SNORKELDB'] = 'postgres:///snorkel-intro'
import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()
We repeat our definition of the Spouse Candidate subclass from Parts II and III.
In [2]:
from snorkel.models import candidate_subclass
Spouse = candidate_subclass('Spouse', ['person1', 'person2'])
In our setting here, we will use the phrase "development set" to refer to a small set of examples (here, a subset of our training set) which we label by hand and use to help us develop and refine labeling functions. Unlike the test set, which we do not look at and use for final evaluation, we can inspect the development set while writing labeling functions.
In our case, gold labels for a development set (split 1) have already been loaded into the database, so we can simply load them now:
In [3]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
Our biggest step in the data programming pipeline is the creation and modeling of a noisy training set. We'll approach this in three main steps:
1. Creating labeling functions (LFs): This is where most of our development time would actually go if this were a real application. Labeling functions encode our heuristics and weak supervision signals to generate (noisy) labels for our training candidates.
2. Applying the LFs: Here, we actually use them to label our candidates!
3. Training a generative model of our training set: Here we learn a model over our LFs, estimating their respective accuracies automatically. This will allow us to combine their outputs into a single, higher-quality label set.
We'll also add some detail on how to go about developing labeling functions and then debugging our model of them to improve performance.
In Snorkel, our primary interface through which we provide training signal to the end extraction model we are training is by writing labeling functions (LFs) (as opposed to hand-labeling massive training sets). We'll go through some examples for our spouse extraction task below.
A labeling function is just a Python function that accepts a Candidate and returns 1 to mark the Candidate as true, -1 to mark the Candidate as false, and 0 to abstain from labeling the Candidate (note that the non-binary classification setting is covered in the advanced tutorials!).
In the next stages of the Snorkel pipeline, we'll train a model to learn the accuracies of the labeling functions and reweight them accordingly, and then use them to train a downstream model. It turns out that by doing this, we can get high-quality models even from lower-quality labeling functions, so they don't need to be perfect! Now on to writing some:
In [4]:
import re
from snorkel.lf_helpers import (
get_left_tokens, get_right_tokens, get_between_tokens,
get_text_between, get_tagged_text,
)
These LFs express some common-sense text patterns which indicate that a person pair might be married. For example, LF_husband_wife looks for words from the spouses set between the person mentions, and LF_same_last_name checks whether the two people have the same last name (but aren't the same whole name).
In [5]:
spouses = {'spouse', 'wife', 'husband', 'ex-wife', 'ex-husband'}
family = {'father', 'mother', 'sister', 'brother', 'son', 'daughter',
'grandfather', 'grandmother', 'uncle', 'aunt', 'cousin'}
family = family | {f + '-in-law' for f in family}
other = {'boyfriend', 'girlfriend', 'boss', 'employee', 'secretary', 'co-worker'}
# Helper function to get last name
def last_name(s):
name_parts = s.split(' ')
return name_parts[-1] if len(name_parts) > 1 else None
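# For example, last_name('Barack Obama') returns 'Obama', while last_name('Cher') returns None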
def LF_husband_wife(c):
return 1 if len(spouses.intersection(get_between_tokens(c))) > 0 else 0
def LF_husband_wife_left_window(c):
if len(spouses.intersection(get_left_tokens(c[0], window=2))) > 0:
return 1
elif len(spouses.intersection(get_left_tokens(c[1], window=2))) > 0:
return 1
else:
return 0
def LF_same_last_name(c):
p1_last_name = last_name(c.person1.get_span())
p2_last_name = last_name(c.person2.get_span())
if p1_last_name and p2_last_name and p1_last_name == p2_last_name:
if c.person1.get_span() != c.person2.get_span():
return 1
return 0
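# Negative LF: if no spouse-indicating word appears anywhere in the sentence, label the pair false;
# the random factor below keeps only ~75% of these matches, down-sampling this very common negative signal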
def LF_no_spouse_in_sentence(c):
return -1 if np.random.rand() < 0.75 and len(spouses.intersection(c.get_parent().words)) == 0 else 0
def LF_and_married(c):
return 1 if 'and' in get_between_tokens(c) and 'married' in get_right_tokens(c) else 0
def LF_familial_relationship(c):
return -1 if len(family.intersection(get_between_tokens(c))) > 0 else 0
def LF_family_left_window(c):
if len(family.intersection(get_left_tokens(c[0], window=2))) > 0:
return -1
elif len(family.intersection(get_left_tokens(c[1], window=2))) > 0:
return -1
else:
return 0
def LF_other_relationship(c):
return -1 if len(other.intersection(get_between_tokens(c))) > 0 else 0
In addition to writing labeling functions that describe text pattern-based heuristics for labeling training examples, we can also write labeling functions that distantly supervise examples. Here, we'll load in a list of known spouse pairs and check to see if the candidate pair matches one of these.
In [6]:
import bz2
# Function to remove special characters from text
def strip_special(s):
return ''.join(c for c in s if ord(c) < 128)
# Read in known spouse pairs and save as set of tuples
with bz2.BZ2File('data/spouses_dbpedia.csv.bz2', 'rb') as f:
known_spouses = set(
tuple(strip_special(x.decode('utf-8')).strip().split(',')) for x in f.readlines()
)
# Last name pairs for known spouses
last_names = set([(last_name(x), last_name(y)) for x, y in known_spouses if last_name(x) and last_name(y)])
def LF_distant_supervision(c):
p1, p2 = c.person1.get_span(), c.person2.get_span()
return 1 if (p1, p2) in known_spouses or (p2, p1) in known_spouses else 0
def LF_distant_supervision_last_names(c):
p1, p2 = c.person1.get_span(), c.person2.get_span()
p1n, p2n = last_name(p1), last_name(p2)
return 1 if (p1 != p2) and ((p1n, p2n) in last_names or (p2n, p1n) in last_names) else 0
For later convenience we group the labeling functions into a list.
In [7]:
LFs = [
LF_distant_supervision, LF_distant_supervision_last_names,
LF_husband_wife, LF_husband_wife_left_window, LF_same_last_name,
LF_no_spouse_in_sentence, LF_and_married, LF_familial_relationship,
LF_family_left_window, LF_other_relationship
]
Above, we've already written a bunch of labeling functions, which should give you some sense of how to go about it. While writing them, we probably want to check that they at least work as intended before adding them to our set. Suppose we're thinking about writing a simple LF:
In [8]:
def LF_wife_in_sentence(c):
"""A simple example of a labeling function"""
return 1 if 'wife' in c.get_parent().words else 0
One simple thing we can do is quickly test it on our development set (or any other set) without saving anything to the database. For example, we can easily get every candidate that this LF labels as true:
In [9]:
labeled = []
for c in session.query(Spouse).filter(Spouse.split == 1).all():
if LF_wife_in_sentence(c) != 0:
labeled.append(c)
print("Number labeled:", len(labeled))
We can then easily put this into the Viewer as usual (try it out!):
from snorkel.viewer import SentenceNgramViewer
SentenceNgramViewer(labeled, session)
We also have a simple helper function for getting the empirical accuracy of a single LF with respect to, for example, the development set labels. This function also returns the evaluation buckets of the candidates (true positives, false positives, true negatives, false negatives):
In [10]:
from snorkel.lf_helpers import test_LF
tp, fp, tn, fn = test_LF(session, LF_wife_in_sentence, split=1, annotator_name='gold')
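As a quick sanity check, we can summarize these buckets directly. This is just a sketch, assuming (as described above) that each returned bucket is a collection of candidates:
# Rough sketch: empirical precision over the candidates this LF labeled as true,
# assuming tp and fp (from test_LF above) are collections of candidates
n_pos = len(tp) + len(fp)
if n_pos > 0:
    print("Empirical precision on dev: %0.3f over %d positively-labeled candidates" % (len(tp) / float(n_pos), n_pos))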
Next, we need to actually run the LFs over all of our training candidates, producing a set of Labels and LabelKeys (just the names of the LFs) in the database. We'll do this using the LabelAnnotator class, a UDF which we will again run with UDFRunner. Note that this will delete any existing Labels and LabelKeys for this candidate set. We start by setting up the class:
In [11]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)
Finally, we run the labeler. Note that we set a random seed for reproducibility, since some of the LFs involve random number generators. Again, this can be run in parallel, provided an appropriate database such as Postgres is being used:
In [12]:
np.random.seed(1701)
%time L_train = labeler.apply(split=0)
L_train
Out[12]:
If we've already created the labels (saved in the database), we can load them in as a sparse matrix here too:
In [13]:
%time L_train = labeler.load_matrix(session, split=0)
L_train
Out[13]:
Note that the returned matrix is a special subclass of the scipy.sparse.csr_matrix class, with some special features which we demonstrate below:
In [14]:
L_train.get_candidate(session, 0)
Out[14]:
In [15]:
L_train.get_key(session, 0)
Out[15]:
We can also view statistics about the resulting label matrix.
In [16]:
L_train.lf_stats(session)
Out[16]:
Now, we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.
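As a rough sketch of the idea (following the data programming formulation in the NIPS 2016 paper linked above, and simplifying to the case where the LFs are modeled as independent), the generative model ties the observed label matrix $\Lambda$ to the latent true labels $Y$ through per-LF accuracy weights $\theta_j$:
$$P_\theta(\Lambda, Y) \propto \exp\Big(\sum_j \theta_j \sum_i \Lambda_{ij}\, y_i\Big),$$
where $\Lambda_{ij} \in \{-1, 0, 1\}$ is the label LF $j$ gave to candidate $i$ and $y_i \in \{-1, 1\}$ is that candidate's unobserved true label. Larger learned $\theta_j$ correspond to more accurate LFs, and the training marginals computed below are the posteriors $P_\theta(y_i = 1 \mid \Lambda)$.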
In [17]:
from snorkel.learning import GenerativeModel
gen_model = GenerativeModel()
gen_model.train(L_train, epochs=100, decay=0.95, step_size=0.1 / L_train.shape[0], reg_param=1e-6)
In [18]:
gen_model.weights.lf_accuracy
Out[18]:
We now apply the generative model to the training candidates to get the noise-aware training label set. We'll refer to these as the training marginals:
In [19]:
train_marginals = gen_model.marginals(L_train)
We'll look at the distribution of the training marginals:
In [20]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()
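For instance (a small sanity check on the array of marginals, not part of the original pipeline), we can count how many training candidates the generative model considers more likely true than false:
# Count training candidates with marginal probability above 0.5
print("Candidates with marginal > 0.5: %d / %d" % ((train_marginals > 0.5).sum(), len(train_marginals)))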
We can view the learned accuracy parameters, and other statistics about the LFs learned by the generative model:
In [21]:
gen_model.learned_lf_stats()
Out[21]:
To see how well the generative model does on the development set (and to help debug our LFs), we first apply the LFs to the dev candidates, reusing the existing label keys:
In [22]:
L_dev = labeler.apply_existing(split=1)
And finally, we get the score of the generative model:
In [23]:
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, L_gold_dev)
At this point, we should be getting an F1 score of around 0.4 to 0.5 on the development set, which is pretty good! However, we should be very careful in interpreting this. Since we developed our labeling functions using this development set as a guide, and our generative model is composed of these labeling functions, we expect it to score very well here!
In fact, it is probably somewhat overfit to this set. However, this is fine, since in the next tutorial we'll train a more powerful end extraction model which will generalize beyond the development set, and which we will evaluate on a blind test set (i.e. one we never looked at during development).
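For reference, the F1 score quoted here can be recovered directly from the error-analysis buckets above; a minimal sketch, again assuming the buckets are collections of candidates:
# Compute precision, recall, and F1 from the error-analysis bucket sizes
prec = len(tp) / float(len(tp) + len(fp)) if (len(tp) + len(fp)) > 0 else 0.0
rec = len(tp) / float(len(tp) + len(fn)) if (len(tp) + len(fn)) > 0 else 0.0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
print("Precision: %0.3f, Recall: %0.3f, F1: %0.3f" % (prec, rec, f1))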
At this point, we can also inspect individual examples from one of the error buckets, for example the false negatives, using the Viewer:
In [24]:
from snorkel.viewer import SentenceNgramViewer
# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
sv = SentenceNgramViewer(fn, session)
else:
sv = None
In [25]:
sv
In [26]:
c = sv.get_selected() if sv else list(fp.union(fn))[0]
c
Out[26]:
We can easily see the labels that the LFs gave to this candidate using simple ORM-enabled syntax:
In [27]:
c.labels
Out[27]:
We can also now explore some of the additional functionality of the lf_stats method for our dev set LF labels, L_dev: we can plug in the gold labels that we have, and the accuracies that our generative model has learned:
In [28]:
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])
Out[28]:
Note that for labeling functions with low coverage, our learned accuracies are closer to our prior of 70% accuracy.
Finally, we save the training marginals, our probabilistic training labels, so that we can use them in the next notebook to train the end extraction model:
In [29]:
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)
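When we need them again, they can be loaded back from the database; a minimal sketch, assuming the companion load_marginals helper in snorkel.annotations:
from snorkel.annotations import load_marginals

# Reload the saved probabilistic training labels for the training split (split 0)
train_marginals = load_marginals(session, split=0)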
Next, in Part III, we'll use these probabilistic training labels to train a deep neural network.