In [1]:
import tensorflow as tf
from numpy.random import RandomState

1. 定义神经网络的参数,输入和输出节点。


In [2]:
batch_size = 8
w1= tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2= tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_= tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

2. 定义前向传播过程,损失函数及反向传播算法。


In [3]:
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) 
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

3. 生成模拟数据集。


In [4]:
rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[int(x1+x2 < 1)] for (x1, x2) in X]

4. 创建一个会话来运行TensorFlow程序。


In [5]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    # 输出目前(未经训练)的参数取值。
    print "w1:", sess.run(w1)
    print "w2:", sess.run(w2)
    print "\n"
    
    # 训练模型。
    STEPS = 5000
    for i in range(STEPS):
        start = (i*batch_size) % 128
        end = (i*batch_size) % 128 + batch_size
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 1000 == 0:
            total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
            print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))
    
    # 输出训练后的参数取值。
    print "\n"
    print "w1:", sess.run(w1)
    print "w2:", sess.run(w2)


w1: [[-0.81131822  1.48459876  0.06532937]
 [-2.44270396  0.0992484   0.59122431]]
w2: [[-0.81131822]
 [ 1.48459876]
 [ 0.06532937]]


After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471


w1: [[-1.9618274   2.58235407  1.68203783]
 [-3.4681716   1.06982327  2.11788988]]
w2: [[-1.8247149 ]
 [ 2.68546653]
 [ 1.41819501]]