In [1]:
import numpy as np

In [2]:
# 2 3 1
w11 = 0.5 - np.random.rand()
w12 = 0.5 - np.random.rand()
b1 = 0.5 - np.random.rand()

w21 = 0.5 - np.random.rand()
w22 = 0.5 - np.random.rand()
b2 = 0.5 - np.random.rand()

w31 = 0.5 - np.random.rand()
w32 = 0.5 - np.random.rand()
b3 = 0.5 - np.random.rand()

w41 = 0.5 - np.random.rand()
w42 = 0.5 - np.random.rand()
w43 = 0.5 - np.random.rand()
b4 = 0.5 - np.random.rand()

loss = 0
acc = 0
total_test = 1000000
to_print = total_test / 10
eta = 0.001

for i in range(total_test):
    
    #Data
    x1 = np.random.rand()
    x2 = np.random.rand()
    z_ = float( (x1 + x2) < 0.5 or (x1 + x2) > 1.5 )
    
    #Feedforward
    y1 = w11 * x1 + w12 * x2 + b1
    z1 = 1 / (1 + np.exp(-y1))
    
    y2 = w21 * x1 + w22 * x2 + b2
    z2 = 1 / (1 + np.exp(-y2))
    
    y3 = w31 * x1 + w32 * x2 + b3
    z3 = 1 / (1 + np.exp(-y3))
    
    y4 = w41 * z1 + w42 * z2 + w43 * z3 +  b4
    z = 1 / (1 + np.exp(-y4))
    
    #Loss
    loss += -z_ * np.log(z) - (1 - z_) * np.log(1 - z) #Cross entropy
    acc += float(z_ == float(z > 0.5))
    if (i + 1) % to_print == 0:
        print loss, acc / float(to_print), (i + 1) / to_print
        loss = 0
        acc = 0
    
    #Backpropagation
    dw41 = eta * (z - z_) * z1
    dw42 = eta * (z - z_) * z2
    dw43 = eta * (z - z_) * z3
    db4  = eta * (z - z_)
    
    dw11 = eta * (z - z_) * w41 * z1 * (1 - z1) * x1
    dw12 = eta * (z - z_) * w41 * z1 * (1 - z1) * x2
    db1  = eta * (z - z_) * w41 * z1 * (1 - z1)
    
    dw21 = eta * (z - z_) * w42 * z2 * (1 - z2) * x1
    dw22 = eta * (z - z_) * w42 * z2 * (1 - z2) * x2
    db2  = eta * (z - z_) * w42 * z2 * (1 - z2)
    
    dw31 = eta * (z - z_) * w43 * z3 * (1 - z3) * x1
    dw32 = eta * (z - z_) * w43 * z3 * (1 - z3) * x2
    db3  = eta * (z - z_) * w43 * z3 * (1 - z3)
    
    w41 -= dw41
    w42 -= dw42
    w43 -= dw43
    b4  -= db4
    
    w11 -= dw11
    w12 -= dw12
    b1  -= db1
    
    w21 -= dw21
    w22 -= dw22
    b2  -= db2
    
    w31 -= dw31
    w32 -= dw32
    b3  -= db3


56497.647942 0.74853 1
56305.3045915 0.74969 2
56138.1466804 0.75114 3
56428.8370871 0.74845 4
56201.0312631 0.7505 5
56284.5143919 0.74968 6
56446.6593911 0.74821 7
56393.470542 0.74868 8
56141.7071168 0.75089 9
56356.3721604 0.74893 10