In [19]:
import bhc
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gamma
from scipy.cluster.hierarchy import dendrogram, linkage
In [58]:
import pandas as pd
In [160]:
def Bern_gen(nobs, k, theta, seed):
    """Generate a matrix of Bernoulli draws made of three groups of rows.

    The first ``nobs // 3`` rows are drawn with success probability
    ``theta - 0.3``, the last ``nobs // 3`` rows with ``theta + 0.3`` and the
    remaining middle rows with ``theta``.

    Parameters
    ----------
    nobs : int
        Number of rows (observations) to generate.
    k : int
        Number of binary features per row.
    theta : float
        Base success probability.  When ``nobs >= 3`` both ``theta - 0.3`` and
        ``theta + 0.3`` must be valid probabilities, otherwise
        ``np.random.binomial`` raises ``ValueError``.
    seed : int
        Seed handed to ``np.random.seed``; note that this resets NumPy's
        global random state.

    Returns
    -------
    numpy.matrix
        Matrix with one row per observation and ``k`` 0/1 columns.
    """
    np.random.seed(seed)
    third = nobs // 3

    # Explicit float dtype: with an integer ``theta`` an integer array would
    # silently truncate the shifted probabilities below.
    theta_list = np.full(nobs, theta, dtype=float)

    # Guard the group shifts: with ``third == 0`` the slice ``[-0:]`` would
    # address the whole array and the assignment used to fail for nobs < 3.
    if third > 0:
        theta_list[:third] = theta - 0.3
        theta_list[-third:] = theta + 0.3

    # One binomial call per row, in row order, so the random stream (and thus
    # the generated data for a given seed) is unchanged.
    obs_list = [np.random.binomial(1, p, k) for p in theta_list]

    # np.matrix is kept as the return type for existing callers.
    return np.matrix(obs_list)
In [164]:
def purity_score(linkage_matrix, y_test, repeats, seed):
    """Estimate the expected dendrogram purity of a tree by sampling.

    For every repeat a leaf is drawn uniformly at random, two distinct leaves
    are then drawn from that leaf's true class, the smallest cluster of the
    tree containing both leaves (their lowest common ancestor) is located and
    the fraction of that cluster's leaves belonging to the class is recorded.
    The mean of these fractions over all repeats is returned.

    Parameters
    ----------
    linkage_matrix : array-like
        One row per merge; columns 0 and 1 hold the ids of the two merged
        nodes.  Leaves have ids ``0 .. N-1`` and the node created by row ``j``
        has id ``N + j``.  Further columns are ignored.
    y_test : sequence of length N
        True class label of every leaf, in leaf-id order.
    repeats : int
        Number of sampled leaf pairs.
    seed : int
        Seed handed to ``np.random.seed`` (resets NumPy's global state).

    Returns
    -------
    float
        Average purity over the sampled pairs.

    Raises
    ------
    ValueError
        If no class has at least two members (no pair can be drawn), or if a
        sampled pair is never merged by ``linkage_matrix``.
    """
    np.random.seed(seed)

    # Convert once: positional indexing below must not depend on the index of
    # a pandas Series, and the conversion is loop-invariant.
    labels = np.asarray(y_test)
    n_leaves = len(labels)

    _, class_sizes = np.unique(labels, return_counts=True)
    if class_sizes.size == 0 or class_sizes.max() < 2:
        raise ValueError("purity_score needs at least one class with two or more members")

    # Leaves below every node, built once instead of once per repeat.
    members = [[leaf] for leaf in range(n_leaves)]
    for row in np.asarray(linkage_matrix):
        members.append(members[int(row[0])] + members[int(row[1])])
    # Only merged nodes can hold two distinct leaves; sets give O(1) lookups.
    merged = [(leaves, set(leaves)) for leaves in members[n_leaves:]]

    purity = 0.0
    for _ in range(repeats):
        # Drawing a label from ``labels`` is the same as drawing a leaf
        # uniformly and taking its class.  A class with a single member has no
        # pair to evaluate, so such draws are repeated.
        while True:
            class_test = np.random.choice(labels, 1)[0]
            candidates = np.flatnonzero(labels == class_test)
            if candidates.size >= 2:
                break
        leaf1, leaf2 = (int(leaf) for leaf in np.random.choice(candidates, size=2, replace=False))

        # Nodes are stored in creation order, so the first one containing both
        # leaves is their lowest common ancestor.
        for leaves, leaf_set in merged:
            if leaf1 in leaf_set and leaf2 in leaf_set:
                common_ancestor = leaves
                break
        else:
            raise ValueError("linkage_matrix never merges leaves %d and %d" % (leaf1, leaf2))

        purity += np.mean(labels[common_ancestor] == class_test)

    return purity / repeats
In [24]:
# NOTE(review): X_test and y_test are only assigned in the "SYNTHETIC binary
# data" cell further down; that cell has to run before this one.
# NOTE(review): the other cells call bhc.bhclust(..., family="bernoulli");
# confirm that bhc.bhclust_BB is still provided by the bhc package.
BHC_test = np.array(bhc.bhclust_BB(X_test)[0])
single_test = linkage(X_test, method='single')
complete_test = linkage(X_test, method='complete')
average_test = linkage(X_test, method='average')

# purity_score takes (linkage_matrix, y_test, repeats, seed).  The former
# per-class argument ('A' / 'B' / 'C') no longer exists and made every call
# here fail with a TypeError, so one overall score is printed per method.
print("BHC_test:", round(purity_score(BHC_test, y_test, 5, 12), 3))
print("Single_linkage:", round(purity_score(single_test, y_test, 5, 12), 3))
print("Complete_linkage:", round(purity_score(complete_test, y_test, 5, 12), 3))
print("Average_linkage:", round(purity_score(average_test, y_test, 5, 12), 3))
In [170]:
# Small two-dimensional toy data set (18 points, one point per row).
mdat = np.array(
    # first ten rows: labelled 'A' by mdat_y
    [
        (0.93637874, 1.61258974),
        (1.95192875, 2.84452075),
        (2.07671748, 3.24442548),
        (3.122903, 4.516753),
        (3.56202194, 5.17531994),
        (3.53211875, 5.75857675),
        (4.65794237, 6.66995537),
        (5.83738797, 8.46562797),
        (6.22595817, 9.28082817),
        (6.51552067, 9.36110867),
    ]
    # last eight rows: labelled 'B' by mdat_y
    + [
        (7.24619975, 3.68958775),
        (6.50554148, 3.69771048),
        (6.58213752, 4.31283952),
        (6.02279742, 4.52753342),
        (5.83280398, 4.85751598),
        (5.12305078, 4.76874878),
        (5.0430706, 5.2911986),
        (2.44081699, 6.35402999),
    ]
)
In [172]:
# True labels for the rows of mdat: ten 'A' entries followed by eight 'B' entries.
mdat_y = list(np.repeat(['A', 'B'], [10, 8]))
In [174]:
# BHC on the toy data; element [0] of the result is what gets scored as the
# linkage matrix below.
Z = np.array(bhc.bhclust(mdat, family="multivariate", alpha=1, r=0.001)[0])

# SciPy agglomerative baselines on the same points.
single_test, complete_test, average_test = (
    linkage(mdat, method=how) for how in ('single', 'complete', 'average')
)

# Every tree is scored with the same number of sampled pairs (5) and seed (12).
print("BHC_test:", round(purity_score(linkage_matrix=Z, y_test=mdat_y, repeats=5, seed=12), 3))
print("Single_linkage:", round(purity_score(linkage_matrix=single_test, y_test=mdat_y, repeats=5, seed=12), 3))
print("Complete_linkage:", round(purity_score(linkage_matrix=complete_test, y_test=mdat_y, repeats=5, seed=12), 3))
print("Average_linkage:", round(purity_score(linkage_matrix=average_test, y_test=mdat_y, repeats=5, seed=12), 3))
References:
- Aggregation dataset: A. Gionis, H. Mannila, and P. Tsaparas, Clustering aggregation. ACM Transactions on Knowledge Discovery from Data (TKDD), 2007. 1(1): p. 1-30.
- Spiral dataset: H. Chang and D.Y. Yeung, Robust path-based spectral clustering. Pattern Recognition, 2008. 41(1): p. 191-203.
In [175]:
# Both benchmark files are read with the same three column names.
# NOTE(review): the paths are absolute and machine-specific; the files must
# exist at exactly these locations for the cell to run.
multivariate_test, multivariate_test_spiral = (
    pd.read_table(path, names=['X1', 'X2', 'class'])
    for path in (
        "/Users/lina/Downloads/Aggregation.txt",
        "/Users/lina/Downloads/spiral.txt",
    )
)
In [177]:
# Split each frame into labels (last column) and features (first two columns).
# DataFrame.ix has been removed from pandas; .iloc performs the same purely
# positional selection here because the column labels are strings.
mvn_y = multivariate_test.iloc[:, -1]
mvn_X = multivariate_test.iloc[:, :2]
mvn_y_spiral = multivariate_test_spiral.iloc[:, -1]
mvn_X_spiral = multivariate_test_spiral.iloc[:, :2]
In [ ]:
# Run BHC on the Aggregation coordinates (mvn_X); element [0] of the returned
# value is kept as Z and later scored as a linkage matrix.
Z = bhc.bhclust(np.array(mvn_X), family = "multivariate", alpha = 1, r = 0.001)[0]
# The same run on the spiral coordinates is left disabled; a literal Z_spiral
# array is assigned in a later cell instead.
#Z_spiral = bhc.bhclust(np.array(mvn_X_spiral), family = "multivariate", alpha = 1, r = 0.001)[0]
In [182]:
# Hard-coded linkage matrix for the spiral data (presumably the saved output of
# the commented-out bhc.bhclust(mvn_X_spiral, ...) call -- TODO confirm).
# One row per merge; purity_score reads the two merged node ids from columns
# 0 and 1. The last column grows from 2 to 312 over the 311 rows.
Z_spiral = np.array([[208, 209, 0.25464380105923845, 2],
[312, 210, 0.46200922472289818, 3],
[313, 211, 0.65792654490843439, 4],
[314, 212, 0.85214870893944072, 5],
[315, 213, 1.0474374513171529, 6],
[316, 214, 1.2401314306850928, 7],
[317, 215, 1.4290644232511895, 8],
[318, 216, 1.6167719628504882, 9],
[319, 207, 1.8022719248004808, 10],
[320, 217, 1.9880160854245341, 11],
[321, 218, 2.1718080026801947, 12],
[322, 219, 2.3535387353956598, 13],
[323, 220, 2.5357948210127699, 14],
[324, 221, 2.7183498634980636, 15],
[325, 222, 2.9017376562581312, 16],
[326, 223, 3.0907544545637697, 17],
[327, 224, 3.2835407243264756, 18],
[328, 225, 3.4787545391292647, 19],
[329, 226, 3.6820322935316341, 20],
[330, 227, 3.8956402344230563, 21],
[331, 228, 4.1153078327050796, 22],
[332, 43, 4.3396699231949301, 23],
[333, 44, 4.5486624720656312, 24],
[334, 42, 4.7464192002386101, 25],
[335, 45, 4.9365581738472208, 26],
[336, 41, 5.1199929291513797, 27],
[337, 40, 5.2977606349866289, 28],
[338, 46, 5.4776216603248136, 29],
[339, 39, 5.6549063376429283, 30],
[340, 47, 5.8304496434687749, 31],
[341, 38, 6.0052000241049424, 32],
[342, 37, 6.1827191806557407, 33],
[343, 48, 6.3587118511316882, 34],
[344, 49, 6.5378099770116558, 35],
[345, 36, 6.7185324918892055, 36],
[346, 35, 6.9036638728001432, 37],
[347, 50, 7.0907838886857864, 38],
[348, 229, 7.2801763538373381, 39],
[349, 51, 7.4677772977308505, 40],
[350, 230, 7.6574399478054431, 41],
[351, 52, 7.8473568829173059, 42],
[352, 231, 8.0386787641377513, 43],
[353, 53, 8.2296605046772306, 44],
[354, 232, 8.4222441218560746, 45],
[355, 54, 8.6127260388388311, 46],
[356, 233, 8.8091782774832819, 47],
[357, 55, 9.0042168473094364, 48],
[358, 234, 9.2025562622563744, 49],
[359, 56, 9.3993575041259003, 50],
[360, 235, 9.5972882192286413, 51],
[361, 57, 9.7925316206515998, 52],
[362, 236, 9.9963483747536692, 53],
[363, 58, 10.194643338846785, 54],
[364, 59, 10.403849334855531, 55],
[365, 237, 10.606966238027683, 56],
[366, 60, 10.818012129426682, 57],
[367, 238, 11.022906901719029, 58],
[368, 34, 11.238431919832287, 59],
[369, 61, 11.455127224528686, 60],
[370, 239, 11.666447654540384, 61],
[371, 106, 11.87486178570534, 62],
[372, 240, 12.083548083852753, 63],
[373, 62, 12.286670947503465, 64],
[374, 107, 12.487047629817287, 65],
[375, 241, 12.693687445314367, 66],
[376, 108, 12.895298719628491, 67],
[377, 63, 13.0948951765226, 68],
[378, 109, 13.301127749813435, 69],
[379, 242, 13.50013297263771, 70],
[380, 64, 13.698624125572444, 71],
[381, 243, 13.904419825771392, 72],
[382, 110, 14.10436220748764, 73],
[383, 65, 14.303817818684839, 74],
[384, 111, 14.510206415919098, 75],
[385, 244, 14.711803817291187, 76],
[386, 66, 14.912086865933183, 77],
[387, 184, 15.116000418791536, 78],
[388, 185, 15.318083296125637, 79],
[389, 187, 15.519199656525453, 80],
[390, 186, 15.718407058789058, 81],
[391, 188, 15.915786318715117, 82],
[392, 189, 16.111777858209962, 83],
[393, 183, 16.306838055039179, 84],
[394, 190, 16.500657624436471, 85],
[395, 67, 16.695531436066762, 86],
[396, 191, 16.889371473122157, 87],
[397, 182, 17.082618930413844, 88],
[398, 192, 17.274546172851583, 89],
[399, 181, 17.465607029314985, 90],
[400, 245, 17.65874172242804, 91],
[401, 193, 17.852860792011416, 92],
[402, 194, 18.04556492896484, 93],
[403, 68, 18.237324653586406, 94],
[404, 180, 18.428820995289747, 95],
[405, 179, 18.621090128133329, 96],
[406, 195, 18.812707691326366, 97],
[407, 112, 19.00526495099378, 98],
[408, 196, 19.197344649457349, 99],
[409, 246, 19.38868605556981, 100],
[410, 69, 19.579676522439613, 101],
[411, 113, 19.773221402900422, 102],
[412, 197, 19.966177545304568, 103],
[413, 178, 20.159480686657847, 104],
[414, 70, 20.352871099079472, 105],
[415, 198, 20.545938699533586, 106],
[416, 247, 20.737917909306216, 107],
[417, 114, 20.933474333830262, 108],
[418, 71, 21.128129483817531, 109],
[419, 248, 21.320990525569673, 110],
[420, 199, 21.513335515675845, 111],
[421, 200, 21.708664035004347, 112],
[422, 72, 21.902794639701991, 113],
[423, 177, 22.097313195620178, 114],
[424, 201, 22.293931260843557, 115],
[425, 115, 22.489703191755133, 116],
[426, 249, 22.682138421065023, 117],
[427, 73, 22.87844806818644, 118],
[428, 202, 23.074695353794031, 119],
[429, 176, 23.270679335413632, 120],
[430, 33, 23.468202838845997, 121],
[431, 175, 23.664716556573079, 122],
[432, 203, 23.861430690340345, 123],
[433, 116, 24.059726883673378, 124],
[434, 250, 24.253849795593823, 125],
[435, 74, 24.448553669717572, 126],
[436, 204, 24.643320578970965, 127],
[437, 205, 24.841107548145519, 128],
[438, 174, 25.038993631693437, 129],
[439, 251, 25.237539981043419, 130],
[440, 117, 25.434372015899868, 131],
[441, 75, 25.629974009975534, 132],
[442, 206, 25.825407820894686, 133],
[443, 252, 26.025837308225103, 134],
[444, 76, 26.22411130657791, 135],
[445, 118, 26.423820142184372, 136],
[446, 77, 26.624804788159455, 137],
[447, 253, 26.825100976433188, 138],
[448, 32, 27.029413294965593, 139],
[449, 173, 27.231863915200705, 140],
[450, 78, 27.435912373576297, 141],
[451, 119, 27.639301834192857, 142],
[452, 254, 27.843749041491037, 143],
[453, 79, 28.04876890623898, 144],
[454, 172, 28.254698838791665, 145],
[455, 120, 28.462215434280054, 146],
[456, 80, 28.668550565806434, 147],
[457, 255, 28.873251976708247, 148],
[458, 81, 29.084243575106829, 149],
[459, 121, 29.293657193023151, 150],
[460, 256, 29.499555315834282, 151],
[461, 82, 29.709840100567661, 152],
[462, 171, 29.920206116518838, 153],
[463, 31, 30.130236022217247, 154],
[464, 257, 30.342712960718742, 155],
[465, 122, 30.551278526310849, 156],
[466, 83, 30.760614235670957, 157],
[467, 170, 30.97394087397516, 158],
[468, 258, 31.187073448612239, 159],
[469, 84, 31.39810930479965, 160],
[470, 123, 31.608187581853009, 161],
[471, 259, 31.821204378437194, 162],
[472, 85, 32.032999463374466, 163],
[473, 311, 32.244858629861767, 164],
[474, 310, 32.456648035428216, 165],
[475, 169, 32.667979651988595, 166],
[476, 30, 32.877216235397825, 167],
[477, 309, 33.087066598119783, 168],
[478, 86, 33.296334866674997, 169],
[479, 308, 33.506220396867427, 170],
[480, 87, 33.717009115693848, 171],
[481, 124, 33.926004728779127, 172],
[482, 260, 34.131687862546698, 173],
[483, 307, 34.341429676301622, 174],
[484, 168, 34.550031661046653, 175],
[485, 88, 34.75841162651291, 176],
[486, 306, 34.96694443003247, 177],
[487, 305, 35.175717614553463, 178],
[488, 29, 35.382991317069816, 179],
[489, 167, 35.589483798811209, 180],
[490, 304, 35.795630497772144, 181],
[491, 303, 36.002410706898281, 182],
[492, 89, 36.20884597478468, 183],
[493, 261, 36.413679560955643, 184],
[494, 125, 36.617577360850596, 185],
[495, 90, 36.822155077470342, 186],
[496, 262, 37.026266431662023, 187],
[497, 302, 37.230355717760318, 188],
[498, 91, 37.435007664927845, 189],
[499, 301, 37.63969155305832, 190],
[500, 28, 37.842926961961069, 191],
[501, 166, 38.043661721257003, 192],
[502, 300, 38.245884537773129, 193],
[503, 92, 38.448804390286213, 194],
[504, 299, 38.651182645166479, 195],
[505, 165, 38.853889549393536, 196],
[506, 93, 39.056278370606435, 197],
[507, 263, 39.257675951309302, 198],
[508, 126, 39.456943545411839, 199],
[509, 94, 39.657162775284533, 200],
[510, 95, 39.857214295200905, 201],
[511, 298, 40.056712910678243, 202],
[512, 264, 40.255679068872766, 203],
[513, 96, 40.454349026455674, 204],
[514, 127, 40.652692980440165, 205],
[515, 97, 40.85054749286568, 206],
[516, 265, 41.048104857529459, 207],
[517, 98, 41.24493495196365, 208],
[518, 297, 41.442016669734741, 209],
[519, 99, 41.638527768229039, 210],
[520, 100, 41.834974453255221, 211],
[521, 101, 42.031030636584596, 212],
[522, 296, 42.226799332056068, 213],
[523, 102, 42.422175541388171, 214],
[524, 128, 42.616988634375161, 215],
[525, 266, 42.810087800657335, 216],
[526, 103, 43.003701081458701, 217],
[527, 104, 43.196954505183804, 218],
[528, 105, 43.39018536326202, 219],
[529, 267, 43.583823519102815, 220],
[530, 129, 43.776699989304291, 221],
[531, 295, 43.969825586047527, 222],
[532, 164, 44.162204824221774, 223],
[533, 27, 44.354910583886046, 224],
[534, 294, 44.546706479973089, 225],
[535, 268, 44.739516015186474, 226],
[536, 293, 44.932357063254429, 227],
[537, 163, 45.124500375910586, 228],
[538, 269, 45.317695836200976, 229],
[539, 130, 45.510198390038624, 230],
[540, 292, 45.702581751565084, 231],
[541, 270, 45.894380815650763, 232],
[542, 162, 46.086517407427827, 233],
[543, 26, 46.277921106583918, 234],
[544, 291, 46.468693184072265, 235],
[545, 290, 46.660115028549853, 236],
[546, 271, 46.851242712399461, 237],
[547, 161, 47.042288960561756, 238],
[548, 289, 47.233175296980349, 239],
[549, 25, 47.423684155378766, 240],
[550, 288, 47.613978873870032, 241],
[551, 160, 47.80365108690863, 242],
[552, 272, 47.992845827780371, 243],
[553, 131, 48.181839348784287, 244],
[554, 273, 48.370382206768042, 245],
[555, 287, 48.558455400881861, 246],
[556, 274, 48.746417690467652, 247],
[557, 286, 48.934014474289981, 248],
[558, 275, 49.121212893010608, 249],
[559, 132, 49.308079600236958, 250],
[560, 276, 49.494235053228522, 251],
[561, 285, 49.680141475358589, 252],
[562, 277, 49.865898096414789, 253],
[563, 284, 50.051010903906352, 254],
[564, 159, 50.235532192535473, 255],
[565, 283, 50.419687294787863, 256],
[566, 278, 50.603217545799232, 257],
[567, 282, 50.786361995087354, 258],
[568, 280, 50.968986028803933, 259],
[569, 279, 51.151014330545891, 260],
[570, 281, 51.332584197279814, 261],
[571, 133, 51.513650594986871, 262],
[572, 158, 51.695447058635992, 263],
[573, 24, 51.876798988392814, 264],
[574, 134, 52.05883183935822, 265],
[575, 157, 52.240674452481173, 266],
[576, 156, 52.423192386470177, 267],
[577, 23, 52.604554164984499, 268],
[578, 135, 52.78670822985984, 269],
[579, 155, 52.968823994310441, 270],
[580, 136, 53.151181947673599, 271],
[581, 154, 53.332981843924486, 272],
[582, 22, 53.514271075382979, 273],
[583, 153, 53.695705388191747, 274],
[584, 137, 53.87668769553229, 275],
[585, 152, 54.057427582272808, 276],
[586, 138, 54.237631679908716, 277],
[587, 139, 54.417667541515144, 278],
[588, 151, 54.59703721784976, 279],
[589, 140, 54.775951076160361, 280],
[590, 150, 54.954394694870459, 281],
[591, 149, 55.132110128563099, 282],
[592, 141, 55.308957703757486, 283],
[593, 143, 55.485450704055999, 284],
[594, 142, 55.661006840016022, 285],
[595, 148, 55.835763444513738, 286],
[596, 147, 56.009844122898336, 287],
[597, 144, 56.183147910738306, 288],
[598, 145, 56.355918476527322, 289],
[599, 146, 56.527929600980769, 290],
[600, 21, 56.701145099604936, 291],
[601, 20, 56.874768596683069, 292],
[602, 19, 57.049473911403105, 293],
[603, 18, 57.224426757534133, 294],
[604, 17, 57.399936809324004, 295],
[605, 16, 57.575261094418813, 296],
[606, 15, 57.751331912122289, 297],
[607, 14, 57.92779189099705, 298],
[608, 13, 58.103770505036081, 299],
[609, 12, 58.279802718923897, 300],
[610, 11, 58.455721118208665, 301],
[611, 10, 58.631339581487026, 302],
[612, 9, 58.806495598194736, 303],
[613, 8, 58.981134536182289, 304],
[614, 7, 59.155628417520155, 305],
[615, 6, 59.329478833657468, 306],
[616, 5, 59.502712831224564, 307],
[617, 4, 59.675199527656936, 308],
[618, 3, 59.847384320385984, 309],
[619, 2, 60.018646697963312, 310],
[620, 1, 60.189092488116778, 311],
[621, 0, 60.35935950732884, 312]])
In [181]:
# Aggregation data: make sure the BHC result is an ndarray before scoring it.
Z = np.array(Z)

# SciPy agglomerative baselines on the Aggregation coordinates.
single_test, complete_test, average_test = (
    linkage(mvn_X, method=how) for how in ('single', 'complete', 'average')
)

# Every tree is scored with the same number of sampled pairs (5) and seed (12).
print("BHC_test:", round(purity_score(linkage_matrix=Z, y_test=mvn_y, repeats=5, seed=12), 3))
print("Single_linkage:", round(purity_score(linkage_matrix=single_test, y_test=mvn_y, repeats=5, seed=12), 3))
print("Complete_linkage:", round(purity_score(linkage_matrix=complete_test, y_test=mvn_y, repeats=5, seed=12), 3))
print("Average_linkage:", round(purity_score(linkage_matrix=average_test, y_test=mvn_y, repeats=5, seed=12), 3))
In [187]:
# Spiral data. Z_spiral is assigned as an np.array literal in an earlier cell,
# so it is scored as-is without a further conversion.
single_test, complete_test, average_test = (
    linkage(mvn_X_spiral, method=how) for how in ('single', 'complete', 'average')
)

# Every tree is scored with the same number of sampled pairs (5) and seed (12).
print("BHC_test:", round(purity_score(linkage_matrix=Z_spiral, y_test=mvn_y_spiral, repeats=5, seed=12), 3))
print("Single_linkage:", round(purity_score(linkage_matrix=single_test, y_test=mvn_y_spiral, repeats=5, seed=12), 3))
print("Complete_linkage:", round(purity_score(linkage_matrix=complete_test, y_test=mvn_y_spiral, repeats=5, seed=12), 3))
print("Average_linkage:", round(purity_score(linkage_matrix=average_test, y_test=mvn_y_spiral, repeats=5, seed=12), 3))
In [106]:
Z = np.array(Z)
# purity_score accepts (linkage_matrix, y_test, repeats, seed). The former
# class_test keyword is not a parameter any more and raised a TypeError, so it
# is dropped; the class is sampled inside purity_score.
purity_score(linkage_matrix=Z, y_test=mvn_y, repeats=5, seed=16)
Out[106]:
In [110]:
Z_spiral = np.array(Z_spiral)
# purity_score accepts (linkage_matrix, y_test, repeats, seed). The former
# class_test keyword is not a parameter any more and raised a TypeError, so it
# is dropped; the class is sampled inside purity_score.
purity_score(linkage_matrix=Z_spiral, y_test=mvn_y_spiral, repeats=5, seed=16)
In [168]:
#SYNTHETIC binary data
# 30 rows of 10 binary features. Bern_gen shifts the success probability of the
# first and last third of the rows, which lines up with the A / B / C labels.
X_test = Bern_gen(nobs=30, k=10, theta=0.5, seed=121)
y_test = []
for i in 'ABC':
    y_test += list(np.repeat(i, 10))

# BHC with the Bernoulli family; element [0] of the result is scored below.
Zb = np.array(bhc.bhclust(X_test, family="bernoulli", alpha=0.001)[0])

# SciPy agglomerative baselines on the same matrix.
single_test, complete_test, average_test = (
    linkage(X_test, method=how) for how in ('single', 'complete', 'average')
)

# Every tree is scored with the same number of sampled pairs (5) and seed (12).
print("BHC_test:", round(purity_score(linkage_matrix=Zb, y_test=y_test, repeats=5, seed=12), 3))
print("Single_linkage:", round(purity_score(linkage_matrix=single_test, y_test=y_test, repeats=5, seed=12), 3))
print("Complete_linkage:", round(purity_score(linkage_matrix=complete_test, y_test=y_test, repeats=5, seed=12), 3))
print("Average_linkage:", round(purity_score(linkage_matrix=average_test, y_test=y_test, repeats=5, seed=12), 3))
In [165]:
#CEDA data from paper
# header=None is the supported way to say "no header row"; the former
# header=-1 is rejected by current pandas with a ValueError.
# NOTE(review): the path is absolute and machine-specific.
# NOTE(review): this reuses the name multivariate_test, replacing the
# Aggregation frame loaded in an earlier cell.
multivariate_test = pd.read_csv("/Users/lina/Downloads/bindat.csv", header=None)

# Labels: 40 rows each of '0', '2' and '4', in that order.
bn_y = list(np.repeat(['0', '2', '4'], 40))
bn_X = np.array(multivariate_test)
In [166]:
# BHC with the Bernoulli family on the binary data loaded from bindat.csv;
# element [0] of the result is scored below.
Zb_paper = np.array(bhc.bhclust(bn_X, family="bernoulli", alpha=0.001)[0])
# Disabled alternative kept from the original cell:
#BHC_test = np.array(bhc.bhclust_BB(bn_X)[0])

# SciPy agglomerative baselines on the same matrix.
single_test, complete_test, average_test = (
    linkage(bn_X, method=how) for how in ('single', 'complete', 'average')
)

# Every tree is scored with the same number of sampled pairs (5) and seed (12).
print("BHC_test:", round(purity_score(linkage_matrix=Zb_paper, y_test=bn_y, repeats=5, seed=12), 3))
print("Single_linkage:", round(purity_score(linkage_matrix=single_test, y_test=bn_y, repeats=5, seed=12), 3))
print("Complete_linkage:", round(purity_score(linkage_matrix=complete_test, y_test=bn_y, repeats=5, seed=12), 3))
print("Average_linkage:", round(purity_score(linkage_matrix=average_test, y_test=bn_y, repeats=5, seed=12), 3))
In [ ]: