In [1]:
%load_ext autoreload
%autoreload 2

Main Objectives

  1. [ ] write a language translation model for English to Chinese
  2. [ ] Congusion Matrix
  3. [ ] Add Attention

Todo

  1. [ ] add evaluate method

  2. [ ] use torch.utils.dataloader and stick with a standard api for data.

  3. [ ] make input and output language configurable
  4. [ ] train multiple models in parallel, with different language, batch_size etc
  5. [ ] how does batch size affect loss?
  6. [ ] add evaluation print. What is the metric to use for evaluating translation? Take a look at Roberson's notes.

Usage of this Training script

Output Figures are saved by hyper parameter id.

for example, batch_size etc.

loss vs time vs sample.

%%bash source activate deep-learning python -u train.py --batch-size=100 --n-epoch=10


In [2]:
from utils import ledger, Struct
import data
from language import get_language_pairs
from visdom_helper import visdom_helper
from model import VanillaSequenceToSequence

In [9]:
args = Struct(**{'BATCH_SIZE': 10,
                 'BI_DIRECTIONAL': False,
                 'DASH_ID': 'seq-to-seq-experiment',
                 'DEBUG': True,
                 'EVAL_INTERVAL': 10,
                 'TEACHER_FORCING_R': 0.5,
                 'INPUT_LANG': 'eng',
                 'LEARNING_RATE': 0.001,
                 'MAX_DATA_LEN': 10,
                 'MAX_OUTPUT_LEN': 100,
                 'N_EPOCH': 5,
                 'N_LAYERS': 1,
                 'OUTPUT_LANG': 'cmn',
                 'SAVE_INTERVAL': 100})


********<loading train module>********

In [4]:
from train import Session


********<loading train module>********

In [5]:
sess = Session(args)


{'d': {'BATCH_SIZE': 10,
       'BI_DIRECTIONAL': False,
       'DASH_ID': 'seq-to-seq-experiment',
       'DEBUG': True,
       'EVAL_INTERVAL': 10,
       'INPUT_LANG': 'eng',
       'LEARNING_RATE': 0.001,
       'MAX_DATA_LEN': 10,
       'N_EPOCH': 5,
       'N_LAYERS': 1,
       'OUTPUT_LANG': 'cmn',
       'SAVE_INTERVAL': 100}}
Reading Lines... 
Number of sentence pairs after filtering: 15811
eng has  5982 words.
cmn has  3206 words.
Sequence to sequence model graph:
 VanillaSequenceToSequence (
  (encoder): EncoderRNN (
    (embedding): Embedding(5982, 200)
    (gru): GRU(200, 200, batch_first=True)
  )
  (decoder): DecoderRNN (
    (embedding): Embedding(3206, 200)
    (gru): GRU(200, 200, batch_first=True)
    (output_embedding): Linear (200 -> 3206)
    (softmax): Softmax ()
  )
)

In [6]:
for i in range(args.N_EPOCH):
    sess.train()
    sess.ledger.green('epoch {} is complete'.format(i))


1582it [01:47, 14.77it/s]
epoch 0 is complete
1582it [01:51, 10.45it/s]
epoch 1 is complete
1582it [01:50, 10.85it/s]
epoch 2 is complete
1582it [01:49, 14.51it/s]
epoch 3 is complete
1582it [01:49, 14.39it/s]
epoch 4 is complete


In [7]:
# TODO: somehow the evaluation always stops.

In [8]:
sentence = sess.evaluate('This is a job.')
print(sentence)


attention-networks/language.py L68: ledger.debug(indexes[0])
----------------------
1
<SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS><SOS>

In [ ]: