In [1]:
import tensorflow as tf
import numpy as np
from collections import Counter
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import random
In [2]:
# --- Corpus download & tokenization --------------------------------------
# text8: ~100MB of cleaned Wikipedia text, a standard word2vec benchmark.
dataset_url = 'http://mattmahoney.net/dc/'
dataset_name = 'text8.zip'

# Download once; reuse the local copy on subsequent runs.
if not os.path.exists(dataset_name):
    dataset_name, _ = urlretrieve(dataset_url + dataset_name, dataset_name)

# The archive holds a single member; read it fully into memory.
with ZipFile(dataset_name) as file:
    data = file.read(file.namelist()[0])

# Whitespace tokenization. NOTE: the member was read in binary mode, so
# every token in `words` is a *bytes* object (hence b'UNK' downstream).
words = data.split()
In [3]:
# --- Vocabulary construction ---------------------------------------------
vocab_size = 10000

# Keep the (vocab_size - 1) most frequent words; id 0 is reserved for the
# out-of-vocabulary token b'UNK'. Ids are assigned in frequency order.
counter = Counter(words)
vocab_dict = {b'UNK': 0}
for word, _ in counter.most_common(vocab_size - 1):
    vocab_dict[word] = len(vocab_dict)

# Inverse mapping (id -> word), used for pretty-printing batches.
reverse_vocab = {v: k for k, v in vocab_dict.items()}

# Encode the whole corpus as integer ids; unknown words map to 0 (UNK).
dataset = [vocab_dict.get(word, 0) for word in words]
In [4]:
class BatchGenerator(object):
    """Produces skip-gram (center, context) training batches from an
    integer-encoded corpus.

    Each batch holds ``batch_size`` pairs: for every center word the
    generator draws ``num_skips`` distinct context words uniformly from the
    window ``[idx - skip_window, idx + skip_window]`` (excluding the center
    itself), then advances to the next center word, wrapping around the
    corpus so a full window always fits on both sides.
    """

    def __init__(self, dataset, batch_size, num_skips, skip_window):
        # dataset: sequence of int word ids (the encoded corpus)
        # batch_size: number of (center, context) pairs per batch
        # num_skips: context samples drawn per center word
        # skip_window: half-width of the context window around the center
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_skips = num_skips
        self.skip_window = skip_window
        # Start far enough in that a full window fits on the left.
        self.idx = skip_window
        assert batch_size % num_skips == 0, 'batch_size must be divisible by num_skips'
        assert num_skips <= 2 * skip_window, 'num_skips <= 2 * skip_window'

    def next_batch(self):
        """Return ``(x, y)``: ``x`` is a (batch_size,) int32 array of center
        word ids, ``y`` is a (batch_size, 1) int32 array of context ids."""
        x = np.empty(self.batch_size, dtype=np.int32)
        y = np.empty((self.batch_size, 1), dtype=np.int32)
        for i in range(self.batch_size // self.num_skips):
            # Window offsets already used for this center word; offset 0
            # (the center itself) is never a valid context.
            used = set([0])
            for s in range(self.num_skips):
                x[i * self.num_skips + s] = self.dataset[self.idx]
                r = 0
                while r in used:
                    r = random.randint(-self.skip_window, self.skip_window)
                used.add(r)
                y[i * self.num_skips + s] = self.dataset[self.idx + r]
            # Advance to the next center word, wrapping so idx always stays
            # in [skip_window, len(dataset) - skip_window - 1].
            # Bug fix: the original advanced by the loop index `i`, so the
            # first group (i == 0) never moved — repeating the same center
            # word — and its wrap arithmetic skipped a position.
            span = len(self.dataset) - 2 * self.skip_window
            self.idx = self.skip_window + (self.idx - self.skip_window + 1) % span
        return x, y
In [5]:
# --- Sanity check: inspect batches under two configurations --------------
# Config 1: window of 1, two context samples per center word.
gen = BatchGenerator(dataset, 4, 2, 1)
bx, by = gen.next_batch()
print([reverse_vocab[w] for w in dataset[:4]])
print([reverse_vocab[w] for w in bx])
print([reverse_vocab[row[0]] for row in by])

# Config 2: window of 1, a single (random) context sample per center word.
gen = BatchGenerator(dataset, 4, 1, 1)
bx, by = gen.next_batch()
print([reverse_vocab[w] for w in dataset[:5]])
print([reverse_vocab[w] for w in bx])
print([reverse_vocab[row[0]] for row in by])
In [ ]: