In [173]:
from __future__ import division
from pandas import *
import os, os.path
import csv
from itertools import product
import numpy as np
import matplotlib.pyplot as plt

from operator import methodcaller
from itertools import groupby
from Bio.Seq import Seq
from Bio import Motif
from StringIO import StringIO

from subprocess import check_output, check_call
from tempfile import NamedTemporaryFile as NTF
import shlex

os.chdir('/home/will/LTRtfAnalysis/')

In [3]:
def yield_motifs():
    """Parse 'Jaspar_PWMs.txt' and yield (name, motif) pairs.

    The file alternates '>...' header lines with jaspar-pfm count
    blocks.  For each matrix both the forward motif and its
    reverse-complement (suffixed '-R') are yielded.
    """
    with open('Jaspar_PWMs.txt') as handle:
        for key, lines in groupby(handle, methodcaller('startswith', '>')):
            if key:
                # header: the last whitespace-separated token is the TF name
                # (next(lines) instead of Py2-only lines.next(), matching
                # the idiom already used in fasta_reader)
                name = next(lines).strip().split()[-1].lower()
            else:
                tmp = ''.join(lines)
                mot = Motif.read(StringIO(tmp), 'jaspar-pfm')
                yield name, mot
                yield name+'-R', mot.reverse_complement()

            
# Build {motif_name: (motif, score_threshold)} for every JASPAR motif.
# The threshold is chosen so that the false-positive rate of a PWM hit
# is 0.0001 (see Bio.Motif.Thresholds.ScoreDistribution).
pwm_dict = {}
for num, (name, mot) in enumerate(yield_motifs()):
    if num % 100 == 0:
        print num  # progress indicator; ~100s of motifs take a while
    thresh = Motif.Thresholds.ScoreDistribution(mot, precision = 50).threshold_fpr(0.0001)
    pwm_dict[name] = (mot, thresh)


0
100
200

In [148]:
import pickle

# Load previously pickled sequence/PSSM data.
# NOTE(review): pickle.load on a file is unsafe for untrusted input
# (arbitrary code execution) — fine here only because these are our own
# local files.  Also, hard-coded absolute paths tie this notebook to one
# machine.
with open('/home/will/HIVReportGen/Data/TrainingSequences/training_seqs.pkl') as handle:
    training_seqs = pickle.load(handle)
    
with open('/home/will/HIVReportGen/Data/TrainingSequences/training_pssm.pkl') as handle:
    pssm_data = pickle.load(handle)
    
with open('/home/will/HIVReportGen/Data/PatientFasta/seq_data.pkl') as handle:
    pat_seq_data = pickle.load(handle)

In [171]:
def fasta_reader(handle):
    """Yield (name, sequence) tuples from a FASTA-formatted line iterable.

    Header lines start with '>'; all following non-header lines are
    concatenated (whitespace-stripped) into one sequence string.
    """
    name = None
    is_header = lambda line: line.startswith('>')
    for header, chunk in groupby(handle, is_header):
        if header:
            # drop the leading '>' and trailing newline/whitespace
            name = next(chunk)[1:].strip()
        else:
            pieces = [part.strip() for part in chunk]
            yield name, ''.join(pieces)

def fasta_writer(seq_tups, handle, alpha_only = True):
    """Write an iterable of (name, sequence) pairs to handle in FASTA format.

    When alpha_only is True, non-alphabetic characters (gaps, etc.) are
    stripped from each sequence before writing.
    """
    for name, seq in seq_tups:
        if alpha_only:
            # keep only letters; discards '-' gap characters among others
            kept = [letter for letter in seq if letter.isalpha()]
            seq = ''.join(kept)
        handle.write('>%s\n%s\n' % (name, seq))

def seq_number_gen(inseq):
    """For each character of an aligned sequence, yield the 0-based index
    of the most recent non-gap residue.

    Gap characters ('-') repeat the previous index; positions before the
    first residue yield -1.  Used to map alignment columns back to
    ungapped sequence coordinates.
    """
    counter = -1
    for residue in inseq:
        if residue != '-':
            counter += 1
        yield counter
            
            


def map_to_conb(conb_seq, other_seq):
    """Align other_seq against the consensus-B LTR with muscle and return
    a dict mapping ungapped positions of other_seq -> ungapped positions
    of conb_seq.

    Requires the external `muscle` binary on PATH; uses two temp files
    (works on Linux where a still-open NamedTemporaryFile can be reopened
    by another process — not portable to Windows).
    """
    with NTF(mode = 'w') as handle:
        fasta_writer([('ConB', conb_seq), ('Other', other_seq)], handle)
        handle.flush()
        # force the bytes to disk before muscle reads the file
        # NOTE(review): os.fsync is documented for int fds; passing the
        # file object relies on the fileno() conversion — confirm.
        os.fsync(handle)
        with NTF(mode = 'rt') as ohandle:
            cmd = shlex.split('muscle -quiet -nocore -in %s -out %s' % (handle.name, ohandle.name))
            check_call(cmd)
            seq_dict = dict(fasta_reader(ohandle))
    
    # zip the per-column residue counters of both aligned sequences;
    # later duplicate keys (gap columns in 'Other') overwrite earlier ones
    mapping = dict(zip(seq_number_gen(seq_dict['Other']),
                       seq_number_gen(seq_dict['ConB'])))
    return mapping

In [170]:
# Merge the two LTR sequence sources: training-set values win, patient
# values fill the gaps; rows missing from both are dropped.
ltr_data = training_seqs['LTR'].combine_first(pat_seq_data['LTR']).dropna()

In [167]:
# Dump all LTR sequences to FASTA, naming each record "<patient>-<visit>".
with open('ltr_seqs.fasta', 'w') as handle:
    for (p,v), seq in zip(ltr_data.index, ltr_data.values):
        handle.write('>%s-%s\n%s\n' % (p,v,seq))

In [83]:
# NOTE(review): this rebinds pssm_data, shadowing the pickled frame
# loaded earlier — presumably intentional, but confirm.
pssm_data = read_csv('/home/will/HIVReportGen/Data/TrainingSequences/pssm_data.csv', index_col = [0,1])
def decide_tropism(inval):
    """Map a geno2pheno-style PSSM score to a tropism call.

    Returns True (R5) for scores below -6.95, False (X4) for scores
    above -2.88, and NaN for the ambiguous band in between (inclusive
    of both cutoffs).
    """
    if inval > -2.88:
        return False
    if inval < -6.95:
        return True
    # score falls inside the indeterminate window
    return np.nan
# Call tropism per (patient, visit) and index it by the "patient-visit"
# string key used elsewhere in the notebook.
tropism_data = pssm_data['score'].map(decide_tropism).dropna()
trop_dict = {}
for (pat, visit), val in zip(tropism_data.index, tropism_data.values):
    trop_dict[pat+'-'+visit] = val
    
    
# Overlay manually curated labels; CSV stores booleans as 'TRUE'/'FALSE'
# strings, so compare against the literal.
with open('/home/will/Dropbox/HIVseqs/BensTropismLabels.csv') as handle:
    reader = csv.DictReader(handle)
    for row in reader:
        trop_dict['%s-%s' % (row['Patient ID'], row['Visit'])] = row['Prediction'] == 'TRUE'

In [1]:
# HXB2 reference LTR sequence, stored as one ungapped string (the
# triple-quoted literal's embedded newlines are stripped below).
hxb2_ltr = """TGGAAGGGCTAATTTACTCCCAAAAAAGACAAGATATCCTTGATCTGTGGGTC
TACCACACACAAGGCTACTTCCCTGATTGGCAGAACTACACACCAGGGCCAGG
GATCAGATATCCACTGACCTTTGGATGGTGCTTCAAGCTAGTACCAGTTGAGC
CAGAGAAGGTAGAAGAGGCCAATGAAGGAGAGAACAACAGCTTGTTACACCCT
ATGAGCCTGCATGGGATGGAGGACCCGGAGAAAGAAGTGTTAGTGTGGAAGTT
TGACAGCCGCCTAGCATTTCATCACATGGCCCGAGAGCTGCATCCGGAGTACT
ACAAGGACTGCTGACATCGAGCTTTCTACAAGGGACTTTCCGCTGGGGACTTT
CCAGGGAGGCGTGGCCTGGGCGGGACTGGGGAGTGGCGAGCCCTCAGATGCTG
CATATAAGCAGCTGCTTTTTGCCTGTACTGGGTCTCTCTGGTTAGACCAGATC
TGAGCCTGGGAGCTCTCTGGCTAACTAGGGAACCCACTGCTTAAGCCTCAATA
AAGCTTGCCTTGAGTGCTTCAAGTAGTGTGTGCCCGTCTGTTGTGTGACTCTG
GTAACTAGAGATCCCTCAGACCCTTTTAGTCAGTGTGGAAAATCTCTAGCA""".replace('\n', '')

In [2]:
# sanity check: positions 329-337 should be the ATF/CREB site 'TGACATCG'
hxb2_ltr[329:337]


Out[2]:
'TGACATCG'

In [319]:
# Literature-annotated TF binding sites on the HXB2 LTR:
# (display label, pwm_dict key, HXB2 start position).  Commented-out
# entries are sites with no matching JASPAR matrix in pwm_dict.
known_binding_pos = [('AP-1 IV', 'ap1', 104),
     ('AP-1 III','ap1', 119),
     ('AP-1 II','ap1', 154),
     #('GRE', pwm_dict['GRE'][0], 191),
     ('AP-1 I','ap1', 213),
     ('C/EBP II', 'cebpa', 280),
     #('USF-1', pwm_dict['USF-1'][0], 221),
     ('ETS-1', 'ets1', 304),
     #('Lef-1', pwm_dict['Lef-1'][0], 317),
     ('ATF/Creb', 'creb1', 329),
     ('C/EBP I', 'cebpa', 337),
     ('NFkB II', 'nf-kappab', 349),
     ('NFkB I', 'nf-kappab', 362),
     ('Sp III', 'sp1', 376),
     ('Sp II', 'sp1', 387),
     ('Sp I', 'sp1', 398),
     ('AP-1','ap1', 539),     
     ('AP-1','ap1', 571),
     #('Oct-1', pwm_dict['OCT1'][0], 440),
]

# PWMs to optimize thresholds for, as (name, motif) pairs.
# NOTE(review): the 'sp1' entry pulls the reverse-complement matrix
# ('sp1-R') — presumably deliberate (site on the minus strand), confirm.
wanted_pwms = [('ap1', pwm_dict['ap1'][0]),
                ('cebpa', pwm_dict['cebpa'][0]), 
                ('ets1', pwm_dict['ets1'][0]),
                ('creb1', pwm_dict['creb1'][0]),
                ('nf-kappab', pwm_dict['nf-kappab'][0]),
                ('sp1', pwm_dict['sp1-R'][0]),]

In [95]:
# Locate the ATF/CREB site in HXB2 to double-check the 329-337 coordinates
seq = 'TGACATCG'
pos = hxb2_ltr.find(seq)

print pos, pos+len(seq)


329 337

In [322]:
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from itertools import islice
from itertools import imap, chain
from operator import itemgetter
from collections import defaultdict

from scipy.optimize import minimize_scalar

def unique_justseen(iterable, key=None):
    "List unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBCcAD', str.lower) --> A B C A D
    # Standard itertools recipe; uses Py2 imap so the result stays lazy.
    return imap(next, imap(itemgetter(1), groupby(iterable, key)))

def take(n, iterable):
    """Return the first n items of iterable as a list (itertools recipe)."""
    first_n = islice(iterable, n)
    return list(first_n)


def scan_seqs(seq, pwm_tup):
    """Scan one sequence with one PWM and return (name, start, stop) hits.

    pwm_tup is (name, motif, target_fpr); the score threshold is derived
    from the requested false-positive rate.  Runs in worker processes via
    executor.map, hence the single picklable tuple argument.
    """
    seq = Seq(seq)
    name, mot, thresh_fpr = pwm_tup
    thresh = Motif.Thresholds.ScoreDistribution(mot, precision = 50).threshold_fpr(thresh_fpr)
    results = []
    for loc, m in mot.search_pwm(seq, threshold=thresh):
        # keep forward-strand hits only; presumably negative loc marks
        # reverse-strand matches in Bio.Motif — confirm.
        if loc > 0:
            results.append((name, loc, loc+len(mot)))
    
    return results

def check_seq(seq, mapping, pwms, thresh_fpr, executor):
    """Scan seq with every PWM (in parallel) and yield hits translated to
    HXB2 coordinates via `mapping` as (tf_name, hxb2_start, hxb2_stop)."""
    tups = [(name, pwm, thresh_fpr) for name, pwm in pwms]
    anal_fun = partial(scan_seqs, seq)
    res = executor.map(anal_fun, tups)
    for tf, start, stop in chain.from_iterable(res):
        # stop-1 because `stop` is exclusive while mapping is per-position
        yield tf, mapping[start], mapping[stop-1]
    
def check_all_seqs(seqs, mappings, wanted_pwms, thresh_fpr, executor):
    """Run check_seq over every (sequence, mapping) pair and yield
    (seq_index, tf_name, hxb2_start, hxb2_stop) tuples, grouped by
    sequence index in input order."""
    # build all the per-sequence hit generators first (cheap: generators
    # are lazy), then drain them one sequence at a time
    pending = [(idx, check_seq(s, m, wanted_pwms, thresh_fpr, executor))
               for idx, (s, m) in enumerate(zip(seqs, mappings))]
    for idx, hits in pending:
        for tf, start, stop in hits:
            yield idx, tf, start, stop
    
    
    
def obj_fun(thresh_fpr, seqs, mappings, pwms, allowed_binding_pos, correct_binding_pos, executor):
    """Objective for minimize_scalar: score a PWM threshold FPR.

    allowed_binding_pos maps (tf, hxb2_pos) -> canonical site label for
    positions within +/-5bp of a known site; correct_binding_pos is the
    set of those labels.  Counts, over all sequences, sites recovered
    exactly once (correct), sites never hit (missing), and surplus hits
    (extra).  Returns extra - correct, so the optimizer maximizes correct
    hits while penalizing spurious ones.
    """
    correct_found = 0
    extra_found = 0
    missing = 0
    
    
    # hits arrive ordered by sequence index, so groupby partitions them
    # one sequence at a time
    for pat, rows in groupby(check_all_seqs(seqs, mappings, pwms, thresh_fpr, executor), key = lambda x: x[0]):
        found_tfs = defaultdict(int)
        for _, tf, start, stop in rows:
            # hits outside any allowed window fall into the None bucket
            found_tfs[allowed_binding_pos.get((tf, start), None)] += 1
            #if thresh_fpr < 0.001:
            #    print tf, start, stop
        
        for binding_pos in correct_binding_pos:
            if found_tfs[binding_pos] == 0:
                missing += 1
            elif found_tfs[binding_pos] == 1:
                correct_found += 1
            else:
                # site found, but duplicated: count the surplus as extra
                correct_found += 1
                extra_found += found_tfs[binding_pos]-1
        extra_found += found_tfs[None]
    #if thresh_fpr < 0.001:
    #    raise KeyError
  
    print thresh_fpr, correct_found, missing, extra_found
    return -correct_found + extra_found
    

# Optimize the PWM-score FPR threshold per TF on the first 50 LTRs.
# NOTE(review): final_results is never appended to in this cell, and
# `res` is overwritten on every loop iteration, so only the last TF's
# optimization result survives — looks like a bug or leftover scaffolding.
final_results = []
ltrseqs = ltr_data.dropna().head(n=50)


        
with ProcessPoolExecutor(max_workers = 20) as executor:
    # align every sequence to HXB2 once up front (muscle, in parallel)
    mapping_fun = partial(map_to_conb, hxb2_ltr)
    seq_mappings = list(executor.map(mapping_fun, ltrseqs.values))
    
    for row in wanted_pwms:
        print row
        # accept hits within +/-5bp of each known site for this TF
        allowed_binding_pos = dict()
        for _, tf, pos in known_binding_pos:
            if tf == row[0]:
                for nudge in range(-5,6):
                    allowed_binding_pos[(tf, pos+nudge)] = '%s-%i' % (tf, pos)

        correct_binding_pos = set(allowed_binding_pos.values())
    
        res = minimize_scalar(obj_fun, bounds = [0,0.1], method = 'bounded', 
                args = (ltrseqs.values, seq_mappings, [row], allowed_binding_pos, correct_binding_pos, executor))
    
print res


('ap1', <Bio.Motif._Motif.Motif object at 0x32cf910>)
0.038196601125 1517 1477 8381
0.061803398875 1952 1048 16300
0.02360679775 1114 1868 5614
0.0109366097056 1035 1917 2541
0.0157761908947 1050 1908 3254
0.00722344801765 398 2518 1242
0.00446433639088 393 2487 845
0.00275911162677 377 2023 255
0.0017052247641 350 1822 96
0.00141465917949 350 1816 54
0.000874307455421 345 1761 22
0.000924988702411 345 1773 41
0.000540351724068 2 22 2
0.000746747716782 3 33 3
0.000825583970857 345 1761 22
0.000849945713139 345 1761 22
0.000840640355613 345 1761 22
0.000834889328384 345 1761 22
0.000831334998086 345 1761 22
('cebpa', <Bio.Motif._Motif.Motif object at 0x32cfa50>)
0.038196601125 832 164 6443
0.061803398875 890 108 10320
0.02360679775 498 494 3648
0.014589803375 433 547 1998
0.00901699437495 409 547 736
0.00557280900008 294 360 129
0.00344418537486 293 347 90
0.00411054819811 293 359 111
0.00212862362522 15 105 53
0.0029416855008 15 137 72
0.00369871332451 293 355 100
0.00325224750231 293 345 88
0.00326764053909 293 345 88
0.0032599440207 293 345 88
0.00325661063901 293 345 88
('ets1', <Bio.Motif._Motif.Motif object at 0x32ba910>)
0.038196601125 481 17 7240
0.061803398875 481 17 9414
0.02360679775 477 17 3420
0.014589803375 477 17 2691
0.0138174706229 477 17 2691
0.014203636999 477 17 2691
0.0140561345686 477 17 2691
0.0139649730533 477 17 2691
0.0139086321383
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-322-45976514e874> in <module>()
    100 
    101         res = minimize_scalar(obj_fun, bounds = [0,0.1], method = 'bounded', 
--> 102                 args = (ltrseqs.values, seq_mappings, [row], allowed_binding_pos, correct_binding_pos, executor))
    103 
    104 print res

/usr/local/lib/python2.7/dist-packages/scipy/optimize/_minimize.pyc in minimize_scalar(fun, bracket, bounds, args, method, tol, options)
    470             raise ValueError('The `bounds` parameter is mandatory for '
    471                              'method `bounded`.')
--> 472         return _minimize_scalar_bounded(fun, bounds, args, **options)
    473     elif meth == 'golden':
    474         return _minimize_scalar_golden(fun, bracket, args, **options)

/usr/local/lib/python2.7/dist-packages/scipy/optimize/optimize.pyc in _minimize_scalar_bounded(func, bounds, args, xtol, maxiter, disp, **unknown_options)
   1474         si = numpy.sign(rat) + (rat == 0)
   1475         x = xf + si*max([abs(rat), tol1])
-> 1476         fu = func(x, *args)
   1477         num += 1
   1478         fmin_data = (num, x, fu)

<ipython-input-322-45976514e874> in obj_fun(thresh_fpr, seqs, mappings, pwms, allowed_binding_pos, correct_binding_pos, executor)
     59     for pat, rows in groupby(check_all_seqs(seqs, mappings, pwms, thresh_fpr, executor), key = lambda x: x[0]):
     60         found_tfs = defaultdict(int)
---> 61         for _, tf, start, stop in rows:
     62             found_tfs[allowed_binding_pos.get((tf, start), None)] += 1
     63             #if thresh_fpr < 0.001:

<ipython-input-322-45976514e874> in check_all_seqs(seqs, mappings, wanted_pwms, thresh_fpr, executor)
     45         all_res.append((n, check_seq(seq, mapping, wanted_pwms, thresh_fpr, executor)))
     46     for n, res in all_res:
---> 47         for tf, start, stop in res:
     48             yield n, tf, start, stop
     49 

<ipython-input-322-45976514e874> in check_seq(seq, mapping, pwms, thresh_fpr, executor)
     36     anal_fun = partial(scan_seqs, seq)
     37     res = executor.map(anal_fun, tups)
---> 38     for tf, start, stop in chain.from_iterable(res):
     39         yield tf, mapping[start], mapping[stop-1]
     40 

/usr/local/lib/python2.7/dist-packages/concurrent/futures/_base.pyc in map(self, fn, *iterables, **kwargs)
    547             for future in fs:
    548                 if timeout is None:
--> 549                     yield future.result()
    550                 else:
    551                     yield future.result(end_time - time.time())

/usr/local/lib/python2.7/dist-packages/concurrent/futures/_base.pyc in result(self, timeout)
    397                 return self.__get_result()
    398 
--> 399             self._condition.wait(timeout)
    400 
    401             if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

/usr/lib/python2.7/threading.pyc in wait(self, timeout)
    241         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    242             if timeout is None:
--> 243                 waiter.acquire()
    244                 if __debug__:
    245                     self._note("%s.wait(): got it", self)

KeyboardInterrupt: 
 477 17 2691

In [293]:
# display leftover value from the (interrupted) optimization loop above
correct_binding_pos


Out[293]:
set(['nf-kappab-349',
     'sp1-387',
     'nf-kappab-362',
     'sp1-398',
     'ap1-104',
     'ap1-154',
     'ets1-304',
     'creb1-329',
     'ap1-213',
     'cebpa-337',
     'ap1-119',
     'cebpa-280',
     'sp1-376'])

In [201]:
# NOTE(review): final_results is initialized empty above and never
# appended in the visible cells — presumably filled by an earlier run
# of a different cell version; confirm before re-running top-to-bottom.
tfdata = DataFrame(final_results, columns = ['Patient ID', 'Visit Number', 'TFName', 'Start', 'Stop'])
tfdata


Out[201]:
<class 'pandas.core.frame.DataFrame'>
Int64Index: 36513 entries, 0 to 36512
Data columns:
Patient ID      36513  non-null values
Visit Number    36513  non-null values
TFName          36513  non-null values
Start           36513  non-null values
Stop            36513  non-null values
dtypes: int64(2), object(3)

In [202]:
# hits per TF across all patients/visits
tf_counts = tfdata[['Patient ID', 'TFName']].groupby('TFName').count()['Patient ID']
print tf_counts


TFName
ap1               7
ap1-R             5
ar               63
ar-R              4
arid3a-R          9
arnt            543
arnt-R          543
brca1            45
brca1-R           4
cebpa             1
cebpa-R           1
creb1-R           2
ctcf             15
ctcf-R            2
ddit3::cebpa      1
...
tfap2a-R          9
tlx1::nfic       17
tlx1::nfic-R      2
tp53              1
tp53-R            9
usf1            535
usf1-R           10
zeb1             52
zeb1-R            9
zfp423            5
zfp423-R         82
zfx             845
zfx-R            14
znf143           12
znf143-R         14
Name: Patient ID, Length: 216

In [203]:
# deduplicate: keep one hit per (TF, patient, visit, start position)
tf_grouped = tfdata.groupby(['TFName', 'Patient ID', 'Visit Number', 'Start']).first()
print tf_grouped


<class 'pandas.core.frame.DataFrame'>
MultiIndex: 36511 entries, (ap1, A0060, R00, 575) to (znf143-R, M17449, RN, 607)
Data columns:
Stop    36511  non-null values
dtypes: float64(1)

In [204]:
def crazy_iterable():
    """Yield (tropism_label, row_number, patient, visit) tuples, grouping
    samples by tropism call so each group gets contiguous row numbers for
    the heatmaps below.

    trop_dict values are True (R5), False (X4) or missing ('Unknown').
    NOTE(review): the sort key mixes bool and str — legal only in
    Python 2; this cell will break under Python 3.
    """
    tindex = list(ltrseqs.index.copy())
    print tindex[:10]
    tindex.sort(key = lambda x: trop_dict.get('%s-%s'%x, 'Unknown'))
    for key, inds in groupby(tindex, key = lambda x: trop_dict.get('%s-%s'%x, 'Unknown')):
        if key == True:
            key = 'R5'
        elif key == False:
            key = 'X4'
            
        for n, (p, v) in enumerate(inds):
            yield (key, n, p, v)
pat_inds = DataFrame(list(crazy_iterable()), columns = ['Tropism', 'Row', 'Patient ID', 'Visit Number'])
# max row index per tropism group -> heatmap height
map_sizes = pat_inds[['Tropism', 'Row']].groupby('Tropism').max()['Row']


[('A0001', 'R00'), ('A0001', 'R01'), ('A0001', 'R04'), ('A0002', 'R00'), ('A0002', 'R01'), ('A0002', 'R03'), ('A0002', 'R04'), ('A0002', 'R05'), ('A0002', 'R06'), ('A0002', 'R09')]

In [205]:
def make_filename(inp):
    """Sanitize a TF label for use as a filename: spaces and colons
    both become '-'."""
    no_spaces = inp.replace(' ', '-')
    return no_spaces.replace(':', '-')

In [206]:
from pylab import get_cmap


# One binding-footprint heatmap per TF with >40 hits: rows are samples
# (grouped by tropism), columns are LTR positions 0-699.
order = ['R5', 'X4', 'Unknown']
cmap = get_cmap('Greys')
for tf, num in zip(tf_counts.index, tf_counts.values):
    if num > 40:
        map_dict = {
            'Unknown':np.zeros((map_sizes['Unknown']+1,700)),
            'R5':np.zeros((map_sizes['R5']+1,700)),
            'X4':np.zeros((map_sizes['X4']+1,700)),
        }
        # NOTE(review): .ix is long-deprecated pandas indexing
        tmp = tf_grouped.ix[tf].reset_index()
        merged = merge(tmp, pat_inds, left_on = ['Patient ID', 'Visit Number'], right_on = ['Patient ID', 'Visit Number'])
        for _, row in merged.iterrows():
            map_dict[row['Tropism']][row['Row'], row['Start']:row['Stop']] += 1
    
        fig, axes = plt.subplots(3,1, sharex = True, figsize = (10,10))
        # NOTE(review): plt.title targets the current (last) axes, not
        # the whole figure — fig.suptitle(tf) was probably intended
        plt.title(tf)
        for key, ax in zip(order, axes.flatten()):
            ax.imshow(map_dict[key], cmap = cmap, aspect = 'auto')
            ax.set_ylim(0, map_sizes[key]+1)
            if key != 'Unknown':
                ax.set_ylabel(key + ' TF:' + tf)
            else:
                ax.set_ylabel(key)
        plt.xlabel('LTR Position')
        
        fname = make_filename(tf)
        # NOTE(review): figures are never saved (savefig commented out)
        # nor closed — many open figures can exhaust memory
        #plt.savefig('figures/%s.png' % fname)



In [207]:
# show heatmap row counts per tropism group
map_sizes


Out[207]:
Tropism
R5           94
Unknown    1146
X4           18
Name: Row

In [228]:
from sklearn.cluster import MeanShift


# Assign a cluster id to each TF's binding hits, grouping hits whose
# Start positions fall within the MeanShift bandwidth (~10bp).
cluster_data = DataFrame(columns = ['Patient ID', 'Visit Number', 'TFName', 'Start', 'Cluster'])

# Collect per-TF frames and concatenate once at the end: growing a
# DataFrame with concat inside the loop is quadratic in total row count.
clustered_frames = []
for tf, num in zip(tf_counts.index, tf_counts.values):
    
    data = tf_grouped.ix[tf].reset_index()
    data['TFName'] = tf
    clust = MeanShift(bandwidth = 10)
    res = clust.fit_predict(data[['Start']].values)
    data['Cluster'] = res
    clustered_frames.append(data)

# the empty seed frame comes first so the final column order matches
# the original iterative-concat result
cluster_data = concat([cluster_data] + clustered_frames, axis = 0, ignore_index = True)

In [229]:
# samples x (TF, cluster) hit-count matrix.  NOTE(review): rows=/cols=
# are the old pandas crosstab keywords (later renamed index=/columns=);
# also rebinds `res`, shadowing the minimize_scalar result above.
res = crosstab(rows = [cluster_data['Patient ID'], cluster_data['Visit Number']], cols = [cluster_data['TFName'], cluster_data['Cluster']])

In [233]:
# NOTE(review): k_means is imported but never used in this cell
from sklearn.cluster import k_means, mean_shift


# cluster samples by their TF-site occupancy profiles
centroids, labels = mean_shift(res.values)

labels = Series(labels, index = res.index)
# in-place sort groups samples by cluster label for the image below
# (Series.sort() was later removed from pandas; sort_values is the
# modern equivalent)
labels.sort()

plt.figure(figsize = (20,20))

# show the occupancy matrix with rows reordered by cluster
plt.imshow(res.ix[labels.index].values)


/usr/local/lib/python2.7/dist-packages/sklearn/cluster/mean_shift_.py:45: NeighborsWarning: kneighbors: neighbor k+1 and neighbor k have the same distance: results will be dependent on data order.
  d, _ = nbrs.kneighbors(X, return_distance=True)
Out[233]:
<matplotlib.image.AxesImage at 0x1057fd50>

In [234]:
# show the cluster assignment per (patient, visit)
labels


Out[234]:
Patient ID  Visit Number
A0001       R00             0
            R01             0
            R04             0
A0002       R00             0
            R01             0
            R04             0
            R05             0
            R09             0
A0003       R00             0
A0004       R00             0
            R01             0
            R02             0
            R03             0
            R04             0
            R05             0
...
JN944909    RN               2
JN944941    RN               2
AB289587    RN               3
AB289588    RN               3
A0005       R04              4
A0461       R00              5
A0078       R00              6
A0132       R07              7
A0289       R01              7
A0371       R01              7
A0056       R00              8
JQ316126    RN               9
A0231       R00             10
A0083       R00             11
A0132       R02             12
Length: 1247

In [ ]: