In [13]:
import sys, os
import torch as t
from torch import nn
from torch.autograd import Variable
import tqdm
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader

class Config(object):
    data_path = 'data/'  # directory containing the raw poetry text files
    pickle_path = 'D:/project/ml/test/rnn/input/tang.npz'  # preprocessed binary file
    author = None  # only learn from poems by this author
    constrain = None  # length constraint
    category = 'poet.tang'  # category: Tang poems ('poet.tang') or Song poems ('poet.song')
    lr = 1e-3
    weight_decay = 1e-4
    use_gpu = False
    epoch = 1
    batch_size = 128
    maxlen = 125  # characters beyond this length are dropped; shorter poems are padded at the front
    plot_every = 20  # log / visualize every 20 batches
    # use_env = True  # whether to use visdom
    env = 'poetry'  # visdom env
    max_gen_len = 200  # maximum length of a generated poem
    debug_file = '/tmp/debugp'
    model_path = 'D:/project/ml/data/tang_199.pth'  # path to the pretrained model
    prefix_words = '细雨鱼儿出,微风燕子斜。'  # not part of the poem; used to steer the mood of the generated poem
    start_words = '闲云潭影日悠悠'  # opening words of the poem
    acrostic = False  # whether to generate an acrostic poem
    model_prefix = 'checkpoints/tang'  # prefix for saved model checkpoints

opt = Config()

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.size()
        if hidden is None:
            #  h_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda()
            #  c_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda()
            h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
        else:
            h_0, c_0 = hidden
        # size: (seq_len, batch_size, embedding_dim)
        embeds = self.embeddings(input)
        # output size: (seq_len,batch_size,hidden_dim)
        output, hidden = self.lstm(embeds, (h_0, c_0))

        # size: (seq_len*batch_size,vocab_size)
        output = self.linear1(output.view(seq_len * batch_size, -1))
        return output, hidden
    
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
    """
    给定几个词,根据这几个词接着生成一首完整的诗歌
    start_words:u'春江潮水连海平'
    比如start_words 为 春江潮水连海平,可以生成:

    """
    
    results = list(start_words)
    start_word_len = len(start_words)
    # manually set the first input token to <START>
    input = t.Tensor([word2ix['<START>']]).view(1, 1).long()
    if opt.use_gpu: input = input.cuda()
    hidden = None

    if prefix_words:
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = input.data.new([word2ix[word]]).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)

        if i < start_word_len:
            w = results[i]
            input = input.data.new([word2ix[w]]).view(1, 1)
        else:
            top_index = output.data[0].topk(1)[1][0].item()
            w = ix2word[top_index]
            results.append(w)
            input = input.data.new([top_index]).view(1, 1)
        if w == '<EOP>':
            del results[-1]
            break
    return results


def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):
    """
    生成藏头诗
    start_words : u'深度学习'
    生成:
    深木通中岳,青苔半日脂。
    度山分地险,逆浪到南巴。
    学道兵犹毒,当时燕不移。
    习根通古岸,开镜出清羸。
    """
    results = []
    start_word_len = len(start_words)
    input = (t.Tensor([word2ix['<START>']]).view(1, 1).long())
    if opt.use_gpu: input = input.cuda()
    hidden = None

    index = 0  # how many of the acrostic head characters have been used so far
    # the previous word
    pre_word = '<START>'

    if prefix_words:
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = (input.data.new([word2ix[word]])).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)
        top_index = output.data[0].topk(1)[1][0].item()
        w = ix2word[top_index]

        if (pre_word in {u'。', u'!', '<START>'}):
            # the previous word ended a sentence, so feed in the next acrostic head character

            if index == start_word_len:
                # all acrostic head characters have been used; stop
                break
            else:
                # feed the next acrostic head character into the model
                w = start_words[index]
                index += 1
                input = (input.data.new([word2ix[w]])).view(1, 1)
        else:
            # otherwise, feed the previously predicted word as the next input
            input = (input.data.new([word2ix[w]])).view(1, 1)
        results.append(w)
        pre_word = w
    return results

def gen(**kwargs):
    """
    提供命令行接口,用以生成相应的诗
    """

    for k, v in kwargs.items():
        setattr(opt, k, v)
    data, word2ix, ix2word = get_data(opt)
    model = PoetryModel(len(word2ix), 128, 256)
    map_location = lambda s, l: s
    state_dict = t.load(opt.model_path, map_location=map_location)
    model.load_state_dict(state_dict)

    if opt.use_gpu:
        model.cuda()

    # Python 2 / Python 3 string compatibility
    if sys.version_info.major == 3:
        if opt.start_words.isprintable():
            start_words = opt.start_words
            prefix_words = opt.prefix_words if opt.prefix_words else None
        else:
            start_words = opt.start_words.encode('ascii', 'surrogateescape').decode('utf8')
            prefix_words = opt.prefix_words.encode('ascii', 'surrogateescape').decode(
                'utf8') if opt.prefix_words else None
    else:
        start_words = opt.start_words.decode('utf8')
        prefix_words = opt.prefix_words.decode('utf8') if opt.prefix_words else None

    start_words = start_words.replace(',', u',') \
        .replace('.', u'。') \
        .replace('?', u'?')

    gen_poetry = gen_acrostic if opt.acrostic else generate
    result = gen_poetry(model, start_words, ix2word, word2ix, prefix_words)
    print(''.join(result))

def get_data(opt):
    """
    @param opt 配置选项 Config对象
    @return word2ix: dict,每个字对应的序号,形如u'月'->100
    @return ix2word: dict,每个序号对应的字,形如'100'->u'月'
    @return data: numpy数组,每一行是一首诗对应的字的下标
    """
    if os.path.exists(opt.pickle_path):
        data = np.load(opt.pickle_path)  # on newer numpy this may need allow_pickle=True
        data, word2ix, ix2word = data['data'], data['word2ix'].item(), data['ix2word'].item()
        return data, word2ix, ix2word

    # no preprocessed binary file, so process the raw json files
    data = _parseRawData(opt.author, opt.constrain, opt.data_path, opt.category)
    words = {_word for _sentence in data for _word in _sentence}
    word2ix = {_word: _ix for _ix, _word in enumerate(words)}
    word2ix['<EOP>'] = len(word2ix)  # end-of-poem token
    word2ix['<START>'] = len(word2ix)  # start-of-poem token
    word2ix['</s>'] = len(word2ix)  # padding (blank) token
    ix2word = {_ix: _word for _word, _ix in list(word2ix.items())}

    # add a start token and an end token to every poem
    for i in range(len(data)):
        data[i] = ["<START>"] + list(data[i]) + ["<EOP>"]

    # convert each poem from characters to indices,
    # e.g. [春, 江, 花, 月, 夜] becomes [1, 2, 3, 4, 5]
    new_data = [[word2ix[_word] for _word in _sentence]
                for _sentence in data]

    # poems shorter than opt.maxlen are padded at the front; longer ones are truncated at the end
    pad_data = pad_sequences(new_data,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=len(word2ix) - 1)

    # save as a binary file
    np.savez_compressed(opt.pickle_path,
                        data=pad_data,
                        word2ix=word2ix,
                        ix2word=ix2word)
    return pad_data, word2ix, ix2word
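
The helpers _parseRawData and pad_sequences are not defined in this notebook; they are assumed to come from the data-processing utilities of the original pytorch-book project (pad_sequences behaves like the Keras function of the same name). A minimal sketch with the pre-padding / post-truncation behaviour described in the comments above, for illustration only:

In [ ]:
def pad_sequences(sequences, maxlen, padding='pre', truncating='post', value=0):
    """Pad/truncate a list of index lists into a (num_sequences, maxlen) int array."""
    out = np.full((len(sequences), maxlen), value, dtype=np.int64)
    for i, seq in enumerate(sequences):
        # drop characters beyond maxlen (from the end when truncating='post')
        seq = seq[:maxlen] if truncating == 'post' else seq[-maxlen:]
        if padding == 'pre':
            out[i, maxlen - len(seq):] = seq
        else:
            out[i, :len(seq)] = seq
    return out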

In [6]:
data, word2ix, ix2word = get_data(opt)
data = t.from_numpy(data)

In [15]:
''.join([ix2word[x.item()] for x in data[10]])


Out[15]:
'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>夏景已难度,怀贤思方续。乔树落疎阴,微风散烦燠。伤离枉芳札,忻遂见心曲。蓝上舍已成,田家雨新足。讬邻素多欲,残帙犹见束。日夕上高斋,但望东原绿。<EOP>'

In [18]:
data.shape


Out[18]:
torch.Size([57580, 125])

In [22]:
index = np.random.choice(range(data.shape[0]), size=1)
''.join([ix2word[x.item()] for x in data[index.item()]])


Out[22]:
'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>紫烟楼阁碧纱亭,上界诗仙独自行。奇险驱回还寂寞,云山经用始鲜明。藕绡纹缕裁来滑,镜水波涛滤得清。昏思愿因秋露洗,幸容堦下礼先生。<EOP>'

In [7]:
len(word2ix)


Out[7]:
8293
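
Before training, a quick shape check of the model can be useful. The sketch below is not part of the original notebook (tmp_model and dummy are made-up names): the model expects an input of shape (seq_len, batch_size) holding word indices and returns logits of shape (seq_len * batch_size, vocab_size).

In [ ]:
# shape check with random indices: seq_len=5, batch_size=2
tmp_model = PoetryModel(len(word2ix), 128, 256)
dummy = t.randint(0, len(word2ix), (5, 2)).long()
out, hidden = tmp_model(dummy)
print(out.shape)        # expected: torch.Size([10, 8293])
print(hidden[0].shape)  # expected: torch.Size([2, 2, 256]) = (num_layers, batch, hidden_dim)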

In [58]:
dataloader = DataLoader(data,
                         batch_size=opt.batch_size,
                         shuffle=True,
                         num_workers=1)

# model definition
model = PoetryModel(len(word2ix), 128, 256)
optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
criterion = nn.CrossEntropyLoss()

l = []
for epoch in range(opt.epoch):
    for ii, data_ in tqdm.tqdm(enumerate(dataloader)):

        # training step
        data_ = data_.long().transpose(1, 0).contiguous()
        optimizer.zero_grad()
        input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
        output, _ = model(input_)
        loss = criterion(output, target.view(-1))
        loss.backward()
        optimizer.step()


        # logging / visualization
        if (1 + ii) % opt.plot_every == 0:

            l.append(loss.item())
            print('loss', loss.item())
            # the original poems in this batch
#             poetrys = [[ix2word[_word.item()] for _word in data_[:, _iii]]
#                        for _iii in range(data_.size(1))][:16]
#             txt = '</br>'.join([''.join(poetry) for poetry in poetrys])
#             print('origin:', txt)

#             gen_poetries = []
#             # generate 8 poems, each starting with one of these characters
#             for word in list(u'春江花月夜凉如水'):
#                 gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
#                 gen_poetries.append(gen_poetry)
#             txt = '</br>'.join([''.join(poetry) for poetry in gen_poetries])
#             print('generate:', txt)





0it [00:00, ?it/s]
19it [01:26,  4.56s/it]
loss 3.632603883743286
39it [03:00,  4.62s/it]
loss 3.09531831741333
59it [04:33,  4.64s/it]
loss 2.913165330886841
79it [06:03,  4.60s/it]
loss 2.8733842372894287
99it [07:35,  4.60s/it]
loss 2.774152994155884
119it [09:08,  4.61s/it]
loss 2.8415277004241943
139it [10:40,  4.61s/it]
loss 2.8828182220458984
159it [12:10,  4.59s/it]
loss 2.982400894165039
179it [13:44,  4.60s/it]
loss 2.8406317234039307
199it [15:11,  4.58s/it]
loss 2.4145264625549316
219it [16:39,  4.56s/it]
loss 2.854132890701294
239it [18:07,  4.55s/it]
loss 2.6666290760040283
244it [18:31,  4.55s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-58-06557cff080a> in <module>()
     17         optimizer.zero_grad()
     18         input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
---> 19         output, _ = model(input_)
     20         loss = criterion(output, target.view(-1))
     21         loss.backward()

    ... (intermediate frames inside torch.nn elided) ...
KeyboardInterrupt: 
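
Training was stopped by hand after 244 batches. Config defines model_prefix = 'checkpoints/tang', but the loop above never saves anything; a minimal sketch of saving a checkpoint at the end of each epoch (the filename pattern is an assumption, chosen to mirror the pretrained tang_199.pth used later):

In [ ]:
# sketch: save the model weights after each epoch (assumed filename pattern)
t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))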

In [57]:
loss.item()


Out[57]:
2.6752970218658447

In [19]:
x = data_[:,0]

In [20]:
x


Out[20]:
tensor([8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
        8292, 8292, 8292, 8291, 4038, 2958, 3189, 1175, 4624, 7066, 7630, 3520,
        4710, 1161,  745, 7435, 6225, 3929, 4744, 3286, 7310, 7066, 1787, 1360,
        8010, 6787, 7914, 7435, 1719, 2808, 6016, 4782, 4414, 7066, 7854, 7377,
        6663, 3243, 1540, 7435, 8033, 2884, 4054, 3465, 8150, 7066, 3969, 5135,
        5283,   70, 7905, 7435, 8290])

In [21]:
ix2word[8292]


Out[21]:
'</s>'

In [24]:
a = t.Tensor([1])

In [25]:
a.item()


Out[25]:
1.0

In [26]:
poetrys = [[ix2word[_word.item()] for _word in data_[:, _iii]]
           for _iii in range(data_.size(1))][:16]

In [28]:
txt = '</br>'.join([''.join(poetry) for poetry in poetrys])
txt


Out[28]:
'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>谷中春日暖,渐忆掇茶英。欲及清明火,能销醉客酲。松花飘鼎泛,兰气入瓯轻。饮罢闲无事,扪萝溪上行。<EOP></br><START>游人夜到汝阳间,夜色冥濛不解颜。谁家暗起寒山烧,因此明中得见山。山头山下须臾满,历险缘深无暂断。焦声散著羣树鸣,炎气傍林一川暖。是时西北多海风,吹上连天光更雄。浊烟熏月黑,高豔爇云红。初谓炼丹仙灶里,还疑铸劒神谿中。划为飞电来照物,乍作流星并上空。</br><START>少年落魄楚汉间,风尘萧瑟多苦颜。自言管葛竟谁许,长吁莫错还闭关。一朝君王垂拂拭,剖心输丹雪胸臆。忽蒙白日回景光,直上青云生羽翼。幸陪鸞辇出鸿都,身骑飞龙天马驹。王公大人借颜色,金璋紫绶来相趋。当时结交何纷纷,片言道合惟有君。待吾尽节报明主,然后相携</br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>天子卹疲瘵,坤灵奉其职。年年济世功,贵贱相兼植。因产众草中,所希采者识。一枝当若神,千金亦何直。生草不生药,无以彰土德。生药不生草,无以彰奇特。国忠在臣贤,民患凭药力。灵草犹如此,贤人岂多得。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>生长在荥阳,少小辞乡曲。迢迢四十载,复向荥阳宿。去时十一二,今年五十六。追思儿戏时,宛然犹在目。旧居失处所,故里无宗族。岂唯变市朝,兼亦迁陵谷。独有溱洧水,无情依旧绿。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>阻他罗网到柴扉,不奈偷仓雀转肥。赖尔林塘添景趣,剩留山果引教归。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>卷荷香澹浮烟渚,绿嫩擎新雨。琐窗疎透晓风清,象床珍簟冷光轻,水文平。九疑黛色屏斜掩,枕上眉心歛。不堪相望病将成,钿昏檀粉泪纵横,不胜情。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>彼美汉东国,川藏明月辉。宁知丧乱后,更有一珠归。<EOP></br><START>善为尔诸身,行为尔性命。祸福必可转,莫悫言前定。见人之得,如己之得。则美无不克,见人之失。如己之失,是亨贞吉。反此之徒,天鬼必诛。福先祸始,好杀灭纪。不得不止,守谦寡慾。善善恶恶,不得不作。无见贵热,谄走蹩躠。无轻贱微,上下相依。古圣著书,矻矻孳孳</br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>绰约小天仙,生来十六年。姑山半峰雪,瑶水一枝莲。晚院花留立,春窗月伴眠。回眸虽欲语,阿母在傍边。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>秃人今日已定,不须卜於长安。天坐住汝男津,百官大会千斤肫。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><
/s><START>肃肃雍雍义有余,九天鸞凤莫相疎。唯应静向山窗过,激发英雄夜读书。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>管急弦繁拍渐稠,绿腰宛转曲终头。诚知乐世声声乐,老病人听未免愁。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>多雨殊未已,秋云更沈沈。洛阳故人初解印,山东小吏来相寻。上卿才大名不朽,早朝至尊暮求友。豁达常推海内贤,殷勤但酌尊中酒。饮醉欲言归剡溪,门前驷马光照衣。路傍观者徒唧唧,我公不以为是非。<EOP></br></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>大君膺宝历,出豫表功成。钧天金石响,洞庭弦管清。八音动繁会,九变叶希声。和云留睿赏,熏风悅圣情。盛烈光韶濩,易俗迈咸英。窃吹良无取,率舞抃羣生。<EOP></br><START>之子逍遥尘世薄,格淡於云语如鹤。相见唯谈海上山,碧侧青斜冷相沓。芒鞋竹杖寒冻时,玉霄忽去非有期。僮担亦笼密雪里,世人无人留得之。想入红霞路深邃,孤峰纵啸仙飙起。星精聚观泣海鬼,月湧薄烟花点水。送君丁宁有深旨,好寻佛窟游银地。雪眉衲僧皆正气,伊昔贞白'

In [66]:
gen_poetries = []
# generate one poem for each of these characters, using it as the poem's first character
for word in list(u'春'):
    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
    gen_poetries.append(gen_poetry)
txt = '\n'.join([''.join(poetry) for poetry in gen_poetries])

In [67]:
txt


Out[67]:
'春来不见人不知,人家女儿弄金屋。天涯相见不可论,万里千山万丈余。青荧数点奇奇士,白日葱茏生八区。东山桃李夹城东,西陵道路人不同。顾君余笑为谁子,忆昔陈家亦相望。一从遇此学为名,十年为客无遗名。君不见东西衮衮客,一年何处无人识。君今不见东方来,今日独有江州吟。君不见邺中有美酒,江上一行何处寻。使我哀心自此乐,一夜独向江南去。'
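
generate always takes the most likely next character (topk(1)), so the same start_words produce the same poem every time. A common variant is to sample from the softmax with a temperature instead; the helper below is only a sketch and not part of the original code (sample_word is a made-up name):

In [ ]:
# sketch: sample the next character instead of taking the argmax
def sample_word(output, temperature=1.0):
    # output: (1, vocab_size) logits from the model's last step
    probs = F.softmax(output.data[0] / temperature, dim=0)
    return t.multinomial(probs, 1).item()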

In [35]:
model


Out[35]:
PoetryModel(
  (embeddings): Embedding(8293, 128)
  (lstm): LSTM(128, 256, num_layers=2)
  (linear1): Linear(in_features=256, out_features=8293, bias=True)
)
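
For a rough sense of the model's size, the parameter count can be summed directly (a quick check, not in the original notebook):

In [ ]:
# total number of trainable parameters: embedding + 2-layer LSTM + output linear
sum(p.numel() for p in model.parameters())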

In [36]:
input = Variable(t.Tensor([word2ix['<START>']]).view(1, 1).long())
input


Out[36]:
tensor([[8291]])

In [38]:
output, hidden = model(input)

In [39]:
top_index = output.data[0].topk(1)[1][0]
w = ix2word[top_index.item()]

In [40]:
w


Out[40]:
'</s>'

In [63]:
model = PoetryModel(len(word2ix), 128, 256)
map_location = lambda s, l: s
state_dict = t.load('D:/project/ml/data/tang_199.pth', map_location=map_location)
model.load_state_dict(state_dict)

In [ ]:


In [14]:
gen()


闲云潭影日悠悠,千里万里无人愁。人生不及春已暮,妾家独歌长安曲。西楼月,月明中,一夜东风吹晓雨。今年花落春风起,杨柳青青鸦在水。柳条繁,柳条垂,柳条丝管,杯酒满袖,歌语声尽画蛾眉。妆成蹋,望山云,复君心。妾心明月月如珠翠,绣户垂珰。陵上一行春半,月明双燕燕来。一曲弦歌,女笛歌回。愿言君王镇魏王女,传歌金马不成回。第一斗斛三五,三十六宫千万里。金龙雄劒佩金梭,玉关银汉照金绳。帐前珠佩芙蓉幕,蜡炬光辉
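
To produce an acrostic poem instead, the same gen entry point can set the acrostic flag so that gen_acrostic is used; a sketch using the example head characters from the gen_acrostic docstring:

In [ ]:
# sketch: acrostic generation, each character of start_words opens one line
gen(acrostic=True, start_words=u'深度学习', prefix_words=u'细雨鱼儿出,微风燕子斜。')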

In [ ]: