In [1]:
import numpy as np

# Define commonly used functions
def sigmoid(z):
    """Numerically stable logistic sigmoid, ``1 / (1 + exp(-z))``.

    Parameters
    ----------
    z : scalar or array_like
        Input value(s).

    Returns
    -------
    NumPy scalar or ndarray
        Element-wise sigmoid of ``z``. A scalar input yields a NumPy scalar,
        an array input yields an array of the same shape.

    Notes
    -----
    The naive formula evaluates ``exp(-z)``, which overflows for large
    negative ``z``. Here ``exp`` is only applied to ``-|z|`` (never positive),
    and the algebraically equivalent form ``e / (1 + e)`` is used for the
    negative half of the input range.
    """
    z = np.asarray(z)
    # exp of a non-positive number lies in (0, 1], so it cannot overflow.
    e = np.exp(-np.abs(z))
    out = np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # Indexing with () unwraps a 0-d result into a NumPy scalar and returns
    # real arrays unchanged, matching what the plain formula would return.
    return out[()]

def sigmoid_prime(z):
    """First derivative of the sigmoid, evaluated at ``z``.

    Computed as ``s * (1 - s)`` where ``s = sigmoid(z)``.
    """
    activation = sigmoid(z)
    return activation * (1 - activation)

In [2]:
# Network layout 2-3-1: 2 inputs, 3 hidden sigmoid units, 1 sigmoid output.
# Samples are row vectors multiplied on the left of the weight matrices
# (right-multiplication by the weights): activation = sigmoid(x . W + b).

# Weights start uniformly in [-0.5, 0.5); biases start at zero.
w1 = np.random.rand(2, 3) - 0.5
b1 = np.zeros((1, 3))

w2 = np.random.rand(3, 1) - 0.5
b2 = np.zeros((1, 1))

loss = 0                # cross entropy accumulated over the current report window
acc = 0                 # number of correct predictions in the current report window
total_test = 10000000   # upper bound on the number of training samples
to_print = 100000       # report and reset the running statistics every this many samples
eta = 0.01              # learning rate
log_eps = 1e-15         # keeps the logged cross entropy finite if z2 saturates at 0 or 1

for i in range(total_test):

    # Data: one random point in the unit square; the target is 1 when the
    # coordinate sum is below 0.5 or above 1.5, otherwise 0.
    x = np.random.rand(1, 2)
    x_sum = np.sum(x)
    z_ = float(x_sum < 0.5 or x_sum > 1.5)

    # Feedforward
    s1 = np.dot(x, w1) + b1
    z1 = sigmoid(s1)

    s2 = np.dot(z1, w2) + b2
    z2 = sigmoid(s2)

    # Loss: binary cross entropy. The clipped copy is used for logging only,
    # so log(0) cannot turn the running total into inf/nan; the gradient
    # below still uses the unclipped z2.
    z2_safe = np.clip(z2, log_eps, 1 - log_eps)
    loss += -z_ * np.log(z2_safe) - (1 - z_) * np.log(1 - z2_safe)
    # z2 has shape (1, 1); index the single element explicitly rather than
    # converting a 2-D array to a Python scalar.
    acc += float(z_ == float(z2[0, 0] > 0.5))
    if (i + 1) % to_print == 0:
        score = acc / float(to_print)
        # Formatting with % gives the same text on Python 2 and Python 3;
        # // keeps the window counter an integer on both.
        print("%s %s %s" % (loss, score, (i + 1) // to_print))
        loss = 0
        acc = 0
        if score > 0.99:
            # Stop once the accuracy over the last window exceeds 99%.
            # The current sample is not used for a weight update.
            break

    # Backpropagation
    # With L the cross entropy:  dL/dz2 = -z_ / z2 + (1 - z_) / (1 - z2)
    # and dz2/ds2 = z2 * (1 - z2), so
    #   dL/ds2 = -z_ * (1 - z2) + (1 - z_) * z2
    #          = -z_ + z_*z2 + z2 - z_*z2
    #          = z2 - z_
    delta2_in = z2 - z_

    dw2 = eta * np.dot(z1.T, delta2_in)
    db2 = eta * np.sum(delta2_in, axis=0)

    # Propagate the error through w2 (before w2 is updated) and through the
    # hidden-layer sigmoid.
    delta1_out = np.dot(delta2_in, w2.T)
    delta1_in = delta1_out * sigmoid_prime(s1)

    dw1 = eta * np.dot(x.T, delta1_in)
    db1 = eta * np.sum(delta1_in, axis=0)

    # Gradient-descent step (the learning rate is already folded into d*).
    w1 -= dw1
    b1 -= db1
    w2 -= dw2
    b2 -= db2


[[ 56304.19468505]] 0.75018 1
[[ 56361.22368372]] 0.74961 2
[[ 56486.88388642]] 0.74835 3
[[ 56326.74536639]] 0.74976 4
[[ 56634.6653972]] 0.74684 5
[[ 56343.9921785]] 0.74876 6
[[ 52583.5111166]] 0.74762 7
[[ 45576.18841399]] 0.74907 8
[[ 42346.44592707]] 0.78842 9
[[ 23473.38387846]] 0.91561 10
[[ 11145.3373047]] 0.97108 11
[[ 6277.23528982]] 0.99067 12

In [3]:
# Show the learned weights and biases. %-formatting produces the same text as
# the Python 2 statement `print w1, b1, w2, b2` and also runs on Python 3.
print("%s %s %s %s" % (w1, b1, w2, b2))


[[ -6.6289409  -11.99595684  -3.39362292]
 [ -2.77804889 -11.91187781  -6.89972204]] [[ 6.54707894  6.42226942  7.33172589]] [[ -9.78164921]
 [ 19.76103982]
 [-10.66773685]] [[ 8.10960182]]