Important: This notebook only works with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository, because it will load fastai-0.7.x.


In [ ]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.model import fit
from fastai.dataset import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle
import random

In [ ]:
bs = 64    # batch size
bptt = 70  # presumably the sequence length used for backprop-through-time chunks

Language modeling

Data


In [ ]:
import os
import re
import time

import requests
# feedparser isn't a fastai dependency so you may need to install it.
import feedparser
import pandas as pd


class GetArXiv(object):
    """Fetch arXiv abstracts for a set of categories and cache them in a pickled DataFrame.

    Each row holds title, authors, published (pd.Timestamp), summary, link and
    category, as built by ``get_entry_dict``.
    """

    # Categories queried when the caller passes none.
    DEFAULT_CATEGORIES = ('cs*', 'cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML')

    def __init__(self, pickle_path, categories=None):
        """
        :param pickle_path (str): pickle file to save/load; if it is an existing
            directory, ``all_arxiv.pkl`` inside it is used
        :param categories (list): arXiv categories to query; None or empty selects
            ``DEFAULT_CATEGORIES``
        """
        if os.path.isdir(pickle_path):
            pickle_path = os.path.join(pickle_path, 'all_arxiv.pkl')
        if categories is None or len(categories) < 1:
            categories = self.DEFAULT_CATEGORIES
        # Copy rather than alias, so the caller's list is never mutated through self.categories.
        self.categories = list(categories)
        self.pickle_path = pickle_path
        self.base_url = 'http://export.arxiv.org/api/query'

    @staticmethod
    def build_qs(categories):
        """Build the ``search_query`` value, e.g. 'cat:cs.CV+OR+cat:stat.ML'."""
        return '+OR+'.join(['cat:' + c for c in categories])

    @staticmethod
    def get_entry_dict(entry):
        """Return a dict of the fields we keep from a feedparser entry, or None if any key is missing."""
        try:
            return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],
                        published=pd.Timestamp(entry['published']), summary=entry['summary'],
                        link=entry['link'], category=entry['category'])
        except KeyError:
            print('Missing keys in row: {}'.format(entry))
            return None

    @staticmethod
    def strip_version(link):
        """Strip the trailing version suffix (``v1``, ``v12``, ...) from an arXiv paper link.

        Links without a version suffix are returned unchanged.
        """
        return re.sub(r'v\d+$', '', link)

    def fetch_updated_data(self, max_retry=5, pg_offset=0, pg_size=1000, wait_time=15):
        """
        Fetch papers from the arXiv server page by page until a page contains no
        unseen links, merge them into the cached data, save it and return it.
        :param max_retry: max number of times to retry a request that returns no entries
        :param pg_offset: number of pages to offset
        :param pg_size: num abstracts to fetch per request
        :param wait_time: num seconds to wait between requests
        :return: the merged DataFrame (also written to ``self.pickle_path``)
        """
        page, retry = pg_offset, 0
        df = pd.DataFrame()
        past_links = []
        if os.path.isfile(self.pickle_path):
            df = pd.read_pickle(self.pickle_path).reset_index(drop=True)
        if len(df) > 0:
            past_links = df.link.apply(self.strip_version)

        n_new = 0
        while True:
            params = dict(search_query=self.build_qs(self.categories),
                          sortBy='submittedDate', start=pg_size * page, max_results=pg_size)
            # Passed as a pre-built string so requests does not percent-encode the
            # ':' and '+' characters of the search query.
            query = '&'.join(f'{k}={v}' for k, v in params.items())
            response = requests.get(self.base_url, params=query)
            entries = feedparser.parse(response.text).entries
            if len(entries) < 1:
                if retry < max_retry:
                    retry += 1
                    time.sleep(wait_time)
                    continue
                break

            # get_entry_dict returns None for malformed entries; drop them so they
            # don't turn into all-NaN rows that would crash strip_version below.
            records = [rec for rec in map(self.get_entry_dict, entries) if rec is not None]
            if records:
                results_df = pd.DataFrame(records)
                max_date = results_df.published.max().date()
                new_links = ~results_df.link.apply(self.strip_version).isin(past_links)
                print(f'{page}. Fetched {len(results_df)} abstracts published {max_date} and earlier')
                if not new_links.any():
                    break
                df = pd.concat((df, results_df.loc[new_links]), ignore_index=True)
                n_new += int(new_links.sum())

            page += 1
            retry = 0
            time.sleep(wait_time)

        print(f'Downloaded {n_new} new abstracts')
        if len(df) > 0:
            # Pages may overlap (e.g. if papers are submitted while paging), so keep one
            # row per link, preferring the most recently published one.
            df = (df.sort_values('published', ascending=False)
                    .drop_duplicates('link')
                    .reset_index(drop=True))
        df.to_pickle(self.pickle_path)
        return df

    @classmethod
    def load(cls, pickle_path):
        """Load the cached abstracts DataFrame from ``pickle_path`` (a file or a directory)."""
        return pd.read_pickle(cls(pickle_path).pickle_path)

    @classmethod
    def update(cls, pickle_path, categories=None, **kwargs):
        """
        Update the arXiv data pickle with the latest abstracts.
        :param pickle_path: see ``__init__``
        :param categories: see ``__init__``
        :param kwargs: forwarded to ``fetch_updated_data``
        :return: True once the pickle has been written
        """
        cls(pickle_path, categories).fetch_updated_data(**kwargs)
        return True

In [ ]:
PATH = 'data/arxiv/'

ALL_ARXIV = f'{PATH}all_arxiv.pkl'

# GetArXiv writes its pickle with DataFrame.to_pickle, which does not create parent
# directories; create PATH up front so a long download isn't lost at the final save.
os.makedirs(PATH, exist_ok=True)

# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee
if not os.path.exists(ALL_ARXIV):
    GetArXiv.update(ALL_ARXIV)

# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one (it is not downloaded here)
df_mb = pd.read_csv(f'{PATH}arxiv.csv')
df_all = pd.read_pickle(ALL_ARXIV)

In [ ]:
def get_txt(df):
    """Build the model text for each row: '<CAT> {category} <SUMM> {summary} <TITLE> {title}'.

    '.' and '-' are deleted from the category (e.g. 'cs.CV' -> 'csCV').
    :param df: DataFrame with 'category', 'summary' and 'title' string columns
    :return: Series of combined texts, aligned with df's index
    """
    # str.translate deletes the characters regardless of pandas version; the previous
    # str.replace(r'[\.\-]', '') silently became a literal (no-op) match in pandas >= 2.0.
    category = df.category.str.translate(str.maketrans('', '', '.-'))
    return '<CAT> ' + category + ' <SUMM> ' + df.summary + ' <TITLE> ' + df.title
# Attach the combined model text to both the labelled and the full dataset.
for frame in (df_mb, df_all):
    frame['txt'] = get_txt(frame)
n = len(df_all)
n


Out[ ]:
49600

In [ ]:
# Folder layout: {trn,val}/{yes,no} for the classifier, all/{trn,val} for the
# language model, models/ for saved weights and the pickled TEXT field.
for subdir in ('trn/yes', 'val/yes', 'trn/no', 'val/no', 'all/trn', 'all/val', 'models'):
    os.makedirs(f'{PATH}{subdir}', exist_ok=True)

In [ ]:
# Send ~10% of all abstracts to the LM validation folder. A seeded RNG makes the split
# reproducible, so re-running this cell rewrites the same files instead of leaving a
# stale copy of an abstract in both all/trn and all/val.
lm_split_rng = random.Random(42)
for i, txt in enumerate(df_all['txt']):
    dset = 'trn' if lm_split_rng.random() > 0.1 else 'val'
    with open(f'{PATH}all/{dset}/{i}.txt', 'w') as f:
        f.write(txt)

In [ ]:
# Write each labelled abstract to {trn,val}/{yes,no}/ (yes = tweeted), ~10% to val.
# A seeded RNG keeps the split reproducible across re-runs, so no file ends up in
# both trn and val from different runs.
clf_split_rng = random.Random(42)
for i, (txt, tweeted) in enumerate(zip(df_mb['txt'], df_mb['tweeted'])):
    lbl = 'yes' if tweeted else 'no'
    dset = 'trn' if clf_split_rng.random() > 0.1 else 'val'
    with open(f'{PATH}{dset}/{lbl}/{i}.txt', 'w') as f:
        f.write(txt)

In [ ]:
from spacy.symbols import ORTH

# install the 'en' model if the next line of code fails by running:
#python -m spacy download en              # default English model (~50MB)
#python -m spacy download en_core_web_md  # larger English model (~1GB)
my_tok = spacy.load('en')

