이번 튜토리얼에서는 Sequence to Sequence 모델의 핵심인 RNN Encoder Decoder과 Attention 모델을 이해하고, 이를 활용하여 Machine Translator를 구현해보겠습니다.
Machine Traslator에 핵심인 Sequence to Sequence 모델은 아래의 그림과 같이 구성되어 있습니다.
In [ ]:
MAX_LENGTH = 10
In [ ]:
from data_util import prepare_data
input_lang, output_lang, train_pairs, test_pairs = prepare_data('lang1', 'lang2', MAX_LENGTH, 2, True)
Encoder와 Decoder를 구현해봅니다.
구현하고자 하는 Encoder는 다음과 같은 구조로 구성되어 있습니다.
다음, 구현하고자 하는 Decoder의 구조는 아래와 같습니다.
모델의 전체과정 중 Attention 부분은 다음과 같습니다.
Input에 들어온 데이터는 embedding layer을 통해 이전 스텝의 hidden_vector와 결합을 합니다. 이후 softmax function을 거쳐 attn linear function을 두어 encoder_outputs와 matrix multiplication을 할 수 있도록 해줍니다.
Attention이 적용된 context vector는 input vector와 결합이 되어 hidden vector와 같이 GRU function에 들어갑니다. GRU에서 나온 output은 softmax를 처리하여 return 처리를 합니다.
이제 위 내용을 바탕으로 model을 구현해 보겠습니다.
models.py에 NotImplementedError라 표시된 영역에 구현해보겠습니다.
각 구현에 대한 순서는 다음과 같습니다.
In [ ]:
from models import EncoderRNN, AttnDecoderRNN
hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words,
MAX_LENGTH, dropout_p=0.1)
Training Module 중 Teacher forcing 부분에 대해서 구현을 하고 criterion과 optimizer에 대해서 설정을 해봅니다. train.py에 NotImplementedError라 표시된 영역에 구현해보겠습니다.
Teacher forcing 부분을 구현합니다.
Without Teacher forcing 부분을 구현합니다.
In [ ]:
from train import train_iters
plot_losses = train_iters(encoder1, attn_decoder1, input_lang,
output_lang, train_pairs[:70], 1000, MAX_LENGTH)
In [ ]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
%matplotlib inline
def showPlot(points):
plt.figure()
fig, ax = plt.subplots()
# this locator puts ticks at regular intervals
loc = ticker.MultipleLocator(base=0.2)
ax.yaxis.set_major_locator(loc)
plt.plot(points)
In [ ]:
showPlot(plot_losses)
In [ ]:
from predict import ModelPredictor
predictor = ModelPredictor(encoder1, attn_decoder1, input_lang, output_lang, MAX_LENGTH)
predictor.evaluate_randomly(train_pairs[:10])
predictor.predict_sentence("je comprends il est essentiel .")