In [1]:
from __future__ import division
from pandas import DataFrame, Series, merge, read_csv, MultiIndex, Index, concat
from subprocess import check_call
from tempfile import NamedTemporaryFile as NTF
import os, os.path
import numpy as np
from scipy.stats import ttest_ind
from itertools import groupby,combinations, islice
from operator import itemgetter
from Bio import Phylo
import networkx
import sys

from random import shuffle
import csv, shlex, shutil

os.chdir('/home/will/HIVTropism//')
sys.path.append('/home/will/HIVReportGen/AnalysisCode/')
sys.path.append('/home/will/PySeqUtils/')

In [2]:
from SeqProcessTools import read_pat_seq_data, load_training_seq_data, align_seq_data_frame
from TreeingTools import make_mrbayes_trees, run_bats, get_pairwise_distances, check_distance_pvals
from GeneralSeqTools import fasta_reader, WebPSSM_V3_fasta, yield_chunks
from Bio.Seq import Seq
from Bio.Alphabet import generic_dna
import glob
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from itertools import chain


/usr/local/lib/python2.7/dist-packages/futures/__init__.py:24: DeprecationWarning: The futures package has been deprecated. Use the concurrent.futures package instead.
  DeprecationWarning)

In [12]:
def simple_translate(inseq):
    seq = Seq(inseq, alphabet=generic_dna)
    return seq.translate().tostring()


seq_files = glob.glob('LANLdata/*.fasta')
seq_data = []
for f in seq_files:
    parts = f.split(os.sep)[-1].split('-')
    prot = parts[1]
    subtype = parts[0]
    with open(f) as handle:
        for name, seq in fasta_reader(handle):
            nseq = ''.join(l for l in seq if l.isalpha())
            if prot != 'LTR':
                nseq = simple_translate(nseq)
            seq_data.append((subtype, name, prot, nseq))
            
pat_ltr_seqs = glob.glob('/home/will/HIVReportGen/Data/PatientFasta/*LTR.fasta')
print len(pat_ltr_seqs)
for f in pat_ltr_seqs:
    with open(f) as handle:
        for name, seq in fasta_reader(handle):
            seq_data.append(('SubB', name, 'LTR', seq))
            
pat_v3_seqs = glob.glob('/home/will/HIVReportGen/Data/PatientFasta/*V3.fasta')
print len(pat_v3_seqs)
for f in pat_v3_seqs:
    with open(f) as handle:
        for name, seq in fasta_reader(handle):
            seq_data.append(('SubB', name, 'V3', seq))

            
seq_df = DataFrame(seq_data, columns=['Subtype', 'Name', 'Prot', 'Seq'])


1025
110

In [13]:
v3_seqs = [(name, seq) for subtype, name, prot, seq in seq_data if prot == 'V3']

have_data = []
need_data = set([name for name, _ in v3_seqs])
count = 0
while need_data and count < 5:
    count += 1
    print len(need_data)
    gen = ((name, seq) for name, seq in v3_seqs if name in need_data)
    chunks = yield_chunks(gen, 50)
    with ThreadPoolExecutor(max_workers = 10) as e:
        res = e.map(WebPSSM_V3_fasta, chunks)
        have_data += list(chain.from_iterable(res))
    need_data -= set(row[0] for row in have_data)


21488
2900
2750
2650
2650

In [14]:
import numpy as np
def safe_float(inval):
    try:
        return float(inval)
    except ValueError:
        return np.nan

fields = ['name','score','pred','x4.pct','r5.pct','geno','pos.chg','net.chg','percentile']
pssm_data = DataFrame(have_data, columns=fields)
pssm_data['pred'] = pssm_data['pred'] == '1'

float_cols = [1, 3, 4, 6, 7, 8]
for col in float_cols:
    pssm_data[fields[col]] = pssm_data[fields[col]].map(safe_float)
    
valid_pssm = pssm_data[pssm_data['percentile']<0.95]

In [15]:
trop_scores = valid_pssm.groupby('name')[['score']].mean()
#trop_scores.to_excel('NewPSSMScores.xls')

In [16]:
tmp = dict(zip(seq_df['Name'], seq_df['Subtype']))
grouped_seq_df = seq_df.pivot(index = 'Name', columns='Prot', values='Seq').reset_index()
grouped_seq_df = merge(grouped_seq_df, trop_scores, 
                        left_on = 'Name', right_index = True)
grouped_seq_df['Subtype'] = grouped_seq_df['Name'].map(lambda x: tmp[x])
print grouped_seq_df


<class 'pandas.core.frame.DataFrame'>
Int64Index: 17844 entries, 0 to 24495
Data columns:
Name       17844  non-null values
LTR        144  non-null values
Nef        986  non-null values
V3         17844  non-null values
gp120      2762  non-null values
gp41       2425  non-null values
score      17844  non-null values
Subtype    17844  non-null values
dtypes: float64(1), object(7)

In [17]:
from collections import defaultdict
from random import shuffle
def calculate_mutual_info(signal1, signal2, shuf = False, batch=False, **kwargs):
    """Caluculates the Mutual Information shared by two signals.

    Arguements:
    signal1 -- An iterable indicating the first signal
    signal2 -- An iterable indicating the second signal

    Signals MUST be the same length! Items must be hashable!

    Returns:
    Mutual Information -- float"""


    if shuf:
        shuffle(signal1)
        shuffle(signal2)
    if batch:
        res = []
        extra = {}
        for _ in xrange(batch):
            r, extra = calculate_mutual_info(signal1, signal2, shuf=True, want_extra=True, **extra)
            res.append(r)
        return res

    overlap_prob = signal2prob(zip(signal1, signal2))
    signal1_prob = kwargs.get('S1prob',signal2prob(signal1))
    signal2_prob = kwargs.get('S2prob',signal2prob(signal2))

    num_items = len(signal1)
    mut_info = float()


    for (s1, s2), count in overlap_prob.items():
        mut_info += overlap_prob[(s1, s2)]*log(overlap_prob[(s1, s2)]/(signal1_prob[s1]*signal2_prob[s2]))

    if kwargs.get('want_extra', False):
        return mut_info, {'S1prob':signal1_prob, 'S2prob':signal2_prob}
    else:
        return mut_info
    
def signal2prob(signal):
    counts = signal2count(signal)
    num = len(signal)
    for key, val in counts.items():
        counts[key] = val/num
    return counts

def signal2count(signal):
    counts = defaultdict(int)
    for s in signal:
        counts[s] += 1
    return counts

In [18]:
wanted_seq_data = grouped_seq_df.dropna(subset = ['score'])
align_data = align_seq_data_frame(wanted_seq_data,  '/home/will/HIVReportGen/Data/BlastDB/ConBseqs.txt')

In [21]:
def decide_tropism(inval):
    if inval < -6.95:
        return 'R5'
    elif inval > -2.88:
        return 'X4'
    return np.nan
cols = ['gp120-seq-align',
        'gp41-seq-align',
        'Nef-seq-align',
        'LTR-seq-align']
align_data['Tropism'] = align_data['score'].map(decide_tropism)
#ask = align_data['Tropism'] == 'X4'
#mp_data = align_data[[col, 'V3-seq-align']][mask].dropna()
#print tmp_data
align_data


Out[21]:
<class 'pandas.core.frame.DataFrame'>
Int64Index: 17844 entries, 0 to 24495
Data columns:
Name               17844  non-null values
LTR                144  non-null values
Nef                986  non-null values
V3                 17844  non-null values
gp120              2762  non-null values
gp41               2425  non-null values
score              17844  non-null values
Subtype            17844  non-null values
LTR-bin-align      144  non-null values
LTR-seq-align      144  non-null values
Nef-bin-align      986  non-null values
Nef-seq-align      986  non-null values
V3-bin-align       17844  non-null values
V3-seq-align       17844  non-null values
gp120-bin-align    2762  non-null values
gp120-seq-align    2762  non-null values
gp41-bin-align     2425  non-null values
gp41-seq-align     2425  non-null values
Tropism            14960  non-null values
dtypes: float64(1), object(18)

In [225]:
t = align_data.dropna(subset = ['gp120', 'gp41', 'Nef', 'LTR'], how = 'all')
t['Subtype'].value_counts()


Out[225]:
SubB    2010
SubC     937

In [227]:
tmp = pivot_table(align_data, rows = ['Subtype', 'Tropism'], values = ['gp120', 'gp41', 'LTR', 'Nef'], aggfunc = 'count')
seq_counts = tmp.drop(['Subtype', 'Tropism'], axis = 1)
seq_counts

#crosstab(align_data['Subtype'][v3mask], )
#align_data.dropna(subset = ['V3'])


Out[227]:
LTR Nef gp120 gp41
Subtype Tropism
SubB R5 91 616 1660 1414
X4 15 27 70 59
SubC R5 11 177 524 485
X4 1 8 32 33

In [220]:
pivot_table?

In [39]:
def process_tups(tups):
    ocols, tupA, tupB = tups
    batchsize = 100
    posA, colA = tupA
    posB, colB = tupB
    subtype, col, name = ocols
    
    
    tv = calculate_mutual_info(list(colA), list(colB))
    wrongs = []
    numwrongs = 0
    while (numwrongs < 10) and batchsize < 5000:
        wrongs += calculate_mutual_info(list(colA), list(colB), batch = batchsize)
        numwrongs = (np.array(wrongs) > tv).sum()
        batchsize *= 2
    wrongs = np.array(wrongs)
    pval = (wrongs > tv).mean()
    return subtype, col, name, posA, posB, tv, wrongs.mean(), len(wrongs), pval

In [40]:
from itertools import product, imap
from itertools import dropwhile
from functools import partial

cols = ['gp120-seq-align',
        'gp41-seq-align',
        'Nef-seq-align',
        'LTR-seq-align']
masks = [('All', align_data['V3'].map(len)>0),
          ('R5', align_data['Tropism'] == 'R5'),
          ('X4', align_data['Tropism'] == 'X4')]
subtypes = [('SubBC', (align_data['Subtype'] == 'SubB') | (align_data['Subtype'] == 'SubC')),  
                ('SubC', align_data['Subtype'] == 'SubC'),
                ('SubB', align_data['Subtype'] == 'SubB'), 
            ]

def yield_cols():
    for (subname, submask), col, (trop_name, trop_mask) in product(subtypes, cols, masks):
    
        tmp_data = align_data[[col, 'V3-seq-align']][submask & trop_mask].dropna()
        if len(tmp_data.index) < 20:
            continue
            
        ovals = zip(*tmp_data[col].values)
        v3vals = zip(*tmp_data['V3-seq-align'].values)
        args = product(enumerate(ovals), enumerate(v3vals))
        for tupA, tupB in args:
            yield (subname, col, trop_name), tupA, tupB
    
def check_item(last_res, row):
    
    cols = ['gp120-seq-align',
        'gp41-seq-align',
        'Nef-seq-align',
        'LTR-seq-align']
    names = ['All', 'R5', 'X4']
    
    last_col_num = [num for num, val in enumerate(cols) if val == last_res[1]][0]
    last_mask_num = [num for num, val in enumerate(names) if val == last_res[0]][0]
    last_oval = int(last_res[2])
    last_v3val = int(last_res[3])
    
    (cur_col, cur_name), (posA, _), (posB, _) = row
    #print cur_col, cur_name
    cur_col_num = [num for num, val in enumerate(cols) if val == cur_col][0]
    cur_mask_num = [num for num, val in enumerate(names) if val == cur_name][0]
    
    checks = [(last_col_num, cur_col_num),
                (last_mask_num, cur_mask_num),
                (last_oval, posA),
                (last_v3val, posB)]
    #print checks
    
    for last, cur in checks:
        if last > cur:
            return True

    return False

In [41]:
results = []#list(csv.reader(open('quick_linkage.csv')))[:-2]
writer = csv.writer(open('subtype_linkage.csv', 'w'))
process_iter = yield_cols()

with ProcessPoolExecutor(max_workers = 30) as e:
    
    res_iter = enumerate(e.map(process_tups, process_iter))
    for num, row in res_iter:
        if (num == 10) or (num == 100) or (num == 1000) or (num % 5000 == 0):
            print num, row
        
        writer.writerow(row)
        results.append(row)


0 ('SubBC', 'gp120-seq-align', 'All', 0, 0, 0.0075236991975863206, 0.0063536565230194308, 100, 0.20000000000000001)
10 ('SubBC', 'gp120-seq-align', 'All', 0, 10, 0.03564762173353566, 0.020985485394383663, 6300, 0.0)
100 ('SubBC', 'gp120-seq-align', 'All', 3, 4, 0.037207694500780285, 0.023932139668861312, 6300, 0.0)
1000 ('SubBC', 'gp120-seq-align', 'All', 31, 8, 0.030265592317767865, 0.026162365388082195, 100, 0.10000000000000001)
5000 ('SubBC', 'gp120-seq-align', 'All', 156, 8, 0.035661704722320831, 0.034061222019237052, 100, 0.26000000000000001)
10000 ('SubBC', 'gp120-seq-align', 'All', 312, 16, 0.033479277729287139, 0.022081553942630439, 6300, 0.0)
15000 ('SubBC', 'gp120-seq-align', 'R5', 6, 24, 0.061841850519445504, 0.039441107144451669, 6300, 0.0)
20000 ('SubBC', 'gp120-seq-align', 'R5', 163, 0, 0.0086997133423464146, 0.011966773257781288, 100, 0.96999999999999997)
25000 ('SubBC', 'gp120-seq-align', 'R5', 319, 8, 0.039747785202006773, 0.034812975330859322, 300, 0.076666666666666661)
30000 ('SubBC', 'gp120-seq-align', 'X4', 13, 16, 0.22896494016645516, 0.236958655983938, 100, 0.56000000000000005)
35000 ('SubBC', 'gp120-seq-align', 'X4', 169, 24, 0.42123303204591178, 0.42951094515053334, 100, 0.55000000000000004)
40000 ('SubBC', 'gp120-seq-align', 'X4', 326, 0, 0.015422383485079962, 0.035868320367297611, 100, 0.71999999999999997)
45000 ('SubBC', 'gp41-seq-align', 'All', 20, 8, 0.0056472368457109456, 0.0027176415093186799, 300, 0.053333333333333337)
50000 ('SubBC', 'gp41-seq-align', 'All', 176, 16, 0.018324638501450864, 0.013962228023673428, 300, 0.053333333333333337)
55000 ('SubBC', 'gp41-seq-align', 'R5', 19, 24, 0.012432164703782984, 0.0091829782472780397, 300, 0.093333333333333338)
60000 ('SubBC', 'gp41-seq-align', 'R5', 176, 0, 0.0058621397440392379, 0.0045000410843927235, 100, 0.19)
65000 ('SubBC', 'gp41-seq-align', 'X4', 19, 8, 0.06189934751240167, 0.069112974693351065, 100, 0.60999999999999999)
70000 ('SubBC', 'gp41-seq-align', 'X4', 175, 16, 0.17906057048334081, 0.13561377794696511, 100, 0.10000000000000001)
75000 ('SubBC', 'Nef-seq-align', 'All', 18, 24, 0.035385974446638685, 0.010939648709231654, 6300, 0.0)
80000 ('SubBC', 'Nef-seq-align', 'All', 175, 0, 0.0081015088382062062, 0.0062062579680080528, 100, 0.22)
85000 ('SubBC', 'Nef-seq-align', 'R5', 126, 8, 0.0012387351488129781, 0.0045213004123766795, 100, 0.89000000000000001)
90000 ('SubBC', 'Nef-seq-align', 'X4', 77, 16, 0.0, 0.0, 6300, 0.0)
95000 ('SubBC', 'LTR-seq-align', 'All', 28, 24, 0.092298677299764501, 0.063456746817598955, 300, 0.053333333333333337)
100000 ('SubBC', 'LTR-seq-align', 'All', 185, 0, 0.041758670670721684, 0.0067888725765745828, 6300, 0.0015873015873015873)
105000 ('SubBC', 'LTR-seq-align', 'All', 341, 8, 0.059360533389386107, 0.042121396922692504, 100, 0.14000000000000001)
110000 ('SubBC', 'LTR-seq-align', 'All', 497, 16, 0.019411034884470487, 0.040551996630514037, 100, 0.89000000000000001)
115000 ('SubBC', 'LTR-seq-align', 'R5', 23, 24, 0.068927602131829815, 0.06426653620789996, 100, 0.34999999999999998)
120000 ('SubBC', 'LTR-seq-align', 'R5', 180, 0, 0.017440740617340616, 0.0097876128844944094, 100, 0.12)
125000 ('SubBC', 'LTR-seq-align', 'R5', 336, 8, 0.024100861061428825, 0.015466119974937241, 100, 0.16)
130000 ('SubBC', 'LTR-seq-align', 'R5', 492, 16, 0.028622356390210848, 0.033122458681399965, 100, 0.46000000000000002)
135000 ('SubC', 'gp120-seq-align', 'All', 18, 24, 0.048324322000990197, 0.051982533509611469, 100, 0.67000000000000004)
140000 ('SubC', 'gp120-seq-align', 'All', 175, 0, 0.020349872177393188, 0.01345199629136937, 700, 0.029999999999999999)
145000 ('SubC', 'gp120-seq-align', 'All', 331, 8, 0.06564225145474202, 0.053319537047654984, 700, 0.027142857142857142)
150000 ('SubC', 'gp120-seq-align', 'R5', 25, 16, 0.0095633517171682704, 0.0040500428292512139, 1500, 0.014666666666666666)
155000 ('SubC', 'gp120-seq-align', 'R5', 181, 24, 0.058829867815639751, 0.05267541125329453, 100, 0.19)
160000 ('SubC', 'gp120-seq-align', 'R5', 338, 0, 0.01497848466332817, 0.01082502610163398, 100, 0.11)
165000 ('SubC', 'gp120-seq-align', 'X4', 32, 8, 0.51553370127548348, 0.51796024192624845, 100, 0.5)
170000 ('SubC', 'gp120-seq-align', 'X4', 188, 16, 0.037573604794753884, 0.10178822944118066, 100, 0.69999999999999996)
175000 ('SubC', 'gp120-seq-align', 'X4', 344, 24, 0.74105677170463002, 0.71996511848717704, 100, 0.34000000000000002)
180000 ('SubC', 'gp41-seq-align', 'All', 39, 0, 0.0022982680116013089, 0.0024292516972657468, 100, 0.23999999999999999)
185000 ('SubC', 'gp41-seq-align', 'All', 195, 8, 0.020301242108578747, 0.016932908689463906, 100, 0.16)
190000 ('SubC', 'gp41-seq-align', 'R5', 38, 16, 0.00017303137496480617, 0.00062498249497290376, 300, 0.083333333333333329)
195000 ('SubC', 'gp41-seq-align', 'R5', 194, 24, 0.046101076250819287, 0.048459147057332376, 100, 0.57999999999999996)
200000 ('SubC', 'gp41-seq-align', 'X4', 38, 0, 0.00094712388639037518, 0.0042219755902596417, 700, 0.024285714285714285)
205000 ('SubC', 'gp41-seq-align', 'X4', 194, 8, 0.50290086344787344, 0.5324637688558822, 100, 0.63)
210000 ('SubC', 'Nef-seq-align', 'All', 37, 16, 0.0089335524110161404, 0.0053743662997004227, 100, 0.20000000000000001)
215000 ('SubC', 'Nef-seq-align', 'All', 193, 24, 0.0033843285909430781, 0.011016855040819918, 100, 0.63)
220000 ('SubC', 'Nef-seq-align', 'R5', 145, 0, 9.6854935823293483e-05, 0.00051539974312178594, 700, 0.014285714285714285)
225000 ('SubC', 'LTR-seq-align', 'All', 96, 8, 0.0, 0.0, 6300, 0.0)
230000 ('SubC', 'LTR-seq-align', 'All', 252, 16, 0.0, 0.0, 6300, 0.0)
235000 ('SubC', 'LTR-seq-align', 'All', 408, 24, 0.005412003687107942, 0.033090134704864146, 300, 0.14333333333333334)
240000 ('SubC', 'LTR-seq-align', 'All', 565, 0, 0.0, 0.0, 6300, 0.0)
245000 ('SubB', 'gp120-seq-align', 'All', 91, 8, 0.027561654091975404, 0.021662853080893038, 300, 0.073333333333333334)
250000 ('SubB', 'gp120-seq-align', 'All', 247, 16, 0.062179825796121746, 0.039114141464039041, 6300, 0.0)
255000 ('SubB', 'gp120-seq-align', 'All', 403, 24, 0.059185178719689945, 0.04015245290893589, 6300, 0.0)
260000 ('SubB', 'gp120-seq-align', 'R5', 98, 0, 0.0093181598583692172, 0.0085722742001129778, 100, 0.37)
265000 ('SubB', 'gp120-seq-align', 'R5', 254, 8, 0.052078547671836888, 0.04063506405840546, 3100, 0.0038709677419354839)
270000 ('SubB', 'gp120-seq-align', 'R5', 410, 16, 0.056352537838851316, 0.044505935457677483, 3100, 0.004193548387096774)
275000 ('SubB', 'gp120-seq-align', 'X4', 104, 24, 0.48666298552892445, 0.54097606745780336, 100, 0.83999999999999997)
280000 ('SubB', 'gp120-seq-align', 'X4', 261, 0, 0.030026132337956271, 0.032016908208812478, 100, 0.23999999999999999)
285000 ('SubB', 'gp120-seq-align', 'X4', 417, 8, 0.14205589002903435, 0.1693205216804933, 100, 0.77000000000000002)
290000 ('SubB', 'gp41-seq-align', 'All', 111, 16, 0.013659837476267644, 0.0062442887306795973, 3100, 0.004193548387096774)
295000 ('SubB', 'gp41-seq-align', 'All', 267, 24, 0.10480562592481799, 0.043443547105319717, 6300, 0.0)
300000 ('SubB', 'gp41-seq-align', 'R5', 111, 0, 0.0058891074848446377, 0.0014355265945984346, 700, 0.015714285714285715)
305000 ('SubB', 'gp41-seq-align', 'R5', 267, 8, 0.041692980772599614, 0.039710296741770977, 100, 0.31)
310000 ('SubB', 'gp41-seq-align', 'X4', 110, 16, 0.20850381604363116, 0.15090267406000488, 300, 0.073333333333333334)
315000 ('SubB', 'gp41-seq-align', 'X4', 266, 24, 0.2719760551746494, 0.28636809169812205, 100, 0.59999999999999998)
320000 ('SubB', 'Nef-seq-align', 'All', 110, 0, 0.00058105987228330037, 0.0016513249745187897, 100, 0.32000000000000001)
325000 ('SubB', 'Nef-seq-align', 'R5', 61, 8, 0.0032982688867368641, 0.010599003552834102, 100, 0.93000000000000005)
330000 ('SubB', 'Nef-seq-align', 'X4', 12, 16, 0.0, 0.0, 6300, 0.0)
335000 ('SubB', 'Nef-seq-align', 'X4', 168, 24, 0.33040279636726444, 0.17547482309188644, 1500, 0.015333333333333332)
340000 ('SubB', 'LTR-seq-align', 'All', 120, 0, 0.013015318209610311, 0.0075152028946387757, 100, 0.12)
345000 ('SubB', 'LTR-seq-align', 'All', 276, 8, 0.0026874289688624143, 0.0065471013390024637, 100, 0.27000000000000002)
350000 ('SubB', 'LTR-seq-align', 'All', 432, 16, 0.024244141376697954, 0.030877375806585533, 100, 0.54000000000000004)
355000 ('SubB', 'LTR-seq-align', 'All', 588, 24, 0.031951132312543901, 0.048719540831257714, 100, 0.79000000000000004)
360000 ('SubB', 'LTR-seq-align', 'R5', 115, 0, 0.029783352908309399, 0.016567162457042615, 300, 0.043333333333333335)
365000 ('SubB', 'LTR-seq-align', 'R5', 271, 8, 0.0064158320286882712, 0.022027997175770246, 100, 0.71999999999999997)
370000 ('SubB', 'LTR-seq-align', 'R5', 427, 16, 0.026592210799642076, 0.016899527638579132, 100, 0.17999999999999999)
375000 ('SubB', 'LTR-seq-align', 'R5', 583, 24, 0.040966328844945006, 0.049426991567078984, 100, 0.55000000000000004)

In [52]:
def fix_fun(row):
    trow = [row[0], row[1], int(row[2]), int(row[3]), 
            float(row[4]), float(row[5]), float(row[6]), float(row[7])]
    return trow

tdf = DataFrame(results, columns = ['Subtype', 'Prot', 'Group', 'TPos', 'Vpos', 'MI', 'nMI', 'Count', 'Pval'])
print tdf.head()


  Subtype             Prot Group  TPos  Vpos        MI       nMI  Count  Pval
0   SubBC  gp120-seq-align   All     0     0  0.007524  0.006354    100  0.20
1   SubBC  gp120-seq-align   All     0     1  0.044746  0.019317   6300  0.00
2   SubBC  gp120-seq-align   All     0     2  0.007105  0.007187    100  0.51
3   SubBC  gp120-seq-align   All     0     3  0.010729  0.010631    100  0.47
4   SubBC  gp120-seq-align   All     0     4  0.050342  0.017498   6300  0.00

In [71]:
len(align_data['LTR-seq-align'].dropna().values[0])


Out[71]:
630

In [108]:
tregions = [('gp120', 462),
            ('gp41', 313),
            ('Nef', 205),
            ('LTR', 630)]
nregions = []
mapper = {}
endpos = 0
for prot, size in tregions:
    mapper[prot] = endpos
    nregions.append((prot, endpos, endpos+size, 0, size, 1, 'product'))
    endpos += size + 50

nregions.append(('V3', 267, 302, 0, 32, 2, 'ROI'))
mapper['V3'] = 267
regions = DataFrame(nregions, columns = ['Region_name', 'Genome_Start', 'Genome_Stop', 'Gene_AA_Start', 'Gene_AA_Stop', 'Frame', 'RegionType'])
regions = regions.set_index('Region_name')
print regions, endpos


<class 'pandas.core.frame.DataFrame'>
Index: 5 entries, gp120 to V3
Data columns:
Genome_Start     5  non-null values
Genome_Stop      5  non-null values
Gene_AA_Start    5  non-null values
Gene_AA_Stop     5  non-null values
Frame            5  non-null values
RegionType       5  non-null values
dtypes: int64(5), object(1) 1810

In [185]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path
from math import pi, sin, cos
from random import randint, shuffle, seed
import pylab
from itertools import cycle
import colorsys

class PyCircos(object):
    def __init__(self, regions, genome_size, figsize = (10,10), fig = None, ax = None):
        #a panda object keyed by the region name
        #it must have the following columns:
        # genome_start, genome_stop, frame
        # gene_name
        # aa_start, aa_stop
        # optional: color
        self.regions = regions 
        if ax:
            self.ax = ax
        if fig:
            self.figure = fig
        if fig is None and ax is None:
            self.figure = figure(figsize = figsize)
            self.ax = plt.gca()
            print 'made defaults'
            

        self.genome_size = genome_size
        self.radius = 10
        self.frame_offset = 0.5
        self._prep_circle()
    
    def _prep_circle(self):
        radius_size = self.radius
        circ=pylab.Circle((0,0),radius=radius_size, fc = 'w')
        lim = radius_size + 4.5*self.frame_offset
        self.ax.set_xlim(-lim,lim)
        self.ax.set_ylim(-lim,lim)
        self.ax.add_patch(circ)
    
    def _get_colors(self, num_colors):
        colors=[]
        seed(50) #so we get a consitent set of random colors!
        for i in np.arange(0., 360., 360. / num_colors):
            hue = i/360.
            lightness = (50 + np.random.rand() * 10)/100.
            saturation = (90 + np.random.rand() * 10)/100.
            colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))
        return colors
    
    def add_labels(self, skip_labels = set(), with_text = True):

        radius_size = self.radius
        frame_offset = self.frame_offset
        #print self.genome_size
        roi_types = ['product', 'ROI']
        colors = self._get_colors(len(self.regions))
        shuffle(colors)
        colors = iter(colors)
        for ty in roi_types:
            tyregions = self.regions[self.regions['RegionType'] == ty]
            for region, row in tyregions.iterrows():
                if region == 'gp160':
                    continue
                if region in skip_labels:
                    _ = colors.next()
                    continue
                nrad = row['Frame']*frame_offset + radius_size
                theta1 = 360*(row['Genome_Start']/self.genome_size)
                theta2 = 360*(row['Genome_Stop']/self.genome_size)
                #print region, theta1, theta2, row
                arc_patch = patches.Arc((0,0), 2*nrad, 2*nrad, 
                                    theta1 = theta1, theta2 = theta2,
                                    edgecolor = colors.next(),
                                    linewidth = 10)
                self.ax.add_patch(arc_patch)
                midangle = (theta1 + abs(theta1-theta2)/2)*(2*pi/360)
                point = [(nrad+0.5)*cos(midangle), (nrad+0.5)*sin(midangle)]
                if with_text:
                    self.ax.text(point[0], point[1], region, fontsize = 18)
                
            
    
    def _make_path(self, source_item, target_item):
        rads_per_item = 2*pi/self.genome_size
        srad = source_item*rads_per_item
        trad = target_item*rads_per_item
        
        svert = [self.radius*cos(srad), self.radius*sin(srad)]
        tvert = [self.radius*cos(trad), self.radius*sin(trad)]
        
        verts = [svert, [0,0], tvert]
        codes = [Path.MOVETO,
                 Path.CURVE3,
                 Path.CURVE3]
        return verts, codes
    
    def add_raw_path(self, genome_start, genome_end, **kwargs):
        
        if 'facecolor' not in kwargs:
            kwargs['facecolor'] = 'none'
        if 'lw' not in kwargs:
            kwargs['lw'] = 2
            
        verts, codes = self._make_path(genome_start, genome_end)
        pt = Path(verts, codes)
        linepatch = patches.PathPatch(pt, **kwargs)
        self.ax.add_patch(linepatch)
    
    def add_raw_with_check(self, genome_start, region_start, genome_end, region_end, **kwargs):
        
        reg = self.regions
        startTF = genome_start >= reg['Genome_Start'][region_start] and genome_start <= reg['Genome_Stop'][region_start]
        endTF = genome_end >= reg['Genome_Start'][region_end] and genome_end <= reg['Genome_Stop'][region_end]
        if startTF and endTF:
            self.add_raw_path(genome_start, genome_end, **kwargs)
        
    
    def add_region_path(self, source_region_name, source_pos, target_region_name, target_pos):
        
        regions = self.regions
        sreg = regions.ix[source_region_name]
        treg = regions.ix[target_region_name]
        source_genome_start = sreg['Genome_Start']#+3*(source_pos-sreg['Gene_AA_Start'])
        target_genome_start = treg['Genome_Start']#+3*(target_pos-treg['Gene_AA_Start'])
        
        self.add_raw_path(source_genome_start, target_genome_start)

In [249]:
#['Subtype', 'Prot', 'Group', 'TPos', 'Vpos', 'MI', 'nMI', 'Count', 'Pval']

subtypes = [ 'SubBC', 'SubB', 'SubC']
prots = ['gp120', 'gp41', 'LTR', 'Nef']
color_map = {'All':'k', 'R5':'r', 'X4':'b'}
cuts = {'All':10, 'R5': 10, 'X4': 5}
ncuts = tdf['Group'].map(lambda x: cuts[x])
groups = ['All', 'R5', 'X4', 'Combined']

all_plots = []
wanted_rows = (tdf['Pval'] == 0) #& (tdf['MI'] > ncuts*tdf['nMI'])
for with_text in [False, True]:
    fig, axs = plt.subplots(3,4, figsize = (15*(5/4),15))
    ax_iter = iter(axs.flatten())
    for subtype in subtypes:
        plotted_args = []
        for group in groups:
            ax = ax_iter.next()
            circ_obj = PyCircos(regions, 1810, ax = ax, fig = fig)
    
            if group != 'Combined':
                gmask = (tdf['Group'] == group) & wanted_rows
                gmask &= tdf['Subtype'] == subtype
                tdata = tdf[gmask].sort('deltaMI').dropna()
                skip_labels = set()
                for prot in prots:
                    pdata = tdata[tdata['Prot'] == prot+'-seq-align']
                    pmask = pdata['deltaMI'].rank() > (len(pdata)-20)
                    if pmask.sum() == 0:
                        skip_labels.add(prot)
                    for _, row in pdata[pmask].iterrows():
                        Cpos = row['TPos'] + mapper[prot]
                        Tpos = row['Vpos'] + mapper['V3']
                        if (prot == 'gp120') and (Cpos > 265) and (Cpos < 305):
                            continue
                        edgecolor = color_map[row['Group']]
                        circ_obj.add_raw_path(Tpos, Cpos, edgecolor = edgecolor, alpha = 0.2)
                        plotted_args.append((Tpos, Cpos, edgecolor, group, subtype, prot))
                        all_plots.append((row['Vpos'], row['TPos'], edgecolor, group, subtype, prot))
            else:
                skip_labels = set()
                for Tpos, Cpos, edgecolor, group, subtype, prot in plotted_args:
                    circ_obj.add_raw_path(Tpos, Cpos, edgecolor = edgecolor, alpha = 0.2)
                #all_plots += plotted_args
                
            circ_obj.add_labels(skip_labels=skip_labels, with_text = with_text)
            ax.set_title(group + ' ' + subtype)
            ax.axis('off')
    fname = 'new_tropism_MI_links'
    if with_text:
        fname += '_with_text'
    plt.savefig(fname + '.png', dpi = 200)
    
plt.close()



In [251]:
rolling_count?

In [266]:
from pandas import rolling_sum
def get_group_count(indf):
    tregions = dict([('gp120', 462),
            ('gp41', 313),
            ('Nef', 205),
            ('LTR', 630)])
    nser = Series([0]*tregions[indf['Prot'].values[0]])
    #print nser
    nser[indf['Cpos']] = 1
    res = rolling_sum(nser, 20, min_periods=1)
    out = DataFrame({'RollingCount':res, 'Pos':range(tregions[indf['Prot'].values[0]])})
    out['Prot'] = indf['Prot'].values[0]
    out['Group'] = indf['Group'].values[0]
    out['Subtype'] = indf['Subtype'].values[0]
    return out
    


plotted_df = DataFrame(all_plots, columns = ['V3pos', 'Cpos', 'edgecolor', 'Group', 'Subtype', 'Prot']).sort(['Group', 'Subtype', 'Prot', 'Cpos'])
res = plotted_df.groupby(['Group', 'Subtype', 'Prot'], as_index = False).apply(get_group_count)

out = pivot_table(res, rows = ['Prot', 'Pos'], cols = ['Subtype', 'Group'], values = 'RollingCount').dropna(how = 'all')
wanted = (out>5).any(axis = 1)
out[wanted].to_csv('MIgroupings.csv')
#for key, group in :
#    print key, rolling_count(group['Cpos'], 5)

In [115]:
tdf['ProtName'] = tdf['Prot'].map(lambda x: x.split('-')[0])
res = pivot_table(tdf[wanted_rows], rows = ['ProtName', 'TPos'], cols = 'Group', values = 'Pval', aggfunc = 'min')
(res == 0).to_excel('mutual_info.xlsx')

In [150]:
from scipy.stats import fisher_exact
cols = ['gp120-bin-align',
        'gp41-bin-align',
        'Nef-bin-align',
        'LTR-bin-align']
masks = [('R5', align_data['Tropism'] == 'R5'),
          ('X4', align_data['Tropism'] == 'X4')]
subtypes = [('SubBC', (align_data['Subtype'] == 'SubB') | (align_data['Subtype'] == 'SubC')),  
                ('SubC', align_data['Subtype'] == 'SubC'),
                ('SubB', align_data['Subtype'] == 'SubB'), 
            ]

def yield_cols_for_fishers():
    for (subname, submask), col in product(subtypes, cols):
    
        tmp_data = align_data[[col, 'Tropism']][submask].dropna()
        if len(tmp_data.index) < 20:
            continue
        r5_trops = tmp_data['Tropism'] == 'R5'
        ovals = zip(*tmp_data[col].values)
        for pos, tup in enumerate(ovals):
            yield (subname, col.split('-')[0], pos), Series(tup, index = tmp_data.index), r5_trops
            
def process_fishers(intup):
    
    subtype, col, pos = intup[0]
    muts = intup[1]
    trops = intup[2]
    #print muts
    #print trops
    
    table = [[(muts & trops).sum(), (muts & ~trops).sum()],
             [(~muts & trops).sum(), (~muts & ~trops).sum()]]
    odds, pval = fisher_exact(table)
    return subtype, col, pos, odds, pval, trops.sum(), (~trops).sum()

with ProcessPoolExecutor(max_workers = 30) as e:
    fisher_res = list(e.map(process_fishers, yield_cols_for_fishers()))
fisher_df = DataFrame(fisher_res, columns = ['Subtype', 'Prot', 'Pos', 'Odds', 'Pval', 'R5Count', 'X4Count'])

In [134]:
fisher_df = fisher_df.dropna()

In [210]:
subtypes = ['SubB', 'SubC', 'SubBC', 'Combined']
prots = ['gp120', 'gp41', 'LTR', 'Nef']
color_map = {'SubB':'g', 'SubC':'r', 'SubBC':'b'}
fig, axs = plt.subplots(2,2, figsize = (10,10))
wanted_rows = (fisher_df['Pval'] <= 0.01)
for with_text in [False, True]:
    for subtype, ax in zip(subtypes, axs.flatten()):
        circ_obj = PyCircos(regions, 1810, ax = ax, fig = fig)
        if subtype != 'Combined':
            gmask = (fisher_df['Subtype'] == subtype) & wanted_rows
            skipm = (fisher_df['Subtype'] == subtype) & (fisher_df['X4Count'].fillna(0) > 0)
            skips = set(prots) - set(fisher_df['Prot'][skipm])
        else:
            gmask = wanted_rows.copy()
            skipm = fisher_df['X4Count'].fillna(0) >  0
            skips = set(prots) - set(fisher_df['Prot'][skipm])
        tdata = fisher_df[gmask].dropna()
    
        for _, row in tdata.iterrows():
            prot = row['Prot']
            Cpos = row['Pos']
            if (prot == 'gp120') and (Cpos < 302) and (Cpos > 267):
                continue
            Tpos = 20
            circ_obj.add_raw_path(Tpos+mapper['V3'],Cpos + mapper[prot], edgecolor = color_map[row['Subtype']], alpha = 0.2, lw = 3)
    
        circ_obj.add_labels(skip_labels = skips, with_text=with_text)    
        ax.set_title(subtype)
        ax.axis('off')
    fname = 'new_tropism_fisher_links'
    if with_text:
        fname += '_with_text'
    plt.savefig(fname + '.png', dpi = 200)
    
#plt.savefig('fishers_test_subtype.png')



In [173]:
from pandas import pivot_table
res = pivot_table(fisher_df, rows = ['Prot', 'Pos'], cols = 'Subtype', values = ['Pval', 'R5Count', 'X4Count'], aggfunc = np.mean)
print res.head(n = 30).to_string()


              Pval                  R5Count               X4Count             
Subtype       SubB     SubBC  SubC     SubB  SubBC  SubC     SubB  SubBC  SubC
Prot Pos                                                                      
LTR  0    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     1    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     2    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     3    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     4    1.000000  1.000000   NaN       91    102   NaN       15     16   NaN
     5    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     6    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     7    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     8    1.000000  1.000000   NaN       91    102   NaN       15     16   NaN
     9    0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     10   0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     11   0.375578  0.273641   NaN       91    102   NaN       15     16   NaN
     12   0.568186  0.424841   NaN       91    102   NaN       15     16   NaN
     13   0.565074  0.419807   NaN       91    102   NaN       15     16   NaN
     14   1.000000  1.000000   NaN       91    102   NaN       15     16   NaN
     15   1.000000  0.578973   NaN       91    102   NaN       15     16   NaN
     16   1.000000  0.773317   NaN       91    102   NaN       15     16   NaN
     17   0.568186  0.424841   NaN       91    102   NaN       15     16   NaN
     18   0.568186  0.424841   NaN       91    102   NaN       15     16   NaN
     19   0.564612  0.593502   NaN       91    102   NaN       15     16   NaN
     20   0.249629  0.262332   NaN       91    102   NaN       15     16   NaN
     21   0.568186  0.424841   NaN       91    102   NaN       15     16   NaN
     22   0.755383  1.000000   NaN       91    102   NaN       15     16   NaN
     23   0.541109  0.241085   NaN       91    102   NaN       15     16   NaN
     24   1.000000  1.000000   NaN       91    102   NaN       15     16   NaN
     25   0.545040  0.405432   NaN       91    102   NaN       15     16   NaN
     26   0.401592  0.294268   NaN       91    102   NaN       15     16   NaN
     27   0.564612  0.417293   NaN       91    102   NaN       15     16   NaN
     28   0.401592  0.410725   NaN       91    102   NaN       15     16   NaN
     29   0.768808  0.585968   NaN       91    102   NaN       15     16   NaN

In [159]:
res['R5Count']['SubC'] = (res['R5Count']['SubBC']-res['R5Count']['SubB']).combine_first(res['R5Count']['SubC'])
res['X4Count']['SubC'] = (res['X4Count']['SubBC']-res['X4Count']['SubB']).combine_first(res['X4Count']['SubC'])

In [163]:
mask = (res['Pval'] < 0.01).any(axis = 1)
res[mask].to_csv('fishers_table.csv')

In [ ]: