Load modules


In [1]:
from __future__ import division  # must precede all other statements in the cell

# basic NLP
import nltk, codecs, string, random, math, cPickle as pickle, re, datetime
from collections import Counter

# scikit-learn and numerics
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import linear_kernel

sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
stopset = set(nltk.corpus.stopwords.words('english'))

Load data


In [2]:
corrections = {"Sarcoma, Ewing's": 'Sarcoma, Ewing',
               'Beta-Thalassemia': 'beta-Thalassemia',
               'Von Willebrand Disease, Type 3': 'von Willebrand Disease, Type 3',
               'Von Willebrand Disease, Type 2': 'von Willebrand Disease, Type 2',
               'Von Willebrand Disease, Type 1': 'von Willebrand Disease, Type 1',
               "Felty's Syndrome": 'Felty Syndrome',
               'Von Hippel-Lindau Disease': 'von Hippel-Lindau Disease',
               'Retrognathism': 'Retrognathia',
               'Regurgitation, Gastric': 'Laryngopharyngeal Reflux',
               'Persistent Hyperinsulinemia Hypoglycemia of Infancy': 'Congenital Hyperinsulinism',
               'Von Willebrand Diseases': 'von Willebrand Diseases',
               'Pontine Glioma': 'Brain Stem Neoplasms',
               'Mental Retardation': 'Intellectual Disability',
               'Overdose': 'Drug Overdose',
               'Beta-Mannosidosis': 'beta-Mannosidosis',
               'Alpha 1-Antitrypsin Deficiency': 'alpha 1-Antitrypsin Deficiency',
               'Intervertebral Disk Displacement': 'Intervertebral Disc Displacement',
               'Alpha-Thalassemia': 'alpha-Thalassemia',
               'Mycobacterium Infections, Atypical': 'Mycobacterium Infections, Nontuberculous',
               'Legg-Perthes Disease': 'Legg-Calve-Perthes Disease',
               'Intervertebral Disk Degeneration': 'Intervertebral Disc Degeneration',
               'Alpha-Mannosidosis': 'alpha-Mannosidosis',
               'Gestational Trophoblastic Disease': 'Gestational Trophoblastic Neoplasms'
               }
cond = {}
cond_r = {}
for row in codecs.open('../data/condition_browse.txt','r','utf-8').readlines():
    row_id, trial_id, mesh_term = row.strip().split('|')
    if mesh_term in corrections: mesh_term = corrections[mesh_term]
    if mesh_term not in cond: cond[mesh_term] = []
    cond[mesh_term].append(trial_id)
    if trial_id not in cond_r: cond_r[trial_id] = []
    cond_r[trial_id].append(mesh_term)

mesh_codes = {}
mesh_codes_r = {}
for row in codecs.open('../data/mesh_thesaurus.txt','r','utf-8').readlines():
    row_id, mesh_id, mesh_term = row.strip().split('|')
    mesh_codes[mesh_id] = mesh_term
    if mesh_term not in mesh_codes_r: mesh_codes_r[mesh_term] = []
    mesh_codes_r[mesh_term].append(mesh_id)

# limiting to conditions that appear in ten or more trials
top_cond = {c for c in cond if len(cond[c]) >= 10}
trials = {t for c in top_cond for t in cond[c]}
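
As a quick sanity check on the lookups just built (a sketch; the exact counts depend on the data snapshot):


In [ ]:
# inspect the forward (term -> trials) and reverse (trial -> terms) indexes
print 'MeSH terms seen: %d' % len(cond)
print 'Trials with at least one term: %d' % len(cond_r)
print 'Terms kept (>= 10 trials): %d' % len(top_cond)
print 'Trials covered by kept terms: %d' % len(trials)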

In [ ]:
trial_desc = {}
for row in codecs.open('../data/clinical_study.txt','r','utf-8').readlines():
    data = row.split('|')
    # keep the detailed description only when it is substantive (> 50 chars)
    brief_desc = data[9]
    detail_desc = data[10] if len(data[10]) > 50 else ''
    trial_desc[data[0]] = brief_desc, detail_desc

to_classify = [t for t in trial_desc if t not in trials]

In [ ]:
pickle.dump(trial_desc,open('../data/trial_desc.pkl','wb'))

In [3]:
trial_desc = pickle.load(open('../data/trial_desc.pkl','rb'))
# hold out every unlabeled trial plus a random 10% of the labeled trials
to_classify = set([t for t in trial_desc if t not in trials] +
                  random.sample(list(trials), int(len(trials) / 10)))

In [4]:
pickle.dump(to_classify,open('../data/to_classify_holdout.pkl','wb'))

In [5]:
trial_desc = pickle.load(open('../data/trial_desc.pkl','rb'))
to_classify = pickle.load(open('../data/to_classify_holdout.pkl','rb'))

Analyze data


In [ ]:
print 'Total MeSH terms: %d' % len(cond)
print 'Total MeSH terms (level 1): %d' % len(set(mr[:3] for c in cond if c in mesh_codes_r for mr in mesh_codes_r[c]))
print 'Total MeSH terms (level 2): %d' % len(set(mr[:7] for c in cond if c in mesh_codes_r for mr in mesh_codes_r[c]))

Create a trial lookup for MeSH term hypernyms at the second level of the hierarchy


In [6]:
cond_l2 = {}
for m in cond.keys():
    if m in mesh_codes_r:
        # a level-2 hypernym is the first 7 characters of a MeSH tree number
        m_l2 = set([mr[:7] for mr in mesh_codes_r[m]])
        for l2 in m_l2:
            if l2 not in cond_l2: cond_l2[l2] = set()
            cond_l2[l2] |= (set(cond[m]) - to_classify)
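
For intuition: a MeSH tree number encodes its hypernyms as prefixes, so the first 3 and first 7 characters give the level-1 and level-2 ancestors. A minimal sketch with an illustrative tree number:


In [ ]:
# illustrative only: 'C04.557.450' stands in for a real MeSH tree number
tree_number = 'C04.557.450'
print tree_number[:3]   # level-1 ancestor, e.g. 'C04'
print tree_number[:7]   # level-2 ancestor, e.g. 'C04.557'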

Process text


In [6]:
def process_text(text):
    """Lowercase and tokenize, dropping stopwords and pure-punctuation tokens."""
    return [word.lower() 
            for sent in sent_tokenizer.tokenize(text) 
            for word in nltk.word_tokenize(sent)
            if word.lower() not in stopset and
            sum(1 for char in word if char not in string.punctuation) > 0]
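
A quick usage sketch of process_text; the expected output in the comment assumes NLTK's standard English stopword list:


In [ ]:
# stopwords ('the', 'over') and bare punctuation ('.') are dropped
process_text('The quick brown fox jumps over the lazy dog.')
# -> ['quick', 'brown', 'fox', 'jumps', 'lazy', 'dog']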

In [7]:
cond_text = {cond: Counter([word
                            for trial_id in cond_l2[cond] 
                            for desc in trial_desc[trial_id]
                            if len(desc) > 0
                            for word in process_text(desc)])
             for cond in cond_l2.keys()}

In [8]:
total_text = sum(cond_text.values(),Counter())

In [9]:
pickle.dump(cond_text,open('../data/mesh_level2_textcount_holdout.pkl','wb'))
pickle.dump(total_text,open('../data/mesh_level2_alltextcount_holdout.pkl','wb'))

Build a series of individual level-2 MeSH classifiers


In [10]:
# initializing values
mesh_models = {}

# keep tokens longer than two characters that are not purely numeric
total_text_keys, total_text_values = zip(*[(k, v)
                                           for k, v in total_text.items() 
                                           if len(k) > 2 and sum([1 
                                                                  for char in k 
                                                                  if char not in '1234567890']) > 0])

other_text_len = sum(total_text_values)
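
The same token filter (longer than two characters, not purely numeric) recurs in several cells below; a small helper (hypothetical name valid_token) would keep those copies in sync:


In [ ]:
def valid_token(k):
    # keep tokens longer than two characters that contain at least one non-digit
    return len(k) > 2 and sum(1 for char in k if char not in '1234567890') > 0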

In [ ]:
i = len(mesh_models) + 1

for c in cond_text.keys():
    if c not in mesh_models and len(c) > 3:
        # total number of valid words for this term and for everything else
        cond_text_len = sum([v 
                             for k, v in cond_text[c].items() 
                             if len(k) > 2 and sum([1 
                                                    for char in k 
                                                    if char not in '1234567890']) > 0])
        cur_other_text_len = other_text_len - cond_text_len
        
        # one tuple per vocabulary term:
        # (term % of target MeSH descriptor text, term % of all other MeSH descriptor text)
        if cond_text_len > 0:
            vecs = [(cond_text[c][t] / cond_text_len, (total_text[t] - cond_text[c][t]) / cur_other_text_len)
                    for t in total_text.keys()
                    if len(t) > 2 and sum([1
                                           for char in t
                                           if char not in '1234567890']) > 0]

            # fit a logistic model on exactly two samples: the target term's
            # word distribution (label 1) and everyone else's (label 0)
            model = LogisticRegression()
            mesh_models[c] = model.fit(zip(*vecs), [1, 0])

        print '%-3d %s (%s)' % (i, c, mesh_codes[c])
        i += 1
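
To make the training design concrete: each level-2 model is fit on exactly two samples, the target term's word-proportion vector (label 1) and the pooled proportions of everything else (label 0). A toy sketch with a three-word vocabulary:


In [ ]:
# toy illustration of the two-sample fit above
toy_target = [0.5, 0.3, 0.2]   # term % of target descriptor text
toy_other  = [0.1, 0.4, 0.5]   # term % of all other descriptor text
toy_model = LogisticRegression().fit([toy_target, toy_other], [1, 0])
print toy_model.predict_proba([[0.45, 0.35, 0.2]])[0][1]  # P(target-like)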

In [12]:
pickle.dump(mesh_models,open('../data/mesh_models_series_holdout.pkl','wb'))

Apply models to each unclassified trial


In [13]:
classify_text = {trial_id: Counter([word
                                    for desc in trial_desc[trial_id]
                                    if len(desc) > 0
                                    for word in process_text(desc)])
                 for trial_id in to_classify}

In [72]:
guesses = {}
# same token filter as above: tokens longer than two characters, not purely numeric
total_text_keys, total_text_values = zip(*[(k, v)
                                           for k, v in total_text.items() 
                                           if len(k) > 2 and sum([1 
                                                                  for char in k 
                                                                  if char not in '1234567890']) > 0])

other_text_len = sum(total_text_values)

In [ ]:
i = len(guesses) + 1

for c in classify_text.keys():
    if c not in guesses:
        text_len = sum([v
                        for k, v in classify_text[c].items()
                        if len(k) > 2 and sum([1
                                               for char in k
                                               if char not in '1234567890']) > 0])
        
        if text_len > 0:
            # term % of this trial's text, over the same vocabulary the models were fit on
            vecs = [classify_text[c][t] / text_len
                    for t in total_text.keys()
                    if len(t) > 2 and sum([1
                                           for char in t
                                           if char not in '1234567890']) > 0]

            # score the trial against every logistic model; wrap vecs in a list
            # so predict_proba receives a single 2-D sample
            predictions = {}
            for term, model in mesh_models.items():
                predictions[term] = model.predict_proba([vecs])[0][1]

            guesses[c] = predictions

        i += 1
        if i % 10 == 0: print i
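
Each entry of guesses maps a level-2 code to a probability; a sketch for eyeballing one trial's top suggestions (assumes guesses is non-empty):


In [ ]:
# peek at the five highest-scoring level-2 codes for one trial
some_trial = guesses.keys()[0]
for code, p in sorted(guesses[some_trial].items(), key=lambda x: -x[1])[:5]:
    print '%.3f %s (%s)' % (p, code, mesh_codes[code])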

In [128]:
pickle.dump(guesses,open('../data/mesh_guesses.pkl','wb'))

Single-prediction maxent classifier


In [14]:
# note: cond_text is rebuilt here as one text blob per condition (training trials only)
cond_text = {c: ' '.join(' '.join(trial_desc[t]) for t in cond[c] if t not in to_classify)
             for c in top_cond}

In [20]:
tfidf = TfidfVectorizer(stop_words=stopset)
train_mat = tfidf.fit_transform(cond_text.values())
apply_mat = tfidf.transform(' '.join(trial_desc[t]) for t in to_classify)

In [21]:
model = LogisticRegression()
# cond_text.keys() lines up with the cond_text.values() used to build train_mat
model.fit(train_mat, cond_text.keys())
single_preds = dict(zip(to_classify, model.predict(apply_mat)))
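
The maxent model commits to a single MeSH term per trial; predict_proba can expose how confident that pick is. A sketch using the fitted model above:


In [ ]:
# confidence of the single-best prediction for one held-out trial
probs = model.predict_proba(apply_mat[0])[0]
best = probs.argmax()
print '%s (p=%.3f)' % (model.classes_[best], probs[best])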

In [22]:
pickle.dump(single_preds,open('../data/mesh_guesses_maxent_holdout.pkl','wb'))

K-nearest neighbors suggestions


In [23]:
trial_text = {t: ' '.join(trial_desc[t])
              for t in trials 
              if len(trial_desc[t][0] + trial_desc[t][1]) > 50
              and t not in to_classify}
trial_text_other = {t: ' '.join(trial_desc[t]) 
                    for t in to_classify
                    if len(trial_desc[t][0] + trial_desc[t][1]) > 50}

In [26]:
tfidf = TfidfVectorizer(stop_words=stopset)
train_mat = tfidf.fit_transform(trial_text.values())
apply_mat = tfidf.transform(trial_text_other.values())

In [27]:
from sklearn.neighbors import NearestNeighbors
neigh = NearestNeighbors(n_neighbors=10,radius=5)
neigh.fit(train_mat)


Out[27]:
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
         n_neighbors=10, radius=5)

In [28]:
knn_guesses = {}

In [31]:
# freeze key order once; keys() and values() line up for an unmodified dict
apply_ids = trial_text_other.keys()
train_ids = trial_text.keys()

for i in range(len(apply_ids)):
    trial_id = apply_ids[i]
    if trial_id not in knn_guesses:
        dist, idx = (arr.flatten() for arr in neigh.kneighbors(apply_mat[i]))

        # collect the distance of every neighbor that carries each MeSH term
        this_guess = {}
        for j in range(len(idx)):
            k_trial_id = train_ids[idx[j]]
            for mterm in cond_r[k_trial_id]:
                if mterm not in this_guess: this_guess[mterm] = []
                this_guess[mterm].append(dist[j])

        knn_guesses[trial_id] = this_guess
    if i % 100 == 0: print i, datetime.datetime.now().time()


0 05:27:18.711195
100 05:27:18.966047
200 05:27:19.215825
[... one progress line per 100 trials elided ...]
37400 08:14:59.155788
37500 08:17:02.673677

In [32]:
pickle.dump(knn_guesses,open('../data/mesh_guesses_knn_holdout.pkl','wb'))
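
knn_guesses stores the raw neighbor distances behind each suggested term; one plausible ranking (an assumption, not a scheme fixed by this notebook) is by vote count, breaking ties by mean distance:


In [ ]:
# rank one trial's suggested MeSH terms: more supporting neighbors first,
# then smaller average distance (this scoring scheme is an assumption)
some_trial = knn_guesses.keys()[0]
ranked = sorted(knn_guesses[some_trial].items(),
                key=lambda kv: (-len(kv[1]), sum(kv[1]) / len(kv[1])))
for term, dists in ranked[:5]:
    print '%-40s votes=%d avg_dist=%.3f' % (term, len(dists), sum(dists) / len(dists))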

In [33]:
len(knn_guesses)


Out[33]:
37527
