In [1]:
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import BasicLSTMCell, MultiRNNCell, DropoutWrapper

from tensorflow.models.rnn.ptb import reader

We use the Small config settings.


In [2]:
class SmallConfig(object):
    """Small config."""
    init_scale = 0.1
    learning_rate = 1.0
    max_grad_norm = 5
    num_layers = 2
    num_steps = 20
    hidden_size = 200
    max_epoch = 4
    max_max_epoch = 13
    keep_prob = 1.0
    lr_decay = 0.5
    batch_size = 20
    vocab_size = 10000

We create two config objects, one for training and one for evaluation (testing).


In [3]:
config = SmallConfig()
eval_config = SmallConfig()
eval_config.batch_size = 1
eval_config.num_steps = 1

Next, we write a class that builds the PTB model.


In [4]:
class PTBModel(object):
    """The PTB model."""

    def __init__(self, config, is_training=False):
        self.batch_size = config.batch_size
        self.num_steps = config.num_steps
        input_size = [config.batch_size, config.num_steps]
        self.input_data = tf.placeholder(tf.int32, input_size)
        self.targets = tf.placeholder(tf.int32, input_size)

        # Multi-layer LSTM; dropout on the cell outputs is applied only during training.
        lstm_cell = BasicLSTMCell(config.hidden_size, forget_bias=0.0, state_is_tuple=True)
        if is_training and config.keep_prob < 1:
            lstm_cell = DropoutWrapper(lstm_cell, output_keep_prob=config.keep_prob)
        cell = MultiRNNCell([lstm_cell] * config.num_layers, state_is_tuple=True)

        self.initial_state = cell.zero_state(config.batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding_size = [config.vocab_size, config.hidden_size]
            embedding = tf.get_variable("embedding", embedding_size)
            inputs = tf.nn.embedding_lookup(embedding, self.input_data)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)

        outputs = []
        state = self.initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(config.num_steps):
                if time_step > 0: tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(1, outputs), [-1, config.hidden_size])
        softmax_w_size = [config.hidden_size, config.vocab_size]
        softmax_w = tf.get_variable("softmax_w", softmax_w_size)
        softmax_b = tf.get_variable("softmax_b", [config.vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b
    
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [logits],
            [tf.reshape(self.targets, [-1])],
            [tf.ones([config.batch_size * config.num_steps])])
        self.cost = tf.reduce_sum(loss) / config.batch_size
        self.final_state = state

        if not is_training:
            return

        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))

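Because `state_is_tuple=True`, `initial_state` and `final_state` are tuples holding one `LSTMStateTuple(c, h)` per layer, and the epoch function below relies on this structure when it feeds and fetches the recurrent state. The following cell is a minimal sketch, not part of the original run, assuming the same TensorFlow 0.x-era API as the code above; it only inspects the state shapes.


In [ ]:
with tf.Graph().as_default():
    initializer = tf.random_uniform_initializer(-0.1, 0.1)
    with tf.variable_scope("inspect", initializer=initializer):
        m_check = PTBModel(SmallConfig())
    # One LSTMStateTuple(c, h) per layer; each tensor has shape
    # [batch_size, hidden_size] = [20, 200] for the small config.
    for i, (c, h) in enumerate(m_check.initial_state):
        print(i, c.get_shape(), h.get_shape())
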
We define a function that runs one training or evaluation epoch.


In [5]:
def run_epoch(session, m, data, is_training=False):
    """Runs the model on the given data."""
    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
    start_time = time.time()
    costs = 0.0
    iters = 0
    
    eval_op = m.train_op if is_training else tf.no_op()
    
    # Evaluate the (zero-valued) initial state once; it is fed back into the
    # model at every step so the state carries over between batches.
    state_list = []
    for c, h in m.initial_state:
        state_list.extend([c.eval(), h.eval()])
    
    ptb_iter = reader.ptb_iterator(data, m.batch_size, m.num_steps)
    for step, (x, y) in enumerate(ptb_iter):
        fetch_list = [m.cost]
        for c, h in m.final_state:
            fetch_list.extend([c, h])
        fetch_list.append(eval_op)
        
        feed_dict = {m.input_data: x, m.targets: y}
        for i in range(len(m.initial_state)):
            c, h = m.initial_state[i]
            feed_dict[c], feed_dict[h] = state_list[i*2:(i+1)*2]
        
        cost, *state_list, _ = session.run(fetch_list, feed_dict)

        costs += cost
        iters += m.num_steps

        if is_training and step % (epoch_size // 10) == 10:
            print("%.3f perplexity: %.3f speed: %.0f wps" %
                    (step * 1.0 / epoch_size, np.exp(costs / iters),
                     iters * m.batch_size / (time.time() - start_time)))

    return np.exp(costs / iters)

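run_epoch returns the perplexity, i.e. the exponential of the average per-word cross-entropy accumulated over the epoch. A tiny illustration of that formula, with made-up numbers rather than values from the run below:

In [ ]:
# perplexity = exp(total_cost / total_word_steps); the numbers here are
# hypothetical, purely to show the calculation.
example_costs = 110.0   # accumulated cost over an epoch
example_iters = 20      # number of word steps processed
print(np.exp(example_costs / example_iters))  # ~244.69
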
In [6]:
raw_data = reader.ptb_raw_data('simple-examples/data')
train_data, valid_data, test_data, _ = raw_data

train_data, valid_data, and test_data are lists in which each word has been replaced by an integer ID. IDs are assigned in order of word frequency, starting at 0 for the most frequent word, over a vocabulary of 10,000 words.

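The epoch function consumes these lists through reader.ptb_iterator, which cuts the word-ID stream into [batch_size, num_steps] arrays. A hedged sketch, not part of the original run and assuming the (data, batch_size, num_steps) signature of the bundled PTB reader, showing the batch format it yields:

In [ ]:
# x and y are [batch_size, num_steps] integer arrays; y is x shifted one
# position to the right, i.e. the next-word targets for each input word.
for x, y in reader.ptb_iterator(train_data, batch_size=2, num_steps=5):
    print(x)
    print(y)
    break
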

In [7]:
with tf.Graph().as_default(), tf.Session() as session:
    initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)

    with tf.variable_scope("model", reuse=None, initializer=initializer):
        m = PTBModel(config, is_training=True)
    with tf.variable_scope("model", reuse=True, initializer=initializer):
        mvalid = PTBModel(config)
        mtest = PTBModel(eval_config)
        
    tf.initialize_all_variables().run()
    
    for i in range(config.max_max_epoch):
        lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
        m.assign_lr(session, config.learning_rate * lr_decay)
        print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
        
        perplexity = run_epoch(session, m, train_data, is_training=True)
        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, perplexity))

        perplexity = run_epoch(session, mvalid, valid_data)
        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, perplexity))

    perplexity = run_epoch(session, mtest, test_data)
    print("Test Perplexity: %.3f" % perplexity)


Epoch: 1 Learning rate: 1.000
0.004 perplexity: 5623.590 speed: 1776 wps
0.104 perplexity: 837.693 speed: 1834 wps
0.204 perplexity: 621.171 speed: 1819 wps
0.304 perplexity: 502.748 speed: 1811 wps
0.404 perplexity: 433.921 speed: 1813 wps
0.504 perplexity: 388.876 speed: 1812 wps
0.604 perplexity: 350.280 speed: 1811 wps
0.703 perplexity: 323.886 speed: 1809 wps
0.803 perplexity: 303.027 speed: 1803 wps
0.903 perplexity: 283.861 speed: 1800 wps
Epoch: 1 Train Perplexity: 269.491
Epoch: 1 Valid Perplexity: 178.930
Epoch: 2 Learning rate: 1.000
0.004 perplexity: 212.452 speed: 1776 wps
0.104 perplexity: 151.408 speed: 1746 wps
0.204 perplexity: 158.838 speed: 1739 wps
0.304 perplexity: 153.748 speed: 1731 wps
0.404 perplexity: 150.825 speed: 1719 wps
0.504 perplexity: 148.415 speed: 1698 wps
0.604 perplexity: 143.767 speed: 1669 wps
0.703 perplexity: 141.649 speed: 1647 wps
0.803 perplexity: 139.754 speed: 1630 wps
0.903 perplexity: 136.156 speed: 1614 wps
Epoch: 2 Train Perplexity: 134.051
Epoch: 2 Valid Perplexity: 144.795
Epoch: 3 Learning rate: 1.000
0.004 perplexity: 144.703 speed: 1482 wps
0.104 perplexity: 106.289 speed: 1454 wps
0.204 perplexity: 115.377 speed: 1434 wps
0.304 perplexity: 112.380 speed: 1426 wps
0.404 perplexity: 111.319 speed: 1426 wps
0.504 perplexity: 110.437 speed: 1426 wps
0.604 perplexity: 107.767 speed: 1426 wps
0.703 perplexity: 107.108 speed: 1428 wps
0.803 perplexity: 106.567 speed: 1429 wps
0.903 perplexity: 104.342 speed: 1431 wps
Epoch: 3 Train Perplexity: 103.330
Epoch: 3 Valid Perplexity: 133.454
Epoch: 4 Learning rate: 1.000
0.004 perplexity: 116.389 speed: 1426 wps
0.104 perplexity: 86.294 speed: 1457 wps
0.204 perplexity: 94.770 speed: 1457 wps
0.304 perplexity: 92.424 speed: 1454 wps
0.404 perplexity: 91.953 speed: 1456 wps
0.504 perplexity: 91.528 speed: 1457 wps
0.604 perplexity: 89.621 speed: 1462 wps
0.703 perplexity: 89.398 speed: 1463 wps
0.803 perplexity: 89.187 speed: 1463 wps
0.903 perplexity: 87.572 speed: 1464 wps
Epoch: 4 Train Perplexity: 86.974
Epoch: 4 Valid Perplexity: 127.934
Epoch: 5 Learning rate: 1.000
0.004 perplexity: 99.530 speed: 1467 wps
0.104 perplexity: 74.640 speed: 1498 wps
0.204 perplexity: 82.279 speed: 1490 wps
0.304 perplexity: 80.323 speed: 1492 wps
0.404 perplexity: 80.090 speed: 1493 wps
0.504 perplexity: 79.886 speed: 1495 wps
0.604 perplexity: 78.470 speed: 1495 wps
0.703 perplexity: 78.460 speed: 1496 wps
0.803 perplexity: 78.449 speed: 1497 wps
0.903 perplexity: 77.164 speed: 1497 wps
Epoch: 5 Train Perplexity: 76.807
Epoch: 5 Valid Perplexity: 128.210
Epoch: 6 Learning rate: 0.500
0.004 perplexity: 88.207 speed: 1551 wps
0.104 perplexity: 65.028 speed: 1504 wps
0.204 perplexity: 70.426 speed: 1499 wps
0.304 perplexity: 67.778 speed: 1497 wps
0.404 perplexity: 66.742 speed: 1496 wps
0.504 perplexity: 65.888 speed: 1499 wps
0.604 perplexity: 63.992 speed: 1502 wps
0.703 perplexity: 63.296 speed: 1504 wps
0.803 perplexity: 62.584 speed: 1505 wps
0.903 perplexity: 60.883 speed: 1504 wps
Epoch: 6 Train Perplexity: 59.966
Epoch: 6 Valid Perplexity: 122.017
Epoch: 7 Learning rate: 0.250
0.004 perplexity: 71.439 speed: 1465 wps
0.104 perplexity: 53.845 speed: 1427 wps
0.204 perplexity: 58.533 speed: 1415 wps
0.304 perplexity: 56.148 speed: 1411 wps
0.404 perplexity: 55.253 speed: 1414 wps
0.504 perplexity: 54.451 speed: 1437 wps
0.604 perplexity: 52.781 speed: 1451 wps
0.703 perplexity: 52.078 speed: 1456 wps
0.803 perplexity: 51.342 speed: 1460 wps
0.903 perplexity: 49.769 speed: 1462 wps
Epoch: 7 Train Perplexity: 48.860
Epoch: 7 Valid Perplexity: 122.090
Epoch: 8 Learning rate: 0.125
0.004 perplexity: 63.181 speed: 1454 wps
0.104 perplexity: 47.797 speed: 1505 wps
0.204 perplexity: 52.037 speed: 1499 wps
0.304 perplexity: 49.854 speed: 1457 wps
0.404 perplexity: 49.044 speed: 1465 wps
0.504 perplexity: 48.303 speed: 1466 wps
0.604 perplexity: 46.791 speed: 1472 wps
0.703 perplexity: 46.123 speed: 1471 wps
0.803 perplexity: 45.409 speed: 1472 wps
0.903 perplexity: 43.942 speed: 1478 wps
Epoch: 8 Train Perplexity: 43.073
Epoch: 8 Valid Perplexity: 123.251
Epoch: 9 Learning rate: 0.062
0.004 perplexity: 59.041 speed: 1574 wps
0.104 perplexity: 44.766 speed: 1565 wps
0.204 perplexity: 48.769 speed: 1569 wps
0.304 perplexity: 46.702 speed: 1574 wps
0.404 perplexity: 45.921 speed: 1575 wps
0.504 perplexity: 45.215 speed: 1577 wps
0.604 perplexity: 43.784 speed: 1578 wps
0.703 perplexity: 43.131 speed: 1579 wps
0.803 perplexity: 42.433 speed: 1581 wps
0.903 perplexity: 41.024 speed: 1582 wps
Epoch: 9 Train Perplexity: 40.176
Epoch: 9 Valid Perplexity: 123.995
Epoch: 10 Learning rate: 0.031
0.004 perplexity: 56.746 speed: 1545 wps
0.104 perplexity: 43.123 speed: 1546 wps
0.204 perplexity: 47.044 speed: 1526 wps
0.304 perplexity: 45.051 speed: 1521 wps
0.404 perplexity: 44.293 speed: 1521 wps
0.504 perplexity: 43.604 speed: 1519 wps
0.604 perplexity: 42.216 speed: 1519 wps
0.703 perplexity: 41.571 speed: 1519 wps
0.803 perplexity: 40.874 speed: 1519 wps
0.903 perplexity: 39.494 speed: 1519 wps
Epoch: 10 Train Perplexity: 38.660
Epoch: 10 Valid Perplexity: 124.196
Epoch: 11 Learning rate: 0.016
0.004 perplexity: 55.580 speed: 1552 wps
0.104 perplexity: 42.193 speed: 1520 wps
0.204 perplexity: 46.077 speed: 1515 wps
0.304 perplexity: 44.133 speed: 1514 wps
0.404 perplexity: 43.392 speed: 1516 wps
0.504 perplexity: 42.716 speed: 1516 wps
0.604 perplexity: 41.353 speed: 1517 wps
0.703 perplexity: 40.719 speed: 1518 wps
0.803 perplexity: 40.023 speed: 1518 wps
0.903 perplexity: 38.660 speed: 1519 wps
Epoch: 11 Train Perplexity: 37.835
Epoch: 11 Valid Perplexity: 123.893
Epoch: 12 Learning rate: 0.008
0.004 perplexity: 54.922 speed: 1507 wps
0.104 perplexity: 41.660 speed: 1517 wps
0.204 perplexity: 45.516 speed: 1518 wps
0.304 perplexity: 43.603 speed: 1518 wps
0.404 perplexity: 42.872 speed: 1518 wps
0.504 perplexity: 42.204 speed: 1518 wps
0.604 perplexity: 40.857 speed: 1515 wps
0.703 perplexity: 40.235 speed: 1512 wps
0.803 perplexity: 39.543 speed: 1512 wps
0.903 perplexity: 38.191 speed: 1513 wps
Epoch: 12 Train Perplexity: 37.372
Epoch: 12 Valid Perplexity: 123.478
Epoch: 13 Learning rate: 0.004
0.004 perplexity: 54.537 speed: 1521 wps
0.104 perplexity: 41.354 speed: 1515 wps
0.204 perplexity: 45.195 speed: 1511 wps
0.304 perplexity: 43.301 speed: 1513 wps
0.404 perplexity: 42.580 speed: 1513 wps
0.504 perplexity: 41.916 speed: 1514 wps
0.604 perplexity: 40.579 speed: 1514 wps
0.703 perplexity: 39.964 speed: 1515 wps
0.803 perplexity: 39.277 speed: 1514 wps
0.903 perplexity: 37.932 speed: 1515 wps
Epoch: 13 Train Perplexity: 37.117
Epoch: 13 Valid Perplexity: 123.184
Test Perplexity: 118.012

In [ ]: