In [1]:
import torch
from seq2seq.tools.inference import Translator
import warnings
warnings.filterwarnings('ignore')

In [2]:
cuda = False  # run the decoder on CPU; set to True to decode on GPU
# load a trained English<->Hebrew checkpoint; if it was saved on GPU,
# torch.load may also need map_location='cpu' when cuda is False
checkpoint = torch.load('../results/en_he_dual_resumed/checkpoint.pth.tar')
model = checkpoint['model']
# the checkpoint stores the source and target tokenizers in a dict
src_tok, target_tok = checkpoint['tokenizers'].values()

translation_model = Translator(model,
                               src_tok=src_tok,
                               target_tok=target_tok,
                               beam_size=5,                        # beam-search width
                               length_normalization_factor=0.45,   # length penalty on beam scores
                               cuda=cuda)

def translate(s, src_lang, target_lang, target_priming=None):
    # select the language pair, then decode and print the translation
    translation_model.set_src_language(src_lang)
    translation_model.set_target_language(target_lang)
    pred = translation_model.translate(s, target_priming=target_priming)
    print(pred)

def en2he(s, target_priming=None):
    # English -> Hebrew
    translate(s, 'en', 'he', target_priming=target_priming)

def he2en(s, target_priming=None):
    # Hebrew -> English
    translate(s, 'he', 'en', target_priming=target_priming)

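As a quick sanity check, the same helpers can be looped over a few short sentences in both directions. This is an illustrative sketch only: the sentences and the loop were not part of the original run, so no outputs are shown.

for s in ['good morning', 'where is the train station?']:
    en2he(s)  # English -> Hebrew
for s in ['בוקר טוב', 'איפה תחנת הרכבת?']:  # "good morning", "where is the train station?"
    he2en(s)  # Hebrew -> English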
In [3]:
he2en('שלום עולם')  # "hello world"


Hello, world.

In [4]:
he2en('מה שלומך הבוקר?')  # "how are you this morning?"


How are you this morning?

In [5]:
en2he('hello world')


שלום עולם העולם ("hello world, the world")

In [6]:
en2he('This seem to be working good')


נראה שזה עובד טוב טוב. ("It seems this is working well well.")

In [7]:
en2he('may the force be with you')


ייתכן שהכוח יהיה איתך. ("It is possible that the force will be with you.")

In [8]:
en2he('three dogs?', target_priming='מה אתה עושה')  # priming: "what are you doing"


מה אתה עושה שם? ("What are you doing there?")
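
The target_priming argument seeds the decoder with a target-language prefix, and beam search continues from it: in this run the output starts with the priming text ("what are you doing") and the model completes the sentence. One could try other prefixes the same way; the prefix below is an illustrative assumption and no output is shown.

en2he('three dogs?', target_priming='יש לי')  # priming: "I have"; the decoder is assumed to continue from this prefix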