In [1]:
%matplotlib inline

In [2]:
import networkx as nx
from skbio import TreeNode
from skbio.diversity.beta._unifrac import _validate
from skbio.diversity._base import _validate_counts_vectors


1.10

In [181]:
def _get_mst(graph, tips, reference_node):
    """Get the MST for an environment
    
    This method is algorithmically stupid, will not scale, and needs to be
    revisited.
    """
    # determine a plausible set of nodes that correspond to the environment
    nodes = set()
    for tip in tips:
        nodes.update(set(nx.shortest_path(graph, source=tip, target=reference_node)))
    
    # compute the MST (nx.minimum_spanning_tree is undefined for directed graphs)
    directed = graph.subgraph(nodes)
    mst = nx.minimum_spanning_tree(directed.to_undirected())
    
    # coerce the undirected MST back to a digraph
    m_e = mst.edges()
    E = [(i,o,w) for i,o,w in directed.edges_iter(data=True) if ((i,o) in m_e or (o,i) in m_e)]
    directed_mst = nx.DiGraph(E)
    
    return directed_mst

def _treenode_to_graph(tree):
    """Take a skbio TreeNode and convert it into a DiGraph"""
    g = nx.DiGraph()

    label_base = '_int-%d'
    tree.assign_ids()
    for node in tree.preorder(include_self=False):
        if node.name is None:
            node.name = label_base % node.id
        
        if node.parent.is_root():
            g.add_edge(node.name, 'root', weight=node.length)
        else:
            g.add_edge(node.name, node.parent.name, weight=node.length)
    return g

def unweighted_unifrac_uag(u_counts, v_counts, otu_ids, graph, validate=True):
    """Compute unweighted unifrac over a directed acyclic graph
    
    The computation relies on the presence of a root. But, in the general case of a DAG,
    calling this node a "root" does not make sense. Instead, the code labels this node as
    the reference_node. The intuition being that all comparisons are relative to this
    node. In the case of a skbio TreeNode being passed in, the method will use the root
    of the tree as the reference. The determination of a reference for a DAG is open, and
    right now, an arbitary node is picked. This is _not_ a good idea as it is not assured
    to be stable on pairwise operations and how this choice is made needs to be revisited.
    """
    # do the validation and conversion to a DAG as needed. This likely should be 
    # decomposed.
    if isinstance(graph, TreeNode):
        if validate:
            _validate(u_counts=u_counts, v_counts=v_counts,
                      otu_ids=otu_ids, tree=graph)
        graph = _treenode_to_graph(graph)
    else:
        if validate:
            _validate_counts_vectors(u_counts, v_counts, suppress_cast=True)
    
    # handle boundary cases
    if sum(u_counts) == 0 or sum(v_counts) == 0 or len(u_counts) == 0 or len(v_counts) == 0:
        if sum(u_counts) + sum(v_counts) == 0:
            return 0.0
        else:
            return 1.0
    
    # determine a reference point
    if graph.has_node('root'):
        reference_node = 'root'
    else:        
        # pick an arbitrary point of reference that is not an observation of interest.
        # this is not stable across method calls and needs to be revisited.
        reference_node = sorted([n for n in graph.nodes() if n not in otu_ids])[-1]
    
    # determine what IDs are represented in each environment
    u_ids = [i for u, i in zip(u_counts, otu_ids) if u]
    v_ids = [i for v, i in zip(v_counts, otu_ids) if v]

    # get all the minimum spanning tree for each environment
    u_sg = _get_mst(graph, u_ids, reference_node)
    v_sg = _get_mst(graph, v_ids, reference_node)
    
    # determine what edges are shared between environments and merge those 
    # networks
    shared_edges = set(u_sg.edges()).intersection(set(v_sg.edges()))
    shared_sg = graph.subgraph(set(flatten(shared_edges)))

    # determine the total length observed in the network 
    total_sg = graph.subgraph(set(u_sg.nodes()) | set(v_sg.nodes()))
    
    # compute unifrac
    return (total_sg.size(weight='weight') - shared_sg.size(weight='weight')) / total_sg.size(weight='weight')

In [180]:
from skbio.diversity.beta.tests.test_unifrac_base import StatsTests
from unittest import TestCase, TestLoader, TextTestRunner

class NetworkUniFracTests(StatsTests, TestCase):
    _method = {'unweighted_unifrac': unweighted_unifrac_uag,
               'weighted_unifrac': lambda w, x, y, z: 0.0}

suite = TestLoader().loadTestsFromTestCase(NetworkUniFracTests)
TextTestRunner(verbosity=1,stream=sys.stderr).run( suite )


F.........FEFF.EFEE.EFE
======================================================================
ERROR: test_weighted_normalized_root_not_observed (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 142, in test_weighted_normalized_root_not_observed
    self.oids2, self.t2, normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
ERROR: test_weighted_unifrac_identity_normalized (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 499, in test_weighted_unifrac_identity_normalized
    self.b1[i], self.b1[i], self.oids1, self.t1, normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
ERROR: test_weighted_unifrac_non_overlapping_normalized (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 517, in test_weighted_unifrac_non_overlapping_normalized
    self.b1[4], self.b1[5], self.oids1, self.t1, normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
ERROR: test_weighted_unifrac_normalized (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 550, in test_weighted_unifrac_normalized
    self.b1[0], self.b1[1], self.oids1, self.t1, normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
ERROR: test_weighted_unifrac_symmetry_normalized (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 508, in test_weighted_unifrac_symmetry_normalized
    normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
ERROR: test_weighted_unifrac_zero_counts_normalized (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 532, in test_weighted_unifrac_zero_counts_normalized
    normalized=True)
TypeError: <lambda>() got an unexpected keyword argument 'normalized'

======================================================================
FAIL: test_invalid_input (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 184, in test_invalid_input
    v_counts, otu_ids, t)
AssertionError: DuplicateNodeError not raised by <lambda>

======================================================================
FAIL: test_weighted_minimal_trees (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 95, in test_weighted_minimal_trees
    self.assertEqual(actual, expected)
AssertionError: 0.0 != 0.25

======================================================================
FAIL: test_weighted_root_not_observed (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 128, in test_weighted_root_not_observed
    self.assertAlmostEqual(actual, expected)
AssertionError: 0.0 != 0.15 within 7 places

======================================================================
FAIL: test_weighted_unifrac (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 434, in test_weighted_unifrac
    self.assertAlmostEqual(actual, expected)
AssertionError: 0.0 != 2.4 within 7 places

======================================================================
FAIL: test_weighted_unifrac_non_overlapping (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 408, in test_weighted_unifrac_non_overlapping
    self.assertAlmostEqual(actual, expected)
AssertionError: 0.0 != 4.0 within 7 places

======================================================================
FAIL: test_weighted_unifrac_zero_counts (__main__.NetworkUniFracTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/mcdonadt/ResearchWork/software/scikit-bio/skbio/diversity/beta/tests/test_unifrac_base.py", line 421, in test_weighted_unifrac_zero_counts
    self.assertAlmostEqual(actual, expected)
AssertionError: 0.0 != 2.0 within 7 places

----------------------------------------------------------------------
Ran 23 tests in 0.370s

FAILED (failures=6, errors=6)
Out[180]:
<unittest.runner.TextTestResult run=23 errors=6 failures=6>

In [7]:
def _get_all_nodes_weighted(graph, ids, weights):
    nodes = []
    weights = []
    for i in zip(range(len) - 1):
        for node in nx.shortest_path(graph, source=ids[i], target=ids[i+1]):
            nodes.append(node)
            weights.append(weights[])
        

def weighted_unifrac_uag(u_counts, v_counts, otu_ids, graph):
    u_ids = [i for u, i in zip(u_counts, otu_ids) if u]
    u_w   = [u for u, i in zip(u_counts, otu_ids) if u]
    v_ids = [i for v, i in zip(v_counts, otu_ids) if v]
    v_w   = [v for v, i in zip(v_counts, otu_ids) if v]
    
    u_nodes, u_weights = _get_all_nodes_weighted(graph, u_ids, u_w)
    v_nodes, v_weights = _get_all_nodes_weighted(graph, v_ids, v_w)

In [144]:
import networkx as nx
G1 = nx.Graph()
G1.add_path([0,1,2,3])
G2 = nx.Graph()
G2.add_path([1,2,3,4])
set(G1.edges()).intersection(G2.edges())


Out[144]:
{(1, 2), (2, 3)}

In [164]:
?nx.minimum_spanning_tree

In [ ]: