In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [ ]:
#export
from exp.nb_12a import *
One-time download
In [ ]:
#path = datasets.Config().data_path()
#version = '103' # or '2' for the smaller WikiText-2
In [ ]:
#! wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-{version}-v1.zip -P {path}
#! unzip -q -n {path}/wikitext-{version}-v1.zip -d {path}
#! mv {path}/wikitext-{version}/wiki.train.tokens {path}/wikitext-{version}/train.txt
#! mv {path}/wikitext-{version}/wiki.valid.tokens {path}/wikitext-{version}/valid.txt
#! mv {path}/wikitext-{version}/wiki.test.tokens {path}/wikitext-{version}/test.txt
Split the articles: WT103 comes as one big text file, so we need to chunk it into separate articles if we want to be able to shuffle them at the beginning of each epoch.
In [ ]:
path = datasets.Config().data_path()/'wikitext-103'
In [ ]:
def istitle(line):
    return len(re.findall(r'^ = [^=]* = $', line)) != 0
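A quick sanity check on the pattern (hypothetical examples in WT103's format: top-level article titles are wrapped in single '=' signs with surrounding spaces, while subsection headers use '= =' pairs and should stay inside their article):
In [ ]:
assert istitle(' = Valkyria Chronicles III = \n')
assert not istitle(' = = Gameplay = = \n')  # subsection header, not an article title
assert not istitle('Some ordinary text\n')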
In [ ]:
def read_wiki(filename):
    articles = []
    with open(filename, encoding='utf8') as f:
        lines = f.readlines()
    current_article = ''
    for i,line in enumerate(lines):
        current_article += line
        # An article ends when the next line is ' \n' and the one after is a title.
        if i < len(lines)-2 and lines[i+1] == ' \n' and istitle(lines[i+2]):
            current_article = current_article.replace('<unk>', UNK)
            articles.append(current_article)
            current_article = ''
    # Don't forget the last article, which has no title following it.
    current_article = current_article.replace('<unk>', UNK)
    articles.append(current_article)
    return articles
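A quick way to eyeball the result (the exact article count depends on the dataset version):
In [ ]:
articles = read_wiki(path/'valid.txt')
len(articles), articles[0][:80]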
In [ ]:
train = TextList(read_wiki(path/'train.txt'), path=path) #+read_wiki(path/'test.txt')
valid = TextList(read_wiki(path/'valid.txt'), path=path)
In [ ]:
len(train), len(valid)
In [ ]:
sd = SplitData(train, valid)
In [ ]:
proc_tok,proc_num = TokenizeProcessor(),NumericalizeProcessor()
In [ ]:
ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok,proc_num])  # dummy labels: a language model's targets come from the text itself
In [ ]:
pickle.dump(ll, open(path/'ld.pkl', 'wb'))
In [ ]:
ll = pickle.load(open(path/'ld.pkl', 'rb'))
In [ ]:
bs,bptt = 128,70
data = lm_databunchify(ll, bs, bptt)
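A quick shape check on one batch. Assuming lm_databunchify works as in notebook 12a, each batch has shape (bs, bptt) and the targets are the inputs shifted by one token:
In [ ]:
x,y = next(iter(data.train_dl))
x.shape, y.shape  # both should be (128, 70)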
In [ ]:
vocab = ll.train.proc_x[-1].vocab
len(vocab)
In [ ]:
# Dropout probabilities scaled by a global multiplier of 0.2; assuming the same
# order as get_language_model in notebook 12a: output, hidden, input, embedding
# and weight dropout.
dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2
tok_pad = vocab.index(PAD)
In [ ]:
emb_sz, nh, nl = 300, 300, 2
model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)
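A rough sanity check on the model size (the exact count depends on how notebook 12a's implementation ties the embedding and decoder weights):
In [ ]:
sum(p.numel() for p in model.parameters() if p.requires_grad)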
In [ ]:
cbs = [partial(AvgStatsCallback,accuracy_flat),
       CudaCallback, Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, α=2., β=1.),
       ProgressCallback]
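The α and β passed to RNNTrainer weight the two extra penalties from the AWD-LSTM paper. A minimal sketch of what the callback adds to the loss after each forward pass, assuming (as in notebook 12a) that out and raw_out are the dropped-out and raw activations of the last LSTM layer:
In [ ]:
# loss += α * out.float().pow(2).mean()                               # AR: activation regularization
# loss += β * (raw_out[:,1:] - raw_out[:,:-1]).float().pow(2).mean()  # TAR: temporal activation regularization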
In [ ]:
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())
In [ ]:
lr = 5e-3
sched_lr = combine_scheds([0.3,0.7], cos_1cycle_anneal(lr/10., lr, lr/1e5))
sched_mom = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.8, 0.7, 0.8))
cbsched = [ParamScheduler('lr', sched_lr), ParamScheduler('mom', sched_mom)]
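To visualize the two schedules before training (assuming, as in the earlier notebooks, that a combined schedule is a function of training progress in [0,1)):
In [ ]:
import matplotlib.pyplot as plt
xs = torch.arange(0, 1, 0.01)
fig,(ax1,ax2) = plt.subplots(1, 2, figsize=(10,3))
ax1.plot(xs, [sched_lr(x.item()) for x in xs]);  ax1.set_title('lr')
ax2.plot(xs, [sched_mom(x.item()) for x in xs]); ax2.set_title('mom')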
In [ ]:
learn.fit(10, cbs=cbsched)
In [ ]:
torch.save(learn.model.state_dict(), path/'pretrained.pth')
pickle.dump(vocab, open(path/'vocab.pkl', 'wb'))
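To reload the pretrained weights and vocab later (e.g. for fine-tuning in the next notebook):
In [ ]:
vocab = pickle.load(open(path/'vocab.pkl', 'rb'))
model.load_state_dict(torch.load(path/'pretrained.pth'))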