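Working on this with help from the Neighbor Joining Wikipedia page and @justin212k's code. The eventual goal is to produce a SciPy linkage matrix, so that the interface matches SciPy's `linkage` function; for now, `nj` below builds a Newick string (parsed into a `TreeNode` by default).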


In [1]:
from __future__ import division

import numpy as np

from skbio.core.distance import DistanceMatrix

def compute_q(dm):
    """Compute the neighbor-joining Q matrix for a distance matrix.

    For every pair of distinct taxa ``i`` and ``j``::

        Q(i, j) = (n - 2) * d(i, j) - sum_k d(i, k) - sum_k d(j, k)

    where ``n`` is the number of taxa. The diagonal is left at zero.

    Parameters
    ----------
    dm : DistanceMatrix
        Input distances; its ``data`` array and ``ids`` are used.

    Returns
    -------
    DistanceMatrix
        The Q matrix (entries may be negative), labelled with ``dm.ids``.
    """
    distances = np.asarray(dm.data, dtype=float)
    n = distances.shape[0]
    # Row totals sum_k d(i, k), computed once rather than once per pair.
    row_sums = distances.sum(axis=1)
    # Add the two row totals first: float addition is commutative, so
    # Q[i, j] and Q[j, i] come out bit-identical and the result passes
    # DistanceMatrix's exact symmetry check.
    pair_totals = row_sums[:, np.newaxis] + row_sums[np.newaxis, :]
    q = (n - 2) * distances - pair_totals
    # Q is only defined for i != j; keep the diagonal at zero.
    np.fill_diagonal(q, 0)
    return DistanceMatrix(q, dm.ids)

In [2]:
# Symmetric 5x5 example distances for taxa a-e (appears to match the worked
# example in the Wikipedia neighbor-joining article).
ids = ['a', 'b', 'c', 'd', 'e']
data = [
    [0, 5, 9, 9, 8],
    [5, 0, 10, 10, 9],
    [9, 10, 0, 8, 7],
    [9, 10, 8, 0, 3],
    [8, 9, 7, 3, 0],
]
dm = DistanceMatrix(data, ids)

In [3]:
# Q matrix for the example; nj joins the pair with the smallest off-diagonal value.
q = compute_q(dm)
print(q)


5x5 distance matrix
IDs:
a, b, c, d, e
Data:
[[  0. -50. -38. -34. -34.]
 [-50.   0. -38. -34. -34.]
 [-38. -38.   0. -40. -40.]
 [-34. -34. -40.   0. -48.]
 [-34. -34. -40. -48.   0.]]

In [4]:
def pair_members_to_new_node(dm, i, j, dissallow_negative_branch_length):
    """Branch lengths from a joined pair ``i``, ``j`` to their new parent ``u``.

    Uses the neighbor-joining formulas::

        delta(i, u) = d(i, j) / 2 + (sum_k d(i, k) - sum_k d(j, k)) / (2 (n - 2))
        delta(j, u) = d(i, j) - delta(i, u)

    Parameters
    ----------
    dm : indexable distance matrix
        Supports ``dm[i, j]``, ``dm[i].sum()`` and ``dm.shape``.
    i, j : index or id
        The two members being joined.
    dissallow_negative_branch_length : bool
        If true, a negative length is replaced with 0.

    Returns
    -------
    tuple
        ``(i_to_u, j_to_u)``.
    """
    n_taxa = dm.shape[0]
    pair_distance = dm[i, j]
    row_sum_gap = dm[i].sum() - dm[j].sum()
    length_i = (0.5 * pair_distance) + (1.0 / (2 * (n_taxa - 2))) * row_sum_gap
    # j's length is derived from the unclamped i length.
    length_j = pair_distance - length_i

    if dissallow_negative_branch_length:
        length_i = max(length_i, 0)
        length_j = max(length_j, 0)

    return length_i, length_j

In [6]:
# Branch lengths from taxa a (index 0) and b (index 1) to their new parent node.
pair_members_to_new_node(dm, 0, 1, dissallow_negative_branch_length=True)


Out[6]:
(2.0, 3.0)

In [7]:
def otu_to_new_node(dm, i, j, k, dissallow_negative_branch_length):
    """Distance from taxon ``k`` to the node ``u`` formed by joining ``i`` and ``j``.

    Computes ``delta(k, u) = (d(i, k) + d(j, k) - d(i, j)) / 2``.

    Parameters
    ----------
    dm : indexable distance matrix
        Supports ``dm[a, b]`` lookups.
    i, j : index or id
        The joined pair.
    k : index or id
        The taxon whose distance to the new node is wanted.
    dissallow_negative_branch_length : bool
        If true, a negative result is replaced with 0.
    """
    distance = 0.5 * (dm[i, k] + dm[j, k] - dm[i, j])
    if not dissallow_negative_branch_length:
        return distance
    return max(distance, 0)

In [8]:
# Distance from taxon c (index 2) to the node formed by joining a and b.
otu_to_new_node(dm, 0, 1, 2, dissallow_negative_branch_length=True)


Out[8]:
7.0

In [9]:
import numpy as np

def lowest_index(dm):
    """Return the position of the smallest strictly-lower-triangle entry.

    Scans ``dm[i, j]`` for ``i > j`` in row-major order; on ties the first
    entry encountered wins.

    Parameters
    ----------
    dm : indexable square matrix
        Supports ``dm.shape`` and ``dm[i, j]`` with integer indices.

    Returns
    -------
    tuple
        ``(i, j)`` with ``i > j``.

    Raises
    ------
    ValueError
        If no lower-triangle entry is less than +inf (e.g. ``dm`` has fewer
        than two rows, or every candidate is +inf or NaN).
    """
    best_value = np.inf
    best_index = None
    for i in range(dm.shape[0]):
        for j in range(i):
            value = dm[i, j]
            # Strict '<' keeps the first minimum and skips NaN/+inf.
            if value < best_value:
                best_value = value
                best_index = (i, j)
    if best_index is None:
        raise ValueError("lowest_index requires at least one off-diagonal "
                         "entry smaller than +inf")
    return best_index

In [11]:
def compute_collapsed_dm(dm, i, j, dissallow_negative_branch_length, new_node_id=None):
    """Replace taxa ``i`` and ``j`` with a single new node.

    Parameters
    ----------
    dm : DistanceMatrix
        Current distances; looked up by id (``dm[a, b]``).
    i, j : str
        Ids of the two members being merged.
    dissallow_negative_branch_length : bool
        Forwarded to ``otu_to_new_node`` when computing distances to the
        new node.
    new_node_id : str, optional
        Id for the merged node; defaults to ``"(i, j)"``.

    Returns
    -------
    DistanceMatrix
        A matrix one row/column smaller. The new node is the first id, and
        the remaining ids keep their original relative order. Distances to
        the new node come from ``otu_to_new_node``; all other distances are
        copied from ``dm``.
    """
    out_n = dm.shape[0] - 1
    new_node_id = new_node_id or "(%s, %s)" % (i, j)
    out_ids = [new_node_id] + [e for e in dm.ids if e not in (i, j)]
    result = np.zeros((out_n, out_n))
    for row, out_id1 in enumerate(out_ids[1:], start=1):
        # Row/column 0 holds distances to the newly created node.
        result[0, row] = result[row, 0] = otu_to_new_node(
            dm, i, j, out_id1, dissallow_negative_branch_length)
        # Copy the unchanged pairwise distances (lower triangle, mirrored).
        for col, out_id2 in enumerate(out_ids[1:row], start=1):
            result[row, col] = result[col, row] = dm[out_id1, out_id2]
    return DistanceMatrix(result, out_ids)

In [17]:
from skbio.core.tree import TreeNode

def _format_branch_length(length):
    """Format a branch length for Newick output without truncating fractions."""
    return repr(float(length))


def nj(dm, dissallow_negative_branch_length=True,
       result_constructor=TreeNode.from_newick):
    """Build a neighbor-joining tree from a distance matrix.

    Repeatedly joins the pair with the smallest Q value (see ``compute_q``)
    until three nodes remain, then connects those three to one central node.

    Parameters
    ----------
    dm : DistanceMatrix
        Distances between taxa; its ids become leaf labels in the output.
    dissallow_negative_branch_length : bool, optional
        If true, negative branch lengths are replaced with 0.
    result_constructor : callable, optional
        Applied to the final Newick string. Defaults to
        ``TreeNode.from_newick``; pass ``str`` to get the string itself.

    Returns
    -------
    object
        ``result_constructor(newick)``.

    Raises
    ------
    ValueError
        If ``dm`` has fewer than three taxa.

    Notes
    -----
    Ids are inserted into the Newick string verbatim (no quoting). Ids that
    contain Newick metacharacters such as ``(),:;`` would presumably break
    parsing.
    """
    if dm.shape[0] < 3:
        raise ValueError("Neighbor joining requires at least three taxa; "
                         "got %d." % dm.shape[0])

    while dm.shape[0] > 3:
        q = compute_q(dm)
        idx1, idx2 = lowest_index(q)
        pair_member_1 = dm.ids[idx1]
        pair_member_2 = dm.ids[idx2]
        pair_member_1_len, pair_member_2_len = pair_members_to_new_node(
            dm, idx1, idx2, dissallow_negative_branch_length)
        # The Newick subtree for the joined pair doubles as its id in the
        # collapsed matrix, so the final string nests naturally.
        node_definition = "(%s:%s, %s:%s)" % (
            pair_member_1, _format_branch_length(pair_member_1_len),
            pair_member_2, _format_branch_length(pair_member_2_len))
        dm = compute_collapsed_dm(
            dm, pair_member_1, pair_member_2,
            dissallow_negative_branch_length=dissallow_negative_branch_length,
            new_node_id=node_definition)

    # Three nodes remain; join them at a single internal node. After any
    # collapse the newest node is dm.ids[0]. For a 3-taxon input it is simply
    # the first taxon.
    third_node = dm.ids[0]
    pair_member_1 = dm.ids[1]
    pair_member_2 = dm.ids[2]
    pair_member_1_len, pair_member_2_len = pair_members_to_new_node(
        dm, pair_member_1, pair_member_2, dissallow_negative_branch_length)
    internal_len = otu_to_new_node(
        dm, pair_member_1, pair_member_2, third_node,
        dissallow_negative_branch_length=dissallow_negative_branch_length)
    newick = "(%s:%s, %s:%s, %s:%s);" % (
        pair_member_1, _format_branch_length(pair_member_1_len),
        third_node, _format_branch_length(internal_len),
        pair_member_2, _format_branch_length(pair_member_2_len))

    return result_constructor(newick)

In [20]:
# Passing str as the constructor returns the raw Newick string instead of a TreeNode.
newick_str = nj(dm, result_constructor=str)
print(newick_str)


(d:2, (c:4, (b:3, a:2):3):2, e:1);

In [19]:



          /-d
         |
         |          /-c
         |---------|
---------|         |          /-b
         |          \--------|
         |                    \-a
         |
          \-e

In [ ]: