Experiment on softmax

$$ softmax(z_j) = {e^{z_j} \over {\sum_{i=1}^m e^{z_i} }} , j=1, ... m $$

反向传播计算步骤:

  • 计算输出层,计算最后一层softmax输出层的下列值:
$$ \delta_i^L = ({a_i^L} - y_i) $$$$ \frac {\partial C}{\partial w_{ij}^L} = ({a_i^L}-{y_i}){a_j^{L-1}} = \delta_i^L \cdot a_j^{L-1} $$$$ \frac {\partial C}{\partial b_{i}^L} = ({a_i^L} - y_i) =\delta_i^L $$
  • 计算隐藏层,反向一层一层计算sigmoid隐藏层的下列值:
$$ \delta_j^{l-1} = (\sum_{k=1}^m {\delta_k^l \cdot w_{kj}}) \cdot { a_j^{l-1} (1 - a_j^{l-1}) } $$$$ \frac {\partial C}{\partial w_{ij}^l} = \delta_i^{l} \cdot a_j^{l-1} $$$$ \frac {\partial C}{\partial b_{i}^l} = \delta_i^{l} $$

In [1]:
import numpy as np

def softmax(z):
    """
    z: is a array of input
    return: softmax correspond to each of the element
    """
    return np.exp(z)/np.sum(np.exp(z), 1, keepdims=True)

def sigmoid(z):
    return 1/(1+np.exp(-z))

def cost(activations, expectations):
    return np.sum(-expectations*np.log(activations))

In [2]:
# prepare the training data

# training input data
input = np.array([[0, 0, 0, 1],
                 [0, 0, 1, 0],
                 [0, 0, 1, 1],
                 [0, 1, 0, 0],
                 [0, 1, 0, 1],
                 [0, 1, 1, 0],
                 [0, 1, 1, 1]])

# training output expectation
output = np.array([[0, 1],
                   [1, 0],
                   [0, 1],
                   [1, 0],
                   [0, 1],
                   [1, 0],
                   [0, 1]])

网络拓扑


In [3]:
# construct the network

# input layer: 4 inputs
# hidden layer: 5 neurons with sigmoid as activate function
# * weight: 4x5 matrices
# * bias: 1x5 matrices
# output layer: 2 neurons with softmax as activate function
# * weight: 5x2 matrices
# * bias: 1x2 matrices

# initialize the weight/bias of the hidden layer (2nd layer)
w2 = np.random.rand(4, 5)
b2 = np.random.rand(1, 5)

# initialize the weight/bias of the output layer (3rd layer) 
w3 = np.random.rand(5, 2)
b3 = np.random.rand(1, 2)

In [4]:
num_epochs = 10000
eta = 0.1

x=[]
y=[]

for i in xrange(num_epochs):
    # feed forward
    z2 = np.dot(input, w2) + b2
    a2 = sigmoid(z2) # 7x5

    z3 = np.dot(a2, w3) + b3
    #z3 = np.dot(a2, w3)
    a3 = softmax(z3) # 7x2
    
    if i%1000 == 0:
        print "Perception", a3
        print "W2", w2
        print "B2", b2
        print "W3", w3
        print "B3", b3

    x.append(i)
    y.append(cost(a3, output))

    # 7x2
    delta_l3 = a3 - output
    deriv_w3 = np.dot(a2.T, delta_l3)
    deriv_b3 = delta_l3
    w3 -= eta*deriv_w3
    b3 -= eta*np.mean(deriv_b3, 0)
    
    delta_l2 = np.dot(delta_l3, w3.T)*(a2*(1-a2)) # 7x5
    deriv_w2 = np.dot(input.T, delta_l2)
    deriv_b2 = delta_l2
    w2 -= eta*deriv_w2
    b2 -= eta*np.mean(deriv_b2, 0)


Perception [[ 0.61447444  0.38552556]
 [ 0.63411271  0.36588729]
 [ 0.63291354  0.36708646]
 [ 0.62566476  0.37433524]
 [ 0.62461907  0.37538093]
 [ 0.64090918  0.35909082]
 [ 0.64021094  0.35978906]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.56285789  0.68637715  0.11755575  0.40969895  0.10522042]
 [ 0.29521794  0.49557402  0.52698973  0.65697418  0.29259944]
 [ 0.05643833  0.79178668  0.0175052   0.00110209  0.07388702]]
B2 [[ 0.55055101  0.92322898  0.29658861  0.61080862  0.67286559]]
W3 [[ 0.99459228  0.98659168]
 [ 0.85772062  0.91232175]
 [ 0.25425736  0.09411044]
 [ 0.59848548  0.1101079 ]
 [ 0.50518666  0.55335981]]
B3 [[ 0.38276182  0.2520216 ]]
Perception [[  6.25341793e-04   9.99374658e-01]
 [  9.96828319e-01   3.17168117e-03]
 [  1.19162234e-03   9.98808378e-01]
 [  9.96818146e-01   3.18185410e-03]
 [  1.14727780e-03   9.98852722e-01]
 [  9.99647489e-01   3.52510850e-04]
 [  3.72190201e-03   9.96278098e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49771115 -0.79860825  0.18782909  0.8705957  -0.61426333]
 [ 0.23741226 -0.87430177  0.38484818  0.83707038 -0.32818166]
 [ 0.49839353  2.75873145 -2.33041413 -4.71851063  2.06724532]]
B2 [[ 0.54081659  0.55062913  0.32309859  0.77071942  0.49830286]]
W3 [[ 0.65598394  1.32520003]
 [-1.44500689  3.21504926]
 [ 1.97492856 -1.62656075]
 [ 4.81887415 -4.11028077]
 [-1.14798614  2.20653262]]
B3 [[ 0.38388758  0.25089584]]
Perception [[  2.57275418e-04   9.99742725e-01]
 [  9.98582643e-01   1.41735744e-03]
 [  4.97163566e-04   9.99502836e-01]
 [  9.98585891e-01   1.41410926e-03]
 [  4.80930572e-04   9.99519069e-01]
 [  9.99878104e-01   1.21895906e-04]
 [  1.72070793e-03   9.98279292e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49569429 -0.87595565  0.21733309  0.91961076 -0.66701011]
 [ 0.23364907 -0.94614176  0.39818114  0.89125948 -0.38589029]
 [ 0.54216801  2.9771831  -2.55257369 -4.96848567  2.24333072]]
B2 [[ 0.53800459  0.50938393  0.34482138  0.8100604   0.46922705]]
W3 [[ 0.60655446  1.3746295 ]
 [-1.65082651  3.42086888]
 [ 2.19036225 -1.84199444]
 [ 5.15769498 -4.4491016 ]
 [-1.31077282  2.3693193 ]]
B3 [[ 0.38395479  0.25082863]]
Perception [[  1.55891584e-04   9.99844108e-01]
 [  9.99104571e-01   8.95429034e-04]
 [  3.03152065e-04   9.99696848e-01]
 [  9.99108526e-01   8.91474194e-04]
 [  2.93737272e-04   9.99706263e-01]
 [  9.99932909e-01   6.70913242e-05]
 [  1.10718113e-03   9.98892819e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49456328 -0.91599953  0.2374615   0.94521842 -0.6956467 ]
 [ 0.23134239 -0.98347664  0.40993345  0.91986118 -0.41850177]
 [ 0.5675332   3.09969579 -2.67213753 -5.10124937  2.34200825]]
B2 [[ 0.53624255  0.48632259  0.35844775  0.83154599  0.45233378]]
W3 [[ 0.57795823  1.40322574]
 [-1.76650088  3.53654325]
 [ 2.31057269 -1.96220488]
 [ 5.34142752 -4.63283414]
 [-1.40281116  2.46135763]]
B3 [[ 0.38388649  0.25089693]]
Perception [[  1.09924016e-04   9.99890076e-01]
 [  9.99351175e-01   6.48824503e-04]
 [  2.14564756e-04   9.99785435e-01]
 [  9.99354784e-01   6.45215807e-04]
 [  2.08064381e-04   9.99791936e-01]
 [  9.99955716e-01   4.42843746e-05]
 [  8.12234773e-04   9.99187765e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49375447 -0.94266092  0.25266099  0.96233182 -0.71519321]
 [ 0.22964026 -1.00836598  0.41954568  0.93901554 -0.44126966]
 [ 0.58548482  3.18465938 -2.75340136 -5.19063646  2.4105747 ]]
B2 [[ 0.53493706  0.47033295  0.36842358  0.84621823  0.44036765]]
W3 [[ 0.55773433  1.42344963]
 [-1.84701171  3.61705408]
 [ 2.39381318 -2.04544537]
 [ 5.46666765 -4.75807427]
 [-1.46711707  2.52566354]]
B3 [[ 0.38379634  0.25098708]]
Perception [[  8.40755128e-05   9.99915924e-01]
 [  9.99493767e-01   5.06233345e-04]
 [  1.64525007e-04   9.99835475e-01]
 [  9.99496942e-01   5.03057995e-04]
 [  1.59608046e-04   9.99840392e-01]
 [  9.99967782e-01   3.22175198e-05]
 [  6.39608816e-04   9.99360391e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49311455 -0.96249391  0.2648657   0.97508723 -0.72997126]
 [ 0.22827659 -1.02688895  0.4275859   0.95329188 -0.45875115]
 [ 0.59939744  3.24953945 -2.81468425 -5.25754223  2.4630563 ]]
B2 [[ 0.53389201  0.45811784  0.37630543  0.85730223  0.43109131]]
W3 [[ 0.54206819  1.43911578]
 [-1.90870637  3.67874874]
 [ 2.45735415 -2.10898634]
 [ 5.56124213 -4.85264875]
 [-1.51653477  2.57508125]]
B3 [[ 0.38370475  0.25107868]]
Perception [[  6.76511633e-05   9.99932349e-01]
 [  9.99586281e-01   4.13719124e-04]
 [  1.32626763e-04   9.99867373e-01]
 [  9.99589073e-01   4.10926957e-04]
 [  1.28693779e-04   9.99871306e-01]
 [  9.99975096e-01   2.49040695e-05]
 [  5.26561104e-04   9.99473439e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49257954 -0.97820985  0.27506062  0.98520644 -0.74181973]
 [ 0.22713156 -1.04156783  0.43447468  0.96460867 -0.47293021]
 [ 0.61076213  3.30192374 -2.8637354  -5.31076113  2.5055286 ]]
B2 [[ 0.53301676  0.44824938  0.38282499  0.86617835  0.42351376]]
W3 [[ 0.52927595  1.45190801]
 [-1.95867767  3.72872004]
 [ 2.50866398 -2.16029617]
 [ 5.63699046 -4.92839708]
 [-1.5566533   2.61519977]]
B3 [[ 0.38361686  0.25116656]]
Perception [[  5.63553741e-05   9.99943645e-01]
 [  9.99650970e-01   3.49030245e-04]
 [  1.10635082e-04   9.99889365e-01]
 [  9.99653442e-01   3.46557742e-04]
 [  1.07368583e-04   9.99892631e-01]
 [  9.99979935e-01   2.00645253e-05]
 [  4.46918643e-04   9.99553081e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49211655 -0.99118392  0.28381339  0.9935662  -0.75168978]
 [ 0.2261404  -1.0536843   0.44049345  0.97394733 -0.48485018]
 [ 0.62037099  3.34579159 -2.90454691 -5.35480412  2.54117434]]
B2 [[ 0.53226158  0.4399799   0.38838598  0.87356265  0.41710775]]
W3 [[ 0.5184636   1.46272036]
 [-2.00064579  3.77068816]
 [ 2.55164825 -2.20328044]
 [ 5.70003402 -4.99144064]
 [-1.59041067  2.64895715]]
B3 [[ 0.38353395  0.25124948]]
Perception [[  4.81429184e-05   9.99951857e-01]
 [  9.99698648e-01   3.01352283e-04]
 [  9.46152793e-05   9.99905385e-01]
 [  9.99700856e-01   2.99144358e-04]
 [  9.18281175e-05   9.99908172e-01]
 [  9.99983341e-01   1.66585854e-05]
 [  3.87849700e-04   9.99612150e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49170632 -1.00220652  0.29148077  1.00067181 -0.76013623]
 [ 0.22526388 -1.06397592  0.44583442  0.98187504 -0.4951276 ]
 [ 0.6286957   3.38348801 -2.93944065 -5.39228478  2.57186856]]
B2 [[ 0.53159607  0.43286915  0.39323519  0.8798735   0.41155898]]
W3 [[ 0.50909895  1.47208501]
 [-2.03680449  3.80684686]
 [ 2.58860409 -2.24023628]
 [ 5.7539419  -5.04534853]
 [-1.61954266  2.67808914]]
B3 [[ 0.38345607  0.25132735]]
Perception [[  4.19211493e-05   9.99958079e-01]
 [  9.99735189e-01   2.64810879e-04]
 [  8.24594314e-05   9.99917541e-01]
 [  9.99737177e-01   2.62823176e-04]
 [  8.00325678e-05   9.99919967e-01]
 [  9.99985850e-01   1.41496966e-05]
 [  3.42333815e-04   9.99657666e-01]]
W2 [[ 0.98581723  0.77097799  0.43174681  0.44405765  0.11745462]
 [ 0.49133661 -1.01177242  0.29830159  1.00684018 -0.7675105 ]
 [ 0.22447637 -1.07290483  0.45063339  0.98874815 -0.50415713]
 [ 0.63604003  3.41651012 -2.9698839  -5.42484897  2.59880848]]
B2 [[ 0.53100026  0.42663619  0.39753467  0.88537608  0.40666489]]
W3 [[ 0.50083933  1.48034463]
 [-2.06855564  3.83859801]
 [ 2.62099604 -2.27262823]
 [ 5.80097356 -5.09238018]
 [-1.64516019  2.70370666]]
B3 [[ 0.38338293  0.25140049]]

In [5]:
import matplotlib.pyplot as plt
plt.plot(x, y)
plt.xlabel("Epoch")
plt.ylabel("Cost")
plt.show()