In this example, we'll be writing an application to extract mentions of chemical-induced disease relationships from PubMed abstracts, as defined by the BioCreative CDR Challenge. This tutorial shows off some of Snorkel's more advanced features, so we'll assume you've already worked through the Intro tutorial.
Let's start by reloading the session and candidates from the last notebook.
In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()
In [2]:
from snorkel.models import candidate_subclass
ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])
train = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()
test = session.query(ChemicalDisease).filter(ChemicalDisease.split == 2).all()
print('Training set:\t{0} candidates'.format(len(train)))
print('Dev set:\t{0} candidates'.format(len(dev)))
print('Test set:\t{0} candidates'.format(len(test)))
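As a quick, optional sanity check that the candidates loaded correctly, you can inspect one directly. The snippet below is a sketch that assumes each `ChemicalDisease` candidate exposes its argument spans as `.chemical` and `.disease` attributes (per the `candidate_subclass` call above), that spans have a `get_span()` method returning their text, and that a span's parent sentence is reachable via `.sentence.text`.
In [ ]:
# Optional: inspect one training candidate to confirm the load worked.
# Attribute names below follow Snorkel's candidate/span data model (assumption).
c = train[0]
print(c)
print('Chemical span:', c.chemical.get_span())
print('Disease span: ', c.disease.get_span())
print('Sentence:     ', c.chemical.sentence.text)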
In the intro tutorial, we automatically featurized the candidates and trained a linear model over these features. Here, we'll train a more complicated model for relation extraction: an LSTM network. An LSTM (long short-term memory) network is a type of recurrent neural network that automatically learns a numerical representation of each candidate from the raw sentence text, so there's no need to featurize explicitly as in the intro tutorial; if you're unfamiliar with LSTMs, there are many good introductions online. LSTMs take longer to train, and Snorkel doesn't currently support hyperparameter searches for them. We'll train a single model here, but feel free to try other parameter settings; just make sure to use the development set, and not the test set, for model selection. A simple manual search over the dev set is sketched after the training cell below.
Note: As before, training for more epochs than we do below will greatly improve performance; try it out!
In [3]:
from snorkel.annotations import load_marginals
train_marginals = load_marginals(session, split=0)
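Before training, it can be useful to eyeball the training marginals we just loaded. Since `train_marginals` is a NumPy array of probabilities produced by the generative model, a quick histogram (an optional sketch below, relying on the `%matplotlib inline` setup from the first cell) shows how confidently the labeling functions labeled the training candidates.
In [ ]:
# Optional sanity check: the training marginals are probabilistic labels in [0, 1],
# so their distribution shows how confident the generative model's labels are.
import matplotlib.pyplot as plt

plt.hist(train_marginals, bins=20)
plt.xlabel('P(candidate is a true chemical-induced-disease relation)')
plt.ylabel('Number of candidates')
plt.show()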
In [4]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
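It's also worth checking the class balance of the gold development labels. The sketch below assumes `L_gold_dev` is a SciPy sparse matrix with entries of -1 and +1 (and 0 where no label exists), which is how Snorkel stores gold annotations.
In [ ]:
# Optional: count how many dev candidates are labeled positive vs. negative.
# Assumes entries of -1 / +1 for negative / positive gold labels (0 = no label).
n_pos = (L_gold_dev == 1).sum()
n_neg = (L_gold_dev == -1).sum()
print('Dev gold labels: {0} positive, {1} negative'.format(n_pos, n_neg))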
In [5]:
from snorkel.learning.pytorch import LSTM
train_kwargs = {
    'lr':            0.01,
    'embedding_dim': 100,
    'hidden_dim':    100,
    'n_epochs':      20,
    'dropout':       0.5,
    'rebalance':     0.25,
    'print_freq':    5,
    'seed':          1701
}
lstm = LSTM(n_threads=None)
lstm.train(train, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **train_kwargs)
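As noted above, Snorkel doesn't run hyperparameter searches for the LSTM automatically, but you can do a simple manual search yourself: train one model per configuration, score each on the development set, and keep the best. The sketch below is illustrative only (each configuration retrains the network from scratch, which is slow), the particular grid of values is arbitrary, and it assumes, as in the final cell of this notebook, that `score` returns precision, recall, and F1 for a binary task.
In [ ]:
# Optional sketch of manual model selection on the dev set (slow: retrains per config).
# Assumes model.score(X, Y) returns (precision, recall, F1) for binary candidates.
best_f1, best_params = -1.0, None
for lr in [0.01, 0.001]:
    for dropout in [0.25, 0.5]:
        kwargs = dict(train_kwargs, lr=lr, dropout=dropout)
        model = LSTM(n_threads=None)
        model.train(train, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **kwargs)
        _, _, f1 = model.score(dev, L_gold_dev)
        if f1 > best_f1:
            best_f1, best_params = f1, (lr, dropout)
print('Best dev F1: {0:.3f} with (lr, dropout) = {1}'.format(best_f1, best_params))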
In [6]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=2, annotator='gold')
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)
L_gold_test
Out[6]:
In [7]:
lstm.score(test, L_gold_test)
Out[7]:
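If you want to see exactly what the score above measures, you can recompute precision, recall, and F1 by hand from the model's marginal probabilities. The sketch below thresholds the predicted marginals at 0.5 and compares them against the gold test labels; it assumes `L_gold_test` is a sparse matrix of -1/+1 labels (unlabeled treated as negative) and that `marginals` returns the predicted probability of the positive class.
In [ ]:
# Optional: recompute the test-set metrics by hand from the predicted marginals.
# Assumes L_gold_test holds -1 / +1 gold labels and lstm.marginals gives P(positive).
test_marginals = lstm.marginals(test)
preds = test_marginals > 0.5
gold = np.ravel(L_gold_test.todense()) == 1

tp = float(np.sum(preds & gold))
fp = float(np.sum(preds & ~gold))
fn = float(np.sum(~preds & gold))
precision = tp / (tp + fp) if tp + fp > 0 else 0.0
recall    = tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
print('Precision: {0:.3f}, Recall: {1:.3f}, F1: {2:.3f}'.format(precision, recall, f1))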