In [1]:
# true division everywhere (Python 2); future imports must come first
from __future__ import division

# SQLAlchemy setup
from sqlalchemy import create_engine
from sqlalchemy.sql import func, select
from connect import mysqlusername, mysqlpassword, mysqlserver, mysqldbname
from db_tables import metadata, ConditionDescription, ConditionSynonym, ConditionLookup, ConditionBrowse

# NLP
import nltk, codecs, string, random, math, cPickle as pickle, re, datetime, pandas as pd
from collections import Counter, defaultdict
from bs4 import BeautifulSoup

# scikit-learn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import linear_kernel
import numpy as np

sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
stopset = set(nltk.corpus.stopwords.words('english'))


/Users/jost/courses/clinicaltrials/env/lib/python2.7/site-packages/numpy/core/fromnumeric.py:2507: VisibleDeprecationWarning: `rank` is deprecated; use the `ndim` attribute or function instead. To find the rank of a matrix see `numpy.linalg.matrix_rank`.
  VisibleDeprecationWarning)

Load basic data


In [2]:
# map outdated, misspelled, or miscapitalized condition terms to their
# current MeSH headings
corrections = {"Sarcoma, Ewing's": 'Sarcoma, Ewing',
               'Beta-Thalassemia': 'beta-Thalassemia',
               'Von Willebrand Disease, Type 3': 'von Willebrand Disease, Type 3',
               'Von Willebrand Disease, Type 2': 'von Willebrand Disease, Type 2',
               'Von Willebrand Disease, Type 1': 'von Willebrand Disease, Type 1',
               "Felty's Syndrome": 'Felty Syndrome',
               'Von Hippel-Lindau Disease': 'von Hippel-Lindau Disease',
               'Retrognathism': 'Retrognathia',
               'Regurgitation, Gastric': 'Laryngopharyngeal Reflux',
               'Persistent Hyperinsulinemia Hypoglycemia of Infancy': 'Congenital Hyperinsulinism',
               'Von Willebrand Diseases': 'von Willebrand Diseases',
               'Pontine Glioma': 'Brain Stem Neoplasms',
               'Mental Retardation': 'Intellectual Disability',
               'Overdose': 'Drug Overdose',
               'Beta-Mannosidosis': 'beta-Mannosidosis',
               'Alpha 1-Antitrypsin Deficiency': 'alpha 1-Antitrypsin Deficiency',
               'Intervertebral Disk Displacement': 'Intervertebral Disc Displacement',
               'Alpha-Thalassemia': 'alpha-Thalassemia',
               'Mycobacterium Infections, Atypical': 'Mycobacterium Infections, Nontuberculous',
               'Legg-Perthes Disease': 'Legg-Calve-Perthes Disease',
               'Intervertebral Disk Degeneration': 'Intervertebral Disc Degeneration',
               'Alpha-Mannosidosis': 'alpha-Mannosidosis',
               'Gestational Trophoblastic Disease': 'Gestational Trophoblastic Neoplasms'
               }

# cond: MeSH term -> trial IDs; cond_r: trial ID -> MeSH terms
cond = defaultdict(set)
cond_r = defaultdict(set)
for row in codecs.open('../data/condition_browse.txt','r','utf-8').readlines():
    row_id, trial_id, mesh_term = row.strip().split('|')
    if mesh_term in corrections: mesh_term = corrections[mesh_term]
    cond[mesh_term].add(trial_id)
    cond_r[trial_id].add(mesh_term)

# mesh_codes: MeSH tree code -> term; mesh_codes_r: term -> tree codes
mesh_codes = {}
mesh_codes_r = defaultdict(set)
for row in codecs.open('../data/mesh_thesaurus.txt','r','utf-8').readlines():
    row_id, mesh_id, mesh_term = row.strip().split('|')
    mesh_codes[mesh_id] = mesh_term
    mesh_codes_r[mesh_term].add(mesh_id)

# limiting to conditions that appear in ten or more trials
top_cond = {c for c in cond if len(cond[c]) >= 10}
trials = {t for c in top_cond for t in cond[c]}
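
As a quick sanity check, the sizes of the condition and trial sets (counts are illustrative and depend on the data snapshot):


In [ ]:
print len(cond), len(top_cond), len(trials)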

In [4]:
# brief and detailed descriptions for each trial; detailed descriptions of
# 50 characters or fewer are treated as empty
trial_desc = {}
for row in codecs.open('../data/clinical_study.txt','r','utf-8').readlines():
    data = row.split('|')
    brief_desc, detail_desc = (data[9].replace('<br />',' '),
                               data[10].replace('<br />',' ') if len(data[10]) > 50 else '')
    trial_desc[data[0]] = brief_desc, detail_desc

Generate model guesses

Maximum entropy model

Concatenating the descriptions of every trial tagged with a condition gives one training document per condition; a logistic regression (maximum entropy) classifier fit on those documents then predicts the single most likely condition for each trial.

In [ ]:
cond_text = {c: ' '.join(' '.join(trial_desc[t]) for t in cond[c])
             for c in top_cond}

In [ ]:
# note: cond_text.values() and cond_text.keys() stay aligned as long as the
# dict is not modified; the labels passed to the model below rely on this
tfidf = TfidfVectorizer(stop_words=stopset)
train_mat = tfidf.fit_transform(cond_text.values())
apply_mat = tfidf.transform(' '.join(trial_desc[t]) for t in trial_desc.keys())

In [ ]:
model = LogisticRegression()
model.fit(train_mat,cond_text.keys())
maxent_preds = dict(zip(trial_desc.keys(),model.predict(apply_mat)))
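
Spot-checking a handful of the maximum entropy guesses (the output depends on the trained model and data snapshot):


In [ ]:
for t, guess in maxent_preds.items()[:5]:
    print t, '->', guess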

In [ ]:
pickle.dump(maxent_preds,open('../data/mesh_maxent.pkl','wb'))

K-Nearest Neighbors model

Each trial is compared with every other trial in tf-idf space; its ten nearest neighbors vote for their known MeSH terms, with closer neighbors weighted more heavily.

In [ ]:
trial_text = {t: ' '.join(trial_desc[t])
              for t in trials 
              if len(trial_desc[t][0] + trial_desc[t][1]) > 50}

In [ ]:
other_text = {t: ' '.join(trial_desc[t])
              for t in trial_desc
              if t not in trials and len(trial_desc[t][0] + trial_desc[t][1]) > 50}

In [ ]:
tfidf = TfidfVectorizer(stop_words=stopset)
train_mat = tfidf.fit_transform(trial_text.values())
apply_mat = tfidf.transform(other_text.values())

In [ ]:
pickle.dump(tfidf,open('../data/tfidf_model.pkl','wb'))
pickle.dump(train_mat,open('../data/tfidf_matrix_alldesc.pkl','wb'))

In [ ]:
neigh = NearestNeighbors(n_neighbors=10,radius=5)
neigh.fit(train_mat)

In [ ]:
# resume from previously saved predictions; uncomment the next line to start fresh
# knn_preds = {}
knn_preds = pickle.load(open('../data/mesh_knn.pkl','rb'))
t_keys = trial_text.keys()

In [ ]:
for i in range(len(t_keys)):
    trial_id = t_keys[i]
    if trial_id not in knn_preds:
        dist, idx = (arr.flatten() for arr in neigh.kneighbors(train_mat[i]))

        this_guess = defaultdict(float)
        for j in range(len(idx)):
            k_trial_id = t_keys[idx[j]]
            if k_trial_id != trial_id:
                for mterm in cond_r[k_trial_id]:
                    this_guess[mterm] += 1 / (10 ** dist[j])
        
        knn_preds[trial_id] = this_guess
    
    if i % 100 == 0: print i, datetime.datetime.now().time()
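
Each neighbor votes for its known MeSH terms with weight 1 / 10**distance, so closer neighbors dominate the guess. A minimal illustration of how quickly that weight decays (hypothetical distances):


In [ ]:
for d in [0.0, 0.5, 1.0, 2.0]:
    print 'distance %.1f -> weight %.4f' % (d, 1 / (10 ** d))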

In [ ]:
# t_keys: unlabeled trials to score; o_keys: training-set trials that the
# returned neighbor indices point into
#knn_preds = {}
# knn_preds = pickle.load(open('../data/mesh_knn.pkl','rb'))
t_keys = other_text.keys()
o_keys = trial_text.keys()

In [ ]:
for i in range(len(t_keys)):
    trial_id = t_keys[i]
    if trial_id not in knn_preds:
        dist, idx = (arr.flatten() for arr in neigh.kneighbors(apply_mat[i]))

        this_guess = defaultdict(float)
        for j in range(len(idx)):
            k_trial_id = o_keys[idx[j]]
            if k_trial_id != trial_id:
                for mterm in cond_r[k_trial_id]:
                    this_guess[mterm] += 1 / (10 ** dist[j])
        
        knn_preds[trial_id] = this_guess
    
    if i % 100 == 0: print i, datetime.datetime.now().time()

In [ ]:
pickle.dump(knn_preds,open('../data/mesh_knn.pkl','wb'))

In [ ]:
%matplotlib inline
knn_accuracy = {}
for trial_id in list(set(knn_preds.keys()) & set(cond_r.keys()) & trials):
    # initialize variables for this prediction; 10000 is a sentinel meaning
    # "no exact match found"
    accuracy = {'exact': 10000,
                'hypernym': {}}
    this_pred = knn_preds[trial_id]
    # rank the distinct prediction scores; the highest score gets rank 0
    val_order = {j: i for i, j in enumerate(sorted(list(set(this_pred.values())), reverse=True))}

    hyp_pred = {m: val_order[v]
                for p, v in this_pred.items() 
                if p in mesh_codes_r 
                for m in mesh_codes_r[p]}
    
    # loop through known MeSH terms to look for greatest overlap
    for m in cond_r[trial_id]:
        if m in this_pred:
            this_rank = val_order[this_pred[m]]
            if this_rank < accuracy['exact']:
                accuracy['exact'] = this_rank
        elif m in mesh_codes_r:
            # fall back to hypernym matches: MeSH tree codes share a prefix
            # with their ancestors, one 4-character segment per tree level
            for a in mesh_codes_r[m]:
                len_a = len(a)
                for i in range(len_a,0,-4):
                    for pa in hyp_pred:
                        if pa[:i] == a[:i]:
                            if i not in accuracy['hypernym'] or hyp_pred[pa] < accuracy['hypernym'][i]:
                                accuracy['hypernym'][i] = hyp_pred[pa]

    if accuracy['exact'] == 10000: accuracy['exact'] = None
    
    knn_accuracy[trial_id] = accuracy

c = Counter([knn_accuracy[t]['exact']+1 
             if knn_accuracy[t]['exact'] is not None
             else 'No match'
             for t in knn_accuracy.keys()])
print sum(c.values()), len(knn_accuracy)

ax = pd.DataFrame(c.values(), index=c.keys()).plot(kind='bar',
                                                   figsize=(8,8),
                                                   legend=False)

ax.set_xlabel("Highest rank of nearest neighbor prediction exact match")
ax.set_ylabel("Number of trials")
ax.set_title("Evaluating KNN predictions using manually assigned MeSH terms:\nHighest ranking match of a known term")

In [ ]:
# finding a good threshold for suggestions
# based on these plots, suggest anything with a weighted score >= 0.4
# (higher scores mean closer neighbors)
rank_accuracy = defaultdict(list)
dist_accuracy = defaultdict(list)
tr = [c for c in cond_r if len(cond_r[c]) >= 5 and c in knn_preds]
for trial_id in tr:
    # initialize variables for this prediction
    accuracy = {'exact': 10000,
                'hypernym': {}}
    this_pred = knn_preds[trial_id]
    this_rank = defaultdict(list)
    this_dist = defaultdict(list)
    val_order = {j: i for i, j in enumerate(sorted(list(set(this_pred.values())), reverse=True))}
    
    for p in this_pred:
        val_round = round(this_pred[p], 2)
        val_rank = val_order[this_pred[p]]
        this_dist[val_round].append(1 if p in cond_r[trial_id] else 0)
        this_rank[val_rank].append(1 if p in cond_r[trial_id] else 0)
    
    for v in this_dist:
        dist_accuracy[v].append(1 if sum(this_dist[v]) > 0 else 0)
    for v in this_rank:
        rank_accuracy[v].append(1 if sum(this_rank[v]) > 0 else 0)
    
r = sorted([(r, sum(rank_accuracy[r]) / len(rank_accuracy[r])) for r in rank_accuracy], key=lambda x: x[0])

ax = pd.DataFrame([t[1] for t in r], index=[t[0] for t in r]).plot(kind='bar',
                                                   figsize=(12,8),
                                                   legend=False)

d = sorted([(r, sum(dist_accuracy[r]) / len(dist_accuracy[r])) for r in dist_accuracy], key=lambda x: x[0])

ax = pd.DataFrame([t[1] for t in d], index=[t[0] for t in d]).plot(kind='bar',
                                                   figsize=(40,8),
                                                   legend=False)
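
Applying that cutoff to a single trial's predictions would look something like this (a sketch; nct is just an arbitrary key from knn_preds):


In [ ]:
# sketch: suggestions that would survive the score >= 0.4 cutoff
nct = next(iter(knn_preds))
print nct, [m for m, v in knn_preds[nct].items() if v >= 0.4]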

Spot-checking specific terms


In [ ]:
def print_guesses(nctid):
    these_guess = sorted(knn_preds[nctid].items(), key=lambda x: x[1], reverse=True)
    for t, c in these_guess:
        print '%s (%g)' % (t, c)

In [ ]:
term = 'Prostatic Neoplasms'

In [ ]:
g = [k for k in knn_preds 
     if sorted(knn_preds[k].items(), key=lambda x: x[1], reverse=True)[0][0] == term
        and term not in cond_r[k]]
print len(g)

In [ ]:
#t = random.choice(g)
t = 'NCT00487786'
print t
print cond_r[t]
print
print trial_text[t]
print
print_guesses(t)

In [ ]:
# good one for prostate cancer: NCT00487786

Alternative: truncated SVD for a faster neighbor lookup


In [ ]:
from sklearn.decomposition import TruncatedSVD
from sklearn.neighbors import BallTree
tfidf = TfidfVectorizer(stop_words=stopset)
train_mat = tfidf.fit_transform(trial_text.values())

In [ ]:
# reduce the tf-idf matrix to 1,000 dense dimensions, then index the
# reduced vectors with a BallTree for faster neighbor queries
svd = TruncatedSVD(n_components=1000)
tsvd = svd.fit_transform(train_mat)
train_mat_b = BallTree(tsvd)
neigh = NearestNeighbors(n_neighbors=10, radius=5)
neigh.fit(train_mat_b)
knn_preds2 = {}
t_keys = trial_text.keys()
for i in range(1000):
    trial_id = t_keys[i]
    if trial_id not in knn_preds2:
        dist, idx = (arr.flatten() for arr in neigh.kneighbors(tsvd[i]))

        this_guess = defaultdict(float)
        for j in range(len(idx)):
            k_trial_id = t_keys[idx[j]]
            if k_trial_id != trial_id:
                for mterm in cond_r[k_trial_id]:
                    this_guess[mterm] += 1 / (10 ** dist[j])
        
        knn_preds2[trial_id] = this_guess
    
    if i % 100 == 0: print i, datetime.datetime.now().time()
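
One rough way to compare the SVD-based guesses against the exact tf-idf neighbors on the trials scored by both (a sketch; the agreement rate depends entirely on the data):


In [ ]:
both = [t for t in knn_preds2 if t in knn_preds and knn_preds[t] and knn_preds2[t]]
agree = sum(1 for t in both
            if max(knn_preds[t], key=knn_preds[t].get) ==
               max(knn_preds2[t], key=knn_preds2[t].get))
print len(both), agree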

Processing MedlinePlus topics and thesaurus information


In [ ]:
soup = BeautifulSoup(codecs.open('../data/mplus_topics_2014-11-04.xml','r','utf-8').read())

In [ ]:
# synonyms for MeSH terms (and reverse), and topic descriptions
mesh_syn = defaultdict(set)
topic_desc = {}

# loop through topics to pull out descriptions and synonyms
for t in soup.find_all("health-topic",language="English"):
    # topic summary
    topic_desc[t.attrs["title"]] = t.find("full-summary").text.replace('\n','').replace('\t','')
    
    # MeSH synonyms
    cur_mesh = t.find("mesh-heading").descriptor.text
    if cur_mesh in cond:
        mesh_syn[cur_mesh] |=  set([t.attrs["title"]] + [a.text for a in t.find_all("also-called")])

In [ ]:
# cleanup synonyms lookup dictionary
for m in mesh_syn.keys():
    cur_set = mesh_syn[m].copy()
    for s in mesh_syn[m]:
        if m.lower() == s.lower() or len(s) == 1: 
            cur_set -= set([s])
    if len(cur_set) == 0:
        del(mesh_syn[m])
    else:
        mesh_syn[m] = cur_set

for m in mesh_syn.keys():
    for s in mesh_syn[m]:
        if s in cond:
            mesh_syn[s].add(m)
            mesh_syn[s] |= mesh_syn[m]
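
A spot check of a few synonym sets (contents depend on the MedlinePlus snapshot):


In [ ]:
for m in list(mesh_syn)[:3]:
    print m, '->', '; '.join(mesh_syn[m])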

In [ ]:
# pick one canonical MeSH term per synonym cluster: the term that appears
# in the most trials wins
king_mesh = defaultdict(set)
not_king = set()
for m in mesh_syn.keys():
    if cond[m] and m not in not_king:
        all_terms = set([m])
        all_terms |= mesh_syn[m]
        all_keys = [s for s in mesh_syn if m in mesh_syn[s]]
        for k in all_keys:
            all_terms.add(k)
            all_terms |= mesh_syn[k]
        top_term = sorted([(t, len(cond[t])) for t in all_terms], key=lambda x: x[1], reverse=True)[0][0]
        king_mesh[top_term] = all_terms
        not_king.update([t for t in all_terms if cond[t] and t != top_term])

In [ ]:
# create canonical description dictionary, linked to MeSH term
descriptions = {}
cond_lower = {c.lower(): c for c in king_mesh}
for t in topic_desc:
    
    tlow = t.lower()
    
    if tlow in cond_lower:
        descriptions[cond_lower[tlow]] = topic_desc[t]
    
    poss_match = [m for m in king_mesh if tlow in [n.lower() for n in king_mesh[m]]]
    for p in poss_match:
        descriptions[p] = topic_desc[t]

In [ ]:
# create reverse lookup to canonical MeSH term
king_mesh_r = {s: k for k in king_mesh for s in king_mesh[k]}

In [ ]:
print len(king_mesh)
print len(king_mesh_r.values())
print len(set(king_mesh_r.values()))
print len(descriptions.keys())
print len(set(king_mesh.keys()) & set(descriptions.keys()))

In [ ]:
pickle.dump(descriptions,open('../data/mesh_descriptions.pkl','wb'))
pickle.dump(king_mesh_r,open('../data/mesh_synonyms.pkl','wb'))

Loading MTI (NLM Medical Text Indexer) suggestions for MeSH terms


In [3]:
trial_lookup = defaultdict(set)
for row in codecs.open('../data/MTIdescriptions_output.txt','r','utf-8').readlines():
    fields = row.strip().split('|')
    if len(fields) == 9 and fields[4] == 'MH':
        nct_id = fields[0]
        mesh_term = fields[1]
        if mesh_term[0] == '*': mesh_term = mesh_term[1:]
        if mesh_term in cond and mesh_term not in cond_r[nct_id]:
            trial_lookup[nct_id].add(mesh_term)
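
A quick count of how many trials picked up MTI suggestions (numbers are illustrative):


In [ ]:
print len(trial_lookup), sum(len(v) for v in trial_lookup.values())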

Writing data to MySQL

Load guess and description dictionaries (generated by the processing cells above)


In [5]:
maxent_preds = pickle.load(open('../data/mesh_maxent.pkl','rb'))
knn_preds = pickle.load(open('../data/mesh_knn.pkl','rb'))
descriptions = pickle.load(open('../data/mesh_descriptions.pkl','rb'))
synonyms = pickle.load(open('../data/mesh_synonyms.pkl','rb'))

In [7]:
# MedlinePlus topic titles that are demographic or lifestyle variants of a
# condition rather than conditions themselves; exclude them as synonyms
no_include = ['Cancer in Children',
              'Diabetes in Children and Teens',
              'Heart Disease in Women',
              'HIV/AIDS in Women',
              'Living with HIV/AIDS',
              'Cancer--Living with Cancer',
              'Teen smoking',
              'Smoking and Youth',
              'Teenage drinking',
              'Underage Drinking']

In [28]:
# re-key synonym entries whose capitalization differs from the condition list
redo = dict([(c, d) for c in cond.keys() for d in synonyms.keys() if c.lower() == d.lower() and c != d])
for c, d in redo.items():
    synonyms[c] = synonyms[d]
    del synonyms[d]

Set up MySQL connection


In [43]:
mysqlserver = 'localhost'  # override the imported server for a local run
engine = create_engine('mysql://%s:%s@%s/%s' % (mysqlusername, mysqlpassword, mysqlserver, mysqldbname), pool_recycle=3600)
conn = engine.connect()
metadata.create_all(engine)

Upload canonical condition description and synonym data


In [34]:
# add conditions that aren't already in descriptions dictionary
for c in cond:
    if c not in descriptions:
        if c in synonyms:
            descriptions[c] = descriptions[synonyms[c]]
        else:
            descriptions[c] = ''

for s in synonyms:
    if s not in descriptions and s not in no_include:
        descriptions[s] = descriptions[synonyms[s]]

In [41]:
desc_id = {k: i for i, k in enumerate(descriptions.keys())}

In [44]:
conn.execute(ConditionDescription.insert(), [{'condition_id': desc_id[d],
                                              'mesh_term': d,
                                              'description': descriptions[d]}
                                             for d in descriptions.keys()])


Out[44]:
<sqlalchemy.engine.result.ResultProxy at 0x11902fa50>

In [46]:
# create lookup for every synonym
syn_rev = {k: set(s for s, j in synonyms.items() if j == k and s not in no_include) for k in set(synonyms.values())}
all_syn = [{'condition_id': desc_id[s],
            'synonym_id': desc_id[k]}
           for s in synonyms
           for k in syn_rev[synonyms[s]]
           if s not in no_include and s != k]

In [47]:
conn.execute(ConditionSynonym.insert(), all_syn)


Out[47]:
<sqlalchemy.engine.result.ResultProxy at 0x13a8bc410>

Insert conditions already labeled in the database


In [48]:
for k in range(0,len(cond),100):
    print k
    conn.execute(ConditionLookup.insert(), [{'condition_id': desc_id[c],
                                             'nct_id': n,
                                             'source': 'CTGOV',
                                             'syn_flag': 0}
                                            for c in cond.keys()[k:k+100]
                                            for n in cond[c]])


0
100
200
...
3300

Insert maximum entropy predictions


In [49]:
for k in range(0,len(maxent_preds),5000):
    print k
    conn.execute(ConditionLookup.insert(), [{'condition_id': desc_id[t],
                                             'nct_id': n,
                                             'source': 'MAXENT',
                                             'syn_flag': 0}
                                            for n, t in maxent_preds.items()[k:k+5000]
                                            if t not in cond_r[n]])


0
5000
10000
...
160000

Insert KNN predictions


In [50]:
# keep a KNN suggestion if its score clears the 0.4 cutoff chosen earlier,
# or if it is the top-ranked guess for the trial
knn_insert = []
for k in knn_preds:
    this_pred = knn_preds[k]
    val_order = {j: i for i, j in enumerate(sorted(list(set(this_pred.values())), reverse=True))}
    for m in this_pred:
        if m not in cond_r[k] and (this_pred[m] >= 0.4 or val_order[this_pred[m]] == 0):
            knn_insert.append({'condition_id': desc_id[m],
                               'nct_id': k,
                               'source': 'KNN',
                               'disp_order': val_order[this_pred[m]],
                               'syn_flag': 0})

In [51]:
for k in range(0,len(knn_insert),5000):
    print k
    conn.execute(ConditionLookup.insert(), knn_insert[k:k+5000])


0
5000
10000
...
105000

Insert MTI suggestions


In [52]:
mti_insert = [{'condition_id': desc_id[c],
                 'nct_id': t,
                 'source': 'MTI',
                 'syn_flag': 0}
                for t, n in trial_lookup.items()
                for c in n
                if c not in cond_r[t]]

In [53]:
for k in range(0,len(mti_insert),10000):
    print k
    conn.execute(ConditionLookup.insert(), mti_insert[k:k+10000])


0
10000
20000
...
130000

Insert synonyms


In [54]:
# for each synonym pair, copy CTGOV trial links in both directions,
# flagging the copies as synonym-derived
for s, k in synonyms.items():
    
    if s != k and s not in no_include:
        
        k_trials = {t[0] for t in conn.execute('''select nct_id 
                                                  from condition_lookup 
                                                  where condition_id = %d
                                                    and source = "CTGOV"''' % desc_id[k]).fetchall()}
        s_trials = {t[0] for t in conn.execute('''select nct_id 
                                                  from condition_lookup
                                                  where condition_id = %d
                                                    and source = "CTGOV"''' % desc_id[s]).fetchall()}
        
        s_insert = [{'condition_id': desc_id[s],
                     'nct_id': t,
                     'source': 'CTGOV',
                     'syn_flag': 1}
                    for t in k_trials - s_trials]
        
        if cond[s]:
            k_insert = [{'condition_id': desc_id[k],
                         'nct_id': t,
                         'source': 'CTGOV',
                         'syn_flag': 1}
                        for t in s_trials - k_trials]
        else:
            k_insert = []
        
        conn.execute(ConditionLookup.insert(), s_insert + k_insert)
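
A quick verification query against the finished table (a sketch; counts depend on what was loaded):


In [ ]:
print conn.execute('select source, count(*) from condition_lookup group by source').fetchall()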

In [55]:
conn.close()
