机器也懂情感:文本分类

在本课中,我们通过抓取京东商场上面的评论形成我们的语料。然后,我们训练一个模型来对这些评论进行分类,分为正面的评论和负面的评论。

首先,我们展示了如何写一个简单的爬虫程序从网上扒文章;

其次,我们设计了一个简单的前馈网络,利用词袋模型获得每一个评论的向量表示,并输入进前馈网络得到很好的分类效果,我们还对这个网络进行了简单的剖析。

然后,我们尝试了两种RNN网络,一种是普通的RNN,另一种是LSTM。本程序展示了它们在处理同样的文本分类问题上的用法

本文件是集智AI学园http://campus.swarma.org 出品的“火炬上的深度学习”第VI课的配套源代码


In [2]:
# 导入程序所需要的程序包

#抓取网页内容用的程序包
import json
import requests

#PyTorch用的包
import torch
import torch.nn as nn
import torch.optim
from torch.autograd import Variable

# 自然语言处理相关的包
import re #正则表达式的包
import jieba #结巴分词包
from collections import Counter #搜集器,可以让统计词频更简单

#绘图、计算用的程序包
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline


/Users/tradeshift/Workspace/deep_learning_tutorial/p3ml-venv/lib/python3.6/site-packages/matplotlib/font_manager.py:279: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
  'Matplotlib is building the font cache using fc-list. '

一、数据处理

我们的数据来源于京东上的商品评论,每一条评论都会配合有一个评分。我们通过调用接口将相应的参数传入进去就可以得到评论。

根据评分的高低,我们可以划分成正向和负向两组标签

1、从京东上抓取评论数据


In [2]:
# 在指定的url处获得评论
def get_comments(url):
    comments = []
    # 打开指定页面
    resp = requests.get(url)
    resp.encoding = 'gbk'
    
    #如果200秒没有打开则失败
    if resp.status_code != 200:
        return []
    
    #获得内容
    content = resp.text
    if content:
        #获得()括号中的内容
        ind = content.find('(')
        s1 = content[ind+1:-2]
        try:
            #尝试利用jason接口来读取内容,并做jason的解析
            js = json.loads(s1)
            #提取出comments字段的内容
            comment_infos = js['comments']
        except:
            print('error')
            return([])
        
        #对每一条评论进行内容部分的抽取
        for comment_info in comment_infos:
            comment_content = comment_info['content']
            str1 = comment_content + '\n'
            comments.append(str1)
    return comments

good_comments = []

#评论抓取的来源地址,其中参数包括:
#productId为商品的id,score为评分,page为对应的评论翻页的页码,pageSize为总页数
#这里,我们设定score=3表示好的评分。
good_comment_url_templates = [
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv8914&productId=10359162198&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv73&productId=10968941641&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv4653&productId=10335204102&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv1&productId=1269194114&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv2777&productId=1409704820&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv174&productId=10103790891&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv9447&productId=1708318938&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv111&productId=10849803616&score=3&sortType=5&page={}&pageSize=10&isShadowSku=0'
]

# 对上述网址进行循环,并模拟翻页100次
j=0
for good_comment_url_template in good_comment_url_templates:
    for i in range(100):
        url = good_comment_url_template.format(i)
        good_comments += get_comments(url)
        print('第{}条纪录,总文本长度{}'.format(j, len(good_comments)))
        j += 1
#将结果存储到good.txt文件中
fw = open('data/good.txt', 'w')
fw.writelines(good_comments)


第0条纪录,总文本长度10
第1条纪录,总文本长度20
第2条纪录,总文本长度30
第3条纪录,总文本长度40
第4条纪录,总文本长度50
第5条纪录,总文本长度60
第6条纪录,总文本长度70
第7条纪录,总文本长度80
第8条纪录,总文本长度90
第9条纪录,总文本长度100
第10条纪录,总文本长度110
第11条纪录,总文本长度120
第12条纪录,总文本长度130
第13条纪录,总文本长度140
第14条纪录,总文本长度150
第15条纪录,总文本长度160
第16条纪录,总文本长度170
第17条纪录,总文本长度180
第18条纪录,总文本长度190
第19条纪录,总文本长度200
第20条纪录,总文本长度210
第21条纪录,总文本长度220
第22条纪录,总文本长度230
第23条纪录,总文本长度240
第24条纪录,总文本长度250
第25条纪录,总文本长度260
第26条纪录,总文本长度270
第27条纪录,总文本长度280
第28条纪录,总文本长度290
第29条纪录,总文本长度300
第30条纪录,总文本长度310
第31条纪录,总文本长度320
第32条纪录,总文本长度330
第33条纪录,总文本长度340
第34条纪录,总文本长度350
第35条纪录,总文本长度360
第36条纪录,总文本长度370
第37条纪录,总文本长度380
第38条纪录,总文本长度390
第39条纪录,总文本长度400
第40条纪录,总文本长度410
第41条纪录,总文本长度420
第42条纪录,总文本长度430
第43条纪录,总文本长度440
第44条纪录,总文本长度450
第45条纪录,总文本长度460
第46条纪录,总文本长度470
第47条纪录,总文本长度480
第48条纪录,总文本长度490
第49条纪录,总文本长度500
第50条纪录,总文本长度510
第51条纪录,总文本长度520
第52条纪录,总文本长度530
第53条纪录,总文本长度540
第54条纪录,总文本长度550
第55条纪录,总文本长度560
第56条纪录,总文本长度570
第57条纪录,总文本长度580
第58条纪录,总文本长度590
第59条纪录,总文本长度600
第60条纪录,总文本长度610
第61条纪录,总文本长度620
第62条纪录,总文本长度630
第63条纪录,总文本长度640
第64条纪录,总文本长度650
第65条纪录,总文本长度660
第66条纪录,总文本长度670
第67条纪录,总文本长度680
第68条纪录,总文本长度690
第69条纪录,总文本长度700
第70条纪录,总文本长度710
第71条纪录,总文本长度720
第72条纪录,总文本长度730
第73条纪录,总文本长度740
第74条纪录,总文本长度750
第75条纪录,总文本长度760
第76条纪录,总文本长度770
第77条纪录,总文本长度780
第78条纪录,总文本长度790
第79条纪录,总文本长度800
第80条纪录,总文本长度810
第81条纪录,总文本长度820
第82条纪录,总文本长度830
第83条纪录,总文本长度840
第84条纪录,总文本长度850
第85条纪录,总文本长度860
第86条纪录,总文本长度870
第87条纪录,总文本长度880
第88条纪录,总文本长度890
第89条纪录,总文本长度900
第90条纪录,总文本长度910
第91条纪录,总文本长度920
第92条纪录,总文本长度930
第93条纪录,总文本长度940
第94条纪录,总文本长度950
第95条纪录,总文本长度960
第96条纪录,总文本长度970
第97条纪录,总文本长度980
第98条纪录,总文本长度990
第99条纪录,总文本长度1000
第100条纪录,总文本长度1010
第101条纪录,总文本长度1020
第102条纪录,总文本长度1030
第103条纪录,总文本长度1040
第104条纪录,总文本长度1050
第105条纪录,总文本长度1060
第106条纪录,总文本长度1070
第107条纪录,总文本长度1080
第108条纪录,总文本长度1090
第109条纪录,总文本长度1100
第110条纪录,总文本长度1110
第111条纪录,总文本长度1120
第112条纪录,总文本长度1130
第113条纪录,总文本长度1140
第114条纪录,总文本长度1150
第115条纪录,总文本长度1160
第116条纪录,总文本长度1170
第117条纪录,总文本长度1180
第118条纪录,总文本长度1190
第119条纪录,总文本长度1200
第120条纪录,总文本长度1210
第121条纪录,总文本长度1220
第122条纪录,总文本长度1230
第123条纪录,总文本长度1240
第124条纪录,总文本长度1250
第125条纪录,总文本长度1260
第126条纪录,总文本长度1270
第127条纪录,总文本长度1280
第128条纪录,总文本长度1290
第129条纪录,总文本长度1300
第130条纪录,总文本长度1310
第131条纪录,总文本长度1320
第132条纪录,总文本长度1330
第133条纪录,总文本长度1340
第134条纪录,总文本长度1350
第135条纪录,总文本长度1360
第136条纪录,总文本长度1370
第137条纪录,总文本长度1380
第138条纪录,总文本长度1390
第139条纪录,总文本长度1400
第140条纪录,总文本长度1410
第141条纪录,总文本长度1420
第142条纪录,总文本长度1430
第143条纪录,总文本长度1440
第144条纪录,总文本长度1450
第145条纪录,总文本长度1460
第146条纪录,总文本长度1470
第147条纪录,总文本长度1480
第148条纪录,总文本长度1490
第149条纪录,总文本长度1500
第150条纪录,总文本长度1510
第151条纪录,总文本长度1520
第152条纪录,总文本长度1530
第153条纪录,总文本长度1540
第154条纪录,总文本长度1550
第155条纪录,总文本长度1560
第156条纪录,总文本长度1570
第157条纪录,总文本长度1580
第158条纪录,总文本长度1590
第159条纪录,总文本长度1600
第160条纪录,总文本长度1610
第161条纪录,总文本长度1620
第162条纪录,总文本长度1630
第163条纪录,总文本长度1640
第164条纪录,总文本长度1650
第165条纪录,总文本长度1660
第166条纪录,总文本长度1670
第167条纪录,总文本长度1680
第168条纪录,总文本长度1690
第169条纪录,总文本长度1700
第170条纪录,总文本长度1710
第171条纪录,总文本长度1720
第172条纪录,总文本长度1730
第173条纪录,总文本长度1740
第174条纪录,总文本长度1750
第175条纪录,总文本长度1760
第176条纪录,总文本长度1770
第177条纪录,总文本长度1780
第178条纪录,总文本长度1790
第179条纪录,总文本长度1800
第180条纪录,总文本长度1810
第181条纪录,总文本长度1820
第182条纪录,总文本长度1830
第183条纪录,总文本长度1840
第184条纪录,总文本长度1850
第185条纪录,总文本长度1860
第186条纪录,总文本长度1870
第187条纪录,总文本长度1880
第188条纪录,总文本长度1890
第189条纪录,总文本长度1900
第190条纪录,总文本长度1910
第191条纪录,总文本长度1920
第192条纪录,总文本长度1930
第193条纪录,总文本长度1940
第194条纪录,总文本长度1950
第195条纪录,总文本长度1960
第196条纪录,总文本长度1970
第197条纪录,总文本长度1980
第198条纪录,总文本长度1990
第199条纪录,总文本长度2000
第200条纪录,总文本长度2010
第201条纪录,总文本长度2020
第202条纪录,总文本长度2030
第203条纪录,总文本长度2040
第204条纪录,总文本长度2050
第205条纪录,总文本长度2060
第206条纪录,总文本长度2070
第207条纪录,总文本长度2080
第208条纪录,总文本长度2090
第209条纪录,总文本长度2100
第210条纪录,总文本长度2110
第211条纪录,总文本长度2120
第212条纪录,总文本长度2130
第213条纪录,总文本长度2140
第214条纪录,总文本长度2150
第215条纪录,总文本长度2160
第216条纪录,总文本长度2170
第217条纪录,总文本长度2180
第218条纪录,总文本长度2190
第219条纪录,总文本长度2200
第220条纪录,总文本长度2210
第221条纪录,总文本长度2220
第222条纪录,总文本长度2230
第223条纪录,总文本长度2240
第224条纪录,总文本长度2250
第225条纪录,总文本长度2260
第226条纪录,总文本长度2270
第227条纪录,总文本长度2280
第228条纪录,总文本长度2290
第229条纪录,总文本长度2300
第230条纪录,总文本长度2310
第231条纪录,总文本长度2320
第232条纪录,总文本长度2330
第233条纪录,总文本长度2340
第234条纪录,总文本长度2350
第235条纪录,总文本长度2360
第236条纪录,总文本长度2370
第237条纪录,总文本长度2380
第238条纪录,总文本长度2390
第239条纪录,总文本长度2400
第240条纪录,总文本长度2410
第241条纪录,总文本长度2420
第242条纪录,总文本长度2430
第243条纪录,总文本长度2440
第244条纪录,总文本长度2450
第245条纪录,总文本长度2460
第246条纪录,总文本长度2470
第247条纪录,总文本长度2480
第248条纪录,总文本长度2490
第249条纪录,总文本长度2500
第250条纪录,总文本长度2510
第251条纪录,总文本长度2520
第252条纪录,总文本长度2530
第253条纪录,总文本长度2540
第254条纪录,总文本长度2550
第255条纪录,总文本长度2560
第256条纪录,总文本长度2570
第257条纪录,总文本长度2580
第258条纪录,总文本长度2590
第259条纪录,总文本长度2600
第260条纪录,总文本长度2610
第261条纪录,总文本长度2620
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-2c40157c1d83> in <module>()
     53     for i in range(100):
     54         url = good_comment_url_template.format(i)
---> 55         good_comments += get_comments(url)
     56         print('第{}条纪录,总文本长度{}'.format(j, len(good_comments)))
     57         j += 1

<ipython-input-2-2c40157c1d83> in get_comments(url)
     26 
     27         #对每一条评论进行内容部分的抽取
---> 28         for comment_info in comment_infos:
     29             comment_content = comment_info['content']
     30             str1 = comment_content + '\n'

TypeError: 'NoneType' object is not iterable

In [3]:
# 负向评论如法炮制
bad_comments = []
bad_comment_url_templates = [
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv8914&productId=10359162198&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv73&productId=10968941641&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'http://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv4653&productId=10335204102&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv1&productId=1269194114&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv2777&productId=1409704820&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv174&productId=10103790891&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv9447&productId=1708318938&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0',
    'https://club.jd.com/comment/productPageComments.action?callback=fetchJSON_comment98vv111&productId=10849803616&score=1&sortType=5&page={}&pageSize=10&isShadowSku=0'
]

j = 0
for bad_comment_url_template in bad_comment_url_templates:
    for i in range(100):
        url = bad_comment_url_template.format(i)
        bad_comments += get_comments(url)
        print('第{}条纪录,总文本长度{}'.format(j, len(bad_comments)))
        j += 1

fw = open('data/bad.txt', 'w')
fw.writelines(bad_comments)


第0条纪录,总文本长度10
第1条纪录,总文本长度20
第2条纪录,总文本长度30
第3条纪录,总文本长度40
第4条纪录,总文本长度50
第5条纪录,总文本长度60
第6条纪录,总文本长度70
第7条纪录,总文本长度80
第8条纪录,总文本长度90
第9条纪录,总文本长度100
第10条纪录,总文本长度110
第11条纪录,总文本长度120
第12条纪录,总文本长度130
第13条纪录,总文本长度140
第14条纪录,总文本长度150
第15条纪录,总文本长度160
第16条纪录,总文本长度170
第17条纪录,总文本长度180
第18条纪录,总文本长度190
第19条纪录,总文本长度200
第20条纪录,总文本长度210
第21条纪录,总文本长度220
第22条纪录,总文本长度230
第23条纪录,总文本长度240
第24条纪录,总文本长度250
第25条纪录,总文本长度260
第26条纪录,总文本长度270
第27条纪录,总文本长度280
第28条纪录,总文本长度290
第29条纪录,总文本长度300
第30条纪录,总文本长度310
第31条纪录,总文本长度320
第32条纪录,总文本长度330
第33条纪录,总文本长度340
第34条纪录,总文本长度350
第35条纪录,总文本长度360
第36条纪录,总文本长度368
第37条纪录,总文本长度368
第38条纪录,总文本长度368
第39条纪录,总文本长度368
第40条纪录,总文本长度368
第41条纪录,总文本长度368
第42条纪录,总文本长度368
第43条纪录,总文本长度368
第44条纪录,总文本长度368
第45条纪录,总文本长度368
第46条纪录,总文本长度368
第47条纪录,总文本长度368
第48条纪录,总文本长度368
第49条纪录,总文本长度368
第50条纪录,总文本长度368
第51条纪录,总文本长度368
第52条纪录,总文本长度368
第53条纪录,总文本长度368
第54条纪录,总文本长度368
第55条纪录,总文本长度368
第56条纪录,总文本长度368
第57条纪录,总文本长度368
第58条纪录,总文本长度368
第59条纪录,总文本长度368
第60条纪录,总文本长度368
第61条纪录,总文本长度368
第62条纪录,总文本长度368
第63条纪录,总文本长度368
第64条纪录,总文本长度368
第65条纪录,总文本长度368
第66条纪录,总文本长度368
第67条纪录,总文本长度368
第68条纪录,总文本长度368
第69条纪录,总文本长度368
第70条纪录,总文本长度368
第71条纪录,总文本长度368
第72条纪录,总文本长度368
第73条纪录,总文本长度368
第74条纪录,总文本长度368
第75条纪录,总文本长度368
第76条纪录,总文本长度368
第77条纪录,总文本长度368
第78条纪录,总文本长度368
第79条纪录,总文本长度368
第80条纪录,总文本长度368
第81条纪录,总文本长度368
第82条纪录,总文本长度368
第83条纪录,总文本长度368
第84条纪录,总文本长度368
第85条纪录,总文本长度368
第86条纪录,总文本长度368
第87条纪录,总文本长度368
第88条纪录,总文本长度368
第89条纪录,总文本长度368
第90条纪录,总文本长度368
第91条纪录,总文本长度368
第92条纪录,总文本长度368
第93条纪录,总文本长度368
第94条纪录,总文本长度368
第95条纪录,总文本长度368
第96条纪录,总文本长度368
第97条纪录,总文本长度368
第98条纪录,总文本长度368
第99条纪录,总文本长度368
第100条纪录,总文本长度378
第101条纪录,总文本长度388
第102条纪录,总文本长度398
第103条纪录,总文本长度408
第104条纪录,总文本长度418
第105条纪录,总文本长度428
第106条纪录,总文本长度438
第107条纪录,总文本长度448
第108条纪录,总文本长度458
第109条纪录,总文本长度468
第110条纪录,总文本长度478
第111条纪录,总文本长度488
第112条纪录,总文本长度498
第113条纪录,总文本长度508
第114条纪录,总文本长度518
第115条纪录,总文本长度528
第116条纪录,总文本长度538
第117条纪录,总文本长度548
第118条纪录,总文本长度558
第119条纪录,总文本长度568
第120条纪录,总文本长度578
第121条纪录,总文本长度588
第122条纪录,总文本长度598
第123条纪录,总文本长度608
第124条纪录,总文本长度618
第125条纪录,总文本长度628
第126条纪录,总文本长度638
第127条纪录,总文本长度648
第128条纪录,总文本长度658
第129条纪录,总文本长度668
第130条纪录,总文本长度678
第131条纪录,总文本长度688
第132条纪录,总文本长度698
第133条纪录,总文本长度708
第134条纪录,总文本长度718
第135条纪录,总文本长度728
第136条纪录,总文本长度738
第137条纪录,总文本长度748
第138条纪录,总文本长度758
第139条纪录,总文本长度768
第140条纪录,总文本长度778
第141条纪录,总文本长度788
第142条纪录,总文本长度798
第143条纪录,总文本长度808
第144条纪录,总文本长度818
第145条纪录,总文本长度828
第146条纪录,总文本长度838
第147条纪录,总文本长度848
第148条纪录,总文本长度858
第149条纪录,总文本长度868
第150条纪录,总文本长度878
第151条纪录,总文本长度888
第152条纪录,总文本长度898
第153条纪录,总文本长度908
第154条纪录,总文本长度918
第155条纪录,总文本长度928
第156条纪录,总文本长度938
第157条纪录,总文本长度948
第158条纪录,总文本长度958
第159条纪录,总文本长度968
第160条纪录,总文本长度978
第161条纪录,总文本长度988
第162条纪录,总文本长度998
第163条纪录,总文本长度1008
第164条纪录,总文本长度1018
第165条纪录,总文本长度1028
第166条纪录,总文本长度1038
第167条纪录,总文本长度1048
第168条纪录,总文本长度1058
第169条纪录,总文本长度1068
第170条纪录,总文本长度1078
第171条纪录,总文本长度1088
第172条纪录,总文本长度1098
第173条纪录,总文本长度1108
第174条纪录,总文本长度1118
第175条纪录,总文本长度1128
第176条纪录,总文本长度1138
第177条纪录,总文本长度1148
第178条纪录,总文本长度1158
第179条纪录,总文本长度1168
第180条纪录,总文本长度1178
第181条纪录,总文本长度1188
第182条纪录,总文本长度1198
第183条纪录,总文本长度1208
第184条纪录,总文本长度1218
第185条纪录,总文本长度1228
第186条纪录,总文本长度1238
第187条纪录,总文本长度1248
第188条纪录,总文本长度1258
第189条纪录,总文本长度1268
第190条纪录,总文本长度1278
第191条纪录,总文本长度1288
第192条纪录,总文本长度1298
第193条纪录,总文本长度1308
第194条纪录,总文本长度1318
第195条纪录,总文本长度1328
第196条纪录,总文本长度1338
第197条纪录,总文本长度1348
第198条纪录,总文本长度1358
第199条纪录,总文本长度1368
第200条纪录,总文本长度1378
第201条纪录,总文本长度1388
第202条纪录,总文本长度1398
第203条纪录,总文本长度1408
第204条纪录,总文本长度1418
第205条纪录,总文本长度1428
第206条纪录,总文本长度1438
第207条纪录,总文本长度1448
第208条纪录,总文本长度1458
第209条纪录,总文本长度1468
第210条纪录,总文本长度1478
第211条纪录,总文本长度1488
第212条纪录,总文本长度1498
第213条纪录,总文本长度1508
第214条纪录,总文本长度1518
第215条纪录,总文本长度1528
第216条纪录,总文本长度1538
第217条纪录,总文本长度1548
第218条纪录,总文本长度1558
第219条纪录,总文本长度1568
第220条纪录,总文本长度1578
第221条纪录,总文本长度1588
第222条纪录,总文本长度1598
第223条纪录,总文本长度1608
第224条纪录,总文本长度1618
第225条纪录,总文本长度1628
第226条纪录,总文本长度1638
第227条纪录,总文本长度1648
第228条纪录,总文本长度1658
第229条纪录,总文本长度1668
第230条纪录,总文本长度1678
第231条纪录,总文本长度1688
第232条纪录,总文本长度1698
第233条纪录,总文本长度1708
第234条纪录,总文本长度1718
第235条纪录,总文本长度1728
第236条纪录,总文本长度1738
第237条纪录,总文本长度1748
第238条纪录,总文本长度1758
第239条纪录,总文本长度1768
第240条纪录,总文本长度1778
第241条纪录,总文本长度1788
第242条纪录,总文本长度1798
第243条纪录,总文本长度1808
第244条纪录,总文本长度1818
第245条纪录,总文本长度1828
第246条纪录,总文本长度1838
第247条纪录,总文本长度1848
第248条纪录,总文本长度1858
第249条纪录,总文本长度1868
第250条纪录,总文本长度1878
第251条纪录,总文本长度1888
第252条纪录,总文本长度1898
第253条纪录,总文本长度1908
第254条纪录,总文本长度1918
第255条纪录,总文本长度1928
第256条纪录,总文本长度1938
第257条纪录,总文本长度1948
第258条纪录,总文本长度1958
第259条纪录,总文本长度1968
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-58c9313c3f8e> in <module>()
     16     for i in range(100):
     17         url = bad_comment_url_template.format(i)
---> 18         bad_comments += get_comments(url)
     19         print('第{}条纪录,总文本长度{}'.format(j, len(bad_comments)))
     20         j += 1

<ipython-input-2-2c40157c1d83> in get_comments(url)
     26 
     27         #对每一条评论进行内容部分的抽取
---> 28         for comment_info in comment_infos:
     29             comment_content = comment_info['content']
     30             str1 = comment_content + '\n'

TypeError: 'NoneType' object is not iterable

2.数据预处理


In [3]:
# 数据来源文件
good_file = 'data/good.txt'
bad_file  = 'data/bad.txt'

# 将文本中的标点符号过滤掉
def filter_punc(sentence):
    sentence = re.sub("[\s+\.\!\/_,$%^*(+\"\'“”《》?“]+|[+——!,。?、~@#¥%……&*():]+", "", sentence)  
    return(sentence)

#扫描所有的文本,分词、建立词典,分出正向还是负向的评论,is_filter可以过滤是否筛选掉标点符号
def Prepare_data(good_file, bad_file, is_filter = True):
    all_words = [] #存储所有的单词
    pos_sentences = [] #存储正向的评论
    neg_sentences = [] #存储负向的评论
    with open(good_file, 'r') as fr:
        for idx, line in enumerate(fr):
            if is_filter:
                #过滤标点符号
                line = filter_punc(line)
            #分词
            words = jieba.lcut(line)
            if len(words) > 0:
                all_words += words
                pos_sentences.append(words)
    print('{0} 包含 {1} 行, {2} 个词.'.format(good_file, idx+1, len(all_words)))

    count = len(all_words)
    with open(bad_file, 'r') as fr:
        for idx, line in enumerate(fr):
            if is_filter:
                line = filter_punc(line)
            words = jieba.lcut(line)
            if len(words) > 0:
                all_words += words
                neg_sentences.append(words)
    print('{0} 包含 {1} 行, {2} 个词.'.format(bad_file, idx+1, len(all_words)-count))

    #建立词典,diction的每一项为{w:[id, 单词出现次数]}
    diction = {}
    cnt = Counter(all_words)
    for word, freq in cnt.items():
        diction[word] = [len(diction), freq]
    print('字典大小:{}'.format(len(diction)))
    return(pos_sentences, neg_sentences, diction)

#根据单词返还单词的编码
def word2index(word, diction):
    if word in diction:
        value = diction[word][0]
    else:
        value = -1
    return(value)

#根据编码获得单词
def index2word(index, diction):
    for w,v in diction.items():
        if v[0] == index:
            return(w)
    return(None)

pos_sentences, neg_sentences, diction = Prepare_data(good_file, bad_file, True)
st = sorted([(v[1], w) for w, v in diction.items()])
st


Building prefix dict from the default dictionary ...
Dumping model to file cache /var/folders/tg/tk9z3txd6gx6135zwd7n5fjc0000gn/T/jieba.cache
Loading model cost 1.553 seconds.
Prefix dict has been built succesfully.
data/good.txt 包含 8089 行, 100844 个词.
data/bad.txt 包含 5076 行, 56073 个词.
字典大小:7133
Out[3]:
[(1, '000000000'),
 (1, '0000000000000000'),
 (1, '00000000000000000'),
 (1, '0000000000000000000000'),
 (1, '000000000000000000000000000'),
 (1, '00000000000000000000000000000000000000000000000'),
 (1, '1006'),
 (1, '111'),
 (1, '1111111'),
 (1, '11111111111'),
 (1, '11111111111111'),
 (1, '111111111111111111111111'),
 (1, '11111111111111111111111111111'),
 (1, '111111111111111111111111111111111'),
 (1, '1111111111111111111111111111111111111111111'),
 (1, '115'),
 (1, '122'),
 (1, '12315'),
 (1, '123456'),
 (1, '1255888'),
 (1, '128'),
 (1, '130'),
 (1, '136'),
 (1, '138'),
 (1, '15548'),
 (1, '158105'),
 (1, '160'),
 (1, '16067cm'),
 (1, '165140'),
 (1, '165cm'),
 (1, '166cm68kgxl'),
 (1, '16860'),
 (1, '170CM55'),
 (1, '170M'),
 (1, '170cm75kg'),
 (1, '171819202122'),
 (1, '172'),
 (1, '172153'),
 (1, '175cm'),
 (1, '1805'),
 (1, '180xxl'),
 (1, '183'),
 (1, '185'),
 (1, '1852XL'),
 (1, '185XXL'),
 (1, '185mm'),
 (1, '200000000'),
 (1, '20161218'),
 (1, '210'),
 (1, '211'),
 (1, '211111111111111111111111111'),
 (1, '22'),
 (1, '233333'),
 (1, '27'),
 (1, '28'),
 (1, '2XL'),
 (1, '34'),
 (1, '36'),
 (1, '3Q'),
 (1, '3X'),
 (1, '3o'),
 (1, '44'),
 (1, '45'),
 (1, '502'),
 (1, '60'),
 (1, '618'),
 (1, '65665556666666'),
 (1, '66'),
 (1, '666'),
 (1, '666666'),
 (1, '73120'),
 (1, '75'),
 (1, '78'),
 (1, '800'),
 (1, '881860881860881860'),
 (1, '88888888'),
 (1, '89'),
 (1, '91'),
 (1, '924'),
 (1, 'B'),
 (1, 'CARTELO'),
 (1, 'DAORD'),
 (1, 'Daord'),
 (1,
  'E5B0B1E698AFE5A4A7E4BA86E782B92020202020E58F91E8B4A7E4B99FE5A4AAE685A2202020202020E99E8BE58FAAE883BDE8AFB4E8BF98E58FAFE4BBA5'),
 (1, 'Fashion'),
 (1, 'GG'),
 (1, 'GM'),
 (1, 'GOODGOODGOOD'),
 (1, 'GOODGOODGOODGOOD'),
 (1, 'Itisnotgood'),
 (1, 'KTJF'),
 (1, 'LJ'),
 (1, 'LOGO'),
 (1, 'MLGB'),
 (1, 'MM'),
 (1, 'MiGo'),
 (1, 'NICCE'),
 (1, 'NICE'),
 (1, 'Omega'),
 (1, 'P'),
 (1, 'PLC'),
 (1, 'T1'),
 (1, 'X'),
 (1, 'XP1'),
 (1, 'acute'),
 (1, 'app'),
 (1, 'asdasdas'),
 (1, 'bb'),
 (1, 'bdhdiifjdjdhdhdueiejf'),
 (1, 'bhhbbfdfchh'),
 (1, 'da'),
 (1, 'de'),
 (1, 'dfgsdfgsvdfgsdfgsdf'),
 (1, 'die'),
 (1, 'epsilon'),
 (1, 'fashion'),
 (1, 'gel'),
 (1, 'gggfcvfg'),
 (1, 'gghhhh'),
 (1, 'gsdfgsdfgsdfvdfgsdfgsdfgsdfg'),
 (1, 'hao'),
 (1, 'hhhhh'),
 (1, 'i'),
 (1, 'ing'),
 (1, 'j'),
 (1, 'j8'),
 (1, 'jb'),
 (1, 'jd'),
 (1, 'jia'),
 (1, 'lj'),
 (1, 'lo'),
 (1, 'long'),
 (1, 'mmmmmm'),
 (1, 'n'),
 (1, 'nm'),
 (1, 'okok'),
 (1, 'omega'),
 (1, 'playboy'),
 (1, 'q'),
 (1, 'ri'),
 (1, 's'),
 (1, 'sb'),
 (1, 'sdrgdfsgdfsgdsfgsdfdfff'),
 (1, 'shi'),
 (1, 'si'),
 (1, 'tM'),
 (1, 'thesameasthewhite'),
 (1, 'tmd'),
 (1, 'tryreygjhgjghf'),
 (1, 'u'),
 (1, 'uoyo6'),
 (1, 'vip'),
 (1, 'wo'),
 (1, 'wqnmgb'),
 (1, 'xx'),
 (1, 'xxxL'),
 (1, 'ybb'),
 (1, '|'),
 (1, 'ˊ'),
 (1, 'ˋ'),
 (1, '∵'),
 (1, '≦'),
 (1, '≧'),
 (1, '⑧'),
 (1, '╭'),
 (1, '╮'),
 (1, '一一'),
 (1, '一万个'),
 (1, '一上午'),
 (1, '一不小心'),
 (1, '一个个'),
 (1, '一个包'),
 (1, '一个半月'),
 (1, '一中'),
 (1, '一事'),
 (1, '一二次'),
 (1, '一亮'),
 (1, '一侧'),
 (1, '一出'),
 (1, '一切顺利'),
 (1, '一办'),
 (1, '一包'),
 (1, '一匹'),
 (1, '一口'),
 (1, '一口气'),
 (1, '一向'),
 (1, '一坨'),
 (1, '一夜'),
 (1, '一大堆'),
 (1, '一对'),
 (1, '一巴掌'),
 (1, '一开'),
 (1, '一律'),
 (1, '一心'),
 (1, '一截'),
 (1, '一打'),
 (1, '一批'),
 (1, '一把'),
 (1, '一抹'),
 (1, '一拉'),
 (1, '一换'),
 (1, '一探'),
 (1, '一提'),
 (1, '一晚'),
 (1, '一月'),
 (1, '一有'),
 (1, '一条线'),
 (1, '一架'),
 (1, '一次次'),
 (1, '一段'),
 (1, '一比'),
 (1, '一泡'),
 (1, '一灰一'),
 (1, '一点两点'),
 (1, '一点儿'),
 (1, '一点半点'),
 (1, '一班'),
 (1, '一瓶'),
 (1, '一百'),
 (1, '一百块'),
 (1, '一盆'),
 (1, '一等品'),
 (1, '一肚'),
 (1, '一肚子气'),
 (1, '一肚子火'),
 (1, '一至'),
 (1, '一般化'),
 (1, '一般见识'),
 (1, '一装'),
 (1, '一见钟情'),
 (1, '一觉'),
 (1, '一说'),
 (1, '一货'),
 (1, '七五'),
 (1, '七八天'),
 (1, '七十'),
 (1, '七十八'),
 (1, '七家'),
 (1, '七彩'),
 (1, '七毛'),
 (1, '丅'),
 (1, '万恶'),
 (1, '三三三'),
 (1, '三五十块'),
 (1, '三倍'),
 (1, '三十'),
 (1, '三十五'),
 (1, '三十几块'),
 (1, '三十岁'),
 (1, '三厘米'),
 (1, '三块'),
 (1, '三处'),
 (1, '三无'),
 (1, '三星'),
 (1, '三月份'),
 (1, '三番五次'),
 (1, '三种'),
 (1, '三级'),
 (1, '三道'),
 (1, '上乘'),
 (1, '上亮'),
 (1, '上佳'),
 (1, '上内'),
 (1, '上去'),
 (1, '上吊'),
 (1, '上图'),
 (1, '上天'),
 (1, '上帅帅'),
 (1, '上年'),
 (1, '上来'),
 (1, '上海'),
 (1, '上等货'),
 (1, '上网'),
 (1, '上能'),
 (1, '上课'),
 (1, '上长'),
 (1, '上门'),
 (1, '下不为例'),
 (1, '下会'),
 (1, '下边带'),
 (1, '不久'),
 (1, '不乍样'),
 (1, '不付'),
 (1, '不住'),
 (1, '不作'),
 (1, '不佳'),
 (1, '不俗'),
 (1, '不信'),
 (1, '不值一提'),
 (1, '不值钱'),
 (1, '不全'),
 (1, '不冷脚'),
 (1, '不划算'),
 (1, '不加绒'),
 (1, '不勒脚'),
 (1, '不可多得'),
 (1, '不合格品'),
 (1, '不同于'),
 (1, '不吭声'),
 (1, '不周到'),
 (1, '不和身'),
 (1, '不复'),
 (1, '不够意思'),
 (1, '不大一样'),
 (1, '不太爽'),
 (1, '不失'),
 (1, '不学'),
 (1, '不容'),
 (1, '不寄到'),
 (1, '不对头'),
 (1, '不对板'),
 (1, '不帅'),
 (1, '不带'),
 (1, '不平'),
 (1, '不开'),
 (1, '不得好死'),
 (1, '不忍'),
 (1, '不快'),
 (1, '不怪'),
 (1, '不恳'),
 (1, '不愧'),
 (1, '不愿'),
 (1, '不戴'),
 (1, '不扎人'),
 (1, '不扎肉'),
 (1, '不拉得'),
 (1, '不挤'),
 (1, '不掉'),
 (1, '不搁'),
 (1, '不料'),
 (1, '不易'),
 (1, '不服气'),
 (1, '不杂'),
 (1, '不正'),
 (1, '不比大'),
 (1, '不气'),
 (1, '不满'),
 (1, '不知去向'),
 (1, '不算数'),
 (1, '不素'),
 (1, '不经脏'),
 (1, '不缝'),
 (1, '不耐烦'),
 (1, '不聊'),
 (1, '不脏'),
 (1, '不至于'),
 (1, '不良'),
 (1, '不要紧'),
 (1, '不見'),
 (1, '不见踪影'),
 (1, '不负'),
 (1, '不赖'),
 (1, '不赶'),
 (1, '不足'),
 (1, '不软'),
 (1, '不轻'),
 (1, '不过如此'),
 (1, '不进'),
 (1, '不送'),
 (1, '不适'),
 (1, '不选'),
 (1, '不透'),
 (1, '不配'),
 (1, '不重'),
 (1, '不长不短'),
 (1, '不问'),
 (1, '与众不同'),
 (1, '丑女'),
 (1, '丑好'),
 (1, '丑爆'),
 (1, '丑花'),
 (1, '专用'),
 (1, '专门'),
 (1, '东京'),
 (1, '东哥'),
 (1, '东方不败'),
 (1, '东非'),
 (1, '丝丝'),
 (1, '丝线'),
 (1, '丝绵'),
 (1, '丝袜'),
 (1, '丢人'),
 (1, '丢件'),
 (1, '丢掉'),
 (1, '丢脸'),
 (1, '两三根'),
 (1, '两号'),
 (1, '两周'),
 (1, '两回'),
 (1, '两支'),
 (1, '两月'),
 (1, '两步'),
 (1, '两水'),
 (1, '两种'),
 (1, '两遍'),
 (1, '严密'),
 (1, '严得'),
 (1, '严紧'),
 (1, '严重错误'),
 (1, '个别'),
 (1, '个子'),
 (1, '个把月'),
 (1, '个星'),
 (1, '丫'),
 (1, '中下部'),
 (1, '中到'),
 (1, '中奖'),
 (1, '中差'),
 (1, '中招'),
 (1, '中码'),
 (1, '中肯'),
 (1, '中评'),
 (1, '中途'),
 (1, '中通够'),
 (1, '丰满'),
 (1, '串'),
 (1, '串风'),
 (1, '为之动容'),
 (1, '主'),
 (1, '主体'),
 (1, '主图'),
 (1, '主意'),
 (1, '乃是'),
 (1, '乃至'),
 (1, '久些'),
 (1, '之一'),
 (1, '之余'),
 (1, '之作'),
 (1, '之家'),
 (1, '之慨且'),
 (1, '之气'),
 (1, '之清莲'),
 (1, '之选'),
 (1, '乌龟'),
 (1, '九十块'),
 (1, '九成'),
 (1, '也许'),
 (1, '乡镇'),
 (1, '买一送一'),
 (1, '买不到'),
 (1, '买京'),
 (1, '买会'),
 (1, '买假'),
 (1, '买回去'),
 (1, '买大一'),
 (1, '买太坑'),
 (1, '买方'),
 (1, '买狮港'),
 (1, '买祝'),
 (1, '买药'),
 (1, '买要'),
 (1, '买贵'),
 (1, '乱七八糟'),
 (1, '乱搞'),
 (1, '乱支'),
 (1, '乱放'),
 (1, '了不起'),
 (1, '了解'),
 (1, '争论'),
 (1, '事事如意'),
 (1, '二三十'),
 (1, '二个月'),
 (1, '二十一'),
 (1, '二十元'),
 (1, '二十多天'),
 (1, '二号'),
 (1, '二维'),
 (1, '二维码'),
 (1, '二边'),
 (1, '于'),
 (1, '于仁泰'),
 (1, '于是乎'),
 (1, '亏大'),
 (1, '互相理解'),
 (1, '五件'),
 (1, '五六次'),
 (1, '五十几'),
 (1, '五度'),
 (1, '五心'),
 (1, '五棵星'),
 (1, '五点'),
 (1, '五角星'),
 (1, '些小'),
 (1, '交'),
 (1, '交接处'),
 (1, '交班'),
 (1, '交给'),
 (1, '产品描述'),
 (1, '产生'),
 (1, '京客隆'),
 (1, '京豆京'),
 (1, '亮丽'),
 (1, '亮眼'),
 (1, '亲切'),
 (1, '亲已'),
 (1, '亲戚'),
 (1, '亲肤'),
 (1, '亲自'),
 (1, '亲要'),
 (1, '人一扯'),
 (1, '人为'),
 (1, '人事局'),
 (1, '人人'),
 (1, '人共赏'),
 (1, '人品'),
 (1, '人嘛'),
 (1, '人太差'),
 (1, '人心'),
 (1, '人情'),
 (1, '人格担保'),
 (1, '人生'),
 (1, '人神'),
 (1, '人类'),
 (1, '人群'),
 (1, '人间'),
 (1, '今天上午'),
 (1, '介意'),
 (1, '从未'),
 (1, '从来不'),
 (1, '从此'),
 (1, '从此以后'),
 (1, '从没'),
 (1, '从线'),
 (1, '从购'),
 (1, '从里到外'),
 (1, '仓管'),
 (1, '仔细检查'),
 (1, '付出'),
 (1, '仙人'),
 (1, '代购'),
 (1, '以免'),
 (1, '以及'),
 (1, '以往'),
 (1, '以至'),
 (1, '价一洗'),
 (1, '价不'),
 (1, '价亏'),
 (1, '价吧'),
 (1, '价小贵'),
 (1, '价打个'),
 (1, '价挺值'),
 (1, '价是'),
 (1, '价格下降'),
 (1, '价比高'),
 (1, '价没试'),
 (1, '价穿'),
 (1, '价算'),
 (1, '价线'),
 (1, '价谦'),
 (1, '价鞋'),
 (1, '任些'),
 (1, '仿冒品'),
 (1, '休闲服'),
 (1, '众'),
 (1, '众多'),
 (1, '优惠价'),
 (1, '优惠券'),
 (1, '优赞'),
 (1, '伙'),
 (1, '会先'),
 (1, '会卡脚'),
 (1, '会哈'),
 (1, '会多'),
 (1, '会太花'),
 (1, '会常'),
 (1, '会往'),
 (1, '传不上'),
 (1, '伤太深'),
 (1, '伤透'),
 (1, '伪劣'),
 (1, '伪劣产品'),
 (1, '佈'),
 (1, '位'),
 (1, '低档'),
 (1, '低过'),
 (1, '住手'),
 (1, '体会'),
 (1, '体恤'),
 (1, '体谅'),
 (1, '体贴'),
 (1, '体贴入微'),
 (1, '何来'),
 (1, '何用'),
 (1, '作业'),
 (1, '作用'),
 (1, '你們'),
 (1, '佷'),
 (1, '例如'),
 (1, '供'),
 (1, '供参考'),
 (1, '供货商'),
 (1, '侧边'),
 (1, '侧面'),
 (1, '便易'),
 (1, '便靓'),
 (1, '俗气'),
 (1, '保不'),
 (1, '保养'),
 (1, '保守'),
 (1, '保护'),
 (1, '保温'),
 (1, '保证质量'),
 (1, '保障'),
 (1, '信一星'),
 (1, '信不信'),
 (1, '信不过'),
 (1, '俩个'),
 (1, '修个'),
 (1, '修差'),
 (1, '修没见'),
 (1, '修理'),
 (1, '俯看'),
 (1, '俯首'),
 (1, '俺'),
 (1, '倒大霉'),
 (1, '借'),
 (1, '债任'),
 (1, '值钱'),
 (1, '倾吾之'),
 (1, '假期'),
 (1, '假话'),
 (1, '假鞋'),
 (1, '偏大买'),
 (1, '偏大号'),
 (1, '偏小一星'),
 (1, '偏小一码'),
 (1, '偏小好'),
 (1, '偏小得'),
 (1, '偏小换'),
 (1, '偏小本'),
 (1, '偏小需'),
 (1, '偏棕'),
 (1, '偏短'),
 (1, '偏紧'),
 (1, '偏红'),
 (1, '偏色'),
 (1, '偏贵'),
 (1, '偏高'),
 (1, '做不了'),
 (1, '做个'),
 (1, '做事'),
 (1, '做到'),
 (1, '做广告'),
 (1, '做成'),
 (1, '做错事'),
 (1, '停'),
 (1, '停停'),
 (1, '停好'),
 (1, '停机'),
 (1, '偷工减料'),
 (1, '催单'),
 (1, '儿'),
 (1, '儿响'),
 (1, '元件'),
 (1, '元宵节'),
 (1, '充满'),
 (1, '先买'),
 (1, '先寄'),
 (1, '先放'),
 (1, '先看'),
 (1, '光凭'),
 (1, '光四射'),
 (1, '光料'),
 (1, '光泽'),
 (1, '光面'),
 (1, '克'),
 (1, '免'),
 (1, '免包'),
 (1, '免强'),
 (1, '免得'),
 (1, '免郵費'),
 (1, '党'),
 (1, '入'),
 (1, '入住'),
 (1, '入秋'),
 (1, '全他'),
 (1, '全会'),
 (1, '全开'),
 (1, '全粘在'),
 (1, '全错'),
 (1, '全黑'),
 (1, '八'),
 (1, '八十九'),
 (1, '公'),
 (1, '公交车'),
 (1, '公愤'),
 (1, '公民'),
 (1, '公里'),
 (1, '六'),
 (1, '六个'),
 (1, '六件'),
 (1, '六八'),
 (1, '六天'),
 (1, '六点'),
 (1, '共同'),
 (1, '关怀'),
 (1, '关故'),
 (1, '兴趣'),
 (1, '其'),
 (1, '具体来说'),
 (1, '具备'),
 (1, '内地'),
 (1, '内好'),
 (1, '内存'),
 (1, '内敛'),
 (1, '冇'),
 (1, '再多添'),
 (1, '再寄'),
 (1, '再慢'),
 (1, '再烂'),
 (1, '再者'),
 (1, '再脏'),
 (1, '冒牌货'),
 (1, '冒领'),
 (1, '写全'),
 (1, '冤枉'),
 (1, '冬装'),
 (1, '冰棍'),
 (1, '冲动'),
 (1, '冻僵'),
 (1, '冻成'),
 (1, '准穿'),
 (1, '凌乱'),
 (1, '减价'),
 (1, '减点'),
 (1, '减肥'),
 (1, '凑合着'),
 (1, '凑齐'),
 (1, '几十倍'),
 (1, '几单'),
 (1, '几好'),
 (1, '几成'),
 (1, '几时'),
 (1, '几桶'),
 (1, '几滴'),
 (1, '凡客'),
 (1, '凭良心'),
 (1, '凸点'),
 (1, '出动'),
 (1, '出差'),
 (1, '出库'),
 (1, '出手'),
 (1, '出点'),
 (1, '出线'),
 (1, '分家'),
 (1, '分开'),
 (1, '分袖'),
 (1, '分要'),
 (1, '分赞'),
 (1, '划手'),
 (1, '划破'),
 (1, '刚取'),
 (1, '刚回来'),
 (1, '刚来时'),
 (1, '刚穿'),
 (1, '刚试'),
 (1, '刚购'),
 (1, '删除'),
 (1, '別上當'),
 (1, '別買'),
 (1, '利润'),
 (1, '利润率'),
 (1, '利索'),
 (1, '别个'),
 (1, '别信'),
 (1, '别家'),
 (1, '别总想'),
 (1, '别提'),
 (1, '别问'),
 (1, '刮'),
 (1, '刮花'),
 (1, '到哪去'),
 (1, '到家试'),
 (1, '到底'),
 (1, '到时候'),
 (1, '刷子'),
 (1, '刺痒'),
 (1, '剌'),
 (1, '前两天'),
 (1, '前卫'),
 (1, '前来'),
 (1, '前端'),
 (1, '剩'),
 (1, '剪线'),
 (1, '剪裁'),
 (1, '力到'),
 (1, '力给'),
 (1, '办公室'),
 (1, '办越'),
 (1, '功'),
 (1, '加个'),
 (1, '加件'),
 (1, '加呀'),
 (1, '加大'),
 (1, '加工厂'),
 (1, '加工资'),
 (1, '加有'),
 (1, '加棉'),
 (1, '加点'),
 (1, '加都'),
 (1, '动物'),
 (1, '助人为乐'),
 (1, '劲儿'),
 (1, '劳民伤财'),
 (1, '勉勉强强'),
 (1, '勒'),
 (1, '勒死'),
 (1, '匀称'),
 (1, '包包'),
 (1, '包括'),
 (1, '包身'),
 (1, '化学'),
 (1, '化学品'),
 (1, '化学纤维'),
 (1, '化工原料'),
 (1, '北'),
 (1, '北京'),
 (1, '北方'),
 (1, '区'),
 (1, '医药费'),
 (1, '十'),
 (1, '十一下'),
 (1, '十一分'),
 (1, '十七'),
 (1, '十八号'),
 (1, '十六号'),
 (1, '十几件'),
 (1, '十几公里'),
 (1, '十分钟'),
 (1, '十块钱'),
 (1, '十岁'),
 (1, '十部'),
 (1, '千万千万'),
 (1, '千千万万'),
 (1, '千里'),
 (1, '升'),
 (1, '升起'),
 (1, '半'),
 (1, '半夜三更'),
 (1, '半截'),
 (1, '半袖'),
 (1, '协议'),
 (1, '单一'),
 (1, '单上'),
 (1, '单为'),
 (1, '单价'),
 (1, '单党'),
 (1, '单品'),
 (1, '单坑'),
 (1, '单是'),
 (1, '单词'),
 (1, '单调'),
 (1, '单鞋'),
 (1, '卖出'),
 (1, '卖出去'),
 (1, '卖到'),
 (1, '卖无语'),
 (1, '卖货'),
 (1, '卖鞋'),
 (1, '南京'),
 (1, '南北'),
 (1, '南宁'),
 (1, '占便宜'),
 (1, '卡住'),
 (1, '卡帝'),
 (1, '卡拉'),
 (1, '卡片'),
 (1, '卫衣款'),
 (1, '印个'),
 (1, '印完'),
 (1, '印洗'),
 (1, '印渍'),
 (1, '印痕'),
 (1, '印纹'),
 (1, '印象'),
 (1, '即'),
 (1, '即使'),
 (1, '即刻'),
 (1, '厂商'),
 (1, '厂服'),
 (1, '历经'),
 (1, '压成'),
 (1, '压脚'),
 (1, '压迫感'),
 (1, '厕所'),
 (1, '厚厚的'),
 (1, '厚够'),
 (1, '原先'),
 (1, '原则'),
 (1, '原本'),
 (1, '厲害'),
 (1, '参加'),
 (1, '又糙'),
 (1, '叉口'),
 (1, '及其'),
 (1, '友友'),
 (1, '双卡'),
 (1, '双板'),
 (1, '双白鞋'),
 (1, '双肩'),
 (1, '反博'),
 (1, '反映'),
 (1, '反馈'),
 (1, '发图片'),
 (1, '发挥'),
 (1, '发毛'),
 (1, '发汗'),
 (1, '发火'),
 (1, '发痒'),
 (1, '发硬'),
 (1, '发索'),
 (1, '发贷'),
 (1, '发顺丰'),
 (1, '取出'),
 (1, '取取'),
 (1, '取回来'),
 (1, '取笑'),
 (1, '受害者'),
 (1, '受得了'),
 (1, '受欢迎'),
 (1, '受骗'),
 (1, '变大'),
 (1, '变大码'),
 (1, '变红'),
 (1, '变通'),
 (1, '变长'),
 (1, '口'),
 (1, '口口'),
 (1, '口气'),
 (1, '口罩'),
 (1, '句句'),
 (1, '另一边'),
 (1, '只前'),
 (1, '只发'),
 (1, '只用'),
 (1, '只能靠'),
 (1, '只顾'),
 (1, '叫花子'),
 (1, '叮'),
 (1, '可不'),
 (1, '可不可以'),
 (1, '可可'),
 (1, '可帅'),
 (1, '可怜'),
 (1, '可气'),
 (1, '可算松'),
 (1, '可累'),
 (1, '可谓'),
 (1, '台布'),
 (1, '史无前例'),
 (1, '右肩'),
 (1, '叽'),
 (1, '吃不下'),
 (1, '各位朋友'),
 (1, '各大'),
 (1, '各庄'),
 (1, '各来'),
 (1, '合算'),
 (1, '吉利'),
 (1, '吉林省'),
 (1, '吊儿郎当'),
 (1, '吊及'),
 (1, '吊草'),
 (1, '同'),
 (1, '同前'),
 (1, '同码'),
 (1, '名字'),
 (1, '名牌'),
 (1, '名符其实'),
 (1, '后会'),
 (1, '后发'),
 (1, '后期'),
 (1, '后果'),
 (1, '后边'),
 (1, '后退'),
 (1, '后门'),
 (1, '吐'),
 (1, '吐血'),
 (1, '吓死'),
 (1, '吓死人'),
 (1, '否定词'),
 (1, '否认'),
 (1, '含'),
 (1, '含量'),
 (1, '听君'),
 (1, '听说'),
 (1, '吱声'),
 (1, '吵'),
 (1, '吵吵'),
 (1, '吸毛刚'),
 (1, '吸水'),
 (1, '吸灰'),
 (1, '吹牛'),
 (1, '告知'),
 (1, '员到'),
 (1, '员吵'),
 (1, '员多'),
 (1, '员太度'),
 (1, '员妈'),
 (1, '呛'),
 ...]

二、词袋模型

词袋模型实际上是一种对文本进行向量化的手段,通过统计出词表上的每个单词出现频率,从而将一篇文章向量化

1. 训练数据准备


In [4]:
# 输入一个句子和相应的词典,得到这个句子的向量化表示
# 向量的尺寸为词典中词汇的个数,i位置上面的数值为第i个单词在sentence中出现的频率
def sentence2vec(sentence, dictionary):
    vector = np.zeros(len(dictionary))
    for l in sentence:
        vector[l] += 1
    return(1.0 * vector / len(sentence))

# 遍历所有句子,将每一个词映射成编码
dataset = [] #数据集
labels = [] #标签
sentences = [] #原始句子,调试用
# 处理正向评论
for sentence in pos_sentences:
    new_sentence = []
    for l in sentence:
        if l in diction:
            new_sentence.append(word2index(l, diction))
    dataset.append(sentence2vec(new_sentence, diction))
    labels.append(0) #正标签为0
    sentences.append(sentence)

# 处理负向评论
for sentence in neg_sentences:
    new_sentence = []
    for l in sentence:
        if l in diction:
            new_sentence.append(word2index(l, diction))
    dataset.append(sentence2vec(new_sentence, diction))
    labels.append(1) #负标签为1
    sentences.append(sentence)

#打乱所有的数据顺序,形成数据集
# indices为所有数据下标的一个全排列
indices = np.random.permutation(len(dataset))

#重新根据打乱的下标生成数据集dataset,标签集labels,以及对应的原始句子sentences
dataset = [dataset[i] for i in indices]
labels = [labels[i] for i in indices]
sentences = [sentences[i] for i in indices]

#对整个数据集进行划分,分为:训练集、校准集和测试集,其中校准和测试集合的长度都是整个数据集的10分之一
test_size = len(dataset) // 10
train_data = dataset[2 * test_size :]
train_label = labels[2 * test_size :]

valid_data = dataset[: test_size]
valid_label = labels[: test_size]

test_data = dataset[test_size : 2 * test_size]
test_label = labels[test_size : 2 * test_size]

2. 模型定义


In [5]:
# 一个简单的前馈神经网络,三层,第一层线性层,加一个非线性ReLU,第二层线性层,中间有10个隐含层神经元

# 输入维度为词典的大小:每一段评论的词袋模型
model = nn.Sequential(
    nn.Linear(len(diction), 10),
    nn.ReLU(),
    nn.Linear(10, 2),
    nn.LogSoftmax(),
)

def rightness(predictions, labels):
    """计算预测错误率的函数,其中predictions是模型给出的一组预测结果,batch_size行num_classes列的矩阵,labels是数据之中的正确答案"""
    pred = torch.max(predictions.data, 1)[1] # 对于任意一行(一个样本)的输出值的第1个维度,求最大,得到每一行的最大元素的下标
    rights = pred.eq(labels.data.view_as(pred)).sum() #将下标与labels中包含的类别进行比较,并累计得到比较正确的数量
    return rights, len(labels) #返回正确的数量和这一次一共比较了多少元素

3. 训练模型


In [10]:
# 损失函数为交叉熵
cost = torch.nn.NLLLoss()
# 优化算法为Adam,可以自动调节学习率
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
records = []

#循环10个Epoch
for epoch in range(10):
    losses = []
    for i, data in enumerate(zip(train_data, train_label)):
        x, y = data
        
        # 需要将输入的数据进行适当的变形,主要是要多出一个batch_size的维度,也即第一个为1的维度
        x = Variable(torch.FloatTensor(x).view(1,-1))
        # x的尺寸:batch_size=1, len_dictionary
        # 标签也要加一层外衣以变成1*1的张量
        # y = Variable(torch.LongTensor(np.array([y])))
        y = Variable(torch.LongTensor([y]))
        # y的尺寸:batch_size=1, 1
        
        # 清空梯度
        optimizer.zero_grad()
        # 模型预测
        predict = model(x)
        # 计算损失函数
        loss = cost(predict, y)
        # 将损失函数数值加入到列表中
        losses.append(loss.data.numpy()[0])
        # 开始进行梯度反传
        loss.backward()
        # 开始对参数进行一步优化
        optimizer.step()
        
        # 每隔3000步,跑一下校验数据集的数据,输出临时结果
        if i % 3000 == 0:
            val_losses = []
            rights = []
            # 在所有校验数据集上实验
            for j, val in enumerate(zip(valid_data, valid_label)):
                x, y = val
                x = Variable(torch.FloatTensor(x).view(1,-1))
                y = Variable(torch.LongTensor(np.array([y])))
                predict = model(x)
                # 调用rightness函数计算准确度
                right = rightness(predict, y)
                rights.append(right)
                loss = cost(predict, y)
                val_losses.append(loss.data.numpy()[0])
                
            # 将校验集合上面的平均准确度计算出来
            right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])
            print('第{}轮,训练损失:{:.2f}, 校验损失:{:.2f}, 校验准确率: {:.2f}'.format(epoch, np.mean(losses),
                                                                        np.mean(val_losses), right_ratio))
            records.append([np.mean(losses), np.mean(val_losses), right_ratio])


第0轮,训练损失:0.01, 校验损失:0.18, 校验准确率: 0.95
第0轮,训练损失:0.17, 校验损失:0.17, 校验准确率: 0.94
第0轮,训练损失:0.18, 校验损失:0.20, 校验准确率: 0.94
第0轮,训练损失:0.18, 校验损失:0.17, 校验准确率: 0.94
第1轮,训练损失:0.01, 校验损失:0.18, 校验准确率: 0.94
第1轮,训练损失:0.17, 校验损失:0.17, 校验准确率: 0.94
第1轮,训练损失:0.18, 校验损失:0.20, 校验准确率: 0.94
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-10-ed068039d1b6> in <module>()
     39             for j, val in enumerate(zip(train_data, train_label)):
     40                 x, y = val
---> 41                 x = Variable(torch.FloatTensor(x).view(1,-1))
     42                 y = Variable(torch.LongTensor(np.array([y])))
     43                 predict = model(x)

KeyboardInterrupt: 

In [10]:
# 绘制误差曲线
a = [i[0] for i in records]
b = [i[1] for i in records]
c = [i[2] for i in records]
plt.plot(a, label = 'Train Loss')
plt.plot(b, label = 'Valid Loss')
plt.plot(c, label = 'Valid Accuracy')
plt.xlabel('Steps')
plt.ylabel('Loss & Accuracy')
plt.legend()


Out[10]:
<matplotlib.legend.Legend at 0x124195d30>

In [107]:
# 保存、提取模型(为展示用)
#torch.save(model,'bow.mdl')
#model = torch.load('bow.mdl')

In [14]:
#在测试集上分批运行,并计算总的正确率
vals = [] #记录准确率所用列表

#对测试数据集进行循环
for data, target in zip(test_data, test_label):
    data, target = Variable(torch.FloatTensor(data).view(1,-1)), Variable(torch.LongTensor(np.array([target])))
    output = model(data) #将特征数据喂入网络,得到分类的输出
    val = rightness(output, target) #获得正确样本数以及总样本数
    vals.append(val) #记录结果

#计算准确率
rights = (sum([tup[0] for tup in vals]), sum([tup[1] for tup in vals]))
right_rate = 1.0 * rights[0] / rights[1]
right_rate


Out[14]:
0.9009976976208749

4. 解剖神经网络

接下来,我们对训练好的神经网络进行解剖分析。

我们看一看每一个神经元都在检测什么模式;

我们也希望看到神经网络在测试集上判断错误的数据上出错的原因

1). 查看每一层的模式


In [17]:
# 将神经网络的架构打印出来,方便后面的访问
model.named_parameters


Out[17]:
<bound method Module.named_parameters of Sequential (
  (0): Linear (7139 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 2)
  (3): LogSoftmax ()
)>

In [108]:
# 绘制出第二个全链接层的权重大小
# model[2]即提取第2层,网络一共4层,第0层为线性神经元,第1层为ReLU,第2层为第二层神经原链接,第3层为logsoftmax
plt.figure(figsize = (10, 7))
for i in range(model[2].weight.size()[0]):
    #if i == 1:
        weights = model[2].weight[i].data.numpy()
        plt.plot(weights, 'o-', label = i)
plt.legend()
plt.xlabel('Neuron in Hidden Layer')
plt.ylabel('Weights')


Out[108]:
<matplotlib.text.Text at 0x1194aef60>

In [109]:
# 将第一层神经元的权重都打印出来,一条曲线表示一个隐含层神经元。横坐标为输入层神经元编号,纵坐标为权重值大小
plt.figure(figsize = (10, 7))
for i in range(model[0].weight.size()[0]):
    #if i == 1:
        weights = model[0].weight[i].data.numpy()
        plt.plot(weights, alpha = 0.5, label = i)
plt.legend()
plt.xlabel('Neuron in Input Layer')
plt.ylabel('Weights')


Out[109]:
<matplotlib.text.Text at 0x12008d2e8>

In [110]:
# 将第二层的各个神经元与输入层的链接权重,挑出来最大的权重和最小的权重,并考察每一个权重所对应的单词是什么,把单词打印出来
# model[0]是取出第一层的神经元

for i in range(len(model[0].weight)):
    print('\n')
    print('第{}个神经元'.format(i))
    print('max:')
    st = sorted([(w,i) for i,w in enumerate(model[0].weight[i].data.numpy())])
    for i in range(20):
        word = index2word(st[-i][1],diction)
        print(word)
    print('min:')
    for i in range(20):
        word = index2word(st[i][1],diction)
        print(word)



第0个神经元
max:
垃圾
相当
耐心
划算
YY
nbsp
期望值
完全一致
quot
棒棒
超出
老朋友
超值
放心
亲身
试
款
努力
仔细
我用
min:
垃圾
差评
退
差
差劲
不好
无语
难看
承担
最差
我要
很差
发错
一星
半个
千万别
坑人
地摊货
不想
一股


第1个神经元
max:
快快乐乐
真
味道
这鞋
有没有
差劲
买
没有
老大
还
骗人
他
说
问
鳄鱼
含
妥妥
最坑
狂爱
茄色
min:
快快乐乐
健健康康
很
包装
也
~
拼
好
还来
仔细
amp
描述
支持
这次
网购
一定
不容
真的
一样
发货


第2个神经元
max:
说
距
锁边
简直
先是
白白净净
好几个
穿得酷
麻花
发财
别看
很窄
嗎
员太度
衣衣
一货
物流配送
体重
拒收
唯一
min:
说
有没有
买
骗人
啦
味道
老大
还
问
这鞋
的
没有
他
差劲
真
大力
方
物理
心急
同样


第3个神经元
max:
完全一致
退
地摊货
找
坑人
别
发错
丢
严重
要死
却
没收
破
差评
很差
一股
不如
骗人
千万别
差劲
min:
完全一致
谢谢
超值
力
没得说
试试
精细
惊喜
老公
广西南宁
很漂亮
还会来
昨天
物有所值
抱
还行
放心
没话说
合脚
试


第4个神经元
max:
完全一致
退
地摊货
找
严重
坑人
别
发错
丢
要死
很差
却
差评
骗人
差劲
签收
没收
竟然
一股
不如
min:
完全一致
不错
很棒
力
谢谢
没得说
广西南宁
惊喜
精细
昨天
实惠
很漂亮
舒适
漂亮
超值
挺不错
托
试试
刚好
物超所值


第5个神经元
max:
包装
邮寄
拉丝
Itisnotgood
费用
星星
分不多
稀饭
码须
时候
u
体恤
太肥
肯
趣味
皮好
眼瞎
显瘦
每周
剩
min:
包装
呢
速度
这次
宝贝
网购
店家
光顾
的
仔细
一定
快
服务
发货
很
描述
一样
支持
还来
也


第6个神经元
max:
有没有
超小
挺乱
断底
回來
相机
加呀
好久好久
字数
最好
粘胶
说句
那么
手一看
质量
脚耐折性
能
化学品
洞差
鬼刚
min:
有没有
还
说
差劲
这鞋
啦
的
他
味道
真
老大
买
没有
骗人
问
哄
╭
脱毛
温暖
绿色


第7个神经元
max:
折磨
蓝色
xxl
11111111111111
超棒会
害得
姐姐
买不起
揍
差是
回后
店主人
这是
吊儿郎当
喝架
再别
很轻
介绍
级
丝丝
min:
折磨
能
没想到
好
有
太薄
了
既
况且
品
一律
很
胸口
吓死
允许
网点
斤斤计较
挑不出
都还没
灰黄


第8个神经元
max:
不值
太高兴
做个
吃亏
五十几
全家
不脏
贪污犯
難得
腰
一一
偏大号
今天下午
药店
原以为
看不清
味闻
太水
贴切
破个
min:
不值
颜值
两样
M
相机
情侣装
变长
小鸡鸡
气球
老客户
点赞
不管
一共
够用
实体
中通
失去
表扬
懒得
遇到


第9个神经元
max:
质量
议
不让
颗心
一双
花瓣
邮回
粘毛
特步
孑买
沒有
看点
郁闷
返寄眥大
美国
随后
着实
本店
因为
;
min:
质量
衣服
不错
jia
配送
简述
有误
体贴
可行
拥有
不如
不得好死
篮子
太短
此类
尺码
重要
照片
三
留错

2. 寻找判断错误的原因


In [111]:
# 收集到在测试集中判断错误的句子
wrong_sentences = []
targets = []
j = 0
sent_indices = []
for data, target in zip(test_data, test_label):
    predictions = model(Variable(torch.FloatTensor(data).view(1,-1)))
    pred = torch.max(predictions.data, 1)[1]
    target = torch.LongTensor(np.array([target])).view_as(pred)
    rights = pred.eq(target)
    indices = np.where(rights.numpy() == 0)[0]
    for i in indices:
        wrong_sentences.append(data)
        targets.append(target[i])
        sent_indices.append(test_size + j + i)
    j += len(target)

In [178]:
# 逐个查看出错的句子是什么
idx = 65
print(sentences[sent_indices[idx]], targets[idx].numpy()[0])
lst = list(np.where(wrong_sentences[idx]>0)[0])
mm = list(map(lambda x:index2word(x, diction), lst))
print(mm)


['面料', '不是', '很', '好', '样式', '还', '可以'] 1
['不是', '还', '好', '面料', '样式', '可以', '很']

In [179]:
# 观察第一层的权重与输入向量的内积结果,也就是对隐含层神经元的输入,其中最大数值对应的项就是被激活的神经元
# 负值最小的神经元就是被抑制的神经元
model[0].weight.data.numpy().dot(wrong_sentences[idx].reshape(-1, 1))


Out[179]:
array([[  3.68143035e+00],
       [ -2.26595165e-02],
       [ -7.10971770e-03],
       [  2.42715367e-01],
       [  1.30977717e+00],
       [ -5.97434116e-03],
       [ -9.14752623e-03],
       [ -3.05731981e-02],
       [ -9.80353615e-04],
       [ -4.06029967e-04]])

In [180]:
# 显示输入句子的非零项,即对应单词不为空的项,看它们到隐含层指定神经元的权重是多少
model[0].weight[0].data.numpy()[np.where(wrong_sentences[idx]>0)[0]]


Out[180]:
array([ 1.09417236, -0.34580594,  4.98125982,  3.20058918,  6.02985764,
        6.62689352,  4.18304586], dtype=float32)

三、RNN模型

我们分别比较了两种RNN模型,一个是普通的RNN模型,另一个是LSTM。

本单元的主要目的是了解RNN模型如何实现,以及考察它们在测试数据集上的分类准确度

1. 普通RNN模型


In [181]:
# 需要重新数据预处理,主要是要加上标点符号,它对于RNN起到重要作用
# 数据来源文件
good_file = 'data/good.txt'
bad_file  = 'data/bad.txt'
# 生成正样例和反样例,以及词典,很有趣的是,词典中的词语竟然比不考虑标点符号的时候少了(要知道标点也是被当作一个单词的),
# 主要原因应该是总的分词出来的数量变少了。当去掉标点符号以后,有很多字的组合被当作了单词处理了。
pos_sentences, neg_sentences, diction = Prepare_data(good_file, bad_file, False)


good.txt 包含 8089 行, 136364 个词.
bad.txt 包含 5076 行, 75669 个词.
字典大小:7024

In [182]:
# 重新准备数据,输入给RNN
# 与词袋模型不同的是。每一个句子在词袋模型中都被表示为了固定长度的向量,其中长度为字典的尺寸
# 在RNN中,每一个句子就是被单独当成词语的序列来处理的,因此序列的长度是与句子等长的

dataset = []
labels = []
sentences = []

# 正例集合
for sentence in pos_sentences:
    new_sentence = []
    for l in sentence:
        if l in diction:
            # 注意将每个词编码
            new_sentence.append(word2index(l, diction))
    #每一个句子都是一个不等长的整数序列
    dataset.append(new_sentence)
    labels.append(0)
    sentences.append(sentence)

# 反例集合
for sentence in neg_sentences:
    new_sentence = []
    for l in sentence:
        if l in diction:
            new_sentence.append(word2index(l, diction))
    dataset.append(new_sentence)
    labels.append(1)
    sentences.append(sentence)

# 重新对数据洗牌,构造数据集合
indices = np.random.permutation(len(dataset))
dataset = [dataset[i] for i in indices]
labels = [labels[i] for i in indices]
sentences = [sentences[i] for i in indices]

test_size = len(dataset) // 10

# 训练集
train_data = dataset[2 * test_size :]
train_label = labels[2 * test_size :]

# 校验集
valid_data = dataset[: test_size]
valid_label = labels[: test_size]

# 测试集
test_data = dataset[test_size : 2 * test_size]
test_label = labels[test_size : 2 * test_size]

In [183]:
# 一个手动实现的RNN模型

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        # 一个embedding层
        self.embed = nn.Embedding(input_size, hidden_size)
        # 隐含层内部的相互链接
        self.i2h = nn.Linear(2 * hidden_size, hidden_size)
        # 隐含层到输出层的链接
        self.i2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax()

    def forward(self, input, hidden):
        
        # 先进行embedding层的计算,它可以把一个数或者数列,映射成一个向量或一组向量
        # input尺寸:seq_length, 1
        x = self.embed(input)
        # x尺寸:hidden_size
        
        # 将输入和隐含层的输出(hidden)耦合在一起构成了后续的输入
        combined = torch.cat((x, hidden), 1)
        # combined尺寸:2*hidden_size
        #
        # 从输入到隐含层的计算
        hidden = self.i2h(combined)
        # combined尺寸:hidden_size
        
        # 从隐含层到输出层的运算
        output = self.i2o(hidden)
        # output尺寸:output_size
        
        # softmax函数
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        # 对隐含单元的初始化
        # 注意尺寸是:batch_size, hidden_size
        return Variable(torch.zeros(1, self.hidden_size))

In [184]:
# 开始训练这个RNN,10个隐含层单元
rnn = RNN(len(diction), 10, 2)

# 交叉熵评价函数
cost = torch.nn.NLLLoss()

# Adam优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr = 0.0001)
records = []

# 学习周期10次
for epoch in range(10):
    losses = []
    for i, data in enumerate(zip(train_data, train_label)):
        x, y = data
        x = Variable(torch.LongTensor(x))
        #x尺寸:seq_length(序列的长度)
        y = Variable(torch.LongTensor(np.array([y])))
        #x尺寸:batch_size = 1,1
        optimizer.zero_grad()
        
        #初始化隐含层单元全为0
        hidden = rnn.initHidden()
        # hidden尺寸:batch_size = 1, hidden_size
        
        #手动实现RNN的时间步循环,x的长度就是总的循环时间步,因为要把x中的输入句子全部读取完毕
        for s in range(x.size()[0]):
            output, hidden = rnn(x[s], hidden)
        
        #校验函数
        loss = cost(output, y)
        losses.append(loss.data.numpy()[0])
        loss.backward()
        # 开始优化
        optimizer.step()
        if i % 3000 == 0:
            # 每间隔3000步来一次校验集上面的计算
            val_losses = []
            rights = []
            for j, val in enumerate(zip(valid_data, valid_label)):
                x, y = val
                x = Variable(torch.LongTensor(x))
                y = Variable(torch.LongTensor(np.array([y])))
                hidden = rnn.initHidden()
                for s in range(x.size()[0]):
                    output, hidden = rnn(x[s], hidden)
                right = rightness(output, y)
                rights.append(right)
                loss = cost(output, y)
                val_losses.append(loss.data.numpy()[0])
            # 计算准确度
            right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])
            print('第{}轮,训练损失:{:.2f}, 测试损失:{:.2f}, 测试准确率: {:.2f}'.format(epoch, np.mean(losses),
                                                                        np.mean(val_losses), right_ratio))
            records.append([np.mean(losses), np.mean(val_losses), right_ratio])


第0轮,训练损失:0.51, 测试损失:0.78, 测试准确率: 0.40
第0轮,训练损失:0.68, 测试损失:0.67, 测试准确率: 0.60
第0轮,训练损失:0.66, 测试损失:0.66, 测试准确率: 0.61
第0轮,训练损失:0.65, 测试损失:0.64, 测试准确率: 0.65
第1轮,训练损失:1.49, 测试损失:0.64, 测试准确率: 0.65
第1轮,训练损失:0.62, 测试损失:0.62, 测试准确率: 0.66
第1轮,训练损失:0.60, 测试损失:0.59, 测试准确率: 0.68
第1轮,训练损失:0.58, 测试损失:0.52, 测试准确率: 0.74
第2轮,训练损失:1.09, 测试损失:0.51, 测试准确率: 0.75
第2轮,训练损失:0.49, 测试损失:0.49, 测试准确率: 0.77
第2轮,训练损失:0.46, 测试损失:0.47, 测试准确率: 0.78
第2轮,训练损失:0.46, 测试损失:0.45, 测试准确率: 0.79
第3轮,训练损失:0.56, 测试损失:0.45, 测试准确率: 0.80
第3轮,训练损失:0.44, 测试损失:0.44, 测试准确率: 0.81
第3轮,训练损失:0.41, 测试损失:0.43, 测试准确率: 0.81
第3轮,训练损失:0.41, 测试损失:0.42, 测试准确率: 0.82
第4轮,训练损失:0.41, 测试损失:0.42, 测试准确率: 0.82
第4轮,训练损失:0.40, 测试损失:0.41, 测试准确率: 0.82
第4轮,训练损失:0.38, 测试损失:0.41, 测试准确率: 0.83
第4轮,训练损失:0.38, 测试损失:0.40, 测试准确率: 0.83
第5轮,训练损失:0.33, 测试损失:0.40, 测试准确率: 0.83
第5轮,训练损失:0.38, 测试损失:0.39, 测试准确率: 0.83
第5轮,训练损失:0.36, 测试损失:0.40, 测试准确率: 0.84
第5轮,训练损失:0.36, 测试损失:0.38, 测试准确率: 0.84
第6轮,训练损失:0.28, 测试损失:0.38, 测试准确率: 0.85
第6轮,训练损失:0.37, 测试损失:0.38, 测试准确率: 0.84
第6轮,训练损失:0.34, 测试损失:0.38, 测试准确率: 0.86
第6轮,训练损失:0.34, 测试损失:0.37, 测试准确率: 0.85
第7轮,训练损失:0.24, 测试损失:0.37, 测试准确率: 0.86
第7轮,训练损失:0.35, 测试损失:0.37, 测试准确率: 0.86
第7轮,训练损失:0.32, 测试损失:0.37, 测试准确率: 0.86
第7轮,训练损失:0.33, 测试损失:0.36, 测试准确率: 0.86
第8轮,训练损失:0.21, 测试损失:0.36, 测试准确率: 0.87
第8轮,训练损失:0.34, 测试损失:0.36, 测试准确率: 0.86
第8轮,训练损失:0.31, 测试损失:0.37, 测试准确率: 0.87
第8轮,训练损失:0.32, 测试损失:0.35, 测试准确率: 0.87
第9轮,训练损失:0.18, 测试损失:0.35, 测试准确率: 0.87
第9轮,训练损失:0.32, 测试损失:0.35, 测试准确率: 0.87
第9轮,训练损失:0.30, 测试损失:0.36, 测试准确率: 0.87
第9轮,训练损失:0.31, 测试损失:0.35, 测试准确率: 0.87

In [185]:
# 绘制误差曲线
a = [i[0] for i in records]
b = [i[1] for i in records]
c = [i[2] for i in records]
plt.plot(a, label = 'Train Loss')
plt.plot(b, label = 'Valid Loss')
plt.plot(c, label = 'Valid Accuracy')
plt.xlabel('Steps')
plt.ylabel('Loss & Accuracy')
plt.legend()


Out[185]:
<matplotlib.legend.Legend at 0x124c8cb70>

In [186]:
#在测试集上运行,并计算准确率
vals = [] #记录准确率所用列表

#对测试数据集进行循环
for j, test in enumerate(zip(test_data, test_label)):
    x, y = test
    x = Variable(torch.LongTensor(x))
    y = Variable(torch.LongTensor(np.array([y])))
    hidden = rnn.initHidden()
    for s in range(x.size()[0]):
        output, hidden = rnn(x[s], hidden)
    right = rightness(output, y)
    rights.append(right)
    val = rightness(output, y) #获得正确样本数以及总样本数
    vals.append(val) #记录结果

#计算准确率
rights = (sum([tup[0] for tup in vals]), sum([tup[1] for tup in vals]))
right_rate = 1.0 * rights[0] / rights[1]
right_rate


Out[186]:
0.8860182370820668

In [76]:
#保存、加载模型(为讲解用)
#torch.save(rnn, 'rnn.mdl')
#rnn = torch.load('rnn.mdl')


/Users/jake/anaconda/envs/learning_pytorch/lib/python3.5/site-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type RNN. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

LSTM网络

普通RNN的效果并不好,我们尝试利用改进型的RNN,即LSTM。LSTM与RNN最大的区别就是在于每个神经元中多增加了3个控制门:遗忘门、输入门和输出门. 另外,在每个隐含层神经元中,LSTM多了一个cell的状态,起到了记忆的作用。

这就使得LSTM可以记忆更长时间的Pattern


In [199]:
class LSTMNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(LSTMNetwork, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        # LSTM的构造如下:一个embedding层,将输入的任意一个单词映射为一个向量
        # 一个LSTM隐含层,共有hidden_size个LSTM神经元
        # 一个全链接层,外接一个softmax输出
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers)
        self.fc = nn.Linear(hidden_size, 2)
        self.logsoftmax = nn.LogSoftmax()

    def forward(self, input, hidden=None):
        
        #input尺寸: seq_length
        #词向量嵌入
        embedded = self.embedding(input)
        #embedded尺寸: seq_length, hidden_size
        
        #PyTorch设计的LSTM层有一个特别别扭的地方是,输入张量的第一个维度需要是时间步,
        #第二个维度才是batch_size,所以需要对embedded变形
        embedded = embedded.view(input.data.size()[0], 1, self.hidden_size)
        #embedded尺寸: seq_length, batch_size = 1, hidden_size
    
        #调用PyTorch自带的LSTM层函数,注意有两个输入,一个是输入层的输入,另一个是隐含层自身的输入
        # 输出output是所有步的隐含神经元的输出结果,hidden是隐含层在最后一个时间步的状态。
        # 注意hidden是一个tuple,包含了最后时间步的隐含层神经元的输出,以及每一个隐含层神经元的cell的状态
        
        output, hidden = self.lstm(embedded, hidden)
        #output尺寸: seq_length, batch_size = 1, hidden_size
        #hidden尺寸: 二元组(n_layer = 1 * batch_size = 1 * hidden_size, n_layer = 1 * batch_size = 1 * hidden_size)
        
        #我们要把最后一个时间步的隐含神经元输出结果拿出来,送给全连接层
        output = output[-1,...]
        #output尺寸: batch_size = 1, hidden_size

        #全链接层
        out = self.fc(output)
        #out尺寸: batch_size = 1, output_size
        # softmax
        out = self.logsoftmax(out)
        return out

    def initHidden(self):
        # 对隐单元的初始化
        
        # 对隐单元输出的初始化,全0.
        # 注意hidden和cell的维度都是layers,batch_size,hidden_size
        hidden = Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
        # 对隐单元内部的状态cell的初始化,全0
        cell = Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
        return (hidden, cell)

In [200]:
# 开始训练LSTM网络

# 构造一个LSTM网络的实例
lstm = LSTMNetwork(len(diction), 10, 2)

#定义损失函数
cost = torch.nn.NLLLoss()

#定义优化器
optimizer = torch.optim.Adam(lstm.parameters(), lr = 0.001)
records = []

# 开始训练,一共5个epoch,否则容易过拟合
for epoch in range(5):
    losses = []
    for i, data in enumerate(zip(train_data, train_label)):
        x, y = data
        x = Variable(torch.LongTensor(x))
        #x尺寸:seq_length,序列的长度
        y = Variable(torch.LongTensor(np.array([y])))
        #y尺寸:batch_size = 1, 1
        optimizer.zero_grad()
        
        #初始化LSTM隐含层单元的状态
        hidden = lstm.initHidden()
        #hidden: 二元组(n_layer = 1 * batch_size = 1 * hidden_size, n_layer = 1 * batch_size = 1 * hidden_size)
        
        #让LSTM开始做运算,注意,不需要手工编写对时间步的循环,而是直接交给PyTorch的LSTM层。
        #它自动会根据数据的维度计算若干时间步
        output = lstm(x, hidden)
        #output尺寸: batch_size = 1, output_size
        
        #损失函数
        loss = cost(output, y)
        losses.append(loss.data.numpy()[0])
        
        #反向传播
        loss.backward()
        optimizer.step()
        
        #每隔3000步,跑一次校验集,并打印结果
        if i % 3000 == 0:
            val_losses = []
            rights = []
            for j, val in enumerate(zip(valid_data, valid_label)):
                x, y = val
                x = Variable(torch.LongTensor(x))
                y = Variable(torch.LongTensor(np.array([y])))
                hidden = lstm.initHidden()
                output = lstm(x, hidden)
                #计算校验数据集上的分类准确度
                right = rightness(output, y)
                rights.append(right)
                loss = cost(output, y)
                val_losses.append(loss.data.numpy()[0])
            right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])
            print('第{}轮,训练损失:{:.2f}, 测试损失:{:.2f}, 测试准确率: {:.2f}'.format(epoch, np.mean(losses),
                                                                        np.mean(val_losses), right_ratio))
            records.append([np.mean(losses), np.mean(val_losses), right_ratio])


第0轮,训练损失:0.51, 测试损失:0.76, 测试准确率: 0.40
第0轮,训练损失:0.51, 测试损失:0.44, 测试准确率: 0.81
第0轮,训练损失:0.42, 测试损失:0.39, 测试准确率: 0.84
第0轮,训练损失:0.39, 测试损失:0.37, 测试准确率: 0.86
第1轮,训练损失:0.12, 测试损失:0.35, 测试准确率: 0.87
第1轮,训练损失:0.30, 测试损失:0.35, 测试准确率: 0.87
第1轮,训练损失:0.28, 测试损失:0.34, 测试准确率: 0.87
第1轮,训练损失:0.28, 测试损失:0.33, 测试准确率: 0.88
第2轮,训练损失:0.05, 测试损失:0.33, 测试准确率: 0.88
第2轮,训练损失:0.26, 测试损失:0.33, 测试准确率: 0.88
第2轮,训练损失:0.25, 测试损失:0.33, 测试准确率: 0.88
第2轮,训练损失:0.25, 测试损失:0.33, 测试准确率: 0.89
第3轮,训练损失:0.04, 测试损失:0.32, 测试准确率: 0.89
第3轮,训练损失:0.24, 测试损失:0.33, 测试准确率: 0.89
第3轮,训练损失:0.23, 测试损失:0.32, 测试准确率: 0.88
第3轮,训练损失:0.23, 测试损失:0.32, 测试准确率: 0.89
第4轮,训练损失:0.04, 测试损失:0.32, 测试准确率: 0.89
第4轮,训练损失:0.22, 测试损失:0.33, 测试准确率: 0.88
第4轮,训练损失:0.21, 测试损失:0.33, 测试准确率: 0.89
第4轮,训练损失:0.21, 测试损失:0.33, 测试准确率: 0.89

In [197]:
# 绘制误差曲线
a = [i[0] for i in records]
b = [i[1] for i in records]
c = [i[2] for i in records]
plt.plot(a, label = 'Train Loss')
plt.plot(b, label = 'Valid Loss')
plt.plot(c, label = 'Valid Accuracy')
plt.xlabel('Steps')
plt.ylabel('Loss & Accuracy')
plt.legend()


Out[197]:
<matplotlib.legend.Legend at 0x123c22fd0>

In [198]:
#在测试集上计算总的正确率
vals = [] #记录准确率所用列表

#对测试数据集进行循环
for j, test in enumerate(zip(test_data, test_label)):
    x, y = test
    x = Variable(torch.LongTensor(x))
    y = Variable(torch.LongTensor(np.array([y])))
    hidden = lstm.initHidden()
    output = lstm(x, hidden)
    right = rightness(output, y)
    rights.append(right)
    val = rightness(output, y) #获得正确样本数以及总样本数
    vals.append(val) #记录结果

#计算准确率
rights = (sum([tup[0] for tup in vals]), sum([tup[1] for tup in vals]))
right_rate = 1.0 * rights[0] / rights[1]
right_rate


Out[198]:
0.9019756838905775

In [99]:
#保存、加载模型(为讲解用)
#torch.save(lstm, 'lstm.mdl')
#rnn = torch.load('rnn.mdl')


/Users/jake/anaconda/envs/learning_pytorch/lib/python3.5/site-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type LSTMNetwork. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

本文件是集智AI学园http://campus.swarma.org 出品的“火炬上的深度学习”第III课的配套源代码


In [ ]: