Testing Parallel GridSearch

Note: There are currently issues with running parallel grid search in a notebook where other models have already been run (see Issue #707), so this test is run here in a separate notebook.


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
os.environ['SNORKELDB'] = 'sqlite:///{0}{1}crowdsourcing.db'.format(os.getcwd(), os.sep)

from snorkel import SnorkelSession
session = SnorkelSession()

Load candidates and training marginals


In [ ]:
from snorkel.models import candidate_subclass
# NOTE(review): RawText is never referenced by name in this cell; the import is
# presumably needed for its side effect of registering the context class that
# Tweet candidates point to -- confirm before removing.
from snorkel.contrib.models.text import RawText

# Candidate class with a single 'tweet' argument and cardinality 5.
Tweet = candidate_subclass('Tweet', ['tweet'], cardinality=5)

# Split 0 holds the training candidates; ordering by id keeps the row order
# deterministic across runs.
train_query = session.query(Tweet).filter(Tweet.split == 0).order_by(Tweet.id)
train_tweets = train_query.all()
len(train_tweets)

In [ ]:
from snorkel.annotations import load_marginals

# Load the stored training marginals for the split-0 candidates and display
# their shape as the cell output.
train_marginals = load_marginals(
    session, train_tweets, split=0
)
train_marginals.shape

In [ ]:
import numpy as np

# Test labels are read from a .npy file located relative to the working
# directory.
test_labels_path = 'crowdsourcing_test_labels.npy'
test_labels = np.load(test_labels_path)

In [ ]:
from snorkel.learning.utils import GridSearch
from snorkel.learning.tensorflow import TextRNN

# Candidates from split 1, ordered by id; passed to the search together with
# test_labels.
test_tweets = (
    session.query(Tweet)
    .filter(Tweet.split == 1)
    .order_by(Tweet.id)
    .all()
)

# Searching over learning rate and dropout (two values each)
param_ranges = {
    'lr':      [1e-3, 1e-4],
    'dropout': [0.0, 0.2],
}

# Passed to GridSearch as `model_class_params`.
model_class_params = {
    'seed':        123,
    'cardinality': Tweet.cardinality,
    'n_threads':   2,
}

# Passed to GridSearch as `model_hyperparams`; not part of the searched ranges.
model_hyperparams = {'dim': 100, 'n_epochs': 20, 'print_freq': 10}

searcher = GridSearch(
    TextRNN,
    param_ranges,
    train_tweets,
    train_marginals,
    model_class_params=model_class_params,
    model_hyperparams=model_hyperparams
)

# Use test set here (just for testing)
lstm, run_stats = searcher.fit(test_tweets, test_labels, n_threads=2)

In [ ]:
# Score the model returned by the grid search on the test candidates, show the
# value, and fail the notebook if it does not exceed the expected minimum.
min_expected_acc = 0.60
acc = lstm.score(test_tweets, test_labels)
print(acc)
assert acc > min_expected_acc