In [216]:
%matplotlib inline
import numpy as np

In [ ]:


In [6]:
N = 100  # number of samples
D = 2    # number of input features per sample

In [8]:
X = np.random.randn(N, D)  # N samples of D standard-normal features (no seed set, so values differ per run)

In [11]:
np.shape(X)


Out[11]:
(100, 2)

In [37]:
# bias / intercept term: an N x 1 column of ones to prepend to X
ones = np.ones((N, 1))  # idiomatic replacement for np.array([[1]*N]).T

In [39]:
ones.shape


Out[39]:
(100, 1)

In [40]:
Xb = np.concatenate([ones, X], axis=1)  # prepend the bias column -> Xb has shape (N, D+1)

In [41]:
Xb


Out[41]:
array([[  1.00000000e+00,   9.93751891e-01,   6.66752005e-01],
       [  1.00000000e+00,   3.77411084e-01,  -2.75911808e-01],
       [  1.00000000e+00,  -1.06539594e+00,   6.78236568e-01],
       [  1.00000000e+00,   5.85289875e-01,   3.50537352e-01],
       [  1.00000000e+00,   3.73202539e-04,  -7.85370598e-01],
       [  1.00000000e+00,  -5.17690595e-01,  -6.00503426e-01],
       [  1.00000000e+00,   2.50419729e-01,  -1.97391658e-02],
       [  1.00000000e+00,   9.35649812e-01,   1.58260119e+00],
       [  1.00000000e+00,  -1.39902635e+00,   4.40801514e-01],
       [  1.00000000e+00,  -1.56059368e+00,   2.58841172e-01],
       [  1.00000000e+00,   1.20245547e+00,  -5.27449027e-01],
       [  1.00000000e+00,   5.68798233e-01,   1.27989234e+00],
       [  1.00000000e+00,  -1.94083376e+00,   7.85364277e-01],
       [  1.00000000e+00,  -5.47674725e-01,  -4.61496228e-01],
       [  1.00000000e+00,   1.15682927e+00,  -6.11882423e-01],
       [  1.00000000e+00,  -1.04944631e+00,   1.53769888e-01],
       [  1.00000000e+00,  -4.77018041e-01,   7.19466067e-01],
       [  1.00000000e+00,   1.26937437e+00,  -4.23274840e-01],
       [  1.00000000e+00,   2.71184117e-01,  -6.02906394e-01],
       [  1.00000000e+00,  -1.03980268e+00,  -2.33020826e-01],
       [  1.00000000e+00,  -3.93102497e-01,   8.40331277e-01],
       [  1.00000000e+00,   1.69474533e+00,   1.81684274e-01],
       [  1.00000000e+00,   2.87085602e-01,  -6.39326765e-01],
       [  1.00000000e+00,  -1.97205899e+00,  -8.98187299e-03],
       [  1.00000000e+00,  -1.06484606e-01,  -5.27449639e-01],
       [  1.00000000e+00,   1.29049731e+00,  -6.09356382e-01],
       [  1.00000000e+00,  -2.09556086e+00,  -1.53734837e+00],
       [  1.00000000e+00,   1.40293339e+00,   4.71114919e-01],
       [  1.00000000e+00,   2.07364324e-01,   1.72590826e-01],
       [  1.00000000e+00,   6.68958704e-01,   1.61005999e+00],
       [  1.00000000e+00,  -1.39334915e+00,  -9.80094489e-01],
       [  1.00000000e+00,   4.38058072e-01,   3.09098192e-01],
       [  1.00000000e+00,   1.19328161e+00,  -5.50545672e-01],
       [  1.00000000e+00,  -1.04477407e+00,   4.41009468e-01],
       [  1.00000000e+00,  -8.26631709e-01,  -8.71048482e-02],
       [  1.00000000e+00,   7.23647697e-01,  -1.97469679e+00],
       [  1.00000000e+00,   2.35537501e-01,  -1.94855293e-01],
       [  1.00000000e+00,  -2.04840787e-01,   4.53355742e-01],
       [  1.00000000e+00,   7.14459922e-01,  -4.46881885e-01],
       [  1.00000000e+00,   1.16163344e+00,   2.46020187e+00],
       [  1.00000000e+00,   1.62742686e+00,   1.49110416e-01],
       [  1.00000000e+00,  -3.69667321e-01,   4.63082629e-02],
       [  1.00000000e+00,   7.34940255e-02,   7.84390166e-01],
       [  1.00000000e+00,   8.44453382e-01,   6.29314468e-01],
       [  1.00000000e+00,   4.81639249e-01,  -1.84492148e+00],
       [  1.00000000e+00,   1.61459927e-01,   1.09457716e+00],
       [  1.00000000e+00,  -5.77991967e-01,  -1.75670163e-02],
       [  1.00000000e+00,  -8.30498749e-01,  -3.05755189e-02],
       [  1.00000000e+00,  -9.34201380e-01,   2.30659251e-01],
       [  1.00000000e+00,   3.59243755e-01,   3.56971079e-01],
       [  1.00000000e+00,  -7.04499931e-01,   9.10354536e-01],
       [  1.00000000e+00,   4.69353356e-01,   7.75829082e-01],
       [  1.00000000e+00,   8.42463835e-01,  -1.88246530e+00],
       [  1.00000000e+00,  -1.30264707e+00,   1.98955219e+00],
       [  1.00000000e+00,   7.53053406e-01,  -3.81312294e-01],
       [  1.00000000e+00,   4.80865962e-01,   1.88697106e-01],
       [  1.00000000e+00,  -1.09187528e+00,  -2.24780582e-01],
       [  1.00000000e+00,   1.71850556e+00,  -2.70083157e-01],
       [  1.00000000e+00,   2.67825011e-01,  -6.61211872e-01],
       [  1.00000000e+00,   5.30410888e-01,   1.98198847e+00],
       [  1.00000000e+00,   1.13892402e+00,  -2.32774731e-01],
       [  1.00000000e+00,   1.62078348e+00,   4.10962368e-02],
       [  1.00000000e+00,  -2.47627439e-01,   3.16044037e+00],
       [  1.00000000e+00,  -9.30059204e-01,   1.37822706e-01],
       [  1.00000000e+00,  -1.08090636e+00,  -5.63022686e-01],
       [  1.00000000e+00,   5.92377694e-01,   1.43091049e+00],
       [  1.00000000e+00,  -8.11947294e-02,  -2.06003489e+00],
       [  1.00000000e+00,  -7.47181445e-01,   1.95489944e-02],
       [  1.00000000e+00,  -2.59982526e+00,  -6.30983856e-01],
       [  1.00000000e+00,  -1.02557646e+00,   1.33550772e+00],
       [  1.00000000e+00,   1.07860066e+00,   1.52006850e-01],
       [  1.00000000e+00,   8.59796431e-01,  -6.90396527e-01],
       [  1.00000000e+00,  -7.25776213e-01,   9.41723094e-02],
       [  1.00000000e+00,  -1.17364972e+00,   6.30329682e-01],
       [  1.00000000e+00,  -7.75667191e-01,   5.93624838e-01],
       [  1.00000000e+00,  -5.22434025e-01,  -3.56227952e-01],
       [  1.00000000e+00,  -9.96928142e-01,   1.41539496e+00],
       [  1.00000000e+00,  -8.44561553e-01,   1.08040433e-01],
       [  1.00000000e+00,  -1.23422657e+00,   1.36375229e-01],
       [  1.00000000e+00,   1.92496166e+00,  -2.33586287e-01],
       [  1.00000000e+00,   2.23880422e+00,  -1.07687394e+00],
       [  1.00000000e+00,  -7.67752188e-01,   4.28636063e-01],
       [  1.00000000e+00,  -6.89343806e-01,  -4.23820005e-01],
       [  1.00000000e+00,  -1.27159286e+00,  -2.77033854e-01],
       [  1.00000000e+00,   1.23528301e+00,   6.43215198e-02],
       [  1.00000000e+00,  -5.88349632e-01,  -1.87040815e-01],
       [  1.00000000e+00,  -1.44964972e+00,  -1.27403505e+00],
       [  1.00000000e+00,   8.37419957e-01,  -5.15683598e-01],
       [  1.00000000e+00,  -8.80778731e-01,   3.11395856e-01],
       [  1.00000000e+00,  -1.84565752e+00,   1.25246782e+00],
       [  1.00000000e+00,   6.15132985e-01,  -1.20056945e+00],
       [  1.00000000e+00,  -1.16037411e+00,   9.26997800e-01],
       [  1.00000000e+00,   3.95071546e-01,   5.94187922e-03],
       [  1.00000000e+00,   2.72626150e-01,   1.79268726e-01],
       [  1.00000000e+00,   1.95708066e-01,  -1.08264302e+00],
       [  1.00000000e+00,  -4.56233643e-01,  -1.05387002e+00],
       [  1.00000000e+00,   7.60951572e-01,   9.63607708e-01],
       [  1.00000000e+00,  -1.11011819e-01,  -2.18758941e+00],
       [  1.00000000e+00,  -4.86687475e-01,   9.93042320e-01],
       [  1.00000000e+00,  -5.68746299e-01,   3.13975195e-01]])

In [42]:
w = np.random.randn(D + 1)  # number of features + bias term

In [43]:
z = Xb.dot(w)

In [46]:
def sigmoid(z):
    """Logistic function 1 / (1 + e^(-z)); maps any real z into (0, 1), elementwise on arrays."""
    denom = 1 + np.exp(-z)
    return 1 / denom

In [47]:
# print() form works on both Python 2 and 3 (bare `print expr` is Python-2-only syntax)
print(sigmoid(z))


[ 0.61328574  0.74383751  0.04007962  0.57682934  0.78847139  0.51872257
  0.59493879  0.21750332  0.03432538  0.03574312  0.95098277  0.20056885
  0.00735268  0.44360092  0.95418845  0.09892986  0.0982883   0.94766742
  0.8122598   0.18244376  0.09234318  0.9282826   0.82604925  0.02824273
  0.66069155  0.963264    0.26514425  0.82200617  0.49113316  0.14211567
  0.31355269  0.53129842  0.95214559  0.06208669  0.19994046  0.99111928
  0.66189279  0.22062725  0.87696746  0.07912619  0.92418918  0.30515048
  0.20331284  0.56609563  0.98300348  0.14598389  0.2545591   0.1832254
  0.10485619  0.47533004  0.04937774  0.34183509  0.99149483  0.00262529
  0.87151654  0.60254037  0.16713285  0.96804652  0.82681612  0.06256633
  0.91102971  0.9359773   0.00206027  0.12228776  0.27262094  0.16635935
  0.96935402  0.19184977  0.02854101  0.01361918  0.82223505  0.93432916
  0.17740679  0.03625115  0.07476735  0.40838161  0.01242864  0.14585746
  0.075678    0.9760671   0.99688678  0.09917438  0.36753169  0.13845317
  0.8769338   0.3123242   0.41185782  0.90913646  0.10019273  0.00377992
  0.95846005  0.02213785  0.64394071  0.51679308  0.89943202  0.72995467
  0.38246348  0.97417083  0.0616254   0.16088391]

In [ ]:

Process data


In [48]:
import numpy as np
import pandas as pd

In [49]:
df = pd.read_csv("../machine_learning_examples/ann_logistic_extra/ecommerce_data.csv")

In [81]:
df.head()


Out[81]:
is_mobile n_products_viewed visit_duration is_returning_visitor time_of_day user_action
0 1 0 0.657510 0 3 0
1 1 1 0.568571 0 2 1
2 1 0 0.042246 1 1 0
3 1 1 1.659793 1 1 2
4 0 1 2.014745 1 1 2

In [82]:
df.describe()


Out[82]:
is_mobile n_products_viewed visit_duration is_returning_visitor time_of_day user_action
count 500.000000 500.000000 500.000000 500.000000 500.000000 500.00000
mean 0.486000 0.854000 1.055880 0.518000 1.588000 0.74800
std 0.500305 1.046362 0.976711 0.500176 1.121057 0.89336
min 0.000000 0.000000 0.000141 0.000000 0.000000 0.00000
25% 0.000000 0.000000 0.328550 0.000000 1.000000 0.00000
50% 0.000000 1.000000 0.804717 1.000000 2.000000 0.00000
75% 1.000000 1.000000 1.499518 1.000000 3.000000 1.00000
max 1.000000 4.000000 6.368775 1.000000 3.000000 3.00000

In [ ]:


In [135]:
def get_data():
    """Load the e-commerce dataset and return (X2, Y).

    X2: feature matrix where the two continuous columns
        (n_products_viewed, visit_duration) are z-score standardized and
        the categorical last feature column, time_of_day (values 0-3),
        is expanded into 4 one-hot columns appended at the end.
    Y:  the last CSV column (user_action) as the target vector.
    """
    df = pd.read_csv("../machine_learning_examples/ann_logistic_extra/ecommerce_data.csv")
    data = df.values  # .as_matrix() is deprecated/removed in modern pandas

    X = data[:, :-1]
    Y = data[:, -1]

    # standardize the continuous columns (mean 0, std 1)
    X[:, 1] = (X[:, 1] - X[:, 1].mean()) / X[:, 1].std()
    X[:, 2] = (X[:, 2] - X[:, 2].mean()) / X[:, 2].std()

    N, D = X.shape
    # new matrix with room for the one-hot encoding:
    # only time_of_day (the last feature column) needs it
    X2 = np.zeros((N, D + 3))
    X2[:, 0:(D - 1)] = X[:, 0:(D - 1)]

    # one-hot encode time_of_day, loop version
    # (`range` works on both Python 2 and 3, unlike `xrange`)
    for n in range(N):
        t = int(X[n, D - 1])
        X2[n, D - 1 + t] = 1

    # vectorized equivalent, kept as a cross-check against the loop
    Z = np.zeros((N, 4))
    Z[np.arange(N), X[:, D - 1].astype(np.int32)] = 1
    assert(np.abs(X2[:, -4:] - Z).sum() < 10e-10)

    return X2, Y

In [137]:
X2, Y = get_data()

In [139]:
Y


Out[139]:
array([ 0.,  1.,  0.,  2.,  2.,  2.,  0.,  0.,  1.,  0.,  3.,  0.,  0.,
        1.,  0.,  3.,  1.,  1.,  1.,  0.,  2.,  0.,  0.,  3.,  0.,  1.,
        0.,  0.,  2.,  2.,  0.,  1.,  2.,  2.,  1.,  0.,  0.,  1.,  1.,
        0.,  2.,  1.,  0.,  1.,  3.,  0.,  1.,  0.,  0.,  1.,  2.,  3.,
        1.,  2.,  1.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,
        1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,
        0.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  2.,  3.,  0.,
        2.,  1.,  1.,  3.,  1.,  0.,  3.,  0.,  2.,  0.,  2.,  0.,  1.,
        0.,  1.,  0.,  0.,  1.,  0.,  2.,  3.,  1.,  0.,  0.,  1.,  0.,
        0.,  1.,  0.,  2.,  0.,  1.,  0.,  1.,  1.,  2.,  0.,  0.,  0.,
        0.,  0.,  1.,  0.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,  2.,  0.,
        0.,  0.,  1.,  2.,  1.,  2.,  0.,  2.,  2.,  0.,  2.,  2.,  0.,
        0.,  0.,  3.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,  0.,  2.,  0.,
        0.,  0.,  0.,  0.,  2.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  2.,
        0.,  1.,  0.,  2.,  1.,  1.,  0.,  1.,  2.,  0.,  0.,  0.,  0.,
        3.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,  0.,  2.,  0.,
        0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  3.,  3.,  1.,
        0.,  1.,  1.,  1.,  0.,  1.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,
        0.,  0.,  1.,  0.,  1.,  1.,  2.,  0.,  0.,  1.,  2.,  0.,  2.,
        0.,  0.,  1.,  3.,  1.,  2.,  0.,  2.,  0.,  0.,  0.,  1.,  0.,
        0.,  0.,  0.,  0.,  0.,  1.,  3.,  0.,  0.,  1.,  0.,  2.,  2.,
        1.,  0.,  2.,  0.,  0.,  3.,  2.,  1.,  0.,  3.,  0.,  0.,  1.,
        0.,  2.,  0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,
        1.,  0.,  2.,  1.,  1.,  0.,  0.,  1.,  0.,  2.,  0.,  0.,  2.,
        0.,  1.,  1.,  1.,  1.,  2.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,
        0.,  2.,  1.,  1.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,
        0.,  1.,  0.,  2.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,
        1.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,  1.,
        0.,  1.,  2.,  1.,  0.,  0.,  0.,  2.,  3.,  2.,  0.,  2.,  0.,
        2.,  0.,  0.,  0.,  1.,  2.,  3.,  1.,  1.,  1.,  2.,  3.,  0.,
        1.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,  1.,
        0.,  2.,  0.,  1.,  0.,  2.,  2.,  1.,  0.,  2.,  0.,  1.,  0.,
        0.,  0.,  0.,  2.,  1.,  0.,  0.,  0.,  1.,  2.,  0.,  1.,  0.,
        0.,  0.,  0.,  2.,  2.,  0.,  2.,  2.,  3.,  1.,  0.,  0.,  2.,
        0.,  0.,  3.,  0.,  0.,  2.,  0.,  0.,  1.,  0.,  2.,  1.,  1.,
        0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,
        0.,  1.,  0.,  0.,  2.,  0.,  2.,  0.,  3.,  0.,  0.,  0.,  0.,
        0.,  2.,  2.,  1.,  1.,  1.,  0.,  1.,  1.,  3.,  0.,  0.,  0.,
        2.,  2.,  0.,  0.,  0.,  0.])

In [141]:
def get_binary_data():
    """Return only the rows whose target is class 0 or 1 (drops user actions 2 and 3)."""
    X, Y = get_data()
    binary_mask = Y <= 1
    return X[binary_mask], Y[binary_mask]

Prediction


In [142]:
X , Y = get_binary_data()

In [143]:
N, D = X.shape

In [144]:
W = np.random.randn(D)

In [145]:
b = 0

In [149]:
def sigmoid(a):
    """Sigmoid activation: squashes activation a into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-a))

In [150]:
def forward(X, W, b):
    """Forward pass of logistic regression: sigmoid of the affine map X.W + b."""
    logits = X.dot(W) + b
    return sigmoid(logits)

In [151]:
P_Y_given_X = forward(X, W ,b)

In [154]:
prediction = np.round(P_Y_given_X)

In [155]:
def classfication_rate(Y, P):
    """Fraction of predictions P that equal the targets Y.

    (Name keeps the original misspelling so existing callers still work.)
    """
    matches = Y == P
    return np.mean(matches)

In [157]:
print ("score: %s" % classfication_rate(Y,prediction))


score: 0.683417085427

In [ ]:


In [ ]:


In [ ]:


In [ ]:
##################

In [85]:
# Scratch version of get_data(), run cell-by-cell for inspection.
df = pd.read_csv("../machine_learning_examples/ann_logistic_extra/ecommerce_data.csv")
data = df.values  # .as_matrix() is deprecated/removed in modern pandas

X = data[:, :-1]
Y = data[:, -1]

# standardize the two continuous columns (mean 0, std 1)
X[:, 1] = (X[:, 1] - X[:, 1].mean()) / X[:, 1].std()
X[:, 2] = (X[:, 2] - X[:, 2].mean()) / X[:, 2].std()

N, D = X.shape
# new matrix with room for the one-hot encoding:
# only time_of_day (the last feature column) needs it
X2 = np.zeros((N, D + 3))
X2[:, 0:(D - 1)] = X[:, 0:(D - 1)]

# one-hot encode time_of_day (`range` works on Python 2 and 3, unlike `xrange`)
for n in range(N):
    t = int(X[n, D - 1])
    X2[n, D - 1 + t] = 1

Z = np.zeros((N, 4))

In [114]:
Z.shape


Out[114]:
(500, 4)

In [120]:
Z[np.arange(N), X[:,D-1].astype(np.int32)] = 1

In [125]:
type(X[:,D-1])


Out[125]:
numpy.ndarray

In [158]:
np.arrange()  # typo: the function is np.arange — this raises AttributeError (traceback shown below)


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-158-8ba384b6b80e> in <module>()
----> 1 np.arrange()

AttributeError: 'module' object has no attribute 'arrange'

Gaussian and log(Gaussian) plot


In [211]:
import seaborn as sns
import scipy

In [213]:
scipy.stats.norm.pdf(2)


Out[213]:
0.053990966513188063

In [214]:
x = np.linspace(-3, 3, 100)

In [215]:
y = scipy.stats.norm.pdf(x)

In [217]:
sns.plt.plot(x,y)  # NOTE(review): `sns.plt` was removed in seaborn 0.9+; import matplotlib.pyplot directly instead


Out[217]:
[<matplotlib.lines.Line2D at 0x112769b10>]

In [219]:
log_y = np.log(y)

In [220]:
sns.plt.plot(x, log_y)


Out[220]:
[<matplotlib.lines.Line2D at 0x112984810>]

In [221]:
squar_y = y * y

In [222]:
sns.plt.plot(x, squar_y)


Out[222]:
[<matplotlib.lines.Line2D at 0x112a9b350>]

logistic


In [1]:
import numpy as np

In [2]:
N = 100
D = 2

In [3]:
X = np.random.randn(N, D)

In [4]:
X[0]


Out[4]:
array([-1.09460797,  0.77611705])

In [5]:
X[:50, :] = X[:50, :] - 2*np.ones((50,D))  # shift the first 50 samples to be centered at (-2, -2): class 0 cluster

In [6]:
X[0]


Out[6]:
array([-3.09460797, -1.22388295])

In [7]:
X[50:, :] = X[50:, :] + 2*np.ones((50, D))  # shift the last 50 samples to be centered at (+2, +2): class 1 cluster

In [32]:
X[0]


Out[32]:
array([-3.09460797, -1.22388295])

In [33]:
X


Out[33]:
array([[-3.09460797, -1.22388295],
       [-1.85694746, -2.96521385],
       [-3.19746189, -1.34916859],
       [-4.21130155, -2.21003089],
       [-2.96862735,  0.17410949],
       [-1.71653279, -1.26340179],
       [-3.40668005,  0.13175914],
       [-0.55510237, -0.90495622],
       [-3.58553397, -1.94775111],
       [-1.35235069, -0.59871431],
       [-0.69381335, -1.03667929],
       [-2.34996084, -3.19194379],
       [-4.46878418, -2.42648397],
       [-2.98985864, -2.24585402],
       [-2.97964461, -1.74981333],
       [-1.98102993, -2.09782076],
       [-2.52146891, -2.4886781 ],
       [-2.99636112, -3.76110225],
       [-1.57057041, -3.17323462],
       [-1.88585935, -2.63767718],
       [-1.66675761, -0.865735  ],
       [-1.9713574 , -1.05412651],
       [-0.76331653, -0.939845  ],
       [-0.64318313, -2.72340746],
       [-1.6611602 , -0.39123348],
       [-2.78180514, -0.42406616],
       [-1.85112236, -1.03075703],
       [ 0.53736747, -2.12648289],
       [-2.60045166, -1.61751198],
       [-0.58875122, -2.53701914],
       [-1.89377229, -2.97099694],
       [-2.32840367, -0.49823545],
       [-1.7100886 , -1.11034142],
       [-0.18846388, -3.26571878],
       [-4.06340448, -2.77143177],
       [-1.61941685, -2.17862437],
       [-3.57072848, -0.03620329],
       [-1.35760031, -2.06768856],
       [-2.84681194, -2.4960743 ],
       [-2.46284409, -1.00364446],
       [-2.02885893, -2.3315675 ],
       [-2.70396514, -1.04080066],
       [-1.94665975, -2.61373084],
       [-0.04518893, -3.0538181 ],
       [-0.49782364, -3.52838834],
       [-2.42191977, -1.28813663],
       [-0.04034495, -3.11514386],
       [-2.54299923, -2.34469755],
       [-1.48463301, -2.74685699],
       [-3.69120921, -1.77016267],
       [ 1.59893642,  0.20315069],
       [ 2.70803706,  2.13297713],
       [ 2.17935246,  0.85013127],
       [ 5.17129582,  2.25067801],
       [ 0.80281244,  2.42469083],
       [ 1.82146062,  1.09775641],
       [ 1.1216339 ,  3.01148772],
       [ 1.67343377,  2.98743498],
       [ 1.56697155,  2.06747475],
       [ 1.80464855,  2.4035561 ],
       [ 1.86182917,  0.95765451],
       [ 1.66087608,  2.78537549],
       [ 0.81957532,  0.34304047],
       [ 1.58552   ,  2.06965295],
       [ 1.73411595,  1.83919562],
       [ 1.95753357,  2.57836315],
       [ 1.14510803,  0.92266429],
       [ 1.38025295,  2.48653409],
       [ 3.42649282,  0.96316835],
       [ 1.95646414,  0.61565452],
       [ 2.12854567,  0.62844801],
       [ 2.58202201,  0.05578524],
       [ 1.40518746,  0.04272369],
       [ 1.6819604 ,  1.98935882],
       [ 3.03335236,  1.17808246],
       [ 1.97965832,  4.62737264],
       [ 1.14658626,  2.89787647],
       [ 2.03043623,  2.53094135],
       [ 1.15837894,  1.52070535],
       [ 1.28071655,  2.22764824],
       [ 0.60705701,  1.44094261],
       [ 4.7944334 ,  2.55363817],
       [ 1.9656908 ,  0.74700123],
       [ 1.68719019,  1.89540318],
       [ 2.50819613,  2.86314961],
       [ 1.39820852,  1.98150769],
       [ 0.38871083,  3.67721795],
       [ 2.47415706,  3.2324861 ],
       [ 1.84759635,  2.64037839],
       [ 2.26655992,  1.72769021],
       [ 2.95878749,  1.93937911],
       [ 1.93479571,  2.02961329],
       [ 1.21492926,  1.18146562],
       [ 3.3882496 ,  1.56201497],
       [ 1.25303289,  1.06883205],
       [ 2.11694264,  2.6036041 ],
       [ 1.6171156 ,  2.31342671],
       [ 2.01996033,  2.92560347],
       [ 3.44371585,  1.93177231],
       [ 3.59869074,  1.96535034]])

In [ ]:


In [9]:
T = np.repeat([0, 1], 50)  # labels: first 50 samples are class 0, last 50 are class 1

In [10]:
ones = np.ones((N, 1))  # bias column; idiomatic replacement for np.array([[1]*N]).T

In [11]:
Xb = np.concatenate((ones,X), axis=1)

In [12]:
w = np.random.randn(D+1)

In [13]:
z = Xb.dot(w)

In [14]:
def sigmoid(z):
    """Logistic sigmoid of z, applied elementwise on arrays."""
    exp_neg = np.exp(-z)
    return 1 / (1 + exp_neg)

In [15]:
Y = sigmoid(z)

In [16]:
# T: ground-truth labels (0/1)
# Y: predicted probabilities
def cross_entropy(T, Y):
    """Total (summed, not averaged) binary cross-entropy of predictions Y against targets T.

    Y must lie strictly in (0, 1), otherwise np.log diverges.
    """
    # np.sum is vectorized; the builtin sum() loops in Python and is slower on arrays
    return -np.sum(T * np.log(Y) + (1 - T) * np.log(1 - Y))

In [17]:
cross_entropy(T,Y)


Out[17]:
82.248483442891001

In [18]:
T[0]


Out[18]:
0

In [19]:
Y[0]


Out[19]:
0.69894979687397873

In [20]:
np.log(1-Y[0])


Out[20]:
-1.2004782403467085

In [21]:
np.log(1)


Out[21]:
0.0

In [22]:
w2 = np.array([0, 4, 4])

In [23]:
z2 = Xb.dot(w2)

In [24]:
Y2 = sigmoid(z2)

In [25]:
cross_entropy(T, Y2)


Out[25]:
0.021672366099269647

Visualization


In [26]:
%matplotlib inline
import seaborn as sns

In [31]:
sns.plt.scatter(X[:,0], X[:,1], c=T, s=100, alpha=0.5)


Out[31]:
<matplotlib.collections.PathCollection at 0x11122ce90>

In [ ]:


In [ ]:


In [ ]:


In [ ]:


In [ ]: