Skip-gram word2vec

In this notebook, I'll lead you through using TensorFlow to implement the word2vec algorithm using the skip-gram architecture. By implementing this, you'll learn about embedding words for use in natural language processing. This will come in handy when dealing with things like machine translation.

This notebook uses the skip-gram architecture to implement the word2vec algorithm; word2vec is mainly used in NLP.

Readings

Here are the resources I used to build this notebook. I suggest reading these either beforehand or while you're working on this material.

Word embeddings

When you're dealing with words in text, you end up with tens of thousands of classes to predict, one for each word. Trying to one-hot encode these words is massively inefficient: you'll have one element set to 1 and the other 50,000 set to 0. The matrix multiplication going into the first hidden layer will have almost all of the resulting values be zero. This is a huge waste of computation.

To solve this problem and greatly increase the efficiency of our networks, we use what are called embeddings. Embeddings are just a fully connected layer like you've seen before. We call this layer the embedding layer and the weights are embedding weights. We skip the multiplication into the embedding layer by instead directly grabbing the hidden layer values from the weight matrix. We can do this because the multiplication of a one-hot encoded vector with a matrix returns the row of the matrix corresponding to the index of the "on" input unit.

Instead of doing the matrix multiplication, we use the weight matrix as a lookup table. We encode the words as integers, for example "heart" is encoded as 958, "mind" as 18094. Then to get hidden layer values for "heart", you just take the 958th row of the embedding matrix. This process is called an embedding lookup and the number of hidden units is the embedding dimension.

There is nothing magical going on here. The embedding lookup table is just a weight matrix. The embedding layer is just a hidden layer. The lookup is just a shortcut for the matrix multiplication. The lookup table is trained just like any weight matrix as well.
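
To make the shortcut concrete, here's a minimal NumPy sketch (the vocabulary size, embedding dimension, and word index are toy values for illustration, not the notebook's real ones) showing that multiplying a one-hot vector by the weight matrix gives exactly the row a direct lookup returns:


In [ ]:
import numpy as np

vocab_size, embed_dim = 10, 4   # toy sizes for illustration
embedding_weights = np.random.uniform(-1, 1, (vocab_size, embed_dim))

word_idx = 7                    # pretend some word was encoded as 7
one_hot = np.zeros(vocab_size)
one_hot[word_idx] = 1

hidden_via_matmul = one_hot @ embedding_weights   # mostly multiplying by zeros
hidden_via_lookup = embedding_weights[word_idx]   # just grab the row

print(np.allclose(hidden_via_matmul, hidden_via_lookup))  # True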

Embeddings aren't only used for words of course. You can use them for any model where you have a massive number of classes. A particular type of model called Word2Vec uses the embedding layer to find vector representations of words that contain semantic meaning.

Word2Vec

The word2vec algorithm finds much more efficient representations by finding vectors that represent the words. These vectors also contain semantic information about the words. Words that show up in similar contexts, such as "black", "white", and "red" will have vectors near each other. There are two architectures for implementing word2vec, CBOW (Continuous Bag-Of-Words) and Skip-gram.

When we process language with one-hot encoding, the vectors can have tens of thousands or even millions of dimensions, which is very inefficient.

word2vec tries to find a more efficient representation, encoding each word with far fewer dimensions.

There are two architectures: CBOW (Continuous Bag-Of-Words) and Skip-gram.

In this implementation, we'll be using the skip-gram architecture because it performs better than CBOW. Here, we pass in a word and try to predict the words surrounding it in the text. In this way, we can train the network to learn representations for words that show up in similar contexts.
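
As a quick illustration of the training pairs skip-gram uses (the sentence and window size below are made up for the example), each center word is paired with every word inside a window around it:


In [ ]:
sentence = "the quick brown fox jumps over the lazy dog".split()
window = 2
center = 4                                   # index of "jumps"

# context words within `window` positions before and after the center word
context = sentence[max(0, center - window):center] + sentence[center + 1:center + 1 + window]
pairs = [(sentence[center], target) for target in context]
print(pairs)
# [('jumps', 'brown'), ('jumps', 'fox'), ('jumps', 'over'), ('jumps', 'the')]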

First up, importing packages.


In [1]:
import time

import numpy as np
import tensorflow as tf

import utils

Load the text8 dataset, a file of cleaned up Wikipedia articles from Matt Mahoney. The next cell will download the dataset and extract it into the data folder. Then you can delete the archive file to save storage space.


In [2]:
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import zipfile

dataset_folder_path = 'data'
dataset_filename = 'text8.zip'
dataset_name = 'Text8 Dataset'

class DLProgress(tqdm):
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

if not isfile(dataset_filename):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc=dataset_name) as pbar:
        urlretrieve(
            'http://mattmahoney.net/dc/text8.zip',
            dataset_filename,
            pbar.hook)

if not isdir(dataset_folder_path):
    with zipfile.ZipFile(dataset_filename) as zip_ref:
        zip_ref.extractall(dataset_folder_path)
        
with open('data/text8') as f:
    text = f.read()

Preprocessing

Here I'm fixing up the text to make training easier. This comes from the utils module I wrote. The preprocess function converts any punctuation into tokens, so a period is changed to <PERIOD>. In this data set, there aren't any periods, but it will help in other NLP problems. I'm also removing all words that show up five or fewer times in the dataset. This will greatly reduce issues due to noise in the data and improve the quality of the vector representations. If you want to write your own functions for this stuff, go for it.


In [3]:
words = utils.preprocess(text) # preprocessing replaces punctuation with <TOKEN>-style tokens and removes words that appear five or fewer times
print(words[:30])


['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst']

In [4]:
print("Total words: {}".format(len(words)))
print("Unique words: {}".format(len(set(words))))


Total words: 16680599
Unique words: 63641

And here I'm creating dictionaries to convert words to integers and back again, from integers to words. The integers are assigned in descending frequency order, so the most frequent word ("the") is given the integer 0, the next most frequent is 1, and so on. The words are converted to integers and stored in the list int_words.

Create two mappings, index => word and word => index, with words sorted in descending order of frequency.


In [5]:
vocab_to_int, int_to_vocab = utils.create_lookup_tables(words)
int_words = [vocab_to_int[word] for word in words]
# here the words are converted into a list of integers

Subsampling

High-frequency words such as "the" and "of" carry little useful information, so we want to filter some of them out.

Words that show up often such as "the", "of", and "for" don't provide much context to the nearby words. If we discard some of them, we can remove some of the noise from our data and in return get faster training and better representations. This process is called subsampling by Mikolov. For each word $w_i$ in the training set, we'll discard it with probability given by

$$ P(w_i) = 1 - \sqrt{\frac{t}{f(w_i)}} $$

where $t$ is a threshold parameter and $f(w_i)$ is the frequency of word $w_i$ in the total dataset.
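
For intuition, here's a quick evaluation of the formula (a minimal sketch; the numbers are taken from the counts this notebook prints below: the most frequent word appears 1,061,396 times out of 16,680,599 total words):


In [ ]:
t = 1e-5
f_the = 1061396 / 16680599           # frequency of the most frequent word ("the") in text8
p_drop_the = 1 - (t / f_the) ** 0.5
print(p_drop_the)                    # ~0.987, so almost every occurrence of "the" is discarded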

I'm going to leave this up to you as an exercise. This is more of a programming challenge, than about deep learning specifically. But, being able to prepare your data for your network is an important skill to have. Check out my solution to see how I did it.

Exercise: Implement subsampling for the words in int_words. That is, go through int_words and discard each word given the probability $P(w_i)$ shown above. Note that $P(w_i)$ is the probability that a word is discarded. Assign the subsampled data to train_words.


In [6]:
from collections import Counter
#int_words[:10]
word_counts = Counter(int_words) # count how many times each word appears

In [7]:
word_counts.most_common()


Out[7]:
[(0, 1061396),
 (1, 593677),
 (2, 416629),
 (3, 411764),
 (4, 372201),
 (5, 325873),
 (6, 316376),
 (7, 264975),
 ...]

In [8]:
total_count = len(int_words) # total number of words

In [9]:
freqs = {word: count/total_count for word, count in word_counts.items()}

In [10]:
threshold = 1e-5 # t

p_drop = {word: 1 - np.sqrt(threshold/freqs[word]) for word in word_counts}

In [11]:
## Your code here
import random
train_words = [word for word in int_words if p_drop[word] < random.random()]

#train_words = # The final subsampled word list

Making batches

Now that our data is in good shape, we need to get it into the proper form to pass it into our network. With the skip-gram architecture, for each word in the text, we want to grab all the words in a window around that word, with size $C$.

From Mikolov et al.:

"Since the more distant words are usually less related to the current word than those close to it, we give less weight to the distant words by sampling less from those words in our training examples... If we choose $C = 5$, for each training word we will select randomly a number $R$ in range $< 1; C >$, and then use $R$ words from history and $R$ words from the future of the current word as correct labels."

Exercise: Implement a function get_target that receives a list of words, an index, and a window size, then returns a list of words in the window around the index. Make sure to use the algorithm described above, where you choose a random number of words from the window.


In [12]:
np.random.randint(1, 5)
words[0:2] # slicing is a half-open interval: the start index is included, the stop index is excluded


Out[12]:
['anarchism', 'originated']

In [13]:
def get_target(words, idx, window_size=5):
    ''' Get a list of words in a window around an index. '''
    
    R = np.random.randint(1, window_size+1)
    start = idx - R if (idx - R) > 0 else 0
    stop = idx + R
    target_words = set(words[start:idx] + words[idx+1:stop+1])
    
    return list(target_words)

In [29]:
small_words = words[:30]
print(small_words)
# for small in small_words:
#     print(int_to_vocab[small])


['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst']

In [28]:
get_target(words,5,2)


Out[28]:
['term', 'abuse']

Here's a function that returns batches for our network. The idea is that it grabs batch_size words from a words list. Then for each of those words, it gets the target words in the window. I haven't found a way to pass in a random number of target words and get it to work with the architecture, so I make one row per input-target pair. This is a generator function, by the way, which helps save memory.


In [28]:
#10 // 3 # integer division
#range(0,10,3)
#words[0:3]
# get_batches yields x,x,x,x,x => word1,word2,word3,...


Out[28]:
['anarchism', 'originated', 'as']

In [14]:
def get_batches(words, batch_size, window_size=5):
    ''' Create a generator of word batches as a tuple (inputs, targets) '''
    
    n_batches = len(words)//batch_size
    
    # only full batches
    words = words[:n_batches*batch_size] # keep only full batches; drop leftover words that don't fill a batch
    
    for idx in range(0, len(words), batch_size):
        x, y = [], []
        batch = words[idx:idx+batch_size]
        for ii in range(len(batch)):
            batch_x = batch[ii]
            batch_y = get_target(batch, ii, window_size)
            y.extend(batch_y)
            x.extend([batch_x]*len(batch_y))
        yield x, y

In [39]:
batch_size = 5
words = train_words[:30]
n_batches = len(words)//5
words = words[:n_batches*batch_size]
print(words,n_batches)


[5233, 3080, 741, 476, 10571, 133, 27349, 15067, 58112, 854, 10712, 3672, 2757, 7088, 44611, 2877, 792, 8983, 4147, 6437, 32, 5233, 1818, 19, 4860, 6753, 11064, 51, 7088, 270] 6

In [40]:
# for idx in range(0, len(words), batch_size):
#     print(idx)


0
5
10
15
20
25

In [44]:
x, y = [], []
batch = words[0:0+5]
for ii in range(len(batch)):
    batch_x = batch[ii]
    batch_y = get_target(batch, ii, 2)
    y.extend(batch_y)
    x.extend([batch_x]*len(batch_y))

print(x,y)


[5233, 5233, 3080, 3080, 741, 741, 476, 476, 476, 10571, 10571] [3080, 741, 5233, 741, 3080, 476, 3080, 10571, 741, 476, 741]

In [34]:
# batches = get_batches(train_words[:30],5)
# for x,y in batches:
#     print(x,y)


[5233, 5233, 5233, 5233, 3080, 3080, 3080, 3080, 741, 741, 741, 741, 476, 476, 476, 476, 10571, 10571, 10571, 10571] [3080, 10571, 476, 741, 5233, 10571, 476, 741, 3080, 5233, 10571, 476, 3080, 5233, 10571, 741, 3080, 5233, 476, 741]
[133, 133, 133, 27349, 27349, 15067, 15067, 15067, 15067, 58112, 58112, 58112, 58112, 854, 854, 854] [58112, 15067, 27349, 15067, 133, 27349, 58112, 133, 854, 27349, 15067, 133, 854, 58112, 15067, 27349]
[10712, 10712, 10712, 10712, 3672, 3672, 2757, 2757, 2757, 2757, 7088, 7088, 7088, 7088, 44611, 44611, 44611] [3672, 44611, 2757, 7088, 10712, 2757, 10712, 44611, 3672, 7088, 10712, 44611, 3672, 2757, 3672, 2757, 7088]
[2877, 2877, 2877, 792, 792, 792, 792, 8983, 8983, 8983, 8983, 4147, 4147, 4147, 4147, 6437, 6437, 6437, 6437] [792, 4147, 8983, 4147, 2877, 6437, 8983, 792, 4147, 2877, 6437, 792, 6437, 2877, 8983, 792, 4147, 2877, 8983]
[32, 32, 32, 32, 5233, 5233, 5233, 1818, 1818, 1818, 1818, 19, 19, 19, 4860, 4860, 4860, 4860] [5233, 1818, 19, 4860, 32, 1818, 19, 32, 5233, 19, 4860, 5233, 1818, 4860, 32, 5233, 1818, 19]
[6753, 6753, 6753, 6753, 11064, 11064, 11064, 51, 51, 51, 51, 7088, 7088, 7088, 7088, 270] [11064, 51, 7088, 270, 7088, 6753, 51, 11064, 6753, 7088, 270, 11064, 6753, 51, 270, 7088]

Building the graph

This is similar to the sentiment classification network I built in an earlier neural network exercise.

From Chris McCormick's blog, we can see the general structure of our network.

The input words are passed in as integers. This will go into a hidden layer of linear units, then into a softmax layer. We'll use the softmax layer to make a prediction like normal.

The idea here is to train the hidden layer weight matrix to find efficient representations for our words. We can discard the softmax layer because we don't really care about making predictions with this network. We just want the embedding matrix so we can use it in other networks we build from the dataset.

I'm going to have you build the graph in stages now. First off, creating the inputs and labels placeholders like normal.

Exercise: Assign inputs and labels using tf.placeholder. We're going to be passing in integers, so set the data types to tf.int32. The batches we're passing in will have varying sizes, so set the batch sizes to [None]. To make things work later, you'll need to set the second dimension of labels to None or 1.


In [15]:
train_graph = tf.Graph()
with train_graph.as_default():
    inputs = tf.placeholder(tf.int32, shape=[None], name='inputs')
    #labels = tf.placeholder(tf.int32, [None, 1], name='labels')
    labels = tf.placeholder(tf.int32, shape=[None, 1], name='labels')

In [175]:
inputs.get_shape()


Out[175]:
TensorShape([Dimension(None)])

In [176]:
labels.get_shape()


Out[176]:
TensorShape([Dimension(None), Dimension(1)])

Embedding

The embedding matrix has a size of the number of words by the number of units in the hidden layer. So, if you have 10,000 words and 300 hidden units, the matrix will have size $10,000 \times 300$. Remember that we're using tokenized data for our inputs, usually as integers, where the number of tokens is the number of words in our vocabulary.

Exercise: TensorFlow provides a convenient function tf.nn.embedding_lookup that does this lookup for us. You pass in the embedding matrix and a tensor of integers, then it returns rows in the matrix corresponding to those integers. Below, set the number of embedding features you'll use (200 is a good start), create the embedding matrix variable, and use tf.nn.embedding_lookup to get the embedding tensors. For the embedding matrix, I suggest you initialize it with uniform random numbers between -1 and 1 using tf.random_uniform.


In [159]:
# tf.random_uniform # uniform distribution

In [16]:
n_vocab = len(int_to_vocab)
n_embedding = 200 # Number of embedding features 
with train_graph.as_default():
    embedding = tf.Variable(tf.random_uniform((n_vocab, n_embedding), -1.0, 1.0))
    embed = tf.nn.embedding_lookup(embedding, inputs) # use tf.nn.embedding_lookup to get the hidden layer output

In [178]:
embed.get_shape()


Out[178]:
TensorShape([Dimension(None), Dimension(200)])

In [179]:
embedding.get_shape()


Out[179]:
TensorShape([Dimension(63641), Dimension(200)])

Negative sampling 【负采样】

For every example we give the network, we train it using the output from the softmax layer. That means for each input, we're making very small changes to millions of weights even though we only have one true example. This makes training the network very inefficient. We can approximate the loss from the softmax layer by only updating a small subset of all the weights at once. We'll update the weights for the correct label, but only a small number of incorrect labels. This is called "negative sampling". TensorFlow has a convenient function to do this, tf.nn.sampled_softmax_loss.

Exercise: Below, create weights and biases for the softmax layer. Then, use tf.nn.sampled_softmax_loss to calculate the loss. Be sure to read the documentation to figure out how it works.


In [163]:
# tf.transpose transposes a tensor
## tf.nn.sampled_softmax_loss ~~ tf.nn.softmax(tf.matmul(inputs, tf.transpose(weights)) + biases).

In [17]:
# Number of negative labels to sample
# sampled softmax is only used during training
n_sampled = 100
with train_graph.as_default():
    softmax_w = tf.Variable(tf.truncated_normal((n_vocab, n_embedding), stddev=0.1))
#     softmax_w = tf.Variable(tf.truncated_normal((n_embedding, n_vocab), stddev=0.1))
    softmax_b = tf.Variable(tf.zeros([n_vocab]))
    
    # Calculate the loss using negative sampling
    loss = tf.nn.sampled_softmax_loss(softmax_w, softmax_b, 
                                      labels, embed,
                                      n_sampled, n_vocab)
#     loss = tf.nn.nce_loss(weights=softmax_w,
#                      biases=softmax_b,
#                      labels=labels,
#                      inputs=embed,
#                      num_sampled=n_sampled,
#                      num_classes=n_vocab)
    
    cost = tf.reduce_mean(loss)
    optimizer = tf.train.AdamOptimizer().minimize(cost)

Validation

This code is from Thushan Ganegedara's implementation. Here we're going to choose a few common words and few uncommon words. Then, we'll print out the closest words to them. It's a nice way to check that our embedding table is grouping together words with similar semantic meanings.


In [18]:
with train_graph.as_default():
    ## From Thushan Ganegedara's implementation
    valid_size = 16 # Random set of words to evaluate similarity on.
    valid_window = 100
    # pick 8 samples from each of the ranges (0,100) and (1000,1100); lower ids imply more frequent words
    valid_examples = np.array(random.sample(range(valid_window), valid_size//2))
    valid_examples = np.append(valid_examples, 
                               random.sample(range(1000,1000+valid_window), valid_size//2))

    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
    
    # We use the cosine distance:
    norm = tf.sqrt(tf.reduce_sum(tf.square(embedding), 1, keep_dims=True))
    normalized_embedding = embedding / norm
    # valid_embedding shape (None,200)
    valid_embedding = tf.nn.embedding_lookup(normalized_embedding, valid_dataset)
    # compute similarity: each row of `similarity` corresponds to one validation word, and each column is its similarity to another word in the vocabulary
    similarity = tf.matmul(valid_embedding, tf.transpose(normalized_embedding))

In [46]:
random.sample(range(100),16//2) # randomly pick 16 // 2 numbers from the range
tf.reduce_sum(np.array([[1,2,3],[2,4,6]]),1,keep_dims=True) # sums along each row, giving a (2, 1) result


Out[46]:
<tf.Tensor 'Sum:0' shape=(2, 1) dtype=int64>

In [19]:
# If the checkpoints directory doesn't exist:
!mkdir checkpoints


mkdir: checkpoints: File exists

Training

Below is the code to train the network. Every 100 batches it reports the training loss. Every 1000 batches, it'll print out the validation words.


In [60]:
#np.array([1,2,3])[:,None]

In [20]:
epochs = 10
batch_size = 1000
window_size = 10

with train_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=train_graph) as sess:
    iteration = 1
    loss = 0
    sess.run(tf.global_variables_initializer())

    for e in range(1, epochs+1):
        batches = get_batches(train_words, batch_size, window_size)
        start = time.time()
        for x, y in batches:
            
            feed = {inputs: x,
                    labels: np.array(y)[:, None]}
            train_loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            
            loss += train_loss
            
            if iteration % 100 == 0: 
                end = time.time()
                print("Epoch {}/{}".format(e, epochs),
                      "Iteration: {}".format(iteration),
                      "Avg. Training loss: {:.4f}".format(loss/100),
                      "{:.4f} sec/batch".format((end-start)/100))
                loss = 0
                start = time.time()
            
            if iteration % 1000 == 0:
                ## From Thushan Ganegedara's implementation
                # note that this is expensive (~20% slowdown if computed every 500 steps)
                sim = similarity.eval()
                for i in range(valid_size):
                    valid_word = int_to_vocab[valid_examples[i]]
                    top_k = 8 # number of nearest neighbors
                    nearest = (-sim[i, :]).argsort()[1:top_k+1]
                    log = 'Nearest to %s:' % valid_word
                    for k in range(top_k):
                        close_word = int_to_vocab[nearest[k]]
                        log = '%s %s,' % (log, close_word)
                    print(log)
            
            iteration += 1
    save_path = saver.save(sess, "checkpoints/text8.ckpt")
    embed_mat = sess.run(normalized_embedding)


Epoch 1/10 Iteration: 100 Avg. Training loss: 5.6641 0.4818 sec/batch
Epoch 1/10 Iteration: 200 Avg. Training loss: 5.6106 0.4315 sec/batch
Epoch 1/10 Iteration: 300 Avg. Training loss: 5.4656 0.4631 sec/batch
Epoch 1/10 Iteration: 400 Avg. Training loss: 5.6155 0.4086 sec/batch
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-20-b1d3869cd0d6> in <module>()
---> 20             train_loss, _ = sess.run([cost, optimizer], feed_dict=feed)

KeyboardInterrupt: 

Restore the trained network if you need to:


In [ ]:
with train_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    embed_mat = sess.run(embedding)

Visualizing the word vectors

Below we'll use T-SNE to visualize how our high-dimensional word vectors cluster together. T-SNE is used to project these vectors into two dimensions while preserving local structure. Check out this post from Christopher Olah to learn more about T-SNE and other ways to visualize high-dimensional data.


In [ ]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [ ]:
viz_words = 500
tsne = TSNE()
embed_tsne = tsne.fit_transform(embed_mat[:viz_words, :])

In [ ]:
fig, ax = plt.subplots(figsize=(14, 14))
for idx in range(viz_words):
    plt.scatter(*embed_tsne[idx, :], color='steelblue')
    plt.annotate(int_to_vocab[idx], (embed_tsne[idx, 0], embed_tsne[idx, 1]), alpha=0.7)