In [1]:
cd ../chapter5/models-master/tutorials/rnn/ptb


/home/wjj/TFbook/chapter5/models-master/tutorials/rnn/ptb

In [2]:
# -*- coding:utf-8 -*- 
import time
import numpy as np
import tensorflow as tf
import reader

In [3]:
class PTBInput(object):
    
    def __init__(self,config,data,name=None):
        self.batch_size = batch_size = config.batch_size    #从config中读取参数存到本地变量
        self.num_steps = num_steps = config.num_steps
        self.epoch_size = (len(data) // batch_size - 1) // num_steps
        self.input_data, self.targets = reader.ptb_producer(data,batch_size,num_steps,name=name)

In [4]:
class PTBModel(object):
    def __init__(self,is_training,config,input_):
        self._input = input_
        
        batch_size = input_.batch_size    #从input_中读取参数存到本地变量
        num_steps = input_.num_steps
        size = config.hidden_size    #从config中读取参数存到本地变量,隐含节点个数
        vocab_size = config.vocab_size
        def lstm_cell():    #使用tf.contrib.rnn.BasicLSTMCell函数设置默认的LSTM单元
            return tf.contrib.rnn.BasicLSTMCell(size,forget_bias=0.0,state_is_tuple=True)    #size是隐含节点个数,forgets_bias是forget gate的bias,
        attn_cell = lstm_cell
        if is_training and config.keep_prob < 1:    #如果是在训练状态且keepprob<1则在前面的lstm_cell之后接一个Dropout层
            def attn_cell():    #使用tf.contrib.rnn.DropoutWrapper函数设置一个dropout层
                return tf.contrib.rnn.DropoutWrapper(lstm_cell(),output_keep_prob=config.keep_prob)
        cell = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(config.num_layers)],state_is_tuple=True)#用tf.contrib.rnn.MultiRNNCell函数堆叠前面构造的lstm_cell

        self._initial_state = cell.zero_state(batch_size, tf.float32)    #设置LSTM单元的初始化状态为0    

        with tf.device("/cpu:0"):
            embedding = tf.get_variable(      #embedding是一个向量, 将one-hot编码格式的单词转化为向量表达形式
                "embedding", [vocab_size,size],dtype=tf.float32)    #vocab_size是词汇表数,每个单词向量表达所需的维数为size   分别构成embedding的行和列
            inputs = tf.nn.embedding_lookup(embedding,input_.input_data)    #查询单词对应的向量表达获得inputs
        if is_training and config.keep_prob <1:    #如果是训练状态,还要在后面加上一层dropout层
            inputs = tf.nn.dropout(inputs,config.keep_prob)

        outputs = []
        state = self._initial_state
        with tf.variable_scope("RNN"):     #将下面的操作设为RNN
            for time_step in range(num_steps):    #设置步数,用来限制梯度在反向传播过程步数
                if time_step > 0: tf.get_variable_scope().reuse_variables()    #第二次循环开始设置复用变量
                (cell_output,state) = cell(inputs[:,time_step,:],state) #inputs的三个维度,第1个代表batch中的第几个样本,第2个代表样本中的第几个单词,第三个代表单词的
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(outputs,1),[-1,size]) 
        softmax_w = tf.get_variable(
                    "softmax_w",[size,vocab_size],dtype=tf.float32)
        softmax_b = tf.get_variable("softmax_b",[vocab_size],dtype=tf.float32)
        logits = tf.matmul(output,softmax_w) + softmax_b
        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(   #用这个函数来计算targets和logits的偏差
                [logits],
                [tf.reshape(input_.targets,[-1])],
                [tf.ones([batch_size * num_steps],dtype=tf.float32)])
        self._cost = cost = tf.reduce_sum(loss) / batch_size    #汇总batch的误差,在计算每个样本的误差
        self._final_state = state    #保留最终的状态

        if not is_training:
            return
        self._lr = tf.Variable(0.0,trainable=False)   #定义学习率。且设置为不可训练
        tvars = tf.trainable_variables()    
        grads,_ = tf.clip_by_global_norm(tf.gradients(cost,tvars),  #计算tvars的梯度,设置梯度的最大范数,
                                        config.max_grad_norm)    #这个Gradient Cliping方法,控制梯度的最大范数,某种程度上有正则化的效果,防止梯度爆炸问题
        optimizer = tf.train.GradientDescentOptimizer(self._lr)
        self._train_op = optimizer.apply_gradients(zip(grads,tvars),    #定义一个训练操作,将clip过的梯度应用到所有了训练的参数上
                        global_step=tf.contrib.framework.get_or_create_global_step())    #生成全局统一的训练步数

        self._new_lr = tf.placeholder(
                    tf.float32,shape=[],name="new_learning_rate")   #控制学习速率
        self._lr_update = tf.assign(self._lr,self._new_lr)  #将新的学习速率赋值给当前的学习速率_lr
        
    def assign_lr(self,session,lr_value):
        session.run(self._lr_update,feed_dict={self._new_lr:lr_value})
    
    @property
    def input(self):
        return self._input
    @property
    def initial_state(self):
        return self._initial_state
    
    @property
    def cost(self):
        return self._cost
    
    @property
    def final_state(self):
        return self._final_state
    
    @property
    def lr(self):
        return self._lr
    
    @property
    def train_op(self):
        return self._train_op

In [5]:
class SmallConfig(object):
    init_scale = 0.1
    learning_rate = 1.0
    max_grad_norm = 5
    num_layers = 2
    num_steps = 20
    hidden_size = 200
    max_epoch = 4 
    max_max_epoch = 13
    keep_prob = 1.0
    lr_decay = 0.5
    batch_size = 20
    vocab_size = 10000

In [6]:
class MediumConfig(object):
    init_scale = 0.05
    learning_rate = 1.0
    max_grad_norm = 5
    num_layers = 2
    num_steps = 35
    hidden_size = 650
    max_epoch = 6
    max_max_epoch = 39
    keep_prob = 0.5
    lr_decay = 0.8
    batch_size = 20
    vocab_size = 10000

In [7]:
class LargeConfig(object):
    init_scale = 0.04
    learning_rate = 1.0
    max_grad_norm = 10
    num_layers = 2
    num_steps = 35
    hidden_size = 1500
    max_epoch = 14
    max_max_epoch = 55
    keep_prob = 0.35
    lr_decay = 1 / 1.15
    batch_size = 20
    vocab_size = 10000

In [8]:
class TestConfig(object):
    init_scale = 0.04
    learning_rate = 1.0
    max_grad_norm = 1
    num_layers = 1
    num_steps = 2
    hidden_size = 2
    max_epoch = 1
    max_max_epoch = 1
    keep_prob = 1.0
    lr_decay = 0.5
    batch_size = 20
    vocab_size = 10000

In [9]:
def run_epoch(session,model,eval_op=None,verbose=False):
    start_time = time.time()
    costs = 0.0
    iters = 0
    state = session.run(model.initial_state)    #初始化获得初始状态
    
    fetches = {
        "cost":model.cost,
        "final_state":model.final_state,
    }
    if eval_op is not None:
        fetches["eval_op"] = eval_op    #创建结果的字典表
    
    for step in range(model.input.epoch_size):    #训练epoch_size
        feed_dict = {}
        for i,(c,h) in enumerate(model.initial_state):    #每次把state装入feed_dict
            feed_dict[c] = state[i].c
            feed_dict[h] = state[i].h
        vals = session.run(fetches,feed_dict)   #跑起
        cost = vals["cost"]     #得到cost
        state = vals["final_state"]   #得到state
        
        costs += costs    #累加cost
        iters += model.input.num_steps   #累加迭代次数,
        
        if verbose and step % (model.input.epoch_size // 10) == 10:    #每隔10次做一次展示
            print("%.3f perplexity:%.3f speed: %.0f wps" %
                 (step * 1.0 / model.input.epoch_size,np.exp(costs/iters),
                 iters * model.input.batch_size / (time.time()-start_time)))
                  
    return np.exp(costs / iters)

In [10]:
data_path='/home/wjj/TFbook/chapter7/simple-examples/data/'
raw_data = reader.ptb_raw_data(data_path)     #直接读取解压后的数据
train_data,valid_data,test_data,_ = raw_data    #将解压后的数据分别存为训练数据和验证数据以及测试数据
config = SmallConfig()    #使用SmallConfig的配置
eval_config = SmallConfig()  #测试配置eval_config需和训练配置一致
eval_config.batch_size = 1
eval_config.num_steps = 1

In [12]:



Out[12]:
[102,
 14,
 24,
 32,
 752,
 381,
 2,
 29,
 120,
 0,
 35,
 92,
 60,
 111,
 143,
 32,
 616,
 3148,
 282,
 19,
 0,
 447,
 459,
 438,
 196,
 1621,
 3,
 394,
 90,
 4,
 14,
 7,
 0,
 1106,
 1471,
 14,
 3152,
 1858,
 5,
 1337,
 39,
 1079,
 4,
 6803,
 2,
 57,
 2162,
 4857,
 3845,
 78,
 0,
 522,
 3,
 1037,
 779,
 51,
 74,
 901,
 280,
 117,
 2283,
 5,
 4102,
 0,
 399,
 3866,
 7,
 179,
 149,
 8,
 288,
 2,
 0,
 3,
 60,
 2569,
 365,
 16,
 0,
 129,
 146,
 1022,
 0,
 838,
 8,
 2888,
 4,
 69,
 3222,
 56,
 46,
 3000,
 78,
 0,
 3,
 1037,
 496,
 554,
 79,
 32,
 2444,
 0,
 399,
 844,
 2,
 129,
 145,
 247,
 2066,
 5,
 1213,
 52,
 5,
 0,
 5491,
 5,
 414,
 0,
 8654,
 1022,
 280,
 17,
 360,
 129,
 4234,
 4,
 60,
 280,
 117,
 2,
 772,
 399,
 4,
 739,
 82,
 890,
 9,
 3919,
 216,
 288,
 7,
 482,
 1,
 2770,
 149,
 3006,
 2,
 914,
 129,
 146,
 149,
 408,
 2981,
 3564,
 6705,
 3953,
 180,
 7930,
 1684,
 1973,
 8,
 691,
 5852,
 96,
 1898,
 77,
 8,
 576,
 5807,
 2,
 0,
 1,
 30,
 292,
 2334,
 2,
 0,
 621,
 47,
 24,
 1,
 2,
 496,
 554,
 0,
 2198,
 46,
 63,
 583,
 5,
 2444,
 0,
 6470,
 16,
 0,
 1022,
 4,
 0,
 35,
 92,
 60,
 111,
 15,
 2994,
 1,
 444,
 232,
 70,
 18,
 1,
 125,
 584,
 2,
 1,
 557,
 1,
 141,
 4,
 2198,
 5415,
 997,
 80,
 14,
 13,
 1547,
 5,
 117,
 0,
 2569,
 13,
 32,
 839,
 49,
 789,
 2,
 67,
 0,
 332,
 13,
 7,
 6,
 1,
 113,
 876,
 247,
 562,
 32,
 1154,
 14,
 2,
 3228,
 26,
 2374,
 11,
 6,
 3465,
 4,
 1856,
 10,
 13,
 63,
 83,
 7,
 0,
 47,
 2,
 97,
 161,
 586,
 8,
 57,
 280,
 50,
 292,
 696,
 51,
 1150,
 266,
 282,
 1730,
 16,
 6,
 5044,
 6302,
 272,
 76,
 0,
 60,
 47,
 24,
 105,
 2748,
 2,
 223,
 7,
 6,
 1,
 1227,
 0,
 447,
 459,
 2614,
 7,
 3152,
 31,
 1471,
 8448,
 43,
 6,
 277,
 4,
 51,
 552,
 39,
 40,
 1,
 52,
 6,
 8570,
 36,
 3,
 3,
 252,
 16,
 0,
 272,
 7,
 1,
 77,
 395,
 2,
 1,
 77,
 7240,
 5,
 3,
 21,
 71,
 6,
 439,
 11,
 0,
 129,
 146,
 2,
 18,
 0,
 235,
 4,
 0,
 272,
 3,
 21,
 71,
 46,
 926,
 2,
 0,
 447,
 459,
 2614,
 248,
 18,
 3,
 2,
 0,
 447,
 9,
 460,
 24,
 334,
 7,
 374,
 520,
 86,
 5,
 0,
 1,
 752,
 381,
 1037,
 10,
 2868,
 313,
 3,
 3,
 2,
 7,
 760,
 520,
 227,
 0,
 447,
 9,
 8018,
 24,
 0,
 1,
 821,
 8,
 0,
 9912,
 155,
 0,
 47,
 202,
 3,
 36,
 3,
 3,
 6,
 123,
 78,
 752,
 381,
 2,
 0,
 447,
 202,
 3,
 3,
 16,
 752,
 381,
 2,
 71,
 4,
 408,
 0,
 729,
 4,
 561,
 698,
 46,
 1965,
 1024,
 73,
 272,
 282,
 7089,
 5,
 309,
 8,
 1918,
 43,
 0,
 497,
 12,
 3,
 48,
 588,
 4,
 0,
 763,
 17,
 31,
 1,
 96,
 2,
 330,
 321,
 9,
 7757,
 3228,
 36,
 645,
 5903,
 50,
 1686,
 4052,
 380,
 9608,
 10,
 6,
 420,
 42,
 2280,
 8,
 408,
 60,
 42,
 499,
 2,
 18,
 3,
 2187,
 4686,
 544,
 0,
 1,
 309,
 0,
 129,
 146,
 24,
 1,
 77,
 7,
 408,
 1499,
 309,
 2,
 16,
 0,
 111,
 1022,
 19,
 640,
 19,
 408,
 1898,
 77,
 64,
 1,
 11,
 6,
 3866,
 15,
 54,
 524,
 1022,
 1281,
 2,
 249,
 280,
 79,
 25,
 919,
 7138,
 51,
 2540,
 67,
 0,
 309,
 1,
 2,
 11,
 422,
 0,
 47,
 50,
 58,
 3106,
 43,
 3328,
 78,
 1672,
 65,
 9,
 267,
 4662,
 6219,
 287,
 43,
 0,
 2192,
 11,
 506,
 1327,
 1148,
 3328,
 2,
 8,
 3,
 1495,
 78,
 0,
 408,
 77,
 2671,
 544,
 309,
 10,
 0,
 408,
 96,
 79,
 32,
 188,
 604,
 11,
 27,
 257,
 2,
 18,
 39,
 374,
 0,
 447,
 24,
 118,
 43,
 3,
 394,
 2,
 0,
 47,
 1,
 2,
 5903,
 79,
 32,
 5110,
 51,
 408,
 60,
 29,
 38,
 5812,
 873,
 4,
 440,
 538,
 6202,
 60,
 38,
 50,
 2,
 11,
 471,
 51,
 399,
 988,
 77,
 9888,
 5,
 25,
 1993,
 7,
 2044,
 96,
 41,
 248,
 118,
 3,
 3,
 5,
 3,
 3,
 2790,
 450,
 805,
 41,
 202,
 3,
 3,
 5,
 3,
 3,
 8,
 1,
 553,
 41,
 4998,
 3,
 5,
 3,
 3,
 2,
 144,
 149,
 1413,
 8361,
 2,
 29,
 19,
 3866,
 1439,
 3228,
 542,
 5,
 184,
 2719,
 149,
 88,
 19,
 1684,
 1973,
 8,
 197,
 85,
 753,
 5,
 1195,
 51,
 392,
 2,
 67,
 77,
 24,
 4124,
 7,
 1684,
 1973,
 0,
 60,
 24,
 77,
 18,
 3,
 118,
 3,
 3,
 120,
 688,
 248,
 3,
 3,
 230,
 18,
 3,
 2,
 399,
 1,
 76,
 4,
 5575,
 4,
 4625,
 7733,
 455,
 41,
 26,
 2363,
 17,
 224,
 67,
 112,
 616,
 5,
 425,
 790,
 2,
 90,
 4,
 0,
 60,
 399,
 844,
 544,
 20,
 330,
 321,
 3623,
 210,
 5958,
 156,
 280,
 2,
 280,
 15,
 90,
 4,
 51,
 137,
 1255,
 116,
 16,
 0,
 61,
 1051,
 3889,
 3134,
 2,
 99,
 18,
 3,
 54,
 4,
 0,
 47,
 9,
 9268,
 2320,
 453,
 841,
 19,
 0,
 833,
 3,
 288,
 315,
 50,
 1621,
 3,
 394,
 2438,
 5,
 384,
 6,
 1,
 470,
 7,
 0,
 447,
 2614,
 2,
 124,
 31,
 331,
 1198,
 17,
 0,
 129,
 146,
 8,
 0,
 482,
 3201,
 111,
 77,
 24,
 2152,
 4124,
 7,
 482,
 2,
 78,
 0,
 77,
 2671,
 7,
 0,
 833,
 3,
 6160,
 7,
 482,
 5575,
 4,
 399,
 608,
 5,
 775,
 149,
 873,
 16,
 0,
 129,
 146,
 8,
 2198,
 608,
 5,
 1,
 112,
 118,
 2,
 19,
 6,
 410,
 0,
 3288,
 178,
 0,
 288,
 8,
 60,
 190,
 1,
 3148,
 2,
 349,
 0,
 1,
 4,
 974,
 288,
 0,
 6766,
 4,
 251,
 280,
 316,
 0,
 1267,
 60,
 47,
 13,
 1654,
 97,
 280,
 46,
 4224,
 5,
 457,
 60,
 112,
 1185,
 16,
 0,
 129,
 146,
 2,
 0,
 288,
 2671,
 24,
 113,
 1,
 17,
 129,
 146,
 1022,
 280,
 2,
 14,
 1,
 607,
 52,
 15,
 54,
 137,
 2569,
 2,
 39,
 4246,
 2853,
 4124,
 54,
 942,
 4,
 156,
 77,
 60,
 216,
 975,
 10,
 1101,
 4312,
 0,
 288,
 8,
 60,
 190,
 8,
 30,
 58,
 3065,
 17,
 57,
 11,
 0,
 47,
 9,
 129,
 2042,
 2,
 7,
 6,
 974,
 975,
 184,
 156,
 280,
 182,
 36,
 184,
 129,
 3795,
 4,
 149,
 8,
 1195,
 0,
 218,
 7,
 288,
 5,
 4938,
 7,
 6,
 115,
 1568,
 2,
 67,
 0,
 763,
 529,
 544,
 148,
 14,
 1,
 538,
 1794,
 64,
 50,
 11,
 0,
 3029,
 15,
 6,
 1275,
 276,
 18,
 54,
 4,
 0,
 413,
 2875,
 ...]

In [ ]:
with tf.Graph().as_default():
    initializer = tf.random_uniform_initializer(-config.init_scale,config.init_scale)     #
    
    with tf.name_scope("Train"):
        train_input = PTBInput(config=config,data=train_data,name="TrainInput")
        with tf.variable_scope("Model",reuse = None, initializer=initializer):
            m = PTBModel(is_training=True,config=config,input_=train_input)
    
    with tf.name_scope("Valid"):
        valid_input = PTBInput(config=config,data=valid_data,name="ValidInput")
        with tf.variable_scope("Model",reuse = True, initializer=initializer):
            mvalid = PTBModel(is_training=False,config=config,input_=valid_input)
            
    with tf.name_scope("Test"):
        test_input = PTBInput(config=config,data=test_data,name="TestInput")
        with tf.variable_scope("Model",reuse = True, initializer=initializer):
            mtest = PTBModel(is_training=False,config=config,input_=test_input)
    sv = tf.train.Supervisor()
    with sv.managed_session() as session:
        for i in range(config.max_max_epoch):
            lr_decay = config.lr_decay ** max(i+1-config.max_epoch,0.0)
            m.assign_lr(session,config.learning_rate * lr_decay)
            
            print("Epoch: %d Learning rate: %.3f" % (i+1,session.run(m.lr)))
            
            train_perplexity = run_epoch(session,m,eval_op=m.train_op,verbose=True)
            print("Epoch: %d Train Perplexity: %.3f" % (i+1,train_perplexity))
            
            valid_perplexity = run_epoch(session,mvalid)
            print("Epoch: %d Valid Perplexity: %.3f" % (i+1,valid_perplexity))
            
        test_perplexity = run_epoch(session,mtest)
        print("Test Perplexity: %.3f" % test_perplexity)


WARNING:tensorflow:Standard services need a 'logdir' passed to the SessionManager
Epoch: 1 Learning rate: 1.000
0.004 perplexity:1.000 speed: 11516 wps
0.104 perplexity:1.000 speed: 31775 wps
0.204 perplexity:1.000 speed: 32989 wps
0.304 perplexity:1.000 speed: 33488 wps
0.404 perplexity:1.000 speed: 33744 wps
0.504 perplexity:1.000 speed: 33889 wps
0.604 perplexity:1.000 speed: 34009 wps
0.703 perplexity:1.000 speed: 34088 wps
0.803 perplexity:1.000 speed: 34140 wps
0.903 perplexity:1.000 speed: 34185 wps
Epoch: 1 Train Perplexity: 1.000
Epoch: 1 Valid Perplexity: 1.000
Epoch: 2 Learning rate: 1.000
0.004 perplexity:1.000 speed: 34678 wps
0.104 perplexity:1.000 speed: 34634 wps
0.204 perplexity:1.000 speed: 34644 wps
0.304 perplexity:1.000 speed: 34668 wps
0.404 perplexity:1.000 speed: 34646 wps
0.504 perplexity:1.000 speed: 34649 wps
0.604 perplexity:1.000 speed: 34659 wps
0.703 perplexity:1.000 speed: 34641 wps
0.803 perplexity:1.000 speed: 34634 wps
0.903 perplexity:1.000 speed: 34636 wps
Epoch: 2 Train Perplexity: 1.000
Epoch: 2 Valid Perplexity: 1.000
Epoch: 3 Learning rate: 1.000
0.004 perplexity:1.000 speed: 34733 wps
0.104 perplexity:1.000 speed: 34615 wps
0.204 perplexity:1.000 speed: 34540 wps
0.304 perplexity:1.000 speed: 34582 wps
0.404 perplexity:1.000 speed: 34597 wps
0.504 perplexity:1.000 speed: 34559 wps
0.604 perplexity:1.000 speed: 34579 wps
0.703 perplexity:1.000 speed: 34594 wps
0.803 perplexity:1.000 speed: 34621 wps
0.903 perplexity:1.000 speed: 34637 wps
Epoch: 3 Train Perplexity: 1.000
Epoch: 3 Valid Perplexity: 1.000
Epoch: 4 Learning rate: 1.000
0.004 perplexity:1.000 speed: 34932 wps
0.104 perplexity:1.000 speed: 34754 wps
0.204 perplexity:1.000 speed: 34693 wps
0.304 perplexity:1.000 speed: 34627 wps
0.404 perplexity:1.000 speed: 34663 wps
0.504 perplexity:1.000 speed: 34685 wps
0.604 perplexity:1.000 speed: 34690 wps
0.703 perplexity:1.000 speed: 34698 wps
0.803 perplexity:1.000 speed: 34704 wps

In [ ]:


In [ ]:


In [ ]:


In [ ]:


In [ ]: