In [2]:
import tensorflow as tf
from tensorflow.models.rnn.ptb import reader


---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-2-23fc9fab46d8> in <module>()
      1 import tensorflow as tf
----> 2 from tensorflow.models.rnn.ptb import reader

ImportError: No module named models.rnn.ptb

In [5]:
help(tf.nn.rnn_cell)


Help on module tensorflow.python.ops.rnn_cell in tensorflow.python.ops:

NAME
    tensorflow.python.ops.rnn_cell - Module for constructing RNN Cells.

FILE
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py

DESCRIPTION
    ## Base interface for all RNN Cells
    
    @@RNNCell
    
    ## RNN Cells for use with TensorFlow's core RNN methods
    
    @@BasicRNNCell
    @@BasicLSTMCell
    @@GRUCell
    @@LSTMCell
    
    ## Classes storing split `RNNCell` state
    
    @@LSTMStateTuple
    
    ## RNN Cell wrappers (RNNCells that wrap other RNNCells)
    
    @@MultiRNNCell
    @@DropoutWrapper
    @@DeviceWrapper
    @@ResidualWrapper


1. 读取数据并打印长度及前100位数据。


In [6]:
DATA_PATH = "../../datasets/PTB_data"
train_data, valid_data, test_data, _ = reader.ptb_raw_data(DATA_PATH)
print len(train_data)
print train_data[:100]


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-6-6e2ffd8e0a41> in <module>()
      1 DATA_PATH = "../../datasets/PTB_data"
----> 2 train_data, valid_data, test_data, _ = reader.ptb_raw_data(DATA_PATH)
      3 print len(train_data)
      4 print train_data[:100]

NameError: name 'reader' is not defined

2. 将训练数据组织成batch大小为4、截断长度为5的数据组。并使用队列读取前3个batch。


In [3]:
# ptb_producer返回的为一个二维的tuple数据。
result = reader.ptb_producer(train_data, 4, 5)

# 通过队列依次读取batch。
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(3):
        x, y = sess.run(result)
        print "X%d: "%i, x
        print "Y%d: "%i, y
    coord.request_stop()
    coord.join(threads)


X0:  [[9970 9971 9972 9974 9975]
 [ 332 7147  328 1452 8595]
 [1969    0   98   89 2254]
 [   3    3    2   14   24]]
Y0:  [[9971 9972 9974 9975 9976]
 [7147  328 1452 8595   59]
 [   0   98   89 2254    0]
 [   3    2   14   24  198]]
X1:  [[9976 9980 9981 9982 9983]
 [  59 1569  105 2231    1]
 [   0  312 1641    4 1063]
 [ 198  150 2262   10    0]]
Y1:  [[9980 9981 9982 9983 9984]
 [1569  105 2231    1  895]
 [ 312 1641    4 1063    8]
 [ 150 2262   10    0  507]]
X2:  [[9984 9986 9987 9988 9989]
 [ 895    1 5574    4  618]
 [   8  713    0  264  820]
 [ 507   74 2619    0    1]]
Y2:  [[9986 9987 9988 9989 9991]
 [   1 5574    4  618    2]
 [ 713    0  264  820    2]
 [  74 2619    0    1    8]]