# Keep the section markers (and HTML line breaks) as single tokens instead of
# letting spaCy split them on '<', '>' and '/'.
for special in ('<SUMM>', '<CAT>', '<TITLE>', '<BR />', '<BR>'):
    my_tok.tokenizer.add_special_case(special, [{ORTH: special}])

def my_spacy_tok(x):
    """Tokenize `x` with the customized spaCy tokenizer and return the token strings."""
    doc = my_tok.tokenizer(x)
    return [token.text for token in doc]

In [ ]:
TEXT = data.Field(lower=True, tokenize=my_spacy_tok)
FILES = dict(train='trn', validation='val', test='val')
md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)
# Persist the field (and the vocab built on it) for reuse later; `with` guarantees the
# file is closed and fully flushed.
with open(f'{PATH}models/TEXT.pkl', 'wb') as f:
    pickle.dump(TEXT, f)

In [ ]:
# Sanity check: (training batches, md.nt - presumably the vocab size, training examples,
# tokens in the first example). The output shows a single example: the corpus is concatenated.
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)


Out[ ]:
(2129, 17951, 1, 9543290)

In [ ]:
# First 12 vocabulary entries: special tokens, then (presumably) the most frequent tokens.
TEXT.vocab.itos[:12]


Out[ ]:
['<unk>', '<pad>', '\n', 'the', ',', '.', 'of', '-', 'and', 'a', 'to', 'in']

In [ ]:
# Peek at the first 150 (lower-cased, tokenized) tokens of the training corpus.
' '.join(md.trn_ds[0].text[:150])


Out[ ]:
'<cat> csni <summ> the exploitation of mm - wave bands is one of the key - enabler for 5 g mobile \n radio networks . however , the introduction of mm - wave technologies in cellular \n networks is not straightforward due to harsh propagation conditions that limit \n the mm - wave access availability . mm - wave technologies require high - gain antenna \n systems to compensate for high path loss and limited power . as a consequence , \n directional transmissions must be used for cell discovery and synchronization \n processes : this can lead to a non - negligible access delay caused by the \n exploration of the cell area with multiple transmissions along different \n directions . \n    the integration of mm - wave technologies and conventional wireless access \n networks with the objective of speeding up the cell search process requires new \n'

Train


In [ ]:
# Model size: embedding size, hidden units per layer, number of layers
# (passed as emb_sz / n_hid / n_layers when building the classifier).
em_sz, nh, nl = 200, 500, 3
# Adam with beta1 lowered from PyTorch's default of 0.9.
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

In [ ]:
# Build the language-model learner (em_sz embedding, nh hidden units, nl layers) with light dropout.
learner = md.get_model(opt_fn, em_sz, nh, nl,
    dropout=0.05, dropouth=0.1, dropouti=0.05, dropoute=0.02, wdrop=0.2)
# Alternative dropout settings, kept for reference:
# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
# dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05
# Extra loss regularization via seq2seq_reg (alpha/beta weights - presumably activation and temporal-activation terms).
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
# presumably the gradient-clipping threshold
learner.clip=0.3

In [ ]:
# Warm-up: one cycle at learning rate 3e-3 with weight decay 1e-6.
learner.fit(3e-3, 1, wds=1e-6)


[ 0.       4.36189  4.2185 ]                                  


In [ ]:
# 3 cycles whose length doubles each time (cycle_mult=2): 1 + 2 + 4 = 7 epochs in the log.
learner.fit(3e-3, 3, wds=1e-6, cycle_len=1, cycle_mult=2)


[ 0.       4.11236  3.99207]                                  
[ 1.       4.03207  3.89298]                                  
[ 2.       3.91653  3.81915]                                  
[ 3.       3.97808  3.8428 ]                                  
[ 4.       3.88482  3.76226]                                  
[ 5.       3.79955  3.70472]                                  
[ 6.       3.75721  3.69048]                                  


In [ ]:
# Checkpoint the encoder weights under the name 'adam2_enc'.
learner.save_encoder('adam2_enc')

In [ ]:
# 10 cycles of 5 epochs each (50 epochs in the log), saving weights per cycle under 'adam3_10'.
# NOTE(review): the '_10' suffix matches n_cycle=10, not cycle_len=5 - confirm the naming intent.
learner.fit(3e-3, 10, wds=1e-6, cycle_len=5, cycle_save_name='adam3_10')


[ 0.       3.89388  3.76575]                                  
[ 1.       3.82548  3.71875]                                  
[ 2.       3.76471  3.66974]                                  
[ 3.       3.71713  3.63861]                                  
[ 4.       3.67534  3.62983]                                  
[ 5.       3.83938  3.71551]                                  
[ 6.       3.78093  3.68056]                                  
[ 7.       3.72828  3.63638]                                  
[ 8.       3.66743  3.60355]                                  
[ 9.       3.65793  3.59448]                                  
[ 10.        3.80545   3.68213]                               
[ 11.        3.75299   3.65219]                               
[ 12.        3.7057    3.61324]                               
[ 13.        3.63348   3.58048]                               
[ 14.        3.62304   3.57257]                               
[ 15.        3.78656   3.66324]                               
[ 16.        3.73522   3.63348]                               
[ 17.        3.67258   3.59369]                               
[ 18.        3.6242    3.56674]                               
[ 19.        3.61123   3.55783]                               
[ 20.        3.77443   3.65472]                               
[ 21.        3.7374    3.62169]                               
[ 22.        3.68367   3.58247]                               
[ 23.        3.61606   3.55567]                               
[ 24.        3.58527   3.54725]                               
[ 25.        3.75456   3.63861]                               
[ 26.        3.72061   3.61084]                               
[ 27.        3.65141   3.57073]                               
[ 28.        3.59711   3.54414]                               
[ 29.        3.57052   3.53622]                               
[ 30.        3.75229   3.62935]                               
[ 31.        3.70693   3.60198]                               
[ 32.        3.65193   3.56444]                               
[ 33.        3.59173   3.53742]                               
[ 34.        3.58699   3.53152]                               
[ 35.        3.74211   3.62154]                               
[ 36.        3.70016   3.59831]                               
[ 37.        3.64095   3.55689]                               
[ 38.        3.60686   3.53296]                               
[ 39.       3.5627   3.523 ]                                  
[ 40.        3.72944   3.61956]                               
[ 41.        3.68161   3.58779]                               
[ 42.        3.62305   3.55187]                               
[ 43.        3.58559   3.52524]                               
[ 44.        3.56087   3.51682]                               
[ 45.        3.72533   3.61458]                               
[ 46.        3.68025   3.58452]                               
[ 47.        3.64447   3.55002]                               
[ 48.        3.575     3.52066]                               
[ 49.        3.54424   3.5133 ]                               


In [ ]:
# Checkpoint the encoder after the 10 x 5-epoch run.
learner.save_encoder('adam3_10_enc')

In [ ]:
# 8 cycles of 10 epochs each (80 epochs in the log).
# NOTE(review): saved as 'adam3_5' although n_cycle=8 and cycle_len=10 - confirm the intended name.
learner.fit(3e-3, 8, wds=1e-6, cycle_len=10, cycle_save_name='adam3_5')


[ 0.       3.70587  3.61666]                                  
[ 1.       3.71738  3.61174]                                  
[ 2.       3.68606  3.59661]                                  
[ 3.       3.65407  3.5742 ]                                  
[ 4.       3.62901  3.55795]                                  
[ 5.       3.59921  3.53632]                                  
[ 6.       3.58401  3.52149]                                  
[ 7.       3.55126  3.50797]                                  
[ 8.       3.52965  3.50178]                                  
[ 9.       3.52336  3.49997]                                  
[ 10.        3.7109    3.60817]                               
[ 11.        3.69879   3.60047]                               
[ 12.        3.6735    3.58623]                               
[ 13.        3.64365   3.56568]                               
[ 14.        3.6099    3.54776]                               
[ 15.        3.58244   3.52829]                               
[ 16.        3.54894   3.51071]                               
[ 17.        3.52702   3.50173]                               
[ 18.        3.51357   3.49522]                               
[ 19.        3.50302   3.49272]                               
[ 20.        3.72505   3.60198]                               
[ 21.        3.70037   3.59914]                               
[ 22.        3.68386   3.58279]                               
[ 23.        3.64176   3.56435]                               
[ 24.        3.60259   3.54304]                               
[ 25.        3.58669   3.52432]                               
[ 26.        3.54075   3.50703]                               
[ 27.        3.50951   3.49534]                               
[ 28.        3.51915   3.4896 ]                               
[ 29.        3.48695   3.48968]                               
[ 30.        3.70563   3.59631]                               
[ 31.        3.68822   3.58723]                               
[ 32.        3.67549   3.58141]                               
[ 33.        3.63267   3.55537]                               
[ 34.        3.60638   3.5386 ]                               
[ 35.        3.58803   3.52154]                               
[ 36.        3.53987   3.50394]                               
[ 37.        3.51036   3.49244]                               
[ 38.        3.48651   3.48652]                               
[ 39.        3.49061   3.48673]                               
[ 40.        3.70093   3.59211]                               
[ 41.        3.67371   3.58516]                               
[ 42.        3.66558   3.57032]                               
[ 43.        3.65089   3.55939]                               
[ 44.        3.59885   3.53445]                               
[ 45.        3.56369   3.51585]                               
[ 46.        3.55304   3.50237]                               
[ 47.        3.50469   3.48919]                               
[ 48.        3.49559   3.48289]                               
[ 49.        3.50912   3.48136]                               
[ 50.        3.70603   3.59182]                               
[ 51.        3.669     3.58069]                               
[ 52.        3.64965   3.56896]                               
[ 53.        3.62839   3.55251]                               
[ 54.        3.59578   3.53297]                               
[ 55.        3.55814   3.51205]                               
[ 56.        3.53653   3.49682]                               
[ 57.        3.50043   3.48502]                               
[ 58.        3.49535   3.4797 ]                               
[ 59.        3.48039   3.47882]                               
[ 60.        3.68319   3.58874]                               
[ 61.        3.68893   3.58173]                               
[ 62.        3.6516    3.56403]                               
[ 63.        3.63432   3.55047]                               
[ 64.        3.59697   3.52815]                               
[ 65.        3.55784   3.50832]                               
[ 66.        3.52815   3.49319]                               
[ 67.        3.50618   3.48222]                               
[ 68.        3.48319   3.47645]                               
[ 69.        3.49879   3.47596]                               
[ 70.        3.68466   3.58318]                               
[ 71.        3.67045   3.57351]                               
[ 72.        3.64409   3.5606 ]                               
[ 73.        3.61991   3.54552]                               
[ 74.        3.60503   3.52782]                               
[ 75.        3.56681   3.50743]                               
[ 76.        3.52401   3.49046]                               
[ 77.        3.50519   3.47875]                               
[ 78.        3.49343   3.47452]                               
[ 79.        3.49275   3.47175]                               


In [ ]:
# One long 20-epoch cycle; its weights are saved under 'adam3_20'.
learner.fit(3e-3, 1, wds=1e-6, cycle_len=20, cycle_save_name='adam3_20')


[ 0.       3.47841  3.4751 ]                                  
[ 1.       3.69717  3.57883]                                  
[ 2.       3.68267  3.57793]                                  
[ 3.       3.66797  3.57299]                                  
[ 4.       3.66805  3.56847]                                  
[ 5.       3.63489  3.56238]                                  
[ 6.       3.62479  3.54928]                                  
[ 7.       3.60663  3.53879]                                  
[ 8.       3.59124  3.53175]                                  
[ 9.       3.58617  3.52009]                                  
[ 10.        3.56924   3.51174]                               
[ 11.        3.5509    3.49974]                               
[ 12.        3.51595   3.49008]                               
[ 13.        3.50939   3.48222]                               
[ 14.        3.48886   3.47952]                               
[ 15.        3.4676    3.47311]                               
[ 16.        3.4856    3.46577]                               
[ 17.        3.44909   3.46499]                               
[ 18.        3.46791   3.46314]                               
[ 19.        3.44028   3.46231]                               


In [ ]:
# Save the final LM encoder; the classifier section loads the encoder by this name.
learner.save_encoder('adam3_20_enc')

In [ ]:
# Save the full language model (not just the encoder) as 'adam3_20'.
learner.save('adam3_20')

Test


In [ ]:
def proc_str(s):
    """Tokenize `s` with TEXT's tokenizer, then apply TEXT's preprocessing (lower-casing)."""
    tokens = TEXT.tokenize(s)
    return TEXT.preprocess(tokens)

def num_str(s):
    """Numericalize `s` with TEXT's vocab as a batch containing one sequence."""
    batch = [proc_str(s)]
    return TEXT.numericalize(batch)

In [ ]:
# The trained language model's module, driven directly by sample_model.
m=learner.model

In [ ]:
# Example prompt: a category token followed by the start of a summary
# (the sample_model calls below pass their prompts inline).
s = '<CAT> cscv <SUMM> algorithms that'

In [ ]:
def sample_model(m, s, l=50):
    """Greedily generate up to `l` tokens from language model `m`, primed with the text `s`.

    Prints '...' followed by the generated tokens and stops early at '<eos>'.
    `m[0].bs` (presumably the encoder's batch size) is set to 1 while sampling and
    restored to its previous value afterwards, even if generation is interrupted.
    """
    t = num_str(s)
    prev_bs = m[0].bs
    m[0].bs = 1
    try:
        m.eval()
        m.reset()
        res, *_ = m(t)
        print('...', end='')

        for _ in range(l):
            # Indices of the two highest-scoring next tokens.
            n = res[-1].topk(2)[1]
            # Never emit index 0 (<unk> in the vocab listing); take the runner-up instead.
            n = n[1] if n.data[0] == 0 else n[0]
            word = TEXT.vocab.itos[n.data[0]]
            print(word, end=' ')
            if word == '<eos>': break
            res, *_ = m(n[0].unsqueeze(0))
    finally:
        m[0].bs = prev_bs

In [ ]:
# Continue a summary prompt under category csni.
sample_model(m,"<CAT> csni <SUMM> algorithms that")


...use the same network as a single node are not able to 
 achieve the same performance as the traditional network - based routing 
 algorithms . in this paper , we propose a novel routing scheme for routing 
 protocols in wireless networks . the proposed scheme is based ...

In [ ]:
# Same summary prompt under category cscv, to compare how the category token steers generation.
sample_model(m,"<CAT> cscv <SUMM> algorithms that")


...use the same data to perform image classification are 
 increasingly being used to improve the performance of image classification 
 algorithms . in this paper , we propose a novel method for image classification 
 using a deep convolutional neural network ( cnn ) . the proposed method is ...

In [ ]:
# Prompt ending in '<TITLE> on ' so the model writes a title; sampling stops at <eos>.
sample_model(m,"<CAT> cscv <SUMM> algorithms. <TITLE> on ")


...the performance of deep learning for image classification <eos> 

In [ ]:
# Same title prompt under category csni.
sample_model(m,"<CAT> csni <SUMM> algorithms. <TITLE> on ")


...the performance of wireless networks <eos> 

In [ ]:
# Title prompt starting with 'towards', category cscv.
sample_model(m,"<CAT> cscv <SUMM> algorithms. <TITLE> towards ")


...a new approach to image classification <eos> 

In [ ]:
# Title prompt starting with 'towards', category csni.
sample_model(m,"<CAT> csni <SUMM> algorithms. <TITLE> towards ")


...a new approach to the analysis of wireless networks <eos> 

Classification: predict whether a paper was tweeted


In [ ]:
# Reload the TEXT field (with its vocab) pickled after building the language-model data.
# NOTE: unpickling can execute arbitrary code - only load a TEXT.pkl you produced yourself.
with open(f'{PATH}models/TEXT.pkl', 'rb') as f:
    TEXT = pickle.load(f)

In [ ]:
class ArxivDataset(torchtext.data.Dataset):
    def __init__(self, path, text_field, label_field, **kwargs):
        fields = [('text', text_field), ('label', label_field)]
        examples = []
        for label in ['yes', 'no']:
            fnames = glob(os.path.join(path, label, '*.txt'));
            assert fnames, f"can't find 'yes.txt' or 'no.txt' under {path}/{label}"
            for fname in fnames:
                with open(fname, 'r') as f: text = f.readline()
                examples.append(data.Example.fromlist([text, label], fields))
        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex): return len(ex.text)
    
    @classmethod
    def splits(cls, text_field, label_field, root='.data',
               train='train', test='test', **kwargs):
        return super().splits(
            root, text_field=text_field, label_field=label_field,
            train=train, validation=None, test=test, **kwargs)

In [ ]:
# Each label is a single 'yes'/'no' value, so the label field is not sequential.
ARX_LABEL = data.Field(sequential=False)
splits = ArxivDataset.splits(TEXT, ARX_LABEL, root=PATH, train='trn', test='val')

In [ ]:
# Classifier data object over the (trn, val) splits, using batch size bs.
md2 = TextData.from_splits(PATH, splits, bs)

In [ ]:
#            dropout=0.3, dropouti=0.4, wdrop=0.3, dropoute=0.05, dropouth=0.2)

In [ ]:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

def prec_at_6(preds, targs, pos_idx=2, min_precision=0.6):
    """Print and return the recall at the first PR-curve point whose precision >= `min_precision`.

    :param preds: 2-D array of per-class scores, one row per sample
    :param targs: array of target class indices
    :param pos_idx: index of the positive class; the default 2 assumes the label vocab
        is ['<unk>', 'no', 'yes'] - TODO confirm against ARX_LABEL.vocab.itos
    :param min_precision: precision target (0.6 gives the function its name)
    """
    precision, recall, _ = precision_recall_curve(targs == pos_idx, preds[:, pos_idx])
    # recall is non-increasing along the curve, so the first qualifying point has the
    # highest recall among the points that meet the precision target.
    rec = recall[precision >= min_precision][0]
    print(rec)
    return rec

In [ ]:
# Alternative setting (dropout=0.4 instead of 0.1), kept for reference:
# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
# Classifier with the same em_sz/nh/nl as the LM, so the LM encoder can be loaded into it;
# 1500 is presumably the maximum sequence length.
m3 = md2.get_model(opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl, 
           dropout=0.1, dropouti=0.65, wdrop=0.5, dropoute=0.1, dropouth=0.3)
m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
# Much looser clipping than the language model's 0.3.
m3.clip=25.

In [ ]:
# The LM was built on f'{PATH}all/', so its encoder lives in all/models/, while this
# classifier looks in f'{PATH}models/'. Link it across (equivalent to `ln -sf`), in
# pure Python so failures raise here instead of surfacing later in load_encoder.
enc_link = f'{PATH}models/adam3_20_enc.h5'
if os.path.lexists(enc_link):
    os.remove(enc_link)
os.symlink('../all/models/adam3_20_enc.h5', enc_link)
m3.load_encoder('adam3_20_enc')
# One learning rate per parameter group (presumably layer groups, earliest first).
lrs=np.array([1e-4,1e-3,1e-3,1e-2,3e-2])

In [ ]:
# First train only the last layer group (presumably the new classifier head) at half the
# learning rates, then unfreeze and train the whole model for one epoch.
m3.freeze_to(-1)
m3.fit(lrs/2, 1, metrics=[accuracy])
m3.unfreeze()
m3.fit(lrs, 1, metrics=[accuracy], cycle_len=1)


[ 0.       0.47654  0.44322  0.78525]                         

[ 0.       0.43033  0.40192  0.80087]                        


In [ ]:
# Two 4-epoch cycles (8 epochs in the log), saving weights per cycle.
# NOTE(review): 'imdb2' looks copied from the IMDB notebook, and a later fit here reuses it,
# which would overwrite these checkpoints - consider a distinct arxiv-specific name.
m3.fit(lrs, 2, metrics=[accuracy], cycle_len=4, cycle_save_name='imdb2')


[ 0.       0.42236  0.39006  0.8194 ]                        
[ 1.       0.39477  0.37063  0.82086]                        
[ 2.       0.39389  0.37082  0.82449]                        
[ 3.       0.40728  0.36999  0.82195]                        
[ 4.       0.39308  0.3675   0.81977]                        
[ 5.       0.38662  0.36737  0.8234 ]                        
[ 6.       0.39259  0.36512  0.82486]                        
[ 7.       0.38047  0.36538  0.82522]                        


In [ ]:
# Recall at >= 60% precision for class index 2 (presumably 'yes' = tweeted),
# on the predictions/targets from predict_with_targs (presumably the validation set).
prec_at_6(*m3.predict_with_targs())


0.659305993691
Out[ ]:
0.65930599369085174

In [ ]:
# Four 2-epoch cycles (8 epochs in the log).
# NOTE(review): reuses cycle_save_name 'imdb2', overwriting the previous run's checkpoints - confirm.
m3.fit(lrs, 4, metrics=[accuracy], cycle_len=2, cycle_save_name='imdb2')


[ 0.       0.38752  0.36351  0.82486]                        
[ 1.       0.38664  0.36123  0.82558]                        
[ 2.       0.3904   0.36098  0.82486]                        
[ 3.       0.37319  0.36144  0.82486]                        
[ 4.       0.38074  0.36334  0.82595]                        
[ 5.       0.36405  0.3594   0.82413]                        
[ 6.       0.38781  0.35914  0.82522]                        
[ 7.       0.37722  0.357    0.82631]                        


In [ ]:
# Re-evaluate recall at >= 60% precision after the additional training.
prec_at_6(*m3.predict_with_targs())


0.695583596215
Out[ ]:
0.69558359621451105

In [ ]: