In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from tqdm import tqdm_notebook
In [28]:
class SimpleGRU(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, batch_size, seq_len, n_layers=1):
        super(SimpleGRU, self).__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hid_size = hid_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.emb = nn.Embedding(vocab_size, emb_size)
        # n_layers must be passed through; otherwise the GRU defaults to a
        # single layer and rejects the (n_layers, batch, hid) hidden state
        # built in the training loop
        self.gru = nn.GRU(emb_size, hid_size, n_layers, batch_first=True)
        # GRU outputs for all timesteps are flattened into one vector per sequence
        self.fc1 = nn.Linear(seq_len * hid_size, vocab_size)
        self.selu = nn.SELU()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        x = self.emb(input)              # (batch, seq, emb)
        x, hidden = self.gru(x, hidden)  # (batch, seq, hid)
        # use the runtime batch size so the model also accepts batches of a
        # different size (e.g. single sequences at sampling time)
        x = x.contiguous().view(x.size(0), -1)
        x = self.selu(self.fc1(x))
        x = self.logsoftmax(x)           # (batch, vocab) log-probabilities
        return x, hidden


class CharDataset(data.Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # seq_len - 1 input characters ...
        inp_seq = self.data[index:(index + self.seq_len - 1)]
        # ... paired with the single character that immediately follows them
        tgt_seq = self.data[(index + self.seq_len - 1):(index + self.seq_len)]
        return inp_seq, tgt_seq

    def __len__(self):
        return len(self.data) - self.seq_len
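A quick shape check helps confirm the wiring: the GRU outputs for all timesteps are flattened into one vector per sequence before the linear layer, and each dataset item pairs seq_len - 1 input characters with the single following character. The cell below is an added sketch with throwaway dimensions, not part of the original training run.

In [ ]:
# smoke test with dummy data: dataset item shapes and one forward pass
_ds = CharDataset(torch.arange(0, 100).long(), 25)
_inp, _tgt = _ds[0]
print(_inp.size(), _tgt.size())   # torch.Size([24]) torch.Size([1])

_m = SimpleGRU(vocab_size=60, emb_size=25, hid_size=100,
               batch_size=4, seq_len=24, n_layers=3)
_h0 = Variable(torch.zeros(3, 4, 100))
_x = Variable(torch.zeros(4, 24).long())
_out, _hn = _m(_x, _h0)
print(_out.size(), _hn.size())    # torch.Size([4, 60]) torch.Size([3, 4, 100])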
In [22]:
seq_length = 25
batch_size = 50
emb_size = 25
hid_size = 100
n_layers = 3
In [23]:
with open("/home/david/Programming/data/project_gutenberg/tiny-shakespeare.txt", "r") as f:
    text_raw = [c for l in f.readlines() for c in l]
charset = sorted(list(set(text_raw)))
c2i = {c: i for i, c in enumerate(charset)}  # char -> index
i2c = {i: c for c, i in c2i.items()}         # index -> char
text_idx = [c2i[c] for c in text_raw]
print(len(text_idx), len(text_raw))
inputs = torch.Tensor(text_idx).long()
print(inputs.size())
ds = CharDataset(inputs, seq_length)
dl = data.DataLoader(ds, batch_size=batch_size, drop_last=True)
print(len(dl))
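As an added sanity check (not in the original notebook), decoding the first sample through i2c should reproduce the opening of the corpus, with the target being the character right after the input window:

In [ ]:
sample_inp, sample_tgt = ds[0]
print(repr("".join(i2c[int(i)] for i in sample_inp)))  # first 24 characters
print(repr(i2c[int(sample_tgt[0])]))                   # the 25th character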
In [24]:
vocab_size = len(charset)
num_batches = len(dl)
epochs = 10
In [31]:
model = SimpleGRU(vocab_size, emb_size, hid_size, batch_size, seq_length - 1, n_layers)  # inputs are seq_length - 1 chars long
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
print(model)
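A rough parameter count (an added convenience, not from the original run) gives a feel for the model size before training:

In [ ]:
n_params = sum(p.numel() for p in model.parameters())
print("{:,} parameters".format(n_params))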
In [32]:
for epoch in range(epochs):
    # build a fresh progress bar per epoch; a single tqdm instance closes
    # itself after its first full pass over the dataloader
    batch_bar = tqdm_notebook(dl, desc="epoch {}".format(epoch + 1), mininterval=0.9)
    running_loss = 0
    for i, (mb, tgts) in enumerate(batch_bar):
        # start every batch from a zero hidden state
        h = Variable(torch.zeros(n_layers, batch_size, hid_size))
        tgts.squeeze_()  # (batch, 1) -> (batch,) as NLLLoss expects
        model.train()
        model.zero_grad()
        mb, tgts = Variable(mb), Variable(tgts)
        out, h = model(mb, h)
        loss = criterion(out, tgts)
        loss.backward()
        optimizer.step()
        running_loss += loss.data[0]
        if i % 25 == 0 or i == num_batches - 1:
            batch_bar.set_postfix(loss=(running_loss / (i + 1)))
    # checkpoint after every epoch
    torch.save(model.state_dict(), "model_charrnn_{}.pt".format(epoch + 1))
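Once training has run, the natural next step is sampling. The sketch below is an addition, not code from the original notebook: the model scores a window of seq_length - 1 characters and emits log-probabilities for the next one, so generation slides that window along the text it produces. It relies on forward using the runtime batch size, and the temperature value and corpus-prefix seed are assumptions, not choices made above.

In [ ]:
def sample_text(model, seed, n_chars=300, temperature=0.8):
    model.eval()
    window = seq_length - 1  # the input length the model was trained on
    idxs = [c2i[c] for c in seed]
    assert len(idxs) >= window, "seed must cover at least one input window"
    for _ in range(n_chars):
        inp = Variable(torch.LongTensor(idxs[-window:]).view(1, -1))
        h = Variable(torch.zeros(n_layers, 1, hid_size))
        logp, _ = model(inp, h)  # (1, vocab_size) log-probabilities
        # temperature-scaled sampling from the predicted distribution
        probs = (logp.data.squeeze() / temperature).exp()
        probs /= probs.sum()
        idxs.append(int(torch.multinomial(probs, 1)[0]))
    return "".join(i2c[i] for i in idxs)

print(sample_text(model, "".join(text_raw[:seq_length - 1])))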
In [ ]: