Learning how to define a custom metric in sklearn so it can be used with KDTree which in turn for correlation function


In [23]:
from sklearn.neighbors import *
import numpy as np
import math as m

In [10]:


In [ ]:


In [2]:
dist = DistanceMetric.get_metric('euclidean')
X = [[0, 1, 2],[3, 4, 5]]

dist.pairwise(X)


Out[2]:
array([[ 0.        ,  5.19615242],
       [ 5.19615242,  0.        ]])

In [3]:
dist.dist_to_rdist(3)


Out[3]:
9

In [26]:
def myDistance(x,y):
    return np.sqrt(x**2+y**2)

In [27]:
distc=DistanceMetric.get_metric("pyfunc",func=myDistance)

In [28]:
distc.pairwise(X)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-28-056fb7651e76> in <module>()
----> 1 distc.pairwise(X)

sklearn/neighbors/dist_metrics.pyx in sklearn.neighbors.dist_metrics.DistanceMetric.pairwise (sklearn/neighbors/dist_metrics.c:5720)()

sklearn/neighbors/dist_metrics.pyx in sklearn.neighbors.dist_metrics.DistanceMetric.pdist (sklearn/neighbors/dist_metrics.c:5088)()

sklearn/neighbors/dist_metrics.pyx in sklearn.neighbors.dist_metrics.PyFuncDistance.dist (sklearn/neighbors/dist_metrics.c:11445)()

TypeError: Custom distance function must accept two vectors and return a float.

In [ ]:


In [41]:
class LCDMDistance(DistanceMetric):
    """LCDM Distance metric
       D(x, y) = \sqrt{ x^2+y^2-2xycos(theta) }
    """
    def __init__(self):
        return None
    
    def dist(x,y):
        return np.sqrt(x**2+y**2-2*x*y)
    
    def pairwise(x):
        return dist(x,x)

    def rdist_to_dist(self, rdist):
        return np.sqrt(rdist)

    def dist_to_rdist(self, dist):
        return dist ** 2

In [48]:
distlcdm=LCDMDistance.pairwise(X)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-48-1260f425983b> in <module>()
----> 1 distlcdm=LCDMDistance.pairwise(X)

TypeError: unbound method pairwise() must be called with LCDMDistance instance as first argument (got list instance instead)

In [47]:
distlcdm.pairwise(X)


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-47-8b2bafa960d3> in <module>()
----> 1 distlcdm.pairwise(X)

AttributeError: 'function' object has no attribute 'pairwise'

In [62]:
def dist(x,y):
    return np.sqrt(np.sum(x**2+y**2-2*x*y))

In [63]:
x=[1,2,3]

In [51]:
KDTree.valid_metrics


Out[51]:
['chebyshev',
 'euclidean',
 'cityblock',
 'manhattan',
 'infinity',
 'minkowski',
 'p',
 'l2',
 'l1']

In [54]:
KDTree.data


Out[54]:
<attribute 'data' of 'sklearn.neighbors.kd_tree.BinaryTree' objects>

In [56]:
np.random.seed(0)
X = np.random.random((30, 3))
r = np.linspace(0, 1, 5)
tree = KDTree(X,metric="pyfunc",func=dist)                
tree.two_point_correlation(X, r)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-56-8477bf55e9bb> in <module>()
      2 X = np.random.random((30, 3))
      3 r = np.linspace(0, 1, 5)
----> 4 tree = KDTree(X,metric=lambda x,y:dist)
      5 tree.two_point_correlation(X, r)

sklearn/neighbors/binary_tree.pxi in sklearn.neighbors.kd_tree.BinaryTree.__init__ (sklearn/neighbors/kd_tree.c:9328)()

ValueError: metric PyFuncDistance is not valid for KDTree

In [58]:
BallTree.valid_metrics


Out[58]:
['chebyshev',
 'sokalmichener',
 'canberra',
 'haversine',
 'rogerstanimoto',
 'matching',
 'dice',
 'euclidean',
 'braycurtis',
 'russellrao',
 'cityblock',
 'manhattan',
 'infinity',
 'jaccard',
 'seuclidean',
 'sokalsneath',
 'kulsinski',
 'minkowski',
 'mahalanobis',
 'p',
 'l2',
 'hamming',
 'l1',
 'wminkowski',
 'pyfunc']

In [64]:
np.random.seed(0)
X = np.random.random((30, 3))
r = np.linspace(0, 1, 5)
tree = BallTree(X,metric="pyfunc",func=dist)                
tree.two_point_correlation(X, r)


Out[64]:
array([ 30,  62, 278, 580, 820])

In [ ]: