In this example, we'll write an application to extract mentions of chemical-induced-disease relationships from PubMed abstracts, as per the BioCreative CDR Challenge. This tutorial shows off some of the more advanced features of Snorkel, so we'll assume you've followed the Intro tutorial.
Let's start by reloading from the last notebook.
In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from snorkel import SnorkelSession
session = SnorkelSession()
In [2]:
from snorkel.models import candidate_subclass
ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])
train_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()
We'll use the Comparative Toxicogenomics Database (CTD) for distant supervision. The CTD lists chemical-condition entity pairs under three categories: therapy, marker, and unspecified. Therapy means the chemical treats the condition, marker means the chemical is typically present with the condition, and unspecified is...unspecified. We can write LFs based on these categories.
In [3]:
import bz2
from six.moves.cPickle import load
with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)
In [4]:
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0
In [5]:
def LF_in_ctd_unspecified(c):
    return -1 * cand_in_ctd_unspecified(c)

def LF_in_ctd_therapy(c):
    return -1 * cand_in_ctd_therapy(c)

def LF_in_ctd_marker(c):
    return cand_in_ctd_marker(c)
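Before moving on, it can help to sanity-check the distant-supervision signal. The following quick sketch (not part of the original pipeline; it assumes the cells above have run) tallies how often each CTD-based LF votes on a sample of training candidates:

from collections import Counter

coverage = Counter()
for c in train_cands[:1000]:
    # Count a vote whenever the LF returns a nonzero label
    coverage['unspecified'] += int(LF_in_ctd_unspecified(c) != 0)
    coverage['therapy'] += int(LF_in_ctd_therapy(c) != 0)
    coverage['marker'] += int(LF_in_ctd_marker(c) != 0)
print(coverage)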
In [6]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)
# List to parenthetical, e.g. ltp(['cause', 'induce']) -> '(cause|induce)'
def ltp(x):
    return '(' + '|'.join(x) + ')'

def LF_induce(c):
    return 1 if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', get_tagged_text(c), flags=re.I) else 0
causal_past = ['induced', 'caused', 'due']

def LF_d_induced_by_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + '.{0,9}(by|to).{0,50}', 1)

def LF_d_induced_by_c_tight(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + ' (by|to) ', 1)

def LF_induce_name(c):
    return 1 if 'induc' in c.chemical.get_span().lower() else 0
causal = ['cause[sd]?', 'induce[sd]?', 'associated with']

def LF_c_cause_d(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
    ) else 0
treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']

def LF_d_treat_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)

def LF_c_treat_d(c):
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)

def LF_treat_d(c):
    return rule_regex_search_before_B(c, ltp(treat) + '.{0,50}', -1)

def LF_c_treat_d_wide(c):
    return rule_regex_search_btw_AB(c, '.{0,200}' + ltp(treat) + '.{0,200}', -1)

def LF_c_d(c):
    return 1 if ('{{A}} {{B}}' in get_tagged_text(c)) else 0

def LF_c_induced_d(c):
    return 1 if (
        ('{{A}} {{B}}' in get_tagged_text(c)) and
        (('-induc' in c[0].get_span().lower()) or ('-assoc' in c[0].get_span().lower()))
    ) else 0

def LF_improve_before_disease(c):
    return rule_regex_search_before_B(c, 'improv.*', -1)
pat_terms = ['in a patient with ', 'in patients with']

def LF_in_patient_with(c):
    return -1 if re.search(ltp(pat_terms) + '{{B}}', get_tagged_text(c), flags=re.I) else 0

uncertain = ['combin', 'possible', 'unlikely']

def LF_uncertain(c):
    return rule_regex_search_before_A(c, ltp(uncertain) + '.*', -1)

def LF_induced_other(c):
    return rule_regex_search_tagged_text(c, '{{A}}.{20,1000}-induced {{B}}', -1)

def LF_far_c_d(c):
    return rule_regex_search_btw_AB(c, '.{100,5000}', -1)

def LF_far_d_c(c):
    return rule_regex_search_btw_BA(c, '.{100,5000}', -1)

def LF_risk_d(c):
    return rule_regex_search_before_B(c, 'risk of ', 1)

def LF_develop_d_following_c(c):
    return 1 if re.search(r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}', get_tagged_text(c), flags=re.I) else 0

procedure, following = ['inject', 'administrat'], ['following']

def LF_d_following_c(c):
    return 1 if re.search('{{B}}.{0,50}' + ltp(following) + '.{0,20}{{A}}.{0,50}' + ltp(procedure), get_tagged_text(c), flags=re.I) else 0
def LF_measure(c):
    return -1 if re.search('measur.{0,75}{{A}}', get_tagged_text(c), flags=re.I) else 0

def LF_level(c):
    return -1 if re.search('{{A}}.{0,25} level', get_tagged_text(c), flags=re.I) else 0

def LF_neg_d(c):
    return -1 if re.search('(none|not|no) .{0,25}{{B}}', get_tagged_text(c), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
                'the aim', 'to (investigate|assess|study)']
WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return -1 if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I) else 0
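When writing pattern LFs like these, it's useful to spot-check their behavior on a concrete candidate. A minimal sketch, assuming the cells above have run (get_tagged_text renders the sentence with the chemical and disease spans replaced by the {{A}} and {{B}} placeholders the regexes match against):

c = train_cands[0]
print(get_tagged_text(c))
# See how a few of the LFs vote on this candidate
print(LF_c_cause_d(c), LF_d_treat_c(c), LF_weak_assertions(c))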
Next, we compose the pattern-based LFs with the CTD signal: multiplying by the cand_in_ctd_* membership checks gates each pattern LF so that it fires only when the candidate pair also appears in the corresponding CTD category.
In [7]:
def LF_ctd_marker_c_d(c):
    return LF_c_d(c) * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_unspecified(c)
We also add two distance-based LFs that vote negative when a different chemical (or disease) mention sits markedly closer to the candidate's disease (or chemical) span, on the intuition that the closest mention is the most likely relation partner.
In [8]:
def LF_closer_chem(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find a different chemical within dist/2 of the disease in either direction
    sent = c.get_parent()
    for i in range(dis_end, min(len(sent.words), dis_end + dist // 2)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    for i in range(max(0, dis_start - dist // 2), dis_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    return 0
def LF_closer_dis(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find a different disease within dist/8 of the chemical in either direction
    sent = c.get_parent()
    for i in range(chem_end, min(len(sent.words), chem_end + dist // 8)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    for i in range(max(0, chem_start - dist // 8), chem_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    return 0
In [9]:
LFs = [
    LF_c_cause_d,
    LF_c_d,
    LF_c_induced_d,
    LF_c_treat_d,
    LF_c_treat_d_wide,
    LF_closer_chem,
    LF_closer_dis,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_d_following_c,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_treat_c,
    LF_develop_d_following_c,
    LF_far_c_d,
    LF_far_d_c,
    LF_improve_before_disease,
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_in_patient_with,
    LF_induce,
    LF_induce_name,
    LF_induced_other,
    LF_level,
    LF_measure,
    LF_neg_d,
    LF_risk_d,
    LF_treat_d,
    LF_uncertain,
    LF_weak_assertions,
]
In [10]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)
In [11]:
%time L_train = labeler.apply(split=0)
L_train
Out[11]:
In [12]:
L_train.lf_stats(session)
Out[12]:
As mentioned above, we want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! DependencySelector runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies. We can see that these match up with our prior knowledge. For example, it identified a "reinforcing" dependency between LF_c_induced_d and LF_ctd_marker_induce. Recall that we constructed the latter using the former.
In [13]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
len(deps)
Out[13]:
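To inspect what was selected, we can map the LF indices in deps back to names. This is a sketch under two assumptions: that deps holds (index, index, dependency-type) triples as in Snorkel's structure learning module, and that the column order of L_train matches our LFs list:

lf_names = [lf.__name__ for lf in LFs]
for lf1, lf2, dep_type in sorted(deps)[:10]:
    # dep_type is an integer constant identifying the dependency category
    print(lf_names[lf1], lf_names[lf2], dep_type)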
Now we'll train the generative model, using the deps argument to account for the learned dependencies. We'll also model LF propensity here, unlike in the intro tutorial: in addition to learning the accuracies of the LFs, the model learns how likely each LF is to label an example at all.
In [14]:
from snorkel.learning import GenerativeModel
gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=0.0
)
In [15]:
train_marginals = gen_model.marginals(L_train)
In [16]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()
In [17]:
gen_model.learned_lf_stats()
Out[17]:
In [18]:
from snorkel.annotations import save_marginals
save_marginals(session, L_train, train_marginals)
Finally, we'll run the labeler on the development set, load in some external labels, and then evaluate LF performance. The external labels are applied via a small script for convenience: it maps the document-level relation annotations in the CDR file to mention-level labels. Note that these will not be perfect, although they are pretty good. If we wanted to keep iterating, we could use snorkel.lf_helpers.test_LF against the dev set, or look at some false positive and false negative candidates.
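For instance, a single-LF check against the dev set might look like the following sketch (run it after loading the gold labels below; the test_LF call signature here is assumed from the tutorial series rather than shown in this notebook):

from snorkel.lf_helpers import test_LF
# Compare one LF's votes to the gold dev labels
test_LF(session, LF_c_cause_d, split=1, annotator_name='gold')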
In [19]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=1, annotator='gold')
In [20]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
L_gold_dev
Out[20]:
In [21]:
L_dev = labeler.apply_existing(split=1)
In [22]:
_ = gen_model.error_analysis(session, L_dev, L_gold_dev)
In [23]:
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])
Out[23]: