In [1]:
import tensorflow as tf
import numpy as np

from collections import Counter
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import random

Wczytywanie datasetu


In [2]:
dataset_url = 'http://mattmahoney.net/dc/'
dataset_name = 'text8.zip'

if not os.path.exists(dataset_name):
    dataset_name, _ = urlretrieve(dataset_url + dataset_name, dataset_name)

with ZipFile(dataset_name) as file:
    data = file.read(file.namelist()[0])
    words = data.split()

In [3]:
vocab_size = 10000

counter = Counter(words)
vocab_dict = dict([(b'UNK', 0)])
for word, _ in counter.most_common(vocab_size - 1):
    vocab_dict[word] = len(vocab_dict)
reverse_vocab = {v: k for k, v in vocab_dict.items()}
dataset = list(map(lambda x: vocab_dict.get(x, 0), words))

In [4]:
class BatchGenerator(object):
    def __init__(self, dataset, batch_size, num_skips, skip_window):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_skips = num_skips
        self.skip_window = skip_window
        self.idx = skip_window
                
        assert batch_size % num_skips == 0, 'batch_size must be divisible by num_skips'
        assert num_skips <= 2 * skip_window, 'num_skips <= 2 * skip_window'
           
    def next_batch(self):
        x = np.empty(self.batch_size, dtype=np.int32)
        y = np.empty((self.batch_size, 1), dtype=np.int32)
        
        for i in range(self.batch_size // self.num_skips):
            used = set([0])
            for s in range(self.num_skips):
                x[i * self.num_skips + s] = self.dataset[self.idx]
                r = 0
                while r in used:
                    r = random.randint(-self.skip_window, self.skip_window)
                used.add(r)
                y[i * self.num_skips + s] = self.dataset[self.idx + r]
            self.idx = (self.idx + i) % (len(self.dataset) - 2 * self.skip_window) + self.skip_window
        return x, y

In [5]:
gen = BatchGenerator(dataset, 4, 2, 1)
bx, by = gen.next_batch()

print([reverse_vocab[i] for i in dataset[:4]])
print([reverse_vocab[i] for i in bx])
print([reverse_vocab[i[0]] for i in by])

gen = BatchGenerator(dataset, 4, 1, 1)
bx, by = gen.next_batch()

print([reverse_vocab[i] for i in dataset[:5]])
print([reverse_vocab[i] for i in bx])
print([reverse_vocab[i[0]] for i in by])


[b'anarchism', b'originated', b'as', b'a']
[b'originated', b'originated', b'as', b'as']
[b'as', b'anarchism', b'a', b'originated']
[b'anarchism', b'originated', b'as', b'a', b'term']
[b'originated', b'as', b'term', b'first']
[b'as', b'a', b'a', b'used']

In [ ]: