In [1]:
import numpy as np
import pyGPs
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import rbf_kernel, laplacian_kernel

In [80]:
def compute_gp_regression(X_train, y_train, X_test):
    model = pyGPs.GPR()
    m = pyGPs.mean.Const()
    k = pyGPs.cov.RBF()
    model.setPrior(mean=m, kernel=k)
    model.optimize(X_train, y_train)
    print('Optimized negative log marginal likelihood:', round(model.nlZ,3))
    y_pred, _, _, _, _ = model.predict(X_test)
    return y_pred

def HSIC_d(X, Y, kernel='exponential'):
    n = len(X)

    if kernel == 'exponential':
        apply_kernel = rbf_kernel
    elif kernel == 'laplacian':
        apply_kernel = laplacian_kernel
    K = apply_kernel(X.reshape(-1, 1))
    L = apply_kernel(Y.reshape(-1, 1))
    
    H = np.eye(n) - np.ones((n, n)) * (1.0 / n)
    return ((n - 1) ** -2) * np.trace(np.dot(np.dot(np.dot(K, H), L), H))


def ANM_algorithm(X,y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
    leakage_prob = dict()
    
    for col in range(X_train.shape[1]):
        
        x_train_column = X_train[:,col]
        x_test_column = X_test[:,col]
        
        print x_train_column.shape, y_train.shape, x_test_column.shape
        
        y_pred = compute_gp_regression(x_train_column, y_train, x_test_column)
        x_pred = compute_gp_regression(y_train, x_train_column, y_test)
        print 'y_pred shape', y_pred.shape
        print 'x_pred shape', x_pred.shape
        
        y_residuals = y_test - y_pred.ravel() # esto no deberia ser absolute value?
        x_residuals = x_test_column - x_pred.ravel()
        
        print y_residuals.shape
        print x_test_column.shape
        
        print x_residuals.shape
        print y_test.shape
        
        
        HSIC_x_to_y = HSIC_d(x_test_column, y_residuals)
        HSIC_y_to_x = HSIC_d(y_test, x_residuals)
        
        
        diff_HSIC = HSIC_x_to_y - HSIC_y_to_x
    
        
        leakage_prob[diff_HSIC] = col
    
    keys = leakage_prob.keys()
    keys.sort(reverse=True)
    for key in keys:
        print "The probability of column: " + str(leakage_prob[key]) + " is: " + str(key)

In [60]:
import pandas as pd
pairs = pd.read_csv('data/pair0039.txt', sep=' ', header=None)
pairs.columns = ['X', 'Y']

In [68]:
x = np.array(pairs)[:,0].reshape(-1,1)
y = np.array(pairs)[:,1]

In [69]:
ANM_algorithm(x,y)


Number of line searches 40
('Optimized negative log marginal likelihood:', 2435.659)
Number of line searches 40
('Optimized negative log marginal likelihood:', 2435.659)
y_pred shape (131, 1)
x_pred shape (131, 1)
(131,)
(131,)
(131,)
(131,)
(131,)
The probability of column: 0 is: -0.00114255125667

In [51]:
from sklearn.datasets import load_boston
boston = load_boston()

In [91]:
boston.feature_names


Out[91]:
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
       'TAX', 'PTRATIO', 'B', 'LSTAT'], 
      dtype='|S7')

In [90]:
boston.DESCR


Out[90]:
"Boston House Prices dataset\n===========================\n\nNotes\n------\nData Set Characteristics:  \n\n    :Number of Instances: 506 \n\n    :Number of Attributes: 13 numeric/categorical predictive\n    \n    :Median Value (attribute 14) is usually the target\n\n    :Attribute Information (in order):\n        - CRIM     per capita crime rate by town\n        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.\n        - INDUS    proportion of non-retail business acres per town\n        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)\n        - NOX      nitric oxides concentration (parts per 10 million)\n        - RM       average number of rooms per dwelling\n        - AGE      proportion of owner-occupied units built prior to 1940\n        - DIS      weighted distances to five Boston employment centres\n        - RAD      index of accessibility to radial highways\n        - TAX      full-value property-tax rate per $10,000\n        - PTRATIO  pupil-teacher ratio by town\n        - B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town\n        - LSTAT    % lower status of the population\n        - MEDV     Median value of owner-occupied homes in $1000's\n\n    :Missing Attribute Values: None\n\n    :Creator: Harrison, D. and Rubinfeld, D.L.\n\nThis is a copy of UCI ML housing dataset.\nhttp://archive.ics.uci.edu/ml/datasets/Housing\n\n\nThis dataset was taken from the StatLib library which is maintained at Carnegie Mellon University.\n\nThe Boston house-price data of Harrison, D. and Rubinfeld, D.L. 'Hedonic\nprices and the demand for clean air', J. Environ. Economics & Management,\nvol.5, 81-102, 1978.   Used in Belsley, Kuh & Welsch, 'Regression diagnostics\n...', Wiley, 1980.   N.B. Various transformations are used in the table on\npages 244-261 of the latter.\n\nThe Boston house-price data has been used in many machine learning papers that address regression\nproblems.   \n     \n**References**\n\n   - Belsley, Kuh & Welsch, 'Regression diagnostics: Identifying Influential Data and Sources of Collinearity', Wiley, 1980. 244-261.\n   - Quinlan,R. (1993). Combining Instance-Based and Model-Based Learning. In Proceedings on the Tenth International Conference of Machine Learning, 236-243, University of Massachusetts, Amherst. Morgan Kaufmann.\n   - many more! (see http://archive.ics.uci.edu/ml/datasets/Housing)\n"

In [74]:
X = boston.data
y = boston.target

In [75]:
X.shape


Out[75]:
(506, 13)

In [76]:
y.shape


Out[76]:
(506,)

In [81]:
ANM_algorithm(X, y)


(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1208.048)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1111.66)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1227.439)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1510.418)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1189.604)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1067.432)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1236.236)
Number of line searches 40
('Optimized negative log marginal likelihood:', 25.543)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1212.917)
Number of line searches 40
('Optimized negative log marginal likelihood:', -309.415)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1107.227)
Number of line searches 34
('Optimized negative log marginal likelihood:', 253.242)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1222.879)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1557.402)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1231.582)
Number of line searches 40
('Optimized negative log marginal likelihood:', 702.502)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1207.465)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1155.245)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1209.047)
Number of line searches 40
('Optimized negative log marginal likelihood:', 2141.443)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Warning: adding jitter of 9.7612971734e+05 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 9.7612971734e+06 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 9.7612971734e+07 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 9.7612971734e+08 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 9.7612971734e+09 to diagnol of kernel matrix for numerical stability
Number of line searches 40
('Optimized negative log marginal likelihood:', 1194.508)
Number of line searches 40
('Optimized negative log marginal likelihood:', 707.207)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1219.053)
Warning: adding jitter of 4.7916475229e+12 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 4.7916475229e+13 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 4.7916475229e+14 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 4.7916475229e+15 to diagnol of kernel matrix for numerical stability
Warning: adding jitter of 4.7916475229e+16 to diagnol of kernel matrix for numerical stability
Number of line searches 40
('Optimized negative log marginal likelihood:', 1971.665)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
(339,) (339,) (167,)
Number of line searches 40
('Optimized negative log marginal likelihood:', 1053.072)
Number of line searches 40
('Optimized negative log marginal likelihood:', 959.063)
y_pred shape (167, 1)
x_pred shape (167, 1)
(167,)
(167,)
(167,)
(167,)
The probability of column: 12 is: 0.00149305048352
The probability of column: 5 is: 9.52529158381e-05
The probability of column: 4 is: -5.05522560742e-06
The probability of column: 3 is: -1.16502517361e-05
The probability of column: 9 is: -0.000599836560431
The probability of column: 6 is: -0.000834835429529
The probability of column: 7 is: -0.000926073969679
The probability of column: 10 is: -0.00174040422681
The probability of column: 11 is: -0.00262653264801
The probability of column: 2 is: -0.00430590121497
The probability of column: 8 is: -0.00967054417892
The probability of column: 0 is: -0.0109522976416
The probability of column: 1 is: -0.0289051424721

In [85]:
X[:,12]


Out[85]:
array([  4.98,   9.14,   4.03,   2.94,   5.33,   5.21,  12.43,  19.15,
        29.93,  17.1 ,  20.45,  13.27,  15.71,   8.26,  10.26,   8.47,
         6.58,  14.67,  11.69,  11.28,  21.02,  13.83,  18.72,  19.88,
        16.3 ,  16.51,  14.81,  17.28,  12.8 ,  11.98,  22.6 ,  13.04,
        27.71,  18.35,  20.34,   9.68,  11.41,   8.77,  10.13,   4.32,
         1.98,   4.84,   5.81,   7.44,   9.55,  10.21,  14.15,  18.8 ,
        30.81,  16.2 ,  13.45,   9.43,   5.28,   8.43,  14.8 ,   4.81,
         5.77,   3.95,   6.86,   9.22,  13.15,  14.44,   6.73,   9.5 ,
         8.05,   4.67,  10.24,   8.1 ,  13.09,   8.79,   6.72,   9.88,
         5.52,   7.54,   6.78,   8.94,  11.97,  10.27,  12.34,   9.1 ,
         5.29,   7.22,   6.72,   7.51,   9.62,   6.53,  12.86,   8.44,
         5.5 ,   5.7 ,   8.81,   8.2 ,   8.16,   6.21,  10.59,   6.65,
        11.34,   4.21,   3.57,   6.19,   9.42,   7.67,  10.63,  13.44,
        12.33,  16.47,  18.66,  14.09,  12.27,  15.55,  13.  ,  10.16,
        16.21,  17.09,  10.45,  15.76,  12.04,  10.3 ,  15.37,  13.61,
        14.37,  14.27,  17.93,  25.41,  17.58,  14.81,  27.26,  17.19,
        15.39,  18.34,  12.6 ,  12.26,  11.12,  15.03,  17.31,  16.96,
        16.9 ,  14.59,  21.32,  18.46,  24.16,  34.41,  26.82,  26.42,
        29.29,  27.8 ,  16.65,  29.53,  28.32,  21.45,  14.1 ,  13.28,
        12.12,  15.79,  15.12,  15.02,  16.14,   4.59,   6.43,   7.39,
         5.5 ,   1.73,   1.92,   3.32,  11.64,   9.81,   3.7 ,  12.14,
        11.1 ,  11.32,  14.43,  12.03,  14.69,   9.04,   9.64,   5.33,
        10.11,   6.29,   6.92,   5.04,   7.56,   9.45,   4.82,   5.68,
        13.98,  13.15,   4.45,   6.68,   4.56,   5.39,   5.1 ,   4.69,
         2.87,   5.03,   4.38,   2.97,   4.08,   8.61,   6.62,   4.56,
         4.45,   7.43,   3.11,   3.81,   2.88,  10.87,  10.97,  18.06,
        14.66,  23.09,  17.27,  23.98,  16.03,   9.38,  29.55,   9.47,
        13.51,   9.69,  17.92,  10.5 ,   9.71,  21.46,   9.93,   7.6 ,
         4.14,   4.63,   3.13,   6.36,   3.92,   3.76,  11.65,   5.25,
         2.47,   3.95,   8.05,  10.88,   9.54,   4.73,   6.36,   7.37,
        11.38,  12.4 ,  11.22,   5.19,  12.5 ,  18.46,   9.16,  10.15,
         9.52,   6.56,   5.9 ,   3.59,   3.53,   3.54,   6.57,   9.25,
         3.11,   5.12,   7.79,   6.9 ,   9.59,   7.26,   5.91,  11.25,
         8.1 ,  10.45,  14.79,   7.44,   3.16,  13.65,  13.  ,   6.59,
         7.73,   6.58,   3.53,   2.98,   6.05,   4.16,   7.19,   4.85,
         3.76,   4.59,   3.01,   3.16,   7.85,   8.23,  12.93,   7.14,
         7.6 ,   9.51,   3.33,   3.56,   4.7 ,   8.58,  10.4 ,   6.27,
         7.39,  15.84,   4.97,   4.74,   6.07,   9.5 ,   8.67,   4.86,
         6.93,   8.93,   6.47,   7.53,   4.54,   9.97,  12.64,   5.98,
        11.72,   7.9 ,   9.28,  11.5 ,  18.33,  15.94,  10.36,  12.73,
         7.2 ,   6.87,   7.7 ,  11.74,   6.12,   5.08,   6.15,  12.79,
         9.97,   7.34,   9.09,  12.43,   7.83,   5.68,   6.75,   8.01,
         9.8 ,  10.56,   8.51,   9.74,   9.29,   5.49,   8.65,   7.18,
         4.61,  10.53,  12.67,   6.36,   5.99,   5.89,   5.98,   5.49,
         7.79,   4.5 ,   8.05,   5.57,  17.6 ,  13.27,  11.48,  12.67,
         7.79,  14.19,  10.19,  14.64,   5.29,   7.12,  14.  ,  13.33,
         3.26,   3.73,   2.96,   9.53,   8.88,  34.77,  37.97,  13.44,
        23.24,  21.24,  23.69,  21.78,  17.21,  21.08,  23.6 ,  24.56,
        30.63,  30.81,  28.28,  31.99,  30.62,  20.85,  17.11,  18.76,
        25.68,  15.17,  16.35,  17.12,  19.37,  19.92,  30.59,  29.97,
        26.77,  20.32,  20.31,  19.77,  27.38,  22.98,  23.34,  12.13,
        26.4 ,  19.78,  10.11,  21.22,  34.37,  20.08,  36.98,  29.05,
        25.79,  26.64,  20.62,  22.74,  15.02,  15.7 ,  14.1 ,  23.29,
        17.16,  24.39,  15.69,  14.52,  21.52,  24.08,  17.64,  19.69,
        12.03,  16.22,  15.17,  23.27,  18.05,  26.45,  34.02,  22.88,
        22.11,  19.52,  16.59,  18.85,  23.79,  23.98,  17.79,  16.44,
        18.13,  19.31,  17.44,  17.73,  17.27,  16.74,  18.71,  18.13,
        19.01,  16.94,  16.23,  14.7 ,  16.42,  14.65,  13.99,  10.29,
        13.22,  14.13,  17.15,  21.32,  18.13,  14.76,  16.29,  12.87,
        14.36,  11.66,  18.14,  24.1 ,  18.68,  24.91,  18.03,  13.11,
        10.74,   7.74,   7.01,  10.42,  13.34,  10.58,  14.98,  11.45,
        18.06,  23.97,  29.68,  18.07,  13.35,  12.01,  13.59,  17.6 ,
        21.14,  14.1 ,  12.92,  15.1 ,  14.33,   9.67,   9.08,   5.64,
         6.48,   7.88])

In [ ]: