Snorkel Workshop: Extracting Spouse Relations from the News

Part 3: Training the Generative Model

Now, we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import re
import numpy as np

# Connect to the database backend and initialize a Snorkel session
from lib.init import *
from snorkel.models import candidate_subclass
from snorkel.annotations import load_gold_labels

from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

# initialize our candidate type definition
Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

# gold (human-labeled) development set labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

I. Loading Labeling Matrices

First, we'll load our label matrices from notebook 2.


In [ ]:
from snorkel.annotations import LabelAnnotator

labeler = LabelAnnotator()
L_train = labeler.load_matrix(session, split=0)
L_dev   = labeler.load_matrix(session, split=1)
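Each label matrix has one row per candidate and one column per LF. To make "overlap" and "conflict" concrete, here is a minimal sketch on a toy dense matrix, assuming labels of 1 (positive), -1 (negative), and 0 (abstain). The helper `lf_summary` and the matrix `L_toy` are hypothetical names for illustration, not part of Snorkel.

```python
# Toy dense label matrix: rows are candidates, columns are LFs.
# Values: 1 = positive, -1 = negative, 0 = abstain (assumed encoding).
L_toy = [
    [ 1,  1,  0],
    [ 1, -1,  0],
    [ 0,  0,  0],
    [-1,  0, -1],
]

def lf_summary(L, j):
    """Return (coverage, overlap, conflict) fractions for LF column j."""
    n = float(len(L))
    covered = overlapped = conflicted = 0
    for row in L:
        if row[j] == 0:
            continue  # this LF abstained on the candidate
        covered += 1
        # Non-abstaining votes from every other LF on the same candidate
        others = [v for k, v in enumerate(row) if k != j and v != 0]
        if others:
            overlapped += 1
        if any(v != row[j] for v in others):
            conflicted += 1
    return covered / n, overlapped / n, conflicted / n

for j in range(len(L_toy[0])):
    print(lf_summary(L_toy, j))
```

An LF that often conflicts with several LFs that agree among themselves is a likely candidate for a low estimated accuracy.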

Now we set up and run the hyperparameter search, training our model with different hyperparameters and picking the best model configuration to keep. We'll set the random seed to maintain reproducibility.

Note that we are fitting our model's parameters to the training set generated by our labeling functions, while we are picking hyperparameters with respect to the score over the development set labels, which we created by hand.

II. Unifying Supervision

A. Majority Vote

The simplest way to unify the output of all your LFs is to compute the unweighted majority vote.


In [ ]:
from lib.scoring import *

majority_vote_score(L_dev, L_gold_dev)
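As a rough illustration of what an unweighted majority vote does, here is a minimal sketch on a toy dense label matrix. It assumes labels of 1, -1, and 0 (abstain), and it returns 0 on ties; how `majority_vote_score` from `lib.scoring` handles ties is not shown in this notebook, so treat that choice as an assumption. The helper `majority_vote` is hypothetical.

```python
def majority_vote(row):
    """Unweighted vote over one candidate's LF outputs (1, -1, or 0 = abstain)."""
    total = sum(row)  # abstains add nothing; each vote counts equally
    if total > 0:
        return 1
    if total < 0:
        return -1
    return 0  # tie, or every LF abstained (assumed behavior)

L_toy = [
    [1,  1, -1],   # two positive votes beat one negative vote
    [0, -1,  0],   # a single negative vote
    [1, -1,  0],   # tie
    [0,  0,  0],   # no coverage at all
]
preds = [majority_vote(row) for row in L_toy]
print(preds)
```

Note that every LF gets the same say here, no matter how reliable it is.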

B. Generative Model

In data programming, we use a more sophisticated model to unify our labeling functions. We know that these labeling functions will not be perfect, and some may be quite low-quality, so we will model their accuracies with a generative model, which Snorkel will help us easily apply.

This will ultimately produce a single set of noise-aware training labels, which we will then use to train an end extraction model in the next notebook. For more technical details of this overall approach, see our NIPS 2016 paper.
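To see why modeling accuracies matters, here is a simplified sketch of an accuracy-weighted vote. It assumes the LFs are conditionally independent, that each LF's accuracy is already known, and that both classes are equally likely a priori. This is not Snorkel's `GenerativeModel` (which has to learn the accuracies without labels); `weighted_marginal` is a hypothetical helper.

```python
import math

def weighted_marginal(row, accuracies):
    """P(y = 1 | LF outputs) under independent LFs and a uniform class prior."""
    log_odds = 0.0
    for label, acc in zip(row, accuracies):
        # An abstain (0) contributes nothing; a vote shifts the log-odds
        # by log(acc / (1 - acc)) in the direction of that vote.
        log_odds += label * math.log(acc / (1.0 - acc))
    return 1.0 / (1.0 + math.exp(-log_odds))

accuracies = [0.9, 0.6, 0.6]   # assumed accuracies, for illustration only
row = [1, -1, -1]              # one accurate LF votes yes, two weak LFs vote no
p = weighted_marginal(row, accuracies)
print(p)
```

A majority vote would label this candidate negative, but the weighted vote leans positive because the one accurate LF outweighs the two weak ones. The result is a probability rather than a hard label, which is what makes the training labels "noise-aware".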

1. Training the Model

When training the generative model, we'll tune our hyperparameters using a simple random search over a small grid of candidate values.

Parameter Definitions

epochs     The number of passes through all the data in your training set
step_size  The factor by which we update model weights after computing the gradient
decay      The rate at which our update factor diminishes (decays) over time
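To show how these three parameters interact, here is a toy gradient descent loop. It minimizes a simple one-dimensional function, not Snorkel's actual objective, and it assumes the step size is multiplied by `decay` once per epoch; the exact schedule Snorkel uses is not shown in this notebook. The function `train` is hypothetical.

```python
def train(step_size, decay, epochs, w=0.0):
    """Minimize (w - 3)^2, shrinking the step size by `decay` each epoch."""
    for _ in range(epochs):
        grad = 2.0 * (w - 3.0)   # gradient of (w - 3)^2
        w -= step_size * grad    # update scaled by the current step size
        step_size *= decay       # later epochs take smaller steps
    return w, step_size

w, final_step = train(step_size=0.1, decay=0.9, epochs=50)
print(w, final_step)
```

A smaller `decay` makes training settle sooner, but if the step size shrinks too quickly the model can stop moving before it reaches a good solution, which is one reason these values are worth searching over.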

In [ ]:
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# use random search over a parameter grid to optimize the generative model
step_size_param     = ListParameter('step_size', [0.1 / L_train.shape[0], 1e-5])
decay_param         = ListParameter('decay', [0.9, 0.95])
epochs_param        = ListParameter('epochs', [10, 50])
reg_param           = ListParameter('reg_param', [1e-3, 1e-6])
prior_param         = ListParameter('LF_acc_prior_weight_default', [1.0, 0.9, 0.8])

# search for the best model
param_grid = [step_size_param, decay_param, epochs_param, reg_param, prior_param]
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=10, lf_propensity=False)
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, deps=set())

run_stats
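The grid above has 2 x 2 x 2 x 2 x 3 = 48 possible configurations, and with n=10 only some of them are tried. Here is a minimal sketch of that idea, assuming configurations are sampled without replacement and scored on held-out data. The names `random_search` and `toy_score` are hypothetical, and the scoring function is a stand-in, not a real model evaluation.

```python
import itertools
import random

def random_search(grid, score_fn, n, seed=0):
    """Score n randomly chosen configurations from the grid and keep the best."""
    names = sorted(grid)
    configs = [dict(zip(names, values))
               for values in itertools.product(*(grid[k] for k in names))]
    rng = random.Random(seed)  # fixed seed for reproducibility
    trials = rng.sample(configs, min(n, len(configs)))
    scored = [(score_fn(c), c) for c in trials]
    best_score, best_config = max(scored, key=lambda t: t[0])
    return best_config, best_score, scored

grid = {
    'decay': [0.9, 0.95],
    'epochs': [10, 50],
    'reg_param': [1e-3, 1e-6],
}

def toy_score(config):
    # Stand-in for "train on L_train, then score on the dev set"
    return config['decay'] - 0.001 * config['epochs'] - config['reg_param']

best_config, best_score, scored = random_search(grid, toy_score, n=4)
print(best_config, best_score)
```

Random search trades completeness for speed: it may miss the single best configuration, but it usually finds a good one in far fewer training runs.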

2. Model Accuracies

These are the accuracies the generative model learned for each LF.


In [ ]:
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])

In [ ]:
train_marginals = gen_model.marginals(L_train)

3. Plotting Marginal Probabilities

One immediate sanity check you can perform using the generative model is to visually examine the distribution of predicted training marginals. Ideally, you should get a bimodal distribution with large separation between the two peaks, as shown below by the far right image. This corresponds to good signal for both the positive and negative class labels. For your first Snorkel application, you'll probably see marginals closer to the far left or middle images. With all mass centered around p=0.5, you probably need to write more LFs to get more overall coverage. In the middle image, you have good negative coverage, but not enough positive LFs.


In [ ]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20, range=(0.0, 1.0))
plt.show()
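If you want a number to go with the picture, one rough option is to measure how much of the mass sits near p=0.5. This is a minimal sketch on made-up marginals; `histogram` and `mass_near_half` are hypothetical helpers, and the width of 0.1 is an arbitrary choice.

```python
def histogram(marginals, bins=20):
    """Count marginals per equal-width bin over [0, 1]."""
    counts = [0] * bins
    for p in marginals:
        idx = min(int(p * bins), bins - 1)  # p == 1.0 goes in the last bin
        counts[idx] += 1
    return counts

def mass_near_half(marginals, width=0.1):
    """Fraction of marginals within `width` of 0.5."""
    near = [p for p in marginals if abs(p - 0.5) <= width]
    return len(near) / float(len(marginals))

toy_marginals = [0.02, 0.05, 0.1, 0.5, 0.5, 0.55, 0.9, 0.97]
counts = histogram(toy_marginals, bins=4)
near = mass_near_half(toy_marginals)
print(counts, near)
```

A large fraction near 0.5 suggests many candidates are covered by no LF or by LFs that cancel each other out.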

4. Generative Model Metrics


In [ ]:
dev_marginals = gen_model.marginals(L_dev)
_, _, _, _ = gen_model.error_analysis(session, L_dev, L_gold_dev)
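For reference, here is a minimal sketch of how precision, recall, and F1 can be computed from marginals and gold labels. It assumes a candidate is predicted positive when its marginal exceeds 0.5 and that gold labels are 1 or -1; the exact thresholding inside `error_analysis` is not shown in this notebook. The helper `prf1` and the toy data are hypothetical.

```python
def prf1(marginals, gold, threshold=0.5):
    """Precision, recall, and F1 for the positive class."""
    tp = fp = fn = 0
    for p, y in zip(marginals, gold):
        pred = 1 if p > threshold else -1
        if pred == 1 and y == 1:
            tp += 1   # correctly predicted positive
        elif pred == 1 and y == -1:
            fp += 1   # predicted positive, actually negative
        elif pred == -1 and y == 1:
            fn += 1   # missed a true positive
    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

toy_marginals = [0.9, 0.8, 0.3, 0.6, 0.1]
toy_gold = [1, -1, 1, 1, -1]
precision, recall, f1 = prf1(toy_marginals, toy_gold)
print(precision, recall, f1)
```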

5. Saving our training labels

Finally, we'll save train_marginals, which are our "noise-aware training labels", so that we can use them in the next tutorial to train our end extraction model:


In [ ]:
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)

III. Advanced Generative Model Features

A. Structure Learning

We may also want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! DependencySelector runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies.


In [ ]:
from snorkel.learning.structure import DependencySelector

MAX_DEPS = 5

ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
deps = set(list(deps)[0:min(len(deps), MAX_DEPS)])

print("Using {} dependencies".format(len(deps)))
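To get a feel for what a dependency between two LFs looks like in the label matrix, here is a rough sketch that measures how often pairs of LFs agree when both vote. This is not the algorithm `DependencySelector` uses; it is only a crude diagnostic, and the helper `agreement_rate` and the toy matrix are hypothetical.

```python
from itertools import combinations

def agreement_rate(L, i, j):
    """Fraction of candidates where LFs i and j agree, among those both label."""
    both = [(row[i], row[j]) for row in L if row[i] != 0 and row[j] != 0]
    if not both:
        return None  # the two LFs never label the same candidate
    return sum(1 for a, b in both if a == b) / float(len(both))

# LF 0 and LF 1 are near-duplicates; LF 2 votes independently of them.
L_toy = [
    [ 1,  1, -1],
    [ 1,  1,  1],
    [-1, -1, -1],
    [-1, -1,  1],
    [ 1,  1,  0],
]
rates = {(i, j): agreement_rate(L_toy, i, j)
         for i, j in combinations(range(3), 2)}
print(rates)
```

Two LFs that always agree add little independent evidence, so treating them as separate votes double-counts the same signal. High agreement can also simply mean both LFs are accurate, which is why a proper structure learning step is preferable to this kind of check.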

To train the generative model with dependencies, we just pass in the above set as the deps argument to our model train function.


In [ ]:
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=4, lf_propensity=False)
gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, deps=deps)
run_stats