In [1]:
# Length cap shared by data preparation, the attention decoder, training and
# prediction (presumably measured in tokens — TODO confirm in data_util).
MAX_LENGTH: int = 10
In [2]:
from data_util import prepare_data
# Load both vocabularies and the train/test split in one call.
# NOTE(review): the meaning of the positional `2` and `True` arguments is not
# visible from this notebook — TODO confirm against data_util.prepare_data and
# pass them by keyword once their names are known.
(input_lang, output_lang,
 train_pairs, test_pairs) = prepare_data('lang1', 'lang2', MAX_LENGTH, 2, True)
In [3]:
from models_complete import EncoderRNN, AttnDecoderRNN
# Hidden size handed to both the encoder and the attention decoder.
hidden_size = 256

# Encoder sized to the source-language vocabulary.
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)

# Attention decoder sized to the target-language vocabulary; it also receives
# the sentence-length cap, and dropout_p=0.1 is passed explicitly.
attn_decoder1 = AttnDecoderRNN(
    hidden_size,
    output_lang.n_words,
    MAX_LENGTH,
    dropout_p=0.1,
)
In [4]:
from train_complete import train_iters
# Train on only the first 50 training pairs. The positional 1000 is presumably
# the iteration count — TODO confirm against train_complete.train_iters.
# The returned value is what gets plotted later as the loss curve.
plot_losses = train_iters(
    encoder1,
    attn_decoder1,
    input_lang,
    output_lang,
    train_pairs[:50],
    1000,
    MAX_LENGTH,
)
In [5]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
%matplotlib inline
def showPlot(points, tick_spacing=0.2, title=None, ylabel=None):
    """Plot a sequence of values (e.g. per-interval training losses) on a new figure.

    Args:
        points: sequence of numbers, plotted against their index.
        tick_spacing: distance between major y-axis ticks. Defaults to 0.2,
            the value that was previously hard-coded.
        title: optional axes title; left unset when None.
        ylabel: optional y-axis label; left unset when None.

    Returns:
        None. The figure is rendered by the notebook's inline backend.
    """
    # plt.subplots() creates the figure and its axes together. Calling
    # plt.figure() first would leave behind an extra, empty figure that the
    # inline backend renders as "<Figure ... with 0 Axes>".
    _, ax = plt.subplots()
    # Put major y ticks at regular intervals of `tick_spacing`.
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=tick_spacing))
    # Draw on the explicit axes rather than through the pyplot state machine.
    ax.plot(points)
    if title is not None:
        ax.set_title(title)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
In [6]:
# Render the training-loss curve returned by train_iters.
showPlot(points=plot_losses)
In [7]:
from predict import ModelPredictor
# Wrap the trained encoder/decoder pair together with both vocabularies.
predictor = ModelPredictor(
    encoder1, attn_decoder1,
    input_lang, output_lang,
    MAX_LENGTH,
)

# NOTE(review): this samples from *training* pairs (the model saw the first 50
# of them), so it checks fit rather than generalization; test_pairs is never
# used in this notebook — consider evaluating on it as well.
predictor.evaluate_randomly(train_pairs[:10])

# Last expression: its return value is shown as the cell output.
predictor.predict_sentence("je comprends il est essentiel .")
Out[7]: