Model definition in NMT-Keras

In this module, we are going to create an encoder-decoder model with:

  • A bidirectional GRU encoder and a GRU decoder
  • An attention model
  • The previously generated word fed back into the decoder
  • MLPs for computing the decoder's initial RNN state
  • Skip connections from inputs to outputs
  • Beam search

We first set up the machine.


In [1]:
!pip install --upgrade pip
!pip uninstall -y keras  # Avoid crashes with pre-installed packages
!git clone https://github.com/lvapeab/nmt-keras
import os
os.chdir('nmt-keras')
!pip install -e .


Uninstalling Keras-2.2.5:
  Successfully uninstalled Keras-2.2.5
Cloning into 'nmt-keras'...
remote: Enumerating objects: 4, done.
remote: Counting objects: 100% (4/4), done.
remote: Compressing objects: 100% (4/4), done.
remote: Total 4486 (delta 0), reused 0 (delta 0), pack-reused 4482
Receiving objects: 100% (4486/4486), 5.61 MiB | 26.11 MiB/s, done.
Resolving deltas: 100% (3030/3030), done.
Obtaining file:///content/nmt-keras
Collecting coco-caption@ https://github.com/lvapeab/coco-caption/archive/master.zip
  Downloading https://github.com/lvapeab/coco-caption/archive/master.zip
     / 328.0MB 163kB/s
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.3.0)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (0.16.0)
Collecting keras@ https://github.com/MarcBS/keras/archive/master.zip
  Downloading https://github.com/MarcBS/keras/archive/master.zip
     | 130.2MB 38kB/s
Requirement already satisfied: keras_applications in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.0.8)
Requirement already satisfied: keras_preprocessing in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.1.0)
Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (2.8.0)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (3.2.0)
Collecting multimodal-keras-wrapper
  Downloading https://files.pythonhosted.org/packages/7f/b0/cb8d01fc340ea54f67fa4abd94a7be27fab7ad645156e193527eef004257/multimodal_keras_wrapper-3.0.2-py3-none-any.whl (125kB)
     |████████████████████████████████| 133kB 2.8MB/s 
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.18.2)
Requirement already satisfied: scikit-image in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (0.16.2)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (0.22.2.post1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.12.0)
Requirement already satisfied: tables in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (3.4.4)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (0.25.3)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)
     |████████████████████████████████| 870kB 65.7MB/s 
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from nmt-keras==0.6) (1.4.1)
Requirement already satisfied: tensorflow<2 in /tensorflow-1.15.0/python3.6 (from nmt-keras==0.6) (1.15.0)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras@ https://github.com/MarcBS/keras/archive/master.zip->nmt-keras==0.6) (3.13)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->nmt-keras==0.6) (2.4.6)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->nmt-keras==0.6) (2.8.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->nmt-keras==0.6) (1.1.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->nmt-keras==0.6) (0.10.0)
Requirement already satisfied: toolz in /usr/local/lib/python3.6/dist-packages (from multimodal-keras-wrapper->nmt-keras==0.6) (0.10.0)
Requirement already satisfied: sklearn in /usr/local/lib/python3.6/dist-packages (from multimodal-keras-wrapper->nmt-keras==0.6) (0.0)
Requirement already satisfied: cython in /usr/local/lib/python3.6/dist-packages (from multimodal-keras-wrapper->nmt-keras==0.6) (0.29.15)
Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->nmt-keras==0.6) (2.4)
Requirement already satisfied: imageio>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->nmt-keras==0.6) (2.4.1)
Requirement already satisfied: pillow>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->nmt-keras==0.6) (7.0.0)
Requirement already satisfied: PyWavelets>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->nmt-keras==0.6) (1.1.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->nmt-keras==0.6) (0.14.1)
Requirement already satisfied: numexpr>=2.5.2 in /usr/local/lib/python3.6/dist-packages (from tables->nmt-keras==0.6) (2.7.1)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->nmt-keras==0.6) (2018.9)
Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from sacremoses->nmt-keras==0.6) (2019.12.20)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->nmt-keras==0.6) (7.1.1)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from sacremoses->nmt-keras==0.6) (4.38.0)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (3.2.0)
Requirement already satisfied: tensorboard<1.16.0,>=1.15.0 in /tensorflow-1.15.0/python3.6 (from tensorflow<2->nmt-keras==0.6) (1.15.0)
Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (1.24.3)
Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (1.1.0)
Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (0.8.1)
Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (0.2.2)
Requirement already satisfied: tensorflow-estimator==1.15.1 in /tensorflow-1.15.0/python3.6 (from tensorflow<2->nmt-keras==0.6) (1.15.1)
Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (0.9.0)
Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (3.10.0)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (0.34.2)
Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (1.12.1)
Requirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2->nmt-keras==0.6) (0.2.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib->nmt-keras==0.6) (46.0.0)
Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx>=2.0->scikit-image->nmt-keras==0.6) (4.4.2)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2->nmt-keras==0.6) (1.0.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2->nmt-keras==0.6) (3.2.1)
Building wheels for collected packages: coco-caption, keras, sacremoses
  Building wheel for coco-caption (setup.py) ... done
  Created wheel for coco-caption: filename=coco_caption-0.0-cp36-none-any.whl size=216386196 sha256=32cdf3dab5b81aaba168e6f8dcf3887a5d62096e062466ece7358aa1395ee676
  Stored in directory: /tmp/pip-ephem-wheel-cache-leola3cm/wheels/41/f2/84/a8c0f865fa15ab8ea706f8dbae1fa3fb4073bb4854a3763ff6
  Building wheel for keras (setup.py) ... done
  Created wheel for keras: filename=Keras-2.2.4-cp36-none-any.whl size=455356 sha256=563bb398fa258d52a1ff6015ead650b82f7a5ab8e5a520404832f6fa5ab8c7d4
  Stored in directory: /tmp/pip-ephem-wheel-cache-leola3cm/wheels/82/f8/db/7c0c999dced9850abb60944d255a31dbdf10f76f645454b715
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.38-cp36-none-any.whl size=884628 sha256=a41c8de0c7dd4dc35d4ecf4ceeaf2cff2e5c38a5dd05c87fa2259d072d99b59e
  Stored in directory: /root/.cache/pip/wheels/6d/ec/1a/21b8912e35e02741306f35f66c785f3afe94de754a0eaf1422
Successfully built coco-caption keras sacremoses
Installing collected packages: coco-caption, keras, sacremoses, multimodal-keras-wrapper, nmt-keras
  Running setup.py develop for nmt-keras
Successfully installed coco-caption-0.0 keras-2.2.4 multimodal-keras-wrapper-3.0.2 nmt-keras sacremoses-0.0.38

Let's import the necessary modules:


In [2]:
from keras.layers import *
from keras.models import model_from_json, Model
from keras.optimizers import Adam, RMSprop, Nadam, Adadelta, SGD, Adagrad, Adamax
from keras.regularizers import l2
from keras_wrapper.cnn_model import Model_Wrapper
from keras_wrapper.extra.regularize import Regularize


Using TensorFlow backend.


[24/03/2020 15:28:20] NumExpr defaulting to 2 threads.
[24/03/2020 15:28:20] <<< Cupy not available. Using numpy. >>>

Now, let's define the dimensions of our model: for instance, a word embedding size of 50 and 100 units in the RNNs. The identifiers of the model inputs and outputs are stored in ids_inputs and ids_outputs.


In [0]:
ids_inputs = ['source_text', 'state_below']
ids_outputs = ['target_text']
word_embedding_size = 50
hidden_state_size = 100
input_vocabulary_size = 686  # Set automatically by the library
output_vocabulary_size = 513  # Set automatically by the library

Encoder

Let's define our encoder. First, we have to create an Input layer to connect the input text to our model. Next, we'll apply a word embedding to the sequence of input indices. This word embedding will feed a Bidirectional GRU network, which will produce our sequence of annotations:


In [4]:
# 1. Source text input
src_text = Input(name=ids_inputs[0],
                 batch_shape=tuple([None, None]),  # Input sequences have variable length, so we do not restrict the Input shape
                 dtype='int32')
# 2. Encoder
# 2.1. Source word embedding
src_embedding = Embedding(input_vocabulary_size, word_embedding_size, 
                          name='source_word_embedding', mask_zero=True # Zeroes as mask
                          )(src_text)
# 2.2. BRNN encoder (GRU/LSTM)
annotations = Bidirectional(GRU(hidden_state_size, 
                                return_sequences=True  # Return the full sequence
                                ),
                            name='bidirectional_encoder',
                            merge_mode='concat')(src_embedding)


WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:532: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4719: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3454: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
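To make the shape of the encoder output concrete, here is a NumPy sketch of a bidirectional GRU run over an already-embedded batch. It is not the Keras implementation: it assumes the Keras 2.x GRU formulation with one bias per gate, ignores masking and uses random weights. Concatenating the forward and backward states (merge_mode='concat') gives 2 * hidden_state_size = 200 features per source position, which is the annotation size that appears later in the model summary.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_layer(x, params, reverse=False):
    """Single-direction GRU over x of shape (batch, time, in_dim).
    Returns every hidden state, shape (batch, time, units), aligned with x."""
    W, U, b = params  # W: (in_dim, 3*units), U: (units, 3*units), b: (3*units,)
    units = U.shape[0]
    batch, steps, _ = x.shape
    h = np.zeros((batch, units))
    outputs = np.zeros((batch, steps, units))
    order = range(steps - 1, -1, -1) if reverse else range(steps)
    for t in order:
        x_proj = x[:, t] @ W + b
        z = sigmoid(x_proj[:, :units] + h @ U[:, :units])                   # update gate
        r = sigmoid(x_proj[:, units:2 * units] + h @ U[:, units:2 * units])  # reset gate
        h_tilde = np.tanh(x_proj[:, 2 * units:] + (r * h) @ U[:, 2 * units:])
        h = z * h + (1.0 - z) * h_tilde
        outputs[:, t] = h  # stored at the original position, also when reversed
    return outputs

def init_params(rng, in_dim, units):
    return (rng.normal(scale=0.1, size=(in_dim, 3 * units)),
            rng.normal(scale=0.1, size=(units, 3 * units)),
            np.zeros(3 * units))

rng = np.random.default_rng(0)
batch, steps, emb, units = 2, 7, 50, 100
embedded = rng.normal(size=(batch, steps, emb))  # stands in for src_embedding
fwd = gru_layer(embedded, init_params(rng, emb, units))
bwd = gru_layer(embedded, init_params(rng, emb, units), reverse=True)
annotations = np.concatenate([fwd, bwd], axis=-1)  # merge_mode='concat'
print(annotations.shape)  # (2, 7, 200)
```

With these shapes each direction holds 50*300 + 100*300 + 300 = 45,300 weights, so both directions together account for the 90,600 parameters reported for bidirectional_encoder.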

Decoder

Once we have built the encoder, let's build our decoder. First, we have an additional input: the previously generated word (the so-called state_below). We introduce it by means of an Input layer and a (target-language) word embedding:


In [0]:
# 3. Decoder
# 3.1.1. Previously generated words as inputs for training -> Teacher forcing
next_words = Input(name=ids_inputs[1], batch_shape=tuple([None, None]), dtype='int32')
# 3.1.2. Target word embedding
state_below = Embedding(output_vocabulary_size, word_embedding_size,
                        name='target_word_embedding', 
                        mask_zero=True)(next_words)
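Under teacher forcing, the decoder input at timestep t is the reference word at t-1, so state_below is the target sentence shifted one position to the right. The sketch below shows that relation on plain Python lists; the start index and the word indices are hypothetical, and how NMT-Keras actually builds this input from the dataset may differ. The start index is not 0 because, with mask_zero=True, Keras treats index 0 as padding.

```python
def shift_right(target_ids, start_id):
    """Decoder input for teacher forcing: the target delayed by one step.
    At timestep t the decoder sees target[t-1] (start_id at t = 0) and is
    trained to predict target[t]."""
    return [start_id] + list(target_ids[:-1])

START_ID = 1             # hypothetical start-of-sentence index (0 is the padding index)
target = [12, 7, 45, 3]  # hypothetical word indices of one target sentence
state_below = shift_right(target, START_ID)
for prev, expected in zip(state_below, target):
    print(f'input {prev:>2} -> predict {expected}')
```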

The initial hidden state of the decoder's GRU is computed by an MLP (in this case, a single-layer one) from the masked average of the annotations:


In [0]:
ctx_mean = MaskedMean()(annotations)
annotations = MaskLayer()(annotations)  # We may want the padded annotations

initial_state = Dense(hidden_state_size, name='initial_state',
                      activation='tanh')(ctx_mean)
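The following NumPy sketch reproduces what this cell computes, under the assumption that MaskedMean averages only over the non-padded timesteps: a time average of the annotations that ignores padding, followed by a single tanh layer mapping 200 features to hidden_state_size = 100. Weights are random and the mask is made up.

```python
import numpy as np

def masked_mean(annotations, mask):
    """Average annotations over time, ignoring padded (mask == 0) positions.
    annotations: (batch, time, dim); mask: (batch, time) of 0/1."""
    m = mask[:, :, None].astype(annotations.dtype)
    return (annotations * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)

rng = np.random.default_rng(0)
annotations = rng.normal(size=(2, 5, 200))
mask = np.array([[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0]])  # the second sentence has 2 padding positions
ctx_mean = masked_mean(annotations, mask)

# Single-layer MLP: Dense(hidden_state_size, activation='tanh')
W = rng.normal(scale=0.1, size=(200, 100))
b = np.zeros(100)
initial_state = np.tanh(ctx_mean @ W + b)
print(ctx_mean.shape, initial_state.shape)  # (2, 200) (2, 100)
```

The 200*100 + 100 = 20,100 weights of this layer match the initial_state row of the model summary.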

So, these are the inputs of our decoder:


In [0]:
input_attentional_decoder = [state_below, annotations, initial_state]

Note that, for a given sample, the sequence of annotations and the initial state are the same at every decoding timestep. To avoid recomputing them at each step, we build two models: one for training and another one for sampling. They share weights, but the sampling model is made up of two different models. The first one (model_init) computes the sequence of annotations and the initial_state. The second one (model_next) computes a single recurrent step, given the sequence of annotations, the previous hidden state and the previously generated word.

Therefore, we slightly change the way we declare layers: since the decoding models must share layers, we first create each layer and store it in a variable, and only then apply it to its inputs.

So, let's start by building the attentional-conditional GRU:


In [0]:
# Define the AttGRUCond function
sharedAttGRUCond = AttGRUCond(hidden_state_size,
                              return_sequences=True,
                              return_extra_variables=True,  # Return the attended input and the attention weights
                              return_states=True  # Return the sequence of hidden states (see discussion above)
                              )
[proj_h, x_att, alphas, h_state] = sharedAttGRUCond(input_attentional_decoder)  # Apply sharedAttGRUCond to our input
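The alphas and x_att outputs are the attention weights and the attended context. As an illustration, the NumPy sketch below implements generic additive (Bahdanau-style) attention for one decoding step, scoring each annotation against the previous decoder state. This is an assumption about the general mechanism, not a transcription of AttGRUCond, whose exact parameterization may differ; all weights and the mask are made up.

```python
import numpy as np

def additive_attention(prev_h, annotations, mask, Wa, Ua, va):
    """One decoding step of additive attention.
    prev_h: (batch, units); annotations: (batch, time, ctx_dim);
    mask: (batch, time) of 0/1. Returns (context, alphas)."""
    # Score each source position: e_j = va . tanh(Wa h_prev + Ua a_j)
    energies = np.tanh(prev_h[:, None, :] @ Wa + annotations @ Ua) @ va  # (batch, time)
    energies = np.where(mask > 0, energies, -np.inf)  # padded positions get zero weight
    energies -= energies.max(axis=1, keepdims=True)   # numerically stable softmax
    weights = np.exp(energies)
    alphas = weights / weights.sum(axis=1, keepdims=True)
    context = (alphas[:, :, None] * annotations).sum(axis=1)  # (batch, ctx_dim)
    return context, alphas

rng = np.random.default_rng(0)
batch, steps, units, ctx_dim, att_dim = 2, 6, 100, 200, 100
annotations = rng.normal(size=(batch, steps, ctx_dim))
mask = np.array([[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0, 0]])
prev_h = rng.normal(size=(batch, units))
Wa = rng.normal(scale=0.1, size=(units, att_dim))
Ua = rng.normal(scale=0.1, size=(ctx_dim, att_dim))
va = rng.normal(scale=0.1, size=(att_dim,))
context, alphas = additive_attention(prev_h, annotations, mask, Wa, Ua, va)
```

Each row of alphas is a probability distribution over source positions, and the context vector is the corresponding weighted sum of annotations.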

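Now, we set skip connections from the inputs to the output layer: the decoder hidden states (proj_h), the attended context vectors (x_att) and the embedding of the previously generated word (state_below) are each projected to word_embedding_size. Since the RNN decoder gives us a temporal dimension, we must apply these layers in a TimeDistributed way. Finally, we merge all skip connections by summing them and apply a 'tanh' non-linearity: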


In [0]:
# Define layer function
shared_FC_mlp = TimeDistributed(Dense(word_embedding_size, activation='linear',),
                                name='logit_lstm')
# Apply layer function
out_layer_mlp = shared_FC_mlp(proj_h)

# Define layer function
shared_FC_ctx = TimeDistributed(Dense(word_embedding_size, activation='linear'),
                                name='logit_ctx')
# Apply layer function
out_layer_ctx = shared_FC_ctx(x_att)
shared_Lambda_Permute = PermuteGeneral((1, 0, 2))
out_layer_ctx = shared_Lambda_Permute(out_layer_ctx)

# Define layer function
shared_FC_emb = TimeDistributed(Dense(word_embedding_size, activation='linear'),
                                name='logit_emb')
# Apply layer function
out_layer_emb = shared_FC_emb(state_below)

shared_additional_output_merge = Add(name='additional_input')
additional_output = shared_additional_output_merge([out_layer_mlp, out_layer_ctx, out_layer_emb])
shared_activation_tanh = Activation('tanh')
out_layer = shared_activation_tanh(additional_output)
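A TimeDistributed Dense applies the same affine map at every timestep, which for a 3-D tensor is a matrix product over the last axis. The NumPy sketch below builds the merged skip-connection output that way. It assumes all three tensors are laid out batch-first, (batch, time, features), and uses random weights.

```python
import numpy as np

def time_distributed_dense(x, W, b):
    """Same affine map at every timestep: (batch, time, in) -> (batch, time, out)."""
    return x @ W + b

rng = np.random.default_rng(0)
batch, steps = 2, 4
proj_h = rng.normal(size=(batch, steps, 100))      # decoder hidden states
x_att = rng.normal(size=(batch, steps, 200))       # attended contexts
state_below = rng.normal(size=(batch, steps, 50))  # embedded previous words

def dense_params(in_dim, out_dim):
    return rng.normal(scale=0.1, size=(in_dim, out_dim)), np.zeros(out_dim)

W_h, b_h = dense_params(100, 50)  # plays the role of logit_lstm
W_c, b_c = dense_params(200, 50)  # plays the role of logit_ctx
W_e, b_e = dense_params(50, 50)   # plays the role of logit_emb
# Add the three projections, then apply tanh
out_layer = np.tanh(time_distributed_dense(proj_h, W_h, b_h)
                    + time_distributed_dense(x_att, W_c, b_c)
                    + time_distributed_dense(state_below, W_e, b_e))
```

Because the three projections share the output size, they can be summed element-wise, which is why every skip connection is mapped to word_embedding_size.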

Now, we apply a deep output layer with linear activation:


In [0]:
shared_deep_out = TimeDistributed(Dense(word_embedding_size, activation='linear', name='maxout_layer'))
out_layer = shared_deep_out(out_layer)

Our output layer is a projection to the target vocabulary space plus a softmax function for obtaining a probability distribution over the target vocabulary words at each timestep.


In [0]:
shared_FC_soft = TimeDistributed(Dense(output_vocabulary_size,
                                       activation='softmax',
                                       name='softmax_layer'),
                                 name=ids_outputs[0])
softout = shared_FC_soft(out_layer)
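The NumPy sketch below mirrors the last two cells under the same batch-first assumption: a linear deep-output layer followed by a projection to the 513-word target vocabulary and a softmax at each timestep. Weights are random; the point is the shapes and the fact that every timestep yields a distribution over the vocabulary.

```python
import numpy as np

def softmax(logits, axis=-1):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # avoid overflow in exp
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
out_layer = np.tanh(rng.normal(size=(2, 4, 50)))  # stands in for the merged skip connections
# Deep output: linear TimeDistributed Dense(word_embedding_size)
W_deep, b_deep = rng.normal(scale=0.1, size=(50, 50)), np.zeros(50)
deep = out_layer @ W_deep + b_deep
# Projection to the target vocabulary + softmax at every timestep
W_soft, b_soft = rng.normal(scale=0.1, size=(50, 513)), np.zeros(513)
probs = softmax(deep @ W_soft + b_soft)
print(probs.shape)  # (2, 4, 513)
```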

And we build the Keras model:


In [14]:
model = Model(name='NMT Model', inputs=[src_text, next_words], outputs=softout)

model.summary()


Model: "NMT Model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
source_text (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
source_word_embedding (Embeddin (None, None, 50)     34300       source_text[0][0]                
__________________________________________________________________________________________________
bidirectional_encoder (Bidirect (None, None, 200)    90600       source_word_embedding[0][0]      
__________________________________________________________________________________________________
state_below (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
masked_mean_1 (MaskedMean)      (None, 200)          0           bidirectional_encoder[0][0]      
__________________________________________________________________________________________________
target_word_embedding (Embeddin (None, None, 50)     25650       state_below[0][0]                
__________________________________________________________________________________________________
mask_layer_1 (MaskLayer)        (None, None, 200)    0           bidirectional_encoder[0][0]      
__________________________________________________________________________________________________
initial_state (Dense)           (None, 100)          20100       masked_mean_1[0][0]              
__________________________________________________________________________________________________
att_gru_cond_1 (AttGRUCond)     [(None, None, 100),  135501      target_word_embedding[0][0]      
                                                                 mask_layer_1[0][0]               
                                                                 initial_state[0][0]              
__________________________________________________________________________________________________
logit_ctx (TimeDistributed)     (None, None, 50)     10050       att_gru_cond_1[0][1]             
__________________________________________________________________________________________________
logit_lstm (TimeDistributed)    (None, None, 50)     5050        att_gru_cond_1[0][0]             
__________________________________________________________________________________________________
permute_general_1 (PermuteGener (None, None, 50)     0           logit_ctx[0][0]                  
__________________________________________________________________________________________________
logit_emb (TimeDistributed)     (None, None, 50)     2550        target_word_embedding[0][0]      
__________________________________________________________________________________________________
additional_input (Add)          (None, None, 50)     0           logit_lstm[0][0]                 
                                                                 permute_general_1[0][0]          
                                                                 logit_emb[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, 50)     0           additional_input[0][0]           
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, None, 50)     2550        activation_1[0][0]               
__________________________________________________________________________________________________
target_text (TimeDistributed)   (None, None, 513)    26163       time_distributed_1[0][0]         
==================================================================================================
Total params: 352,514
Trainable params: 352,514
Non-trainable params: 0
__________________________________________________________________________________________________

That's all! We have built an NMT model!
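As a sanity check on the summary, the parameter counts of the standard layers can be reproduced by hand. The sketch assumes Keras 2.x conventions (Dense: kernel plus bias; GRU: three gates, each with an input kernel, a recurrent kernel and a single bias). It does not derive the AttGRUCond count; it only checks that the remainder equals the value the summary reports.

```python
word_embedding_size, hidden_state_size = 50, 100
input_vocabulary_size, output_vocabulary_size = 686, 513

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias

def gru_params(n_in, units):
    # update, reset and candidate gates: input kernel + recurrent kernel + bias each
    return 3 * (n_in * units + units * units + units)

ctx_size = 2 * hidden_state_size  # bidirectional annotations
counts = {
    'source_word_embedding': input_vocabulary_size * word_embedding_size,
    'bidirectional_encoder': 2 * gru_params(word_embedding_size, hidden_state_size),
    'target_word_embedding': output_vocabulary_size * word_embedding_size,
    'initial_state': dense_params(ctx_size, hidden_state_size),
    'logit_ctx': dense_params(ctx_size, word_embedding_size),
    'logit_lstm': dense_params(hidden_state_size, word_embedding_size),
    'logit_emb': dense_params(word_embedding_size, word_embedding_size),
    'time_distributed_1': dense_params(word_embedding_size, word_embedding_size),
    'target_text': dense_params(word_embedding_size, output_vocabulary_size),
}
for name, n in counts.items():
    print(f'{name:25s} {n:>8,d}')
print('remainder (att_gru_cond_1):', 352514 - sum(counts.values()))
```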

Sampling models

Now, let's build the models required for sampling. Recall that we are building two models, one for encoding the inputs and the other one for advancing steps in the decoding stage.

Let's start with model_init. It takes the usual inputs (src_text and state_below) and outputs:

  1. The vector of probabilities (for timestep 1)
  2. The sequence of annotations (from the encoder)
  3. The current decoder's hidden state

The only restriction here is that the first output must be the output layer (probabilities) of the model.


In [0]:
model_init = Model(inputs=[src_text, next_words], outputs=[softout, annotations, h_state])
# Store inputs and outputs names for model_init
ids_inputs_init = ids_inputs

# first output must be the output probs.
ids_outputs_init = ids_outputs + ['preprocessed_input', 'next_state']

Next, we build model_next. It has the following inputs:

  • Preprocessed input
  • Previously generated word
  • Previous hidden state

And the following outputs:

  • Model probabilities
  • Current hidden state

So first, we define the inputs:


In [0]:
preprocessed_size = hidden_state_size*2
preprocessed_annotations = Input(name='preprocessed_input', shape=tuple([None, preprocessed_size]))
prev_h_state = Input(name='prev_state', shape=tuple([hidden_state_size]))
input_attentional_decoder = [state_below, preprocessed_annotations, prev_h_state]

And now, we build the model, reusing the layers stored in the shared variables declared before:


In [0]:
# Apply decoder
[proj_h, x_att, alphas, h_state] = sharedAttGRUCond(input_attentional_decoder)
out_layer_mlp = shared_FC_mlp(proj_h)
out_layer_ctx = shared_FC_ctx(x_att)
out_layer_ctx = shared_Lambda_Permute(out_layer_ctx)
out_layer_emb = shared_FC_emb(state_below)
additional_output = shared_additional_output_merge([out_layer_mlp, out_layer_ctx, out_layer_emb])
out_layer = shared_activation_tanh(additional_output)
out_layer = shared_deep_out(out_layer)
softout = shared_FC_soft(out_layer)
model_next = Model(inputs=[next_words, preprocessed_annotations, prev_h_state],
                   outputs=[softout, preprocessed_annotations, h_state])

Finally, we store the input/output names for model_next. In addition, we create a couple of dictionaries that match the outputs of one model to the inputs of the next one (model_init -> model_next, model_next -> model_next):


In [0]:
# Store inputs and outputs names for model_next
# first input must be previous word
ids_inputs_next = [ids_inputs[1]] + ['preprocessed_input', 'prev_state']
# first output must be the output probs.
ids_outputs_next = ids_outputs + ['preprocessed_input', 'next_state']

# Output -> input matchings from model_init to model_next and from model_next to model_next
matchings_init_to_next = {'preprocessed_input': 'preprocessed_input', 'next_state': 'prev_state'}
matchings_next_to_next = {'preprocessed_input': 'preprocessed_input', 'next_state': 'prev_state'}
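To see how these names and matchings fit together, here is a hypothetical greedy decoding loop in plain Python. fake_init and fake_next are made-up stand-ins for running model_init and model_next: each takes a dict of named inputs and returns a list of outputs ordered like ids_outputs_init / ids_outputs_next. The loop encodes the source once, then routes each step's named outputs to the next step's named inputs through the matchings. It is an illustration, not the multimodal-keras-wrapper implementation.

```python
ids_outputs_init = ['target_text', 'preprocessed_input', 'next_state']
ids_inputs_next = ['state_below', 'preprocessed_input', 'prev_state']
ids_outputs_next = ['target_text', 'preprocessed_input', 'next_state']
matchings_init_to_next = {'preprocessed_input': 'preprocessed_input', 'next_state': 'prev_state'}
matchings_next_to_next = {'preprocessed_input': 'preprocessed_input', 'next_state': 'prev_state'}

def greedy_decode(run_init, run_next, src, start_id, eos_id, max_len=50):
    # Step 0: encode the source once and get the first distribution
    outputs = dict(zip(ids_outputs_init, run_init({'source_text': src, 'state_below': [start_id]})))
    matchings = matchings_init_to_next
    words = []
    while True:
        probs = outputs['target_text']  # first output is always the probabilities
        word = max(range(len(probs)), key=probs.__getitem__)
        words.append(word)
        if word == eos_id or len(words) >= max_len:
            return words
        # Route named outputs to the next step's named inputs
        inputs = {matchings[k]: v for k, v in outputs.items() if k in matchings}
        inputs['state_below'] = [word]
        assert set(inputs) == set(ids_inputs_next)
        outputs = dict(zip(ids_outputs_next, run_next(inputs)))
        matchings = matchings_next_to_next

def fake_init(inputs):  # hypothetical: probs, "annotations", initial state
    return [[0.1, 0.7, 0.2], [len(inputs['source_text'])], 0]

def fake_next(inputs):  # hypothetical: annotations pass through unchanged
    state = inputs['prev_state'] + 1
    probs = [0.0, 0.3, 0.7] if state >= 2 else [0.2, 0.5, 0.3]
    return [probs, inputs['preprocessed_input'], state]

START_ID = 1  # hypothetical start index
print(greedy_decode(fake_init, fake_next, [5, 6, 7], start_id=START_ID, eos_id=2))  # [1, 1, 2]
```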

And that's all! To use this model together with the facilities provided by the multimodal-keras-wrapper library (imported as keras_wrapper), we should declare the model as a method of a Model_Wrapper class. A complete example can be found in model_zoo.py.
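The two sampling models exist so that decoding strategies such as beam search only pay for one recurrent step per hypothesis. The sketch below is a generic beam search in plain Python over a hypothetical step_fn(prefix) that returns next-word probabilities, standing in for a model_next call. It uses summed log-probabilities without length normalization; the library's own beam search may differ in these details. The toy distribution is chosen so that greedy search (beam_size=1) and a beam of 2 return different translations.

```python
import math

def beam_search(step_fn, start_id, eos_id, beam_size=3, max_len=10):
    """Return (best_words, log_prob); step_fn(prefix) -> next-word probabilities."""
    beams = [([start_id], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for words, score in beams:
            for w, p in enumerate(step_fn(words)):
                if p > 0:
                    candidates.append((words + [w], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for words, score in candidates[:beam_size]:
            # Hypotheses ending in EOS are set aside; the rest keep expanding
            (finished if words[-1] == eos_id else beams).append((words, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses that reached max_len without EOS
    best_words, best_score = max(finished, key=lambda c: c[1])
    return best_words[1:], best_score  # drop the start index

# Toy vocabulary: 0 = start, 1 = 'a', 2 = 'b', 3 = EOS
TOY = {
    (0,): [0.0, 0.6, 0.4, 0.0],
    (0, 1): [0.0, 0.3, 0.3, 0.4],
    (0, 2): [0.0, 0.05, 0.05, 0.9],
}

def toy_step(prefix):
    return TOY.get(tuple(prefix), [0.0, 0.0, 0.0, 1.0])

print(beam_search(toy_step, start_id=0, eos_id=3, beam_size=1)[0])  # [1, 3]
print(beam_search(toy_step, start_id=0, eos_id=3, beam_size=2)[0])  # [2, 3]
```

Greedy search commits to 'a' (0.6) and ends with probability 0.24, while a beam of 2 keeps 'b' alive and finds 'b' followed by EOS with probability 0.36.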