Project: Mars Target Encyclopedia
This notebook contains little explanation; detailed walkthroughs are in the original Snorkel intro tutorials: https://github.com/HazyResearch/snorkel/tree/master/tutorials/intro
Launch Jupyter with ./run.sh as described in the Snorkel README.
In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from snorkel import SnorkelSession
import os
import numpy as np
import re
import codecs
os.environ['SNORKELDB'] = 'sqlite:///snorkel-mte.db'
In [3]:
# Open Session
session = SnorkelSession()
In [4]:
# Read input
base_dir = '/Users/thammegr/work/mte/data/newcorpus/MTE-corpus-open/'
def scan_docs(dir):
    txt_filter = lambda name: re.match(r"^[0-9]{4}\.txt$", name)
    for root, dirs, files in os.walk(dir):
        for f in filter(txt_filter, files):
            txt_path = os.path.join(root, f)
            ann_path = txt_path.replace('.txt', '.ann')
            parts = ann_path.split(os.path.sep)
            parts[-2] += "-reviewed-target"  # directory name
            new_ann_path = os.path.sep.join(parts)
            # prefer the reviewed annotations when they exist
            if os.path.exists(new_ann_path):
                ann_path = new_ann_path
            yield (txt_path, ann_path)
corpus_file = "mte-corpus.list"
with open(corpus_file, 'w') as f:
count = 0
for rec in scan_docs(base_dir):
f.write(",".join(rec))
f.write("\n")
count += 1
print("Wrote %d records to %s" %(count, corpus_file))
In [67]:
# sample the first 30 docs to set up the whole pipeline first
!head -30 mte-corpus.list > mte-corpus-head.list
corpus_file = "mte-corpus-head.list"
!wc -l *.list
In [ ]:
from snorkel.parser import CSVPathsPreprocessor
doc_preprocessor = CSVPathsPreprocessor(path=corpus_file, column=0, delim=',')
#doc_preprocessor = CSVPathsPreprocessor("paths-sample.list")
# Corpus parser to get features
from snorkel.parser import CorpusParser
corpus_parser = CorpusParser()
%time corpus_parser.apply(doc_preprocessor)
In [68]:
from snorkel.models import Document, Sentence
print "Documents:", session.query(Document).count()
print "Sentences:", session.query(Sentence).count()
In [6]:
# Schema for Minerals
from snorkel.models import candidate_subclass
Mineral = candidate_subclass('Mineral', ['name'])
In [7]:
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import RegexMatchEach
mineral_matcher = RegexMatchEach(attrib='pos_tags', rgx="NN.*")
ngrams = Ngrams(n_max=3)
cand_extractor = CandidateExtractor(Mineral,
                                    [ngrams], [mineral_matcher],
                                    symmetric_relations=False)
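RegexMatchEach accepts an n-gram only when every token's POS tag matches the regex. A minimal pure-Python sketch of that predicate (for intuition only, not Snorkel's actual implementation):
In [ ]:
# A span is a Mineral candidate only if all of its POS tags are NN*
def all_tags_match(tags, rgx=r"NN.*"):
    return all(re.match(rgx, t) for t in tags)
print all_tags_match(['NN', 'NNS']), all_tags_match(['JJ', 'NN'])  # True False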
In [8]:
# Counts contiguous noun (NN*) runs in a sentence => could be used for filtering
def number_of_nouns(sentence):
    active_sequence = False
    count = 0
    for tag in sentence.pos_tags:
        if tag.startswith('NN') and not active_sequence:
            active_sequence = True
            count += 1
        elif not tag.startswith('NN') and active_sequence:
            active_sequence = False
    return count
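A quick check with a hypothetical mock object (only the pos_tags attribute is needed) shows the function counts noun runs, not individual nouns:
In [ ]:
# MockSentence is illustrative only; real inputs are Snorkel Sentence objects
class MockSentence(object):
    pos_tags = ['DT', 'NN', 'NN', 'VBZ', 'JJ', 'NNS']
print number_of_nouns(MockSentence())  # 2: one run 'NN NN', one run 'NNS'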
In [9]:
from snorkel.models import Document
# load, filter, and split the sentences: 90% train / 5% dev / 5% test by document
docs = session.query(Document).order_by(Document.name).all()
ld = len(docs)
train_sents = set()
dev_sents = set()
test_sents = set()
splits = (0.9, 0.95)
for i, doc in enumerate(docs):
    for s in doc.sentences:
        if number_of_nouns(s) > 0:
            if i < splits[0] * ld:
                train_sents.add(s)
            elif i < splits[1] * ld:
                dev_sents.add(s)
            else:
                test_sents.add(s)
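A one-line check confirms the resulting split sizes:
In [ ]:
print "Train: %d, Dev: %d, Test: %d" % (len(train_sents), len(dev_sents), len(test_sents))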
In [10]:
s1 = session.query(Sentence).all()[26]
s1.pos_tags
Out[10]:
In [11]:
cand_extractor.apply(train_sents, split=0)
In [12]:
train_cands = session.query(Mineral).filter(Mineral.split == 0).all()
print "Number of candidates:", len(train_cands)
In [13]:
# inspect the candidates using this widget
from snorkel.viewer import SentenceNgramViewer
sv = SentenceNgramViewer(train_cands[:300], session)
sv
In [14]:
# Extract candidates for the dev and test splits
for i, sents in enumerate([dev_sents, test_sents]):
    cand_extractor.apply(sents, split=i+1)
    print "Number of candidates:", session.query(Mineral).filter(Mineral.split == i+1).count()
In [165]:
# Distant supervision
minerals_file = "/Users/thammegr/work/mte/git/ref/minerals.txt"
non_minerals_file = "/Users/thammegr/work/mte/git/ref/non-minerals.txt"

def load_set(path, lower=True):
    with codecs.open(path, 'r', 'utf-8') as f:
        lines = f.readlines()
    lines = map(lambda x: x.strip(), lines)
    lines = filter(lambda x: x and not x.startswith('#'), lines)
    if lower:
        lines = map(lambda x: x.lower(), lines)
    return set(lines)

mte_minerals = load_set(minerals_file)
non_minerals = load_set(non_minerals_file)

# dictionary-based LFs
def lf_dict_mte_minerals(c):
    return 1 if c.name.get_span().lower() in mte_minerals else 0

def lf_dict_nonminerals(c):
    return -1 if c.name.get_span().lower() in non_minerals else 0

# rule-based: surface form ends with 'ite'
def lf_rule_ite_minerals(c):
    return 1 if c.name.get_span().lower().endswith('ite') else 0

# rule-based 2: requires at least one vowel before the 'ite' suffix
ends_ite = re.compile("^[a-z]*[aeiou][a-z]*ite$")
def lf_rule_ite2_minerals(c):
    return 1 if ends_ite.match(c.name.get_span().lower()) is not None else 0
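The ends_ite pattern is easy to spot-check on a few surface forms before trusting it as a labeling function:
In [ ]:
# 'write' ends in 'ite' but has no vowel before the suffix, so it is rejected
for w in ['hematite', 'jarosite', 'olivine', 'write']:
    print w, ends_ite.match(w) is not None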
In [12]:
import requests
from lxml import etree

# lxml supports XPath 1.0, which doesn't have a regex match function, so extend it
ns = etree.FunctionNamespace(None)
def matches(dummy, val, pattern):
    if not val:
        return False
    return re.match(pattern, str(val[0])) is not None
ns['matches'] = matches

all_minerals_page = "https://en.wikipedia.org/wiki/List_of_minerals"
tree = etree.HTML(requests.get(all_minerals_page).text)
minerals = tree.xpath('//h2[matches(span/@id, "^[A-Z]$")]/following-sibling::*//li/a/@title')
minerals = set(map(lambda x: x.lower().strip(), minerals))  # remove duplicates
print("Found %d minerals in %s" % (len(minerals), all_minerals_page))

minerals_kb = "wikipedia-minerals.list"
with codecs.open(minerals_kb, 'w', 'utf-8') as out:
    out.write(u"\n".join(minerals))
print("Stored the mineral names at %s" % minerals_kb)
In [197]:
minerals_kb = "wikipedia-minerals.list"
minerals_set = load_set(minerals_kb)
def lf_dict_wikipedia_minerals(c):
return 1 if c.name.get_span().lower() in minerals_set else 0
# returning 0 instead of -1, because the wikipedia page may not be an exhaustive list.
# TODO: check with Kiri to confirm this
In [162]:
# Debugging labeling functions
from pprint import pprint
labeled = []
for c in session.query(Mineral).filter(Mineral.split == 0).all():
    if lf_rule_ite2_minerals(c) != 0:
        labeled.append(c)
print "Number labeled:", len(labeled)
In [139]:
labeled[0]
Out[139]:
In [198]:
# all labeling functions in a list
LFs = [
    lf_dict_mte_minerals, lf_dict_nonminerals,
    lf_dict_wikipedia_minerals,
    #lf_rule_ite_minerals,
    lf_rule_ite2_minerals,
]
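Before labeling in bulk, it can help to see each LF's raw vote on a single candidate (assuming train_cands from above is non-empty):
In [ ]:
c = train_cands[0]
for lf in LFs:
    print lf.__name__, lf(c)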
In [199]:
from snorkel.annotations import LabelAnnotator
import numpy as np
labeler = LabelAnnotator(f=LFs)
In [201]:
np.random.seed(1701)
%time L_train = labeler.apply(split=0)
L_train
Out[201]:
In [202]:
# Loading it again -- resume from here
L_train = labeler.load_matrix(session, split=0)
L_train
Out[202]:
In [170]:
L_train.get_candidate(session, 0)
Out[170]:
In [171]:
L_train.get_key(session, 0)
Out[171]:
In [203]:
L_train.lf_stats(session)
Out[203]:
In [204]:
from snorkel.learning import GenerativeModel
gen_model = GenerativeModel()
gen_model.train(L_train, epochs=500, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=1e-6)
In [191]:
train_marginals = gen_model.marginals(L_train)
# visualize
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()
In [205]:
gen_model.weights.lf_accuracy()
Out[205]:
In [206]:
L_dev = labeler.apply_existing(split=1)
In [177]:
L_dev
Out[177]:
In [7]:
dev_cands = session.query(Mineral).filter(Mineral.split == 1).all()
len(dev_cands)
In [72]:
from snorkel.viewer import SentenceNgramViewer
sv = SentenceNgramViewer(dev_cands, session)
sv
In [ ]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name=os.environ['USER'], split=1)
L_gold_dev
In [209]:
tp, fp, tn, fn = gen_model.score(session, L_dev, L_gold_dev)
In [1]:
fn
In [163]:
L_dev.lf_stats(session, L_gold_dev, gen_model.weights.lf_accuracy())
Out[163]:
In [211]:
# Save labels
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)
In [212]:
# generate features
from snorkel.annotations import FeatureAnnotator
featurizer = FeatureAnnotator()
%time F_train = featurizer.apply(split=0)
F_train
Out[212]:
In [213]:
%%time
F_dev = featurizer.apply_existing(split=1)
F_test = featurizer.apply_existing(split=2)
In [229]:
from snorkel.learning import SparseLogisticRegression
from snorkel.learning.utils import MentionScorer
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# our discriminative model
disc_model = SparseLogisticRegression()

# Hyperparameter search over learning rate and regularization penalties
rate_param = RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10)
l1_param = RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10)
l2_param = RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10)
searcher = RandomSearch(session, disc_model, F_train, train_marginals,
                        [rate_param, l1_param, l2_param], n=20)

from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

# fit
np.random.seed(1701)
searcher.fit(F_dev, L_gold_dev, n_epochs=50, rebalance=0.9, print_freq=25)
Out[229]:
In [228]:
# Score the discriminative model on dev; to score on the held-out test split instead:
#L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)
#_, _, _, _ = disc_model.score(session, F_test, L_gold_test)
tp, fp, tn, fn = disc_model.score(session, F_dev, L_gold_dev)
In [226]:
vars(F_dev[0])
Out[226]:
In [ ]: