In [2]:
# Arithmetic scratch cell (prints 53.36).
print(2.32 * 23)
53.36
In [3]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (8,8)
plt.rcParams['font.size'] = 14
In [4]:
from sklearn.datasets import make_blobs

labels = ['b', 'r']
X, y = make_blobs(n_samples=400, centers=23, random_state=23)
# Collapse the 23 cluster ids into two classes: ids 0-9 -> 'r', ids 10-22 -> 'b'.
y = np.where(y < 10, labels[1], labels[0])
In [7]:
X[:5, :]  # first five rows of the (400, 2) feature matrix
Out[7]:
array([[-4.07558626, -5.17157144],
[ 5.03328786, 2.21139413],
[ 2.68857184, 6.70464571],
[ 6.56078922, -8.17265192],
[-9.67208284, 6.59074933]])
In [11]:
# Same draw as above (same random_state), kept un-binarised to show the raw 0..22 cluster ids.
X_, y_ = make_blobs(random_state=23, n_samples=400, centers=23)
y_
Out[11]:
array([ 9, 10, 20, 8, 5, 22, 1, 6, 10, 12, 16, 7, 13, 5, 17, 7, 1,
4, 0, 0, 12, 10, 13, 22, 13, 15, 3, 11, 13, 0, 11, 8, 11, 14,
2, 4, 5, 13, 4, 12, 15, 8, 0, 5, 9, 1, 4, 17, 0, 7, 20,
22, 19, 10, 2, 12, 10, 18, 14, 1, 2, 9, 19, 2, 17, 19, 12, 5,
0, 10, 1, 18, 4, 16, 8, 9, 13, 7, 11, 20, 3, 15, 16, 14, 11,
18, 3, 17, 18, 0, 15, 14, 19, 7, 13, 7, 12, 20, 9, 14, 17, 14,
11, 3, 2, 6, 21, 17, 14, 10, 20, 10, 15, 4, 16, 2, 11, 9, 9,
18, 0, 8, 17, 12, 5, 20, 21, 5, 19, 0, 18, 21, 1, 7, 21, 7,
10, 1, 22, 20, 22, 7, 17, 3, 16, 12, 11, 4, 8, 2, 14, 7, 19,
22, 3, 1, 0, 12, 9, 5, 18, 19, 1, 9, 11, 16, 2, 5, 17, 18,
3, 15, 5, 21, 21, 8, 1, 18, 9, 17, 22, 1, 15, 6, 17, 4, 15,
2, 5, 0, 5, 0, 8, 16, 18, 4, 2, 8, 15, 7, 4, 9, 10, 21,
10, 10, 3, 10, 22, 8, 18, 13, 10, 20, 3, 1, 16, 2, 6, 12, 17,
8, 13, 3, 19, 22, 11, 20, 19, 4, 17, 13, 6, 15, 2, 20, 8, 12,
5, 13, 13, 19, 2, 2, 15, 12, 21, 5, 19, 17, 14, 4, 14, 6, 0,
15, 5, 13, 19, 22, 7, 18, 12, 19, 7, 6, 14, 12, 14, 18, 4, 9,
4, 16, 7, 15, 16, 8, 0, 14, 11, 22, 21, 10, 4, 14, 20, 16, 18,
5, 21, 17, 6, 6, 11, 20, 7, 6, 21, 13, 22, 3, 2, 21, 3, 21,
6, 22, 0, 11, 6, 9, 1, 3, 1, 21, 9, 8, 3, 3, 13, 9, 19,
0, 13, 6, 18, 20, 5, 6, 22, 14, 8, 15, 16, 18, 12, 7, 1, 8,
6, 15, 16, 4, 14, 9, 2, 20, 3, 6, 14, 3, 19, 21, 0, 0, 21,
6, 16, 15, 20, 20, 2, 4, 16, 16, 10, 6, 22, 11, 20, 10, 22, 17,
18, 1, 7, 8, 1, 5, 1, 11, 22, 11, 21, 19, 7, 11, 16, 2, 12,
8, 13, 19, 12, 3, 17, 4, 15, 9])
In [14]:
plt.scatter(*X.T, c=y)  # columns of X as x/y coordinates, coloured by class label
Out[14]:
<matplotlib.collections.PathCollection at 0x11ac07c88>
In [15]:
from sklearn.neighbors import KNeighborsClassifier

# n_neighbors=5 is the library default (see the repr below), spelled out for the reader.
clf = KNeighborsClassifier(n_neighbors=5)
clf.fit(X, y)
Out[15]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=5, p=2,
weights='uniform')
In [16]:
clf.predict(np.array([[3.0, 2.5]]))  # a single query point, as a (1, 2) sample matrix
Out[16]:
array(['b'],
dtype='<U1')
In [17]:
# `utils` is a project-local module not shown in this notebook; judging by the name,
# plot_surface presumably draws clf's decision regions over the points X coloured by y — verify.
from utils import plot_surface
plot_surface(clf, X, y)
In [18]:
# NOTE(review): this displays all 400 rows of X; X.shape or X[:5] is usually enough.
X
Out[18]:
array([[ -4.07558626e+00, -5.17157144e+00],
[ 5.03328786e+00, 2.21139413e+00],
[ 2.68857184e+00, 6.70464571e+00],
[ 6.56078922e+00, -8.17265192e+00],
[ -9.67208284e+00, 6.59074933e+00],
[ -5.87915382e+00, -7.05053029e+00],
[ 5.05805016e+00, -2.75897249e+00],
[ 6.51526080e+00, -4.62761111e+00],
[ 6.81492966e+00, 3.21927649e+00],
[ 8.24960093e+00, -8.67572688e+00],
[ -7.50651779e+00, 4.21642244e+00],
[ 2.64444155e+00, 7.21418104e+00],
[ -4.91684350e-01, -3.27623858e+00],
[ -1.12565634e+01, 8.91173102e+00],
[ -6.87676031e+00, -8.66747449e+00],
[ 1.92856933e+00, 7.85179347e+00],
[ 4.22382841e+00, -2.52870489e+00],
[ 3.09701318e+00, -2.09633039e+00],
[ 1.02281779e+00, 1.21267548e+01],
[ 2.74633752e-01, 8.89381450e+00],
[ 9.11019111e+00, -7.14042558e+00],
[ 8.15460584e+00, 3.53613479e+00],
[ -1.35387999e+00, -4.31398380e+00],
[ -5.67125327e+00, -5.62356261e+00],
[ -2.31243172e+00, -3.36472862e+00],
[ 5.78612306e+00, 4.26372527e+00],
[ -7.19885472e+00, -1.35735691e+00],
[ -7.87895033e+00, -8.55792394e+00],
[ 9.94888232e-01, -1.83666314e+00],
[ 1.38684365e+00, 8.84521734e+00],
[ -7.90005688e+00, -9.35823449e+00],
[ 7.31535975e+00, -9.54046967e+00],
[ -8.57859716e+00, -1.00063879e+01],
[ 5.39466551e+00, -1.87854465e+00],
[ -4.92547675e+00, 2.97862891e+00],
[ 4.57531840e-01, -2.00502093e+00],
[ -8.57472881e+00, 7.73366723e+00],
[ -2.80651163e+00, -3.84739904e+00],
[ 1.03649892e+00, -2.38331829e+00],
[ 9.63829615e+00, -7.16519481e+00],
[ 6.84239874e+00, 5.15734759e+00],
[ 6.80720194e+00, -6.36065581e+00],
[ 7.83787357e-01, 1.02670397e+01],
[ -9.22124277e+00, 8.50920405e+00],
[ -2.98875493e+00, -5.09986963e+00],
[ 6.23019102e+00, -2.41899329e+00],
[ 3.24743124e+00, -2.00177788e+00],
[ -7.69041421e+00, -1.02135520e+01],
[ 1.43405527e-02, 7.30686580e+00],
[ 1.76645611e+00, 9.79660242e+00],
[ 1.51086608e+00, 5.28758800e+00],
[ -6.99960201e+00, -4.47882425e+00],
[ -6.90359509e+00, -1.49611964e-01],
[ 6.44678614e+00, 2.01607950e+00],
[ -5.48899879e+00, 3.63049712e+00],
[ 8.09377664e+00, -7.49314785e+00],
[ 5.86609670e+00, 1.23378945e+00],
[ 6.44658501e+00, 4.64241112e-01],
[ 6.83538916e+00, -2.19795714e+00],
[ 5.74708767e+00, -4.25235984e+00],
[ -5.44767643e+00, 5.32767582e+00],
[ -2.99150288e+00, -3.63495185e+00],
[ -7.17213754e+00, 1.59287531e+00],
[ -7.59390088e+00, 5.59533568e+00],
[ -7.62340165e+00, -8.84728513e+00],
[ -8.62544829e+00, 1.36953762e+00],
[ 9.95798633e+00, -6.35866363e+00],
[ -1.02482099e+01, 7.34665806e+00],
[ 3.59379934e-01, 9.87519697e+00],
[ 5.91027277e+00, 2.05813969e+00],
[ 7.45997513e+00, -3.92694304e+00],
[ 7.01619954e+00, -3.46521809e-01],
[ 1.11254334e+00, -2.11459072e+00],
[ -7.67934969e+00, 2.79172457e+00],
[ 5.58107115e+00, -9.28071180e+00],
[ -3.49460476e+00, -5.08084938e+00],
[ 1.47776558e-01, -4.75116069e+00],
[ 1.11360603e+00, 9.07503570e+00],
[ -8.85380666e+00, -9.95476994e+00],
[ 3.83470745e+00, 6.25883319e+00],
[ -6.84726906e+00, -2.06986033e+00],
[ 5.29237060e+00, 4.82205220e+00],
[ -8.20695857e+00, 2.08346572e+00],
[ 7.37810112e+00, -2.67531340e+00],
[ -8.89673913e+00, -9.46437816e+00],
[ 8.31547306e+00, -1.77681956e+00],
[ -6.66095736e+00, -1.87918141e+00],
[ -8.08854236e+00, -8.41542282e+00],
[ 5.51728291e+00, -2.14544090e+00],
[ 8.58095072e-01, 9.07079054e+00],
[ 6.46404819e+00, 5.06226181e+00],
[ 6.99386040e+00, -4.72807666e-01],
[ -6.88864357e+00, 1.31612426e+00],
[ 2.35384252e+00, 8.35141460e+00],
[ -3.27576433e+00, -2.66104844e+00],
[ 2.18027095e+00, 1.03359682e+01],
[ 9.97281298e+00, -5.35430951e+00],
[ 1.00559437e+00, 5.44751042e+00],
[ -3.32002566e+00, -4.99045579e+00],
[ 6.79577536e+00, -4.47762958e-01],
[ -8.24755980e+00, -7.98510941e+00],
[ 5.97055758e+00, -2.21020369e+00],
[ -7.76674306e+00, -1.01412212e+01],
[ -8.42607129e+00, -3.59498200e+00],
[ -5.25554710e+00, 2.62124824e+00],
[ 8.36792594e+00, -4.59088761e+00],
[ 2.85109786e+00, -5.98316678e-01],
[ -7.64531537e+00, -7.87143414e+00],
[ 4.81017719e+00, -9.08676019e-01],
[ 6.76040231e+00, 2.42857374e+00],
[ 2.74001361e+00, 7.45297552e+00],
[ 7.12894087e+00, 2.85451692e+00],
[ 5.46391056e+00, 4.43263121e+00],
[ 3.55562818e+00, -2.16174499e+00],
[ -9.03621260e+00, 1.96822623e+00],
[ -5.23557597e+00, 1.69658089e+00],
[ -8.68780864e+00, -8.54636281e+00],
[ -3.37504045e+00, -4.19960446e+00],
[ -4.52553886e+00, -4.42717886e+00],
[ 7.48701072e+00, -9.51075654e-01],
[ 9.65071750e-01, 7.94667830e+00],
[ 5.22412798e+00, -9.24250463e+00],
[ -7.73707897e+00, -7.10870437e+00],
[ 9.00886828e+00, -7.32850123e+00],
[ -1.08229580e+01, 6.41813263e+00],
[ 1.84060995e+00, 5.04787171e+00],
[ 2.75780398e+00, 3.31082243e-01],
[ -1.00768084e+01, 8.04795852e+00],
[ -6.45717432e+00, 1.18600090e+00],
[ -7.34607078e-01, 8.94948138e+00],
[ 7.02961604e+00, -6.74268077e-01],
[ 8.97244866e-01, -2.43331261e-01],
[ 5.37460126e+00, -3.12765245e+00],
[ 1.12390889e+00, 7.10609827e+00],
[ 2.54236423e+00, 6.90928846e-01],
[ 2.05212360e+00, 8.98242498e+00],
[ 6.98578746e+00, 2.35003259e+00],
[ 5.75034670e+00, -5.16952245e+00],
[ -3.95126251e+00, -6.08277976e+00],
[ 2.47698218e+00, 6.59028103e+00],
[ -7.57625101e+00, -4.61972934e+00],
[ 3.00112828e+00, 9.94547756e+00],
[ -7.41382328e+00, -8.38682412e+00],
[ -6.56376569e+00, -3.58147617e-01],
[ -7.04927578e+00, 5.49638814e-01],
[ 7.29967511e+00, -7.39125635e+00],
[ -8.32442402e+00, -1.05885936e+01],
[ 4.05163364e+00, -9.40768894e-01],
[ 6.73539076e+00, -8.30026635e+00],
[ -6.19127335e+00, 2.98759959e+00],
[ 5.04729104e+00, -8.94058810e-01],
[ 1.47805126e+00, 9.52796177e+00],
[ -8.50535269e+00, -4.10391408e-01],
[ -5.60277163e+00, -6.59116228e+00],
[ -7.07426214e+00, -4.24839704e+00],
[ 5.53570693e+00, -4.11818513e+00],
[ 1.04958131e+00, 8.67110994e+00],
[ 1.07928549e+01, -5.37099051e+00],
[ -4.54406593e+00, -3.69410481e+00],
[ -1.08022380e+01, 7.63419393e+00],
[ 7.40469325e+00, -1.21675975e+00],
[ -7.19051186e+00, 1.00037273e+00],
[ 4.51383197e+00, -5.36361832e+00],
[ -4.10909510e+00, -3.48339631e+00],
[ -7.95298055e+00, -9.39354684e+00],
[ -6.30495691e+00, 2.62235288e-01],
[ -7.34340638e+00, 4.29094563e+00],
[ -9.91174554e+00, 8.07057326e+00],
[ -7.40059499e+00, -7.64642839e+00],
[ 7.29930572e+00, 3.79632837e-01],
[ -6.50120306e+00, -2.45574826e+00],
[ 6.06332381e+00, 3.92547486e+00],
[ -1.02146401e+01, 6.48960176e+00],
[ 2.52165071e+00, 1.55700256e+00],
[ 3.80606776e+00, 3.12462043e-01],
[ 6.63797230e+00, -9.28409939e+00],
[ 4.48625651e+00, -5.95019223e+00],
[ 6.85406009e+00, 3.51688916e-02],
[ -5.87024844e+00, -3.93950193e+00],
[ -8.62306840e+00, -7.97335378e+00],
[ -7.70436410e+00, -6.41956409e+00],
[ 5.17617060e+00, -4.36944006e+00],
[ 6.76279578e+00, 3.72449550e+00],
[ 8.95904894e+00, -2.57668171e+00],
[ -6.79755738e+00, -9.59859614e+00],
[ 2.48478837e+00, -2.63903831e+00],
[ 6.00427228e+00, 4.31012728e+00],
[ -6.77445248e+00, 5.40648718e+00],
[ -1.07622843e+01, 7.89177485e+00],
[ 5.96694229e-01, 8.12563583e+00],
[ -1.03813278e+01, 8.62906890e+00],
[ 2.24860129e-01, 9.12939344e+00],
[ 8.29056288e+00, -7.38639647e+00],
[ -8.58019471e+00, 2.00296133e+00],
[ 6.80970919e+00, -2.84728132e+00],
[ 3.27841559e+00, 5.93349289e-01],
[ -3.70589464e+00, 3.22298573e+00],
[ 6.86583377e+00, -9.09974076e+00],
[ 6.29033526e+00, 3.69435225e+00],
[ 3.81985165e+00, 9.44080256e+00],
[ 3.22478702e+00, -1.41173158e+00],
[ -2.62364147e+00, -3.85861153e+00],
[ 5.90142127e+00, 2.52504375e+00],
[ 2.88571939e+00, 2.46456024e+00],
[ 7.01737919e+00, 2.47182704e+00],
[ 4.51444052e+00, 1.50895356e+00],
[ -5.48541983e+00, -2.77973146e+00],
[ 4.88773201e+00, 4.61579903e+00],
[ -7.27668618e+00, -6.37690080e+00],
[ 6.70027920e+00, -8.45715394e+00],
[ 6.87968806e+00, -6.68927723e-01],
[ -2.90673776e+00, -3.63377745e+00],
[ 6.15635782e+00, 1.43228507e+00],
[ 1.55971996e+00, 6.98614568e+00],
[ -7.99077007e+00, -1.65509293e+00],
[ 5.48001290e+00, -5.06133702e+00],
[ -7.27294815e+00, 1.44414864e+00],
[ -4.27230949e+00, 4.27375179e+00],
[ 8.69243068e+00, -3.56290185e+00],
[ 8.40767648e+00, -8.35285411e+00],
[ -7.51459129e+00, -6.88046902e+00],
[ 7.40092879e+00, -9.79166112e+00],
[ -2.08282733e+00, -3.58887409e+00],
[ -6.68377223e+00, -1.52793139e+00],
[ -6.50756090e+00, 1.05624195e+00],
[ -7.66002448e+00, -5.99067625e+00],
[ -6.62917646e+00, -1.24435936e+01],
[ 1.32125540e+00, 6.22206227e+00],
[ -6.55343079e+00, -4.47435184e-02],
[ 2.47061646e+00, -2.52238846e+00],
[ -7.16610607e+00, -8.23687522e+00],
[ -1.29476626e+00, -3.17588893e+00],
[ 7.55723528e+00, -4.51672707e+00],
[ 6.97148190e+00, 4.84342391e+00],
[ -4.51366994e+00, 3.11887810e+00],
[ 1.11531695e+00, 3.65653462e+00],
[ 6.80199177e+00, -6.69286861e+00],
[ 7.54709775e+00, -7.01861166e+00],
[ -1.08113163e+01, 7.42201164e+00],
[ -2.38410713e+00, -3.37069811e+00],
[ -1.90469140e+00, -3.44932740e+00],
[ -6.15004425e+00, 7.72304184e-01],
[ -4.90268958e+00, 1.46855509e+00],
[ -5.12849677e+00, 5.34189373e+00],
[ 6.67617342e+00, 5.07812070e+00],
[ 9.09234071e+00, -6.16612598e+00],
[ 4.02640891e+00, 1.49072771e+00],
[ -9.97497492e+00, 8.43601963e+00],
[ -6.50715204e+00, 6.94536374e-01],
[ -8.43449556e+00, -8.21662976e+00],
[ 8.04794538e+00, -1.36593373e+00],
[ 1.52676949e+00, -2.48362944e+00],
[ 8.75576478e+00, -8.33389322e-02],
[ 8.26137307e+00, -2.29203333e+00],
[ -2.16027264e+00, 9.08621257e+00],
[ 7.83799820e+00, 4.98391615e+00],
[ -9.13875222e+00, 8.21104251e+00],
[ -1.91938262e+00, -3.89771045e+00],
[ -7.45777045e+00, 1.00573859e+00],
[ -6.37825099e+00, -5.06326540e+00],
[ 1.52858005e+00, 7.82656411e+00],
[ 5.52772242e+00, -2.24980029e+00],
[ 8.60861793e+00, -6.85452904e+00],
[ -7.03039270e+00, 2.70351666e+00],
[ 2.79862401e+00, 1.03668877e+01],
[ 7.21809501e+00, -3.17079669e+00],
[ 6.78416689e+00, -1.37124699e+00],
[ 7.19692234e+00, -7.61053887e+00],
[ 8.60683470e+00, -1.39909996e+00],
[ 6.17139687e+00, -1.93911101e-02],
[ 3.88343741e+00, -4.52580404e-01],
[ -5.41573219e+00, -5.42339937e+00],
[ 3.49918314e+00, -1.45885602e+00],
[ -7.59228660e+00, 1.23103076e+00],
[ 5.64753649e-01, 9.82217325e+00],
[ 5.96829746e+00, 4.15522939e+00],
[ -6.88512339e+00, 7.35597921e-01],
[ 7.81700007e+00, -9.54109836e+00],
[ 9.52153167e-01, 8.91671319e+00],
[ 6.77078119e+00, -2.17821403e+00],
[ -6.78863617e+00, -1.23075413e+01],
[ -6.64922681e+00, -5.70181605e+00],
[ 4.06732570e+00, 5.46201497e-02],
[ 7.44480324e+00, 3.19725724e+00],
[ 1.59228096e+00, -1.57224715e+00],
[ 7.29114402e+00, -1.95361277e+00],
[ 2.86618328e+00, 6.33410232e+00],
[ -6.93232068e+00, 1.23049895e+00],
[ 8.37726509e+00, 7.56618596e-01],
[ -1.01625068e+01, 6.21514191e+00],
[ 3.72969694e+00, 1.75824908e+00],
[ -6.98605962e+00, -9.00618453e+00],
[ 6.24959377e+00, -6.06011699e+00],
[ 7.49606728e+00, -3.77792031e+00],
[ -6.91552849e+00, -1.07981853e+01],
[ 9.76829567e-01, 5.64326258e+00],
[ 1.49353067e+00, 9.54614197e+00],
[ 8.46193037e+00, -3.20517678e+00],
[ 2.49659695e+00, 1.95372239e+00],
[ -3.11879493e+00, -5.03029886e+00],
[ -5.75930193e+00, -5.30079895e+00],
[ -5.41829681e+00, -1.97823361e+00],
[ -4.21860791e+00, 2.15459916e+00],
[ 3.64822383e+00, 2.28870306e+00],
[ -6.89613923e+00, -2.68855015e+00],
[ 3.82232744e+00, 2.73909527e+00],
[ 6.12951167e+00, -4.93914500e+00],
[ -7.00981174e+00, -7.05659501e+00],
[ -7.48862763e-02, 8.38726351e+00],
[ -7.29747089e+00, -1.01990782e+01],
[ 8.00270797e+00, -4.24220301e+00],
[ -4.71975344e+00, -2.59020797e+00],
[ 4.53959614e+00, -4.15996344e+00],
[ -5.78415720e+00, -3.78211254e+00],
[ 4.36197274e+00, -5.61365933e+00],
[ 4.24935708e+00, -8.39789570e-01],
[ -4.43578958e+00, -5.55216909e+00],
[ 6.59937181e+00, -7.83501253e+00],
[ -6.46736025e+00, -2.38321644e+00],
[ -6.64402212e+00, -2.63636072e+00],
[ -2.97614133e+00, -1.86351778e+00],
[ -3.72750751e+00, -3.91923064e+00],
[ -7.41659393e+00, 9.38085966e-01],
[ -1.36601802e-01, 1.01838626e+01],
[ -1.88578814e+00, -2.88875464e+00],
[ 7.44024585e+00, -4.93082582e+00],
[ 6.99658655e+00, -1.15209531e+00],
[ 1.88314370e+00, 3.34786567e+00],
[ -1.04233834e+01, 7.79768893e+00],
[ 7.63737383e+00, -4.61929993e+00],
[ -6.33386945e+00, -4.53289421e+00],
[ 9.11724770e+00, -1.77006031e+00],
[ 7.65334031e+00, -9.59241352e+00],
[ 8.04514430e+00, 4.82963911e+00],
[ -6.21265254e+00, 3.37912026e+00],
[ 6.52102783e+00, 3.77152997e-01],
[ 7.58219482e+00, -7.25297405e+00],
[ -6.59284857e-02, 9.36998737e+00],
[ 5.20810227e+00, -4.91612275e+00],
[ 5.70411406e+00, -9.08437710e+00],
[ 7.59496230e+00, -4.82259907e+00],
[ 6.39413699e+00, 3.71893914e+00],
[ -8.60943171e+00, 4.70900745e+00],
[ 3.01304178e+00, -1.19253789e+00],
[ 5.96446547e+00, -2.76980078e-01],
[ -4.50242393e+00, -4.61792362e+00],
[ -5.03030466e+00, 3.10930672e+00],
[ 7.19921894e-01, 4.06635417e+00],
[ -6.84757417e+00, -1.62576663e+00],
[ 6.34558525e+00, -4.81932122e+00],
[ 6.53673662e+00, -2.68709503e+00],
[ -6.54138506e+00, -1.48605545e+00],
[ -6.64850702e+00, 2.33387879e+00],
[ 2.05785879e+00, 2.14435415e-02],
[ 1.84611290e-01, 1.01316564e+01],
[ 7.66580337e-01, 9.35087172e+00],
[ 2.99450051e+00, -1.11481006e-02],
[ 5.55557573e+00, -3.47960024e+00],
[ -8.21445376e+00, 2.98632528e+00],
[ 7.50891611e+00, 4.11876919e+00],
[ 2.18924013e+00, 4.46414699e+00],
[ 1.29226496e+00, 5.39390328e+00],
[ -7.96895851e+00, 3.39116772e+00],
[ 3.17134767e+00, -2.60492674e+00],
[ -8.18109333e+00, 1.71765180e+00],
[ -1.01230222e+01, 3.05568190e+00],
[ 5.88239791e+00, 8.52444471e-01],
[ 7.59555978e+00, -2.60687120e+00],
[ -7.09100461e+00, -4.39430763e+00],
[ -9.34769922e+00, -9.87207838e+00],
[ 1.18061600e+00, 6.09488594e+00],
[ 5.29335797e+00, 1.70672345e+00],
[ -6.22932898e+00, -4.55338523e+00],
[ -8.09025804e+00, -9.11833727e+00],
[ 6.12036619e+00, 1.29799457e-02],
[ 3.97633169e+00, -2.71469190e+00],
[ 2.25878099e+00, 1.03076866e+01],
[ 5.05488570e+00, -9.12018569e+00],
[ 3.58594213e+00, -3.89082392e+00],
[ -1.00064863e+01, 6.68858068e+00],
[ 5.13015959e+00, -3.34102429e+00],
[ -1.00178080e+01, -9.78292480e+00],
[ -6.43509241e+00, -4.61036734e+00],
[ -7.44224108e+00, -1.01894261e+01],
[ 2.75608832e+00, 2.87980300e+00],
[ -8.01539166e+00, 7.43341277e-01],
[ 1.38248729e+00, 9.78870411e+00],
[ -7.93559405e+00, -8.77181820e+00],
[ -7.68646185e+00, 9.34062449e-01],
[ -5.92145768e+00, 3.10095532e+00],
[ 7.88129264e+00, -7.42939077e+00],
[ 6.98411421e+00, -9.47413578e+00],
[ -1.17216442e+00, -3.93972901e+00],
[ -6.42729653e+00, 4.84304639e-01],
[ 9.61843374e+00, -8.01210116e+00],
[ -7.77705812e+00, -1.71379313e+00],
[ -5.79257343e+00, -9.74351328e+00],
[ 2.93399670e+00, -2.41433573e+00],
[ 5.40257356e+00, 5.72028359e+00],
[ -3.23082407e+00, -1.99852141e+00]])
In [22]:
def f(x, beta0=2.19, beta1=3.141, noise_std=2):
    """Sample noisy observations of the straight line ``beta0 + beta1 * x``.

    Parameters
    ----------
    x : scalar or array-like
        Input locations; 1-D vectors and (n, 1) columns are both supported.
    beta0, beta1 : float
        Intercept and slope of the underlying line.
    noise_std : float, default 2
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    np.ndarray (or NumPy scalar for scalar ``x``) with the same shape as ``x``.
    Noise is drawn from the global ``np.random`` state, so seed it for reproducibility.
    """
    x = np.asarray(x)
    # One noise value per element of x. Using x.shape[0] here made (n, 1) inputs
    # broadcast to an (n, n) result and made scalar inputs raise.
    noise = np.random.randn(*x.shape) * noise_std
    return beta0 + beta1 * x + noise
# Seed the global NumPy RNG so the shuffle, the noise drawn inside f, and the later
# train_test_split calls (random_state=None uses this global state) are reproducible
# under Restart & Run All.
np.random.seed(23)
x = np.linspace(-5, 5, 100)
np.random.shuffle(x)
# Keep 40 of the 100 shuffled grid points, sorted so line plots read left to right.
X = np.sort(x[:40])
y = f(X)
In [23]:
# Turn X into a single-feature column, the 2-D layout the estimators below are fitted on.
X = np.reshape(X, (-1, 1))
plt.plot(X, y, 'bo')
Out[23]:
[<matplotlib.lines.Line2D at 0x11bd6c7f0>]
In [24]:
# Candidate (beta0, beta1) pairs on a regular n_steps x n_steps grid over [0, 6]^2.
n_steps = 20
beta0s, beta1s = np.meshgrid(
    np.linspace(0, 6, num=n_steps),
    np.linspace(0, 6, num=n_steps),
)
# One candidate per row; beta0 varies fastest.
coefficients = np.column_stack((beta0s.ravel(), beta1s.ravel()))
print(coefficients)
[[ 0. 0. ]
[ 0.31578947 0. ]
[ 0.63157895 0. ]
[ 0.94736842 0. ]
[ 1.26315789 0. ]
[ 1.57894737 0. ]
[ 1.89473684 0. ]
[ 2.21052632 0. ]
[ 2.52631579 0. ]
[ 2.84210526 0. ]
[ 3.15789474 0. ]
[ 3.47368421 0. ]
[ 3.78947368 0. ]
[ 4.10526316 0. ]
[ 4.42105263 0. ]
[ 4.73684211 0. ]
[ 5.05263158 0. ]
[ 5.36842105 0. ]
[ 5.68421053 0. ]
[ 6. 0. ]
[ 0. 0.31578947]
[ 0.31578947 0.31578947]
[ 0.63157895 0.31578947]
[ 0.94736842 0.31578947]
[ 1.26315789 0.31578947]
[ 1.57894737 0.31578947]
[ 1.89473684 0.31578947]
[ 2.21052632 0.31578947]
[ 2.52631579 0.31578947]
[ 2.84210526 0.31578947]
[ 3.15789474 0.31578947]
[ 3.47368421 0.31578947]
[ 3.78947368 0.31578947]
[ 4.10526316 0.31578947]
[ 4.42105263 0.31578947]
[ 4.73684211 0.31578947]
[ 5.05263158 0.31578947]
[ 5.36842105 0.31578947]
[ 5.68421053 0.31578947]
[ 6. 0.31578947]
[ 0. 0.63157895]
[ 0.31578947 0.63157895]
[ 0.63157895 0.63157895]
[ 0.94736842 0.63157895]
[ 1.26315789 0.63157895]
[ 1.57894737 0.63157895]
[ 1.89473684 0.63157895]
[ 2.21052632 0.63157895]
[ 2.52631579 0.63157895]
[ 2.84210526 0.63157895]
[ 3.15789474 0.63157895]
[ 3.47368421 0.63157895]
[ 3.78947368 0.63157895]
[ 4.10526316 0.63157895]
[ 4.42105263 0.63157895]
[ 4.73684211 0.63157895]
[ 5.05263158 0.63157895]
[ 5.36842105 0.63157895]
[ 5.68421053 0.63157895]
[ 6. 0.63157895]
[ 0. 0.94736842]
[ 0.31578947 0.94736842]
[ 0.63157895 0.94736842]
[ 0.94736842 0.94736842]
[ 1.26315789 0.94736842]
[ 1.57894737 0.94736842]
[ 1.89473684 0.94736842]
[ 2.21052632 0.94736842]
[ 2.52631579 0.94736842]
[ 2.84210526 0.94736842]
[ 3.15789474 0.94736842]
[ 3.47368421 0.94736842]
[ 3.78947368 0.94736842]
[ 4.10526316 0.94736842]
[ 4.42105263 0.94736842]
[ 4.73684211 0.94736842]
[ 5.05263158 0.94736842]
[ 5.36842105 0.94736842]
[ 5.68421053 0.94736842]
[ 6. 0.94736842]
[ 0. 1.26315789]
[ 0.31578947 1.26315789]
[ 0.63157895 1.26315789]
[ 0.94736842 1.26315789]
[ 1.26315789 1.26315789]
[ 1.57894737 1.26315789]
[ 1.89473684 1.26315789]
[ 2.21052632 1.26315789]
[ 2.52631579 1.26315789]
[ 2.84210526 1.26315789]
[ 3.15789474 1.26315789]
[ 3.47368421 1.26315789]
[ 3.78947368 1.26315789]
[ 4.10526316 1.26315789]
[ 4.42105263 1.26315789]
[ 4.73684211 1.26315789]
[ 5.05263158 1.26315789]
[ 5.36842105 1.26315789]
[ 5.68421053 1.26315789]
[ 6. 1.26315789]
[ 0. 1.57894737]
[ 0.31578947 1.57894737]
[ 0.63157895 1.57894737]
[ 0.94736842 1.57894737]
[ 1.26315789 1.57894737]
[ 1.57894737 1.57894737]
[ 1.89473684 1.57894737]
[ 2.21052632 1.57894737]
[ 2.52631579 1.57894737]
[ 2.84210526 1.57894737]
[ 3.15789474 1.57894737]
[ 3.47368421 1.57894737]
[ 3.78947368 1.57894737]
[ 4.10526316 1.57894737]
[ 4.42105263 1.57894737]
[ 4.73684211 1.57894737]
[ 5.05263158 1.57894737]
[ 5.36842105 1.57894737]
[ 5.68421053 1.57894737]
[ 6. 1.57894737]
[ 0. 1.89473684]
[ 0.31578947 1.89473684]
[ 0.63157895 1.89473684]
[ 0.94736842 1.89473684]
[ 1.26315789 1.89473684]
[ 1.57894737 1.89473684]
[ 1.89473684 1.89473684]
[ 2.21052632 1.89473684]
[ 2.52631579 1.89473684]
[ 2.84210526 1.89473684]
[ 3.15789474 1.89473684]
[ 3.47368421 1.89473684]
[ 3.78947368 1.89473684]
[ 4.10526316 1.89473684]
[ 4.42105263 1.89473684]
[ 4.73684211 1.89473684]
[ 5.05263158 1.89473684]
[ 5.36842105 1.89473684]
[ 5.68421053 1.89473684]
[ 6. 1.89473684]
[ 0. 2.21052632]
[ 0.31578947 2.21052632]
[ 0.63157895 2.21052632]
[ 0.94736842 2.21052632]
[ 1.26315789 2.21052632]
[ 1.57894737 2.21052632]
[ 1.89473684 2.21052632]
[ 2.21052632 2.21052632]
[ 2.52631579 2.21052632]
[ 2.84210526 2.21052632]
[ 3.15789474 2.21052632]
[ 3.47368421 2.21052632]
[ 3.78947368 2.21052632]
[ 4.10526316 2.21052632]
[ 4.42105263 2.21052632]
[ 4.73684211 2.21052632]
[ 5.05263158 2.21052632]
[ 5.36842105 2.21052632]
[ 5.68421053 2.21052632]
[ 6. 2.21052632]
[ 0. 2.52631579]
[ 0.31578947 2.52631579]
[ 0.63157895 2.52631579]
[ 0.94736842 2.52631579]
[ 1.26315789 2.52631579]
[ 1.57894737 2.52631579]
[ 1.89473684 2.52631579]
[ 2.21052632 2.52631579]
[ 2.52631579 2.52631579]
[ 2.84210526 2.52631579]
[ 3.15789474 2.52631579]
[ 3.47368421 2.52631579]
[ 3.78947368 2.52631579]
[ 4.10526316 2.52631579]
[ 4.42105263 2.52631579]
[ 4.73684211 2.52631579]
[ 5.05263158 2.52631579]
[ 5.36842105 2.52631579]
[ 5.68421053 2.52631579]
[ 6. 2.52631579]
[ 0. 2.84210526]
[ 0.31578947 2.84210526]
[ 0.63157895 2.84210526]
[ 0.94736842 2.84210526]
[ 1.26315789 2.84210526]
[ 1.57894737 2.84210526]
[ 1.89473684 2.84210526]
[ 2.21052632 2.84210526]
[ 2.52631579 2.84210526]
[ 2.84210526 2.84210526]
[ 3.15789474 2.84210526]
[ 3.47368421 2.84210526]
[ 3.78947368 2.84210526]
[ 4.10526316 2.84210526]
[ 4.42105263 2.84210526]
[ 4.73684211 2.84210526]
[ 5.05263158 2.84210526]
[ 5.36842105 2.84210526]
[ 5.68421053 2.84210526]
[ 6. 2.84210526]
[ 0. 3.15789474]
[ 0.31578947 3.15789474]
[ 0.63157895 3.15789474]
[ 0.94736842 3.15789474]
[ 1.26315789 3.15789474]
[ 1.57894737 3.15789474]
[ 1.89473684 3.15789474]
[ 2.21052632 3.15789474]
[ 2.52631579 3.15789474]
[ 2.84210526 3.15789474]
[ 3.15789474 3.15789474]
[ 3.47368421 3.15789474]
[ 3.78947368 3.15789474]
[ 4.10526316 3.15789474]
[ 4.42105263 3.15789474]
[ 4.73684211 3.15789474]
[ 5.05263158 3.15789474]
[ 5.36842105 3.15789474]
[ 5.68421053 3.15789474]
[ 6. 3.15789474]
[ 0. 3.47368421]
[ 0.31578947 3.47368421]
[ 0.63157895 3.47368421]
[ 0.94736842 3.47368421]
[ 1.26315789 3.47368421]
[ 1.57894737 3.47368421]
[ 1.89473684 3.47368421]
[ 2.21052632 3.47368421]
[ 2.52631579 3.47368421]
[ 2.84210526 3.47368421]
[ 3.15789474 3.47368421]
[ 3.47368421 3.47368421]
[ 3.78947368 3.47368421]
[ 4.10526316 3.47368421]
[ 4.42105263 3.47368421]
[ 4.73684211 3.47368421]
[ 5.05263158 3.47368421]
[ 5.36842105 3.47368421]
[ 5.68421053 3.47368421]
[ 6. 3.47368421]
[ 0. 3.78947368]
[ 0.31578947 3.78947368]
[ 0.63157895 3.78947368]
[ 0.94736842 3.78947368]
[ 1.26315789 3.78947368]
[ 1.57894737 3.78947368]
[ 1.89473684 3.78947368]
[ 2.21052632 3.78947368]
[ 2.52631579 3.78947368]
[ 2.84210526 3.78947368]
[ 3.15789474 3.78947368]
[ 3.47368421 3.78947368]
[ 3.78947368 3.78947368]
[ 4.10526316 3.78947368]
[ 4.42105263 3.78947368]
[ 4.73684211 3.78947368]
[ 5.05263158 3.78947368]
[ 5.36842105 3.78947368]
[ 5.68421053 3.78947368]
[ 6. 3.78947368]
[ 0. 4.10526316]
[ 0.31578947 4.10526316]
[ 0.63157895 4.10526316]
[ 0.94736842 4.10526316]
[ 1.26315789 4.10526316]
[ 1.57894737 4.10526316]
[ 1.89473684 4.10526316]
[ 2.21052632 4.10526316]
[ 2.52631579 4.10526316]
[ 2.84210526 4.10526316]
[ 3.15789474 4.10526316]
[ 3.47368421 4.10526316]
[ 3.78947368 4.10526316]
[ 4.10526316 4.10526316]
[ 4.42105263 4.10526316]
[ 4.73684211 4.10526316]
[ 5.05263158 4.10526316]
[ 5.36842105 4.10526316]
[ 5.68421053 4.10526316]
[ 6. 4.10526316]
[ 0. 4.42105263]
[ 0.31578947 4.42105263]
[ 0.63157895 4.42105263]
[ 0.94736842 4.42105263]
[ 1.26315789 4.42105263]
[ 1.57894737 4.42105263]
[ 1.89473684 4.42105263]
[ 2.21052632 4.42105263]
[ 2.52631579 4.42105263]
[ 2.84210526 4.42105263]
[ 3.15789474 4.42105263]
[ 3.47368421 4.42105263]
[ 3.78947368 4.42105263]
[ 4.10526316 4.42105263]
[ 4.42105263 4.42105263]
[ 4.73684211 4.42105263]
[ 5.05263158 4.42105263]
[ 5.36842105 4.42105263]
[ 5.68421053 4.42105263]
[ 6. 4.42105263]
[ 0. 4.73684211]
[ 0.31578947 4.73684211]
[ 0.63157895 4.73684211]
[ 0.94736842 4.73684211]
[ 1.26315789 4.73684211]
[ 1.57894737 4.73684211]
[ 1.89473684 4.73684211]
[ 2.21052632 4.73684211]
[ 2.52631579 4.73684211]
[ 2.84210526 4.73684211]
[ 3.15789474 4.73684211]
[ 3.47368421 4.73684211]
[ 3.78947368 4.73684211]
[ 4.10526316 4.73684211]
[ 4.42105263 4.73684211]
[ 4.73684211 4.73684211]
[ 5.05263158 4.73684211]
[ 5.36842105 4.73684211]
[ 5.68421053 4.73684211]
[ 6. 4.73684211]
[ 0. 5.05263158]
[ 0.31578947 5.05263158]
[ 0.63157895 5.05263158]
[ 0.94736842 5.05263158]
[ 1.26315789 5.05263158]
[ 1.57894737 5.05263158]
[ 1.89473684 5.05263158]
[ 2.21052632 5.05263158]
[ 2.52631579 5.05263158]
[ 2.84210526 5.05263158]
[ 3.15789474 5.05263158]
[ 3.47368421 5.05263158]
[ 3.78947368 5.05263158]
[ 4.10526316 5.05263158]
[ 4.42105263 5.05263158]
[ 4.73684211 5.05263158]
[ 5.05263158 5.05263158]
[ 5.36842105 5.05263158]
[ 5.68421053 5.05263158]
[ 6. 5.05263158]
[ 0. 5.36842105]
[ 0.31578947 5.36842105]
[ 0.63157895 5.36842105]
[ 0.94736842 5.36842105]
[ 1.26315789 5.36842105]
[ 1.57894737 5.36842105]
[ 1.89473684 5.36842105]
[ 2.21052632 5.36842105]
[ 2.52631579 5.36842105]
[ 2.84210526 5.36842105]
[ 3.15789474 5.36842105]
[ 3.47368421 5.36842105]
[ 3.78947368 5.36842105]
[ 4.10526316 5.36842105]
[ 4.42105263 5.36842105]
[ 4.73684211 5.36842105]
[ 5.05263158 5.36842105]
[ 5.36842105 5.36842105]
[ 5.68421053 5.36842105]
[ 6. 5.36842105]
[ 0. 5.68421053]
[ 0.31578947 5.68421053]
[ 0.63157895 5.68421053]
[ 0.94736842 5.68421053]
[ 1.26315789 5.68421053]
[ 1.57894737 5.68421053]
[ 1.89473684 5.68421053]
[ 2.21052632 5.68421053]
[ 2.52631579 5.68421053]
[ 2.84210526 5.68421053]
[ 3.15789474 5.68421053]
[ 3.47368421 5.68421053]
[ 3.78947368 5.68421053]
[ 4.10526316 5.68421053]
[ 4.42105263 5.68421053]
[ 4.73684211 5.68421053]
[ 5.05263158 5.68421053]
[ 5.36842105 5.68421053]
[ 5.68421053 5.68421053]
[ 6. 5.68421053]
[ 0. 6. ]
[ 0.31578947 6. ]
[ 0.63157895 6. ]
[ 0.94736842 6. ]
[ 1.26315789 6. ]
[ 1.57894737 6. ]
[ 1.89473684 6. ]
[ 2.21052632 6. ]
[ 2.52631579 6. ]
[ 2.84210526 6. ]
[ 3.15789474 6. ]
[ 3.47368421 6. ]
[ 3.78947368 6. ]
[ 4.10526316 6. ]
[ 4.42105263 6. ]
[ 4.73684211 6. ]
[ 5.05263158 6. ]
[ 5.36842105 6. ]
[ 5.68421053 6. ]
[ 6. 6. ]]
In [25]:
def MSE(X, y, coefficients):
    """Mean squared error of the line ``beta0 + beta1 * X`` for each candidate pair.

    Parameters
    ----------
    X : array-like, shape (n,) or (n, 1)
        The single input feature.
    y : array-like, shape (n,) or (n, 1)
        Observed targets.
    coefficients : array-like, shape (m, 2)
        One ``(beta0, beta1)`` candidate per row.

    Returns
    -------
    np.ndarray, shape (m,)
        ``losses[i]`` is the MSE of the line given by ``coefficients[i]``.
    """
    # Put both X and y in column form. Reshaping only y let a 1-D X make
    # ``y - model`` broadcast to an (n, n) matrix and silently return a wrong loss.
    X = np.asarray(X, dtype=float).reshape(-1, 1)
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    coefficients = np.asarray(coefficients, dtype=float).reshape(-1, 2)
    n_samples = X.shape[0]
    losses = []
    for beta0, beta1 in coefficients:
        residual = y - (beta0 + beta1 * X)
        losses.append(np.sum(residual ** 2) / n_samples)
    return np.array(losses)
MSEs = MSE(X, y, coefficients)
# Index of the grid candidate with the lowest loss.
idx = MSEs.argmin()
print('best coefficients:', coefficients[idx])
best coefficients: [ 1.89473684 3.15789474]
In [26]:
from sklearn.linear_model import LinearRegression

# Least-squares fit on all 40 points, to compare with the grid-search optimum above.
rgr = LinearRegression().fit(X, y)
print(rgr.coef_)
print(rgr.intercept_)
[ 3.05581942]
1.97952091908
In [28]:
# Dense x-grid as a column so rgr.predict accepts it; reused by later plots.
line = np.linspace(-5, 5, 100)[:, np.newaxis]
plt.plot(X, y, 'bo')
plt.plot(line, rgr.predict(line), 'r-')
Out[28]:
[<matplotlib.lines.Line2D at 0x11c3375c0>]
In [31]:
from sklearn.model_selection import train_test_split

# Hold out half of the points to estimate generalisation error.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5
)
In [34]:
# Train points (blue circles), held-out points (red triangles), and the line
# from rgr as last fitted (on all 40 points, when the cells run in order).
plt.plot(X_train, y_train, 'bo')
plt.plot(X_test, y_test, 'r^')
plt.plot(line, rgr.predict(line), 'r-')
Out[34]:
[<matplotlib.lines.Line2D at 0x11ca9b9b0>]
In [37]:
from sklearn.metrics import mean_squared_error

rgr.fit(X_train, y_train)
# Prints the training MSE first, then the held-out MSE.
for X_part, y_part in ((X_train, y_train), (X_test, y_test)):
    print(mean_squared_error(y_part, rgr.predict(X_part)))
6.15054347101
5.34654762328
In [39]:
# k-NN train/test accuracy as a function of k, repeated on 50 independently drawn
# blob datasets. Faint lines are individual repeats; solid lines are the means.
# Blob data lives in its own names so the regression variables (X, y, X_train, ...)
# and the earlier `clf` are not silently overwritten.
ks = range(1, 24)
n_repeats = 50
acc_test = []
acc_train = []
for n in range(n_repeats):
    X_blobs, y_blobs = make_blobs(n_samples=400, centers=23, random_state=42 + n)
    y_blobs = np.take(labels, (y_blobs < 10))
    Xb_train, Xb_test, yb_train, yb_test = train_test_split(X_blobs, y_blobs, test_size=0.5)
    train_scores = []
    test_scores = []
    for k in ks:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(Xb_train, yb_train)
        train_scores.append(knn.score(Xb_train, yb_train))
        test_scores.append(knn.score(Xb_test, yb_test))
    acc_test.append(test_scores)
    acc_train.append(train_scores)
    plt.plot(ks, train_scores, '-r', alpha=0.1)
    plt.plot(ks, test_scores, '-b', alpha=0.1)
plt.plot(ks, np.array(acc_test).mean(axis=0), '-b', label='test (mean)')
plt.plot(ks, np.array(acc_train).mean(axis=0), '-r', label='train (mean)')
plt.xlabel('n_neighbors (k)')
plt.ylabel('accuracy')
plt.legend();
Out[39]:
[<matplotlib.lines.Line2D at 0x11caeeba8>]
Content source: andre-martini/advanced-comp-2017
Similar notebooks: