Chemical-Disease Relation (CDR) Tutorial

In this example, we'll build an application that extracts mentions of chemical-induced disease relationships from PubMed abstracts, as per the BioCreative CDR Challenge. This tutorial shows off some of Snorkel's more advanced features, so we'll assume you've completed the Intro tutorial.

Let's start by reloading the session and the candidates from the previous notebook.


In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

session = SnorkelSession()

In [2]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

train_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()
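
As a quick sanity check that the candidates reloaded correctly, we can count them:

print('Training candidates: %d' % len(train_cands))
print('Dev candidates:      %d' % len(dev_cands))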

Part III: Writing LFs

This tutorial features more advanced LFs than the intro tutorial, with a greater focus on distant supervision and on dependencies between LFs.

Distant supervision approaches

We'll use the Comparative Toxicogenomics Database (CTD) for distant supervision. The CTD lists chemical-condition entity pairs under three categories: therapy, marker, and unspecified. Therapy means the chemical treats the condition, marker means the chemical is typically present with the condition, and unspecified is...unspecified. We can write LFs based on these categories.


In [3]:
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)

In [4]:
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0

Since marker pairs indicate that the chemical is typically present with the condition, we treat CTD marker membership as (weak) positive evidence, while therapy and unspecified membership count as negative evidence:

In [5]:
def LF_in_ctd_unspecified(c):
    return -1 * cand_in_ctd_unspecified(c)

def LF_in_ctd_therapy(c):
    return -1 * cand_in_ctd_therapy(c)

def LF_in_ctd_marker(c):
    return cand_in_ctd_marker(c)
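
These LFs follow Snorkel's labeling convention: +1 votes that the candidate is a true chemical-induced-disease pair, -1 votes false, and 0 abstains. As a quick sanity check, we can run them over a few candidates; a minimal sketch using the candidates loaded above:

for c in dev_cands[:5]:
    # Each LF returns +1 (positive), -1 (negative), or 0 (abstain)
    print(LF_in_ctd_marker(c), LF_in_ctd_therapy(c), LF_in_ctd_unspecified(c))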

Text pattern approaches

Now we'll use some LF helpers to create LFs based on indicative text patterns. We came up with these rules by using the viewer to examine training candidates and noting frequent patterns.


In [6]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)

# List to parenthetical, e.g., ltp(['a', 'b']) -> '(a|b)'
def ltp(x):
    return '(' + '|'.join(x) + ')'

def LF_induce(c):
    return 1 if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', get_tagged_text(c), flags=re.I) else 0

causal_past = ['induced', 'caused', 'due']
def LF_d_induced_by_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + '.{0,9}(by|to).{0,50}', 1)
def LF_d_induced_by_c_tight(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + ' (by|to) ', 1)

def LF_induce_name(c):
    return 1 if 'induc' in c.chemical.get_span().lower() else 0     

causal = ['cause[sd]?', 'induce[sd]?', 'associated with']
def LF_c_cause_d(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
    ) else 0

treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']
def LF_d_treat_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)
def LF_c_treat_d(c):
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)
def LF_treat_d(c):
    return rule_regex_search_before_B(c, ltp(treat) + '.{0,50}', -1)
def LF_c_treat_d_wide(c):
    return rule_regex_search_btw_AB(c, '.{0,200}' + ltp(treat) + '.{0,200}', -1)

def LF_c_d(c):
    return 1 if ('{{A}} {{B}}' in get_tagged_text(c)) else 0

def LF_c_induced_d(c):
    return 1 if (
        ('{{A}} {{B}}' in get_tagged_text(c)) and
        (('-induc' in c.chemical.get_span().lower()) or ('-assoc' in c.chemical.get_span().lower()))
    ) else 0

def LF_improve_before_disease(c):
    return rule_regex_search_before_B(c, 'improv.*', -1)

pat_terms = ['in a patient with ', 'in patients with']
def LF_in_patient_with(c):
    return -1 if re.search(ltp(pat_terms) + '{{B}}', get_tagged_text(c), flags=re.I) else 0

uncertain = ['combin', 'possible', 'unlikely']
def LF_uncertain(c):
    return rule_regex_search_before_A(c, ltp(uncertain) + '.*', -1)

def LF_induced_other(c):
    return rule_regex_search_tagged_text(c, '{{A}}.{20,1000}-induced {{B}}', -1)

def LF_far_c_d(c):
    return rule_regex_search_btw_AB(c, '.{100,5000}', -1)

def LF_far_d_c(c):
    return rule_regex_search_btw_BA(c, '.{100,5000}', -1)

def LF_risk_d(c):
    return rule_regex_search_before_B(c, 'risk of ', 1)

def LF_develop_d_following_c(c):
    return 1 if re.search(r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}', get_tagged_text(c), flags=re.I) else 0

procedure, following = ['inject', 'administrat'], ['following']
def LF_d_following_c(c):
    return 1 if re.search('{{B}}.{0,50}' + ltp(following) + '.{0,20}{{A}}.{0,50}' + ltp(procedure), get_tagged_text(c), flags=re.I) else 0

def LF_measure(c):
    return -1 if re.search('measur.{0,75}{{A}}', get_tagged_text(c), flags=re.I) else 0

def LF_level(c):
    return -1 if re.search('{{A}}.{0,25} level', get_tagged_text(c), flags=re.I) else 0

def LF_neg_d(c):
    return -1 if re.search('(none|not|no) .{0,25}{{B}}', get_tagged_text(c), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
                'the aim', 'to (investigate|assess|study)']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return -1 if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I) else 0
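
When debugging pattern LFs like these, it helps to see exactly what the regexes match against: get_tagged_text returns the candidate's sentence with the chemical span replaced by {{A}} and the disease span by {{B}}. A minimal sketch using the names defined above:

c = dev_cands[0]
print(get_tagged_text(c))  # sentence text with {{A}} (chemical) and {{B}} (disease) markers
print(LF_c_cause_d(c))     # 1 if a causal pattern fires without negation, else 0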

Composite LFs

The following LFs take some of the strongest distant supervision and text pattern LFs and combine them to form more specific LFs: multiplying a text-pattern label by the 0/1 CTD membership indicator gates the pattern on distant supervision, so the composite only labels when the pattern fires and the pair is in the relevant CTD category. These LFs introduce some obvious dependencies within the LF set, which we will model later.


In [7]:
def LF_ctd_marker_c_d(c):
    return LF_c_d(c) * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_unspecified(c)

Rules based on context hierarchy

These last two rules make use of the context hierarchy. The first checks whether another chemical mention appears much closer to the candidate's disease mention than the candidate's chemical mention does. The second performs the analogous check for disease mentions.


In [8]:
def LF_closer_chem(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find chemical closer than @dist/2 in either direction
    sent = c.get_parent()
    for i in range(dis_end, min(len(sent.words), dis_end + dist // 2)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    for i in range(max(0, dis_start - dist // 2), dis_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    return 0

def LF_closer_dis(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find a disease closer than @dist/8 in either direction
    sent = c.get_parent()
    for i in range(chem_end, min(len(sent.words), chem_end + dist // 8)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    for i in range(max(0, chem_start - dist // 8), chem_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    return 0
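
Both LFs read the per-word entity tags and CIDs that preprocessing stored on the parent Sentence. To see the fields they operate over, a short sketch:

sent = dev_cands[0].get_parent()
# Parallel per-word arrays: token, entity type ('Chemical', 'Disease', ...), and entity CID
print(list(zip(sent.words, sent.entity_types, sent.entity_cids))[:10])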

Running the LFs on the training set


In [9]:
LFs = [
    LF_c_cause_d,
    LF_c_d,
    LF_c_induced_d,
    LF_c_treat_d,
    LF_c_treat_d_wide,
    LF_closer_chem,
    LF_closer_dis,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_d_following_c,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_treat_c,
    LF_develop_d_following_c,
    LF_far_c_d,
    LF_far_d_c,
    LF_improve_before_disease,
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_in_patient_with,
    LF_induce,
    LF_induce_name,
    LF_induced_other,
    LF_level,
    LF_measure,
    LF_neg_d,
    LF_risk_d,
    LF_treat_d,
    LF_uncertain,
    LF_weak_assertions,
]

In [10]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [11]:
%time L_train = labeler.apply(split=0)
L_train


Clearing existing...
Running UDF...
[========================================] 100%

CPU times: user 1min 4s, sys: 434 ms, total: 1min 4s
Wall time: 1min 4s
Out[11]:
<8272x33 sparse matrix of type '<type 'numpy.int64'>'
	with 20079 stored elements in Compressed Sparse Row format>
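
L_train is a candidates-by-LFs matrix with entries in {-1, 0, +1}, where abstentions (zeros) are simply not stored. To inspect one LF's column of labels, a sketch (column j corresponds to LFs[j] in the list above):

import numpy as np

col = np.ravel(L_train[:, 0].todense())  # labels from LFs[0] (LF_c_cause_d)
print(dict(zip(*np.unique(col, return_counts=True))))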

In [12]:
L_train.lf_stats(session)


Out[12]:
j Coverage Overlaps Conflicts
LF_c_cause_d 0 0.032519 0.029014 0.012935
LF_c_d 1 0.092602 0.088733 0.026838
LF_c_induced_d 2 0.070358 0.070358 0.021518
LF_c_treat_d 3 0.045938 0.045938 0.019221
LF_c_treat_d_wide 4 0.086315 0.085469 0.037476
LF_closer_chem 5 0.193303 0.175290 0.099613
LF_closer_dis 6 0.018133 0.017287 0.012935
LF_ctd_marker_c_d 7 0.085348 0.085348 0.026233
LF_ctd_marker_induce 8 0.084260 0.084260 0.027805
LF_ctd_therapy_treat 9 0.046301 0.046301 0.017771
LF_ctd_unspecified_treat 10 0.055005 0.055005 0.027200
LF_ctd_unspecified_induce 11 0.068424 0.068424 0.024420
LF_d_following_c 12 0.000484 0.000484 0.000000
LF_d_induced_by_c 13 0.037234 0.034574 0.015716
LF_d_induced_by_c_tight 14 0.017892 0.017892 0.006407
LF_d_treat_c 15 0.028288 0.024903 0.015474
LF_develop_d_following_c 16 0.000846 0.000846 0.000484
LF_far_c_d 17 0.108317 0.096591 0.054279
LF_far_d_c 18 0.080513 0.069995 0.044246
LF_improve_before_disease 19 0.001572 0.001451 0.000725
LF_in_ctd_therapy 20 0.296905 0.259067 0.174565
LF_in_ctd_marker 21 0.610977 0.455754 0.344778
LF_in_patient_with 22 0.001572 0.001088 0.000725
LF_induce 23 0.020672 0.020068 0.010759
LF_induce_name 24 0.119923 0.114241 0.053675
LF_induced_other 25 0.041465 0.040982 0.019584
LF_level 26 0.007253 0.005440 0.003022
LF_measure 27 0.003627 0.003022 0.002660
LF_neg_d 28 0.018133 0.015353 0.011485
LF_risk_d 29 0.004836 0.004836 0.004836
LF_treat_d 30 0.021881 0.019826 0.011122
LF_uncertain 31 0.018375 0.017045 0.008100
LF_weak_assertions 32 0.108075 0.094898 0.066248
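
Coverage is the fraction of candidates an LF labels at all; overlaps and conflicts are the fractions where another LF also votes, or votes the opposite way. We can recompute a coverage entry by hand to confirm the table; a sketch:

# Stored (non-zero) entries in column 0 over the total number of candidates
print(L_train[:, 0].nnz / float(L_train.shape[0]))  # should match LF_c_cause_d's coverage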

Part IV: Training the generative model

As mentioned above, we want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! DependencySelector runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies. We can see that these match up with our prior knowledge. For example, it identified a "reinforcing" dependency between LF_c_induced_d and LF_ctd_marker_induce. Recall that we constructed the latter using the former.


In [13]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
len(deps)


Out[13]:
215
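
To verify claims like the reinforcing dependency mentioned above, we can map the selected dependencies back to LF names. A sketch, assuming deps is a set of (LF index, LF index, dependency type) triples as in this version of Snorkel:

for i, j, dep_type in sorted(deps):
    # dep_type is an integer code for the dependency class (e.g., reinforcing, fixing)
    print(LFs[i].__name__, LFs[j].__name__, dep_type)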

Now we'll train the generative model, passing in the deps argument to account for the learned dependencies. Unlike in the intro tutorial, we'll also model LF propensity: in addition to learning each LF's accuracy, the model learns how likely each LF is to label an example at all.


In [14]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=0.0
)


Inferred cardinality: 2

In [15]:
train_marginals = gen_model.marginals(L_train)

In [16]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()
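
train_marginals holds one probability per training candidate: the generative model's estimate that the candidate expresses a true relation. To see how polarized the estimates are, a sketch:

import numpy as np

confident = np.sum((train_marginals < 0.2) | (train_marginals > 0.8))
print('%d of %d marginals are confident' % (confident, len(train_marginals)))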



In [17]:
gen_model.learned_lf_stats()


Out[17]:
Accuracy Coverage Precision Recall
0 0.375000 0.0048 0.454545 0.002006
1 0.566524 0.0233 0.565574 0.013839
2 0.584906 0.0212 0.627451 0.012836
3 0.560000 0.0025 0.300000 0.000602
4 0.571429 0.0231 0.538462 0.011231
5 0.596539 0.1098 0.602941 0.065784
6 0.500000 0.0016 0.600000 0.000602
7 0.551724 0.0232 0.513043 0.011833
8 0.714286 0.0042 0.583333 0.002808
9 0.444444 0.0027 0.368421 0.001404
10 0.409091 0.0066 0.394737 0.003008
11 0.419355 0.0031 0.562500 0.001805
12 0.000000 0.0002 0.000000 0.000000
13 0.500000 0.0002 1.000000 0.000201
14 NaN 0.0000 NaN 0.000000
15 0.560000 0.0050 0.583333 0.002808
16 0.400000 0.0005 0.000000 0.000000
17 0.516854 0.0089 0.630435 0.005816
18 0.638655 0.0119 0.661538 0.008624
19 0.333333 0.0006 0.000000 0.000000
20 0.594558 0.1470 0.600000 0.090854
21 0.723925 0.6625 0.724148 0.477537
22 0.444444 0.0009 0.250000 0.000201
23 0.541667 0.0024 0.500000 0.001604
24 0.566586 0.0413 0.579208 0.023466
25 0.486486 0.0037 0.500000 0.001404
26 0.500000 0.0006 0.250000 0.000201
27 0.500000 0.0008 0.000000 0.000000
28 0.304348 0.0023 0.300000 0.000602
29 1.000000 0.0003 1.000000 0.000401
30 0.583333 0.0012 0.600000 0.000602
31 0.600000 0.0010 0.500000 0.000401
32 0.528302 0.0212 0.470000 0.009426
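
The rows above are indexed by LF position j. Since learned_lf_stats returns a pandas DataFrame, we can attach the LF names and sort by learned accuracy; a sketch:

stats = gen_model.learned_lf_stats()
stats.index = [lf.__name__ for lf in LFs]  # label each row with its LF's name
print(stats.sort_values('Accuracy', ascending=False).head(10))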

In [18]:
from snorkel.annotations import save_marginals
save_marginals(session, L_train, train_marginals)


Saved 8272 marginals

Checking performance against development set labels

Finally, we'll run the labeler on the development set, load in some external labels, and then evaluate LF performance. For convenience, the external labels are applied via a small script that maps the document-level relation annotations in the CDR file to mention-level labels. Note that these labels will not be perfect, although they are pretty good. If we wanted to keep iterating, we could use snorkel.lf_helpers.test_LF against the dev set (see the sketch after the error analysis below), or look at some false positive and false negative candidates.


In [19]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=1, annotator='gold')


AnnotatorLabels created: 888

In [20]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
L_gold_dev


Out[20]:
<888x1 sparse matrix of type '<type 'numpy.int64'>'
	with 888 stored elements in Compressed Sparse Row format>

In [21]:
L_dev = labeler.apply_existing(split=1)


Clearing existing...
Running UDF...
[========================================] 100%


In [22]:
_ = gen_model.error_analysis(session, L_dev, L_gold_dev)


========================================
Scores (Un-adjusted)
========================================
Pos. class accuracy: 0.932
Neg. class accuracy: 0.584
Precision            0.529
Recall               0.932
F1                   0.675
----------------------------------------
TP: 276 | FP: 246 | TN: 346 | FN: 20
========================================
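
To keep iterating, as suggested above, we could spot-check an individual LF against the gold dev labels. A sketch, assuming test_LF's signature in this version of snorkel.lf_helpers:

from snorkel.lf_helpers import test_LF

# Reports how LF_c_cause_d's votes line up with the 'gold' dev-set labels
test_LF(session, LF_c_cause_d, split=1, annotator_name='gold')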


In [23]:
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])


Out[23]:
j Coverage Overlaps Conflicts TP FP FN TN Empirical Acc. Learned Acc.
LF_c_cause_d 0 0.036036 0.034910 0.012387 22 10 0 0 0.687500 0.596154
LF_c_d 1 0.082207 0.078829 0.024775 47 26 0 0 0.643836 0.556738
LF_c_induced_d 2 0.059685 0.059685 0.021396 36 17 0 0 0.679245 0.585774
LF_c_treat_d 3 0.031532 0.031532 0.018018 0 0 9 19 0.678571 0.629630
LF_c_treat_d_wide 4 0.076577 0.075450 0.042793 0 0 17 51 0.750000 0.573394
LF_closer_chem 5 0.218468 0.201577 0.128378 0 0 57 137 0.706186 0.603246
LF_closer_dis 6 0.011261 0.011261 0.006757 0 0 2 8 0.800000 0.692308
LF_ctd_marker_c_d 7 0.072072 0.072072 0.023649 47 17 0 0 0.734375 0.583333
LF_ctd_marker_induce 8 0.067568 0.067568 0.027027 45 15 0 0 0.750000 0.760870
LF_ctd_therapy_treat 9 0.041667 0.041667 0.028153 0 0 8 29 0.783784 0.615385
LF_ctd_unspecified_treat 10 0.057432 0.057432 0.033784 0 0 16 35 0.686275 0.513158
LF_ctd_unspecified_induce 11 0.061937 0.061937 0.025901 39 16 0 0 0.709091 0.551020
LF_d_following_c 12 0.000000 0.000000 0.000000 0 0 0 0 NaN 0.800000
LF_d_induced_by_c 13 0.033784 0.033784 0.019144 21 9 0 0 0.700000 0.500000
LF_d_induced_by_c_tight 14 0.015766 0.015766 0.006757 9 5 0 0 0.642857 0.500000
LF_d_treat_c 15 0.045045 0.037162 0.027027 0 0 16 24 0.600000 0.394737
LF_develop_d_following_c 16 0.000000 0.000000 0.000000 0 0 0 0 NaN 0.600000
LF_far_c_d 17 0.147523 0.132883 0.084459 0 0 28 103 0.786260 0.567901
LF_far_d_c 18 0.077703 0.072072 0.046171 0 0 26 43 0.623188 0.564103
LF_improve_before_disease 19 0.007883 0.007883 0.004505 0 0 1 6 0.857143 0.500000
LF_in_ctd_therapy 20 0.316441 0.274775 0.209459 0 0 93 188 0.669039 0.601782
LF_in_ctd_marker 21 0.599099 0.480856 0.378378 294 238 0 0 0.552632 0.716942
LF_in_patient_with 22 0.000000 0.000000 0.000000 0 0 0 0 NaN 1.000000
LF_induce 23 0.027027 0.027027 0.011261 16 8 0 0 0.666667 0.625000
LF_induce_name 24 0.091216 0.087838 0.037162 48 33 0 0 0.592593 0.599490
LF_induced_other 25 0.034910 0.034910 0.016892 0 0 9 22 0.709677 0.564103
LF_level 26 0.020270 0.015766 0.011261 0 0 6 12 0.666667 0.857143
LF_measure 27 0.002252 0.002252 0.001126 0 0 0 2 1.000000 0.800000
LF_neg_d 28 0.010135 0.009009 0.004505 0 0 2 7 0.777778 0.444444
LF_risk_d 29 0.001126 0.001126 0.001126 1 0 0 0 1.000000 0.666667
LF_treat_d 30 0.013514 0.013514 0.003378 0 0 1 11 0.916667 0.600000
LF_uncertain 31 0.027027 0.025901 0.014640 0 0 2 22 0.916667 0.461538
LF_weak_assertions 32 0.129505 0.113739 0.068694 0 0 34 81 0.704348 0.563107