GridSearch
Note: There are currently issues with running GridSearch in a notebook where other models are also run (see Issue #707), so it is run here in a separate notebook.
In [ ]:
# Notebook setup: enable the autoreload extension (mode 2) and inline plots.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
# Build a SQLite URL pointing at crowdsourcing.db in the current working directory.
# NOTE(review): this is set before importing snorkel, presumably because snorkel
# reads SNORKELDB at import time -- confirm before reordering these lines.
os.environ['SNORKELDB'] = 'sqlite:///{0}{1}crowdsourcing.db'.format(os.getcwd(), os.sep)
from snorkel import SnorkelSession
# Session object used by the later cells for all database queries.
session = SnorkelSession()
In [ ]:
from snorkel.models import candidate_subclass
# RawText is not referenced directly in this cell.
# NOTE(review): presumably imported for its side effect of registering the
# context type used by Tweet -- confirm before removing.
from snorkel.contrib.models.text import RawText

# Candidate type with a single 'tweet' argument and a cardinality of 5.
Tweet = candidate_subclass('Tweet', ['tweet'], cardinality=5)

# Fetch every candidate in split 0 (used as the training set in this
# notebook), ordered by id so the result order is deterministic.
train_tweets = (
    session.query(Tweet)
    .filter(Tweet.split == 0)
    .order_by(Tweet.id)
    .all()
)

# Last expression of the cell: the notebook displays the candidate count.
len(train_tweets)
In [ ]:
from snorkel.annotations import load_marginals
# Load the marginals for the training candidates (split 0) through the session.
# NOTE(review): presumably these were saved by an earlier notebook -- confirm.
train_marginals = load_marginals(session, train_tweets, split=0)
# Last expression of the cell: the notebook displays the shape.
train_marginals.shape
In [ ]:
import numpy as np
# Read the test labels from a .npy file; the path is relative, so it is
# resolved against the current working directory.
test_labels = np.load('crowdsourcing_test_labels.npy')
In [ ]:
from snorkel.learning.utils import GridSearch
from snorkel.learning.tensorflow import TextRNN

# Candidates in split 1 (used as the test set in this notebook), ordered by id.
test_tweets = (
    session.query(Tweet)
    .filter(Tweet.split == 1)
    .order_by(Tweet.id)
    .all()
)

# Values the search iterates over: two learning rates and two dropout rates.
param_ranges = {
    'lr': [1e-3, 1e-4],
    'dropout': [0.0, 0.2],
}

# Passed to GridSearch as model_class_params.
model_class_params = {
    'seed': 123,
    'cardinality': Tweet.cardinality,
    'n_threads': 2,
}

# Passed to GridSearch as model_hyperparams; these are not part of the search.
model_hyperparams = {
    'dim': 100,
    'n_epochs': 20,
    'print_freq': 10,
}

searcher = GridSearch(
    TextRNN,
    param_ranges,
    train_tweets,
    train_marginals,
    model_class_params=model_class_params,
    model_hyperparams=model_hyperparams,
)

# The test split is handed to fit() here (just for testing).
lstm, run_stats = searcher.fit(test_tweets, test_labels, n_threads=2)
In [ ]:
# Score the model returned by the grid search on the test candidates.
acc = lstm.score(test_tweets, test_labels)
print(acc)
# Sanity threshold: the cell fails unless the score is above 0.60.
assert acc > 0.60