Testing TFNoiseAwareModel

We'll start by testing the TextRNN model on a categorical problem from tutorials/crowdsourcing. In particular, we'll test for (a) basic performance and (b) proper construction / reconstruction of the TF computation graph, both (i) after repeated notebook calls and (ii) when used with GridSearch.


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
os.environ['SNORKELDB'] = 'sqlite:///{0}{1}crowdsourcing.db'.format(os.getcwd(), os.sep)

from snorkel import SnorkelSession
session = SnorkelSession()

Load candidates and training marginals


In [ ]:
from snorkel.models import candidate_subclass
from snorkel.contrib.models.text import RawText
Tweet = candidate_subclass('Tweet', ['tweet'], cardinality=5)
train_tweets = session.query(Tweet).filter(Tweet.split == 0).order_by(Tweet.id).all()
len(train_tweets)

In [ ]:
from snorkel.annotations import load_marginals
train_marginals = load_marginals(session, train_tweets, split=0)
train_marginals.shape
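The loaded marginals form a row-stochastic N x 5 matrix: one row per candidate, one column per class, each row summing to 1. A minimal numpy sketch with a toy stand-in matrix, assuming Snorkel's 1-indexed categorical label convention (labels in 1..cardinality), showing how probabilistic labels relate to hard labels:

```python
import numpy as np

# Toy stand-in for a marginals matrix: 3 candidates, cardinality 5
# (not the real train_marginals, just an illustration of its shape)
marginals = np.array([
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.05, 0.05, 0.10, 0.10, 0.70],
    [0.20, 0.20, 0.20, 0.20, 0.20],
])

# Each row is a probability distribution over the 5 classes
assert np.allclose(marginals.sum(axis=1), 1.0)

# Hard labels via argmax, shifted to the assumed 1-indexed convention
hard_labels = marginals.argmax(axis=1) + 1
print(hard_labels)  # [1 5 1] -- ties (last row) resolve to the first class
```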

Train LogisticRegression


In [ ]:
# Simple unigram featurizer
def get_unigram_tweet_features(c):
    for w in c.tweet.text.split():
        yield w, 1

# Construct feature matrix
from snorkel.annotations import FeatureAnnotator
featurizer = FeatureAnnotator(f=get_unigram_tweet_features)

%time F_train = featurizer.apply(split=0)
F_train
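As a quick sanity check outside the annotator pipeline, the unigram featurizer can be exercised directly on a hypothetical stand-in candidate — `SimpleNamespace` mimics the `c.tweet.text` attribute chain, and the tweet text below is made up:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a Tweet candidate: just enough structure
# for the featurizer to traverse c.tweet.text
c = SimpleNamespace(tweet=SimpleNamespace(text="snorkel makes weak supervision easy"))

# Same featurizer as above: yields one (feature_name, value) pair per token
def get_unigram_tweet_features(c):
    for w in c.tweet.text.split():
        yield w, 1

print(dict(get_unigram_tweet_features(c)))
# {'snorkel': 1, 'makes': 1, 'weak': 1, 'supervision': 1, 'easy': 1}
```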

In [ ]:
%time F_test = featurizer.apply_existing(split=1)
F_test

In [ ]:
from snorkel.learning.tensorflow import LogisticRegression

model = LogisticRegression(cardinality=Tweet.cardinality)
model.train(F_train.todense(), train_marginals)

Train SparseLogisticRegression

Note: the testing below doesn't currently work with the dense LogisticRegression above, but there's no real reason to prefer it over the sparse version anyway...


In [ ]:
from snorkel.learning.tensorflow import SparseLogisticRegression

model = SparseLogisticRegression(cardinality=Tweet.cardinality)
model.train(F_train, train_marginals, n_epochs=50, print_freq=10)

In [ ]:
import numpy as np
test_labels = np.load('crowdsourcing_test_labels.npy')
acc = model.score(F_test, test_labels)
print(acc)
assert acc > 0.6

In [ ]:
# Test with batch size s.t. N % batch_size == 1...
model.score(F_test, test_labels, batch_size=9)
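The N % batch_size == 1 case is worth singling out because the final batch then holds a single example, which careless indexing (e.g. `X[i]` instead of `X[i:i+1]`) would collapse to a 1-D array. A small numpy sketch of the batching arithmetic, independent of Snorkel:

```python
import numpy as np

# 10 examples, batch size 9 -> last batch has exactly one row
N, batch_size = 10, 9
X = np.arange(N * 2).reshape(N, 2)

# Slicing (rather than scalar indexing) preserves the 2-D shape
batches = [X[i:i + batch_size] for i in range(0, N, batch_size)]
print([b.shape for b in batches])  # [(9, 2), (1, 2)]
```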

Train basic LSTM

With dev-set scoring during training (note: we use the test set as the dev set here, purely for simplicity)


In [ ]:
from snorkel.learning.tensorflow import TextRNN
test_tweets = session.query(Tweet).filter(Tweet.split == 1).order_by(Tweet.id).all()

train_kwargs = {
    'dim':        100,
    'lr':         0.001,
    'n_epochs':   25,
    'dropout':    0.2,
    'print_freq': 5
}
lstm = TextRNN(seed=123, cardinality=Tweet.cardinality)
lstm.train(train_tweets, train_marginals, X_dev=test_tweets, Y_dev=test_labels, **train_kwargs)

In [ ]:
acc = lstm.score(test_tweets, test_labels)
print(acc)
assert acc > 0.60

In [ ]:
# Test with batch size s.t. N % batch_size == 1...
lstm.score(test_tweets, test_labels, batch_size=9)

Run GridSearch


In [ ]:
from snorkel.learning.utils import GridSearch

# Searching over learning rate and embedding dimension
param_ranges = {'lr': [1e-3, 1e-4], 'dim': [50, 100]}
model_class_params = {'seed': 123, 'cardinality': Tweet.cardinality}
model_hyperparams = {
    'dim':        100,
    'n_epochs':   20,
    'dropout':    0.1,
    'print_freq': 10
}
searcher = GridSearch(TextRNN, param_ranges, train_tweets, train_marginals,
                      model_class_params=model_class_params,
                      model_hyperparams=model_hyperparams)

# Use test set here (just for testing)
lstm, run_stats = searcher.fit(test_tweets, test_labels)
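For reference, the grid above expands to the cross-product of the parameter lists, so four models are trained here. A sketch of that expansion with `itertools.product` — an illustration of the mechanics, not Snorkel's internal code:

```python
from itertools import product

# Same grid as above: 2 learning rates x 2 dimensions = 4 configurations
param_ranges = {'lr': [1e-3, 1e-4], 'dim': [50, 100]}

names = sorted(param_ranges)
configs = [dict(zip(names, vals))
           for vals in product(*(param_ranges[n] for n in names))]
print(len(configs))  # 4
```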

In [ ]:
acc = lstm.score(test_tweets, test_labels)
print(acc)
assert acc > 0.60

Reload saved model outside of GridSearch


In [ ]:
lstm = TextRNN(seed=123, cardinality=Tweet.cardinality)
lstm.load('TextRNN_best', save_dir='checkpoints/grid_search')
acc = lstm.score(test_tweets, test_labels)
print(acc)
assert acc > 0.60

Reload a model with a different structure


In [ ]:
lstm.load('TextRNN_0', save_dir='checkpoints/grid_search')
acc = lstm.score(test_tweets, test_labels)
print(acc)
assert acc < 0.60

Testing GenerativeModel

Testing GridSearch on crowdsourcing data


In [ ]:
from snorkel.annotations import load_label_matrix
import numpy as np

L_train = load_label_matrix(session, split=0)
train_labels = np.load('crowdsourcing_train_labels.npy')

In [ ]:
from snorkel.learning import GenerativeModel

# Searching over the number of training epochs
searcher = GridSearch(GenerativeModel, {'epochs': [0, 10, 30]}, L_train)

# Use training set labels here (just for testing)
gen_model, run_stats = searcher.fit(L_train, train_labels)

In [ ]:
acc = gen_model.score(L_train, train_labels)
print(acc)
assert acc > 0.97