In [1]:
import numpy as np
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Dense
from keras.layers.convolutional import Conv2D
from keras.layers.wrappers import TimeDistributed
from keras import backend as K
import json
from collections import OrderedDict


Using TensorFlow backend.
/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)

In [2]:
def format_decimal(arr, places=6):
    """Return a new list with every number in ``arr`` rounded to ``places`` decimals.

    Each value is scaled by ``10**places``, rounded to the nearest integer with
    the built-in ``round``, and scaled back down, so the serialized fixtures
    stay compact.

    Args:
        arr: iterable of numbers (the notebook passes flattened ``.tolist()`` data).
        places: number of decimal places to keep (default 6).

    Returns:
        list with one rounded value per element of ``arr``, in the same order.
    """
    def _round_one(value):
        # Scale up, snap to the nearest integer, scale back down.
        shifted = round(value * 10**places)
        return shifted / 10**places

    return list(map(_round_one, arr))

In [3]:
# Collects one fixture entry per test case, keyed by test name; the OrderedDict
# keeps the entries in the order the cells below add them.
DATA = OrderedDict()

TimeDistributed

[wrappers.TimeDistributed.0] wrap a Dense layer with 4 units (input: 3x6)


In [4]:
# Fixture for a TimeDistributed wrapper around Dense(4), fed one sample of shape (3, 6).
random_seed = 2000
data_in_shape = (3, 6)

# Single-layer model under test.
model_input = Input(shape=data_in_shape)
model_output = TimeDistributed(Dense(4))(model_input)
model = Model(inputs=model_input, outputs=model_output)

# Reproducible input sample, rescaled from [0, 1) to [-1, 1).
np.random.seed(random_seed)
data_in = np.random.random(data_in_shape) * 2 - 1

# Replace the initial weights with reproducible random ones: the k-th weight
# tensor is drawn after re-seeding with random_seed + k.
# NOTE: for k == 0 this is the same seed used for data_in above, so the first
# weight tensor starts with the same random sequence as the input.
weights = []
for offset, initial in enumerate(model.get_weights()):
    np.random.seed(random_seed + offset)
    weights.append(np.random.random(initial.shape) * 2 - 1)
model.set_weights(weights)

# Run the sample through the model as a batch of one and keep the single output.
batch = np.array([data_in])
data_out = model.predict(batch)[0]
data_out_shape = data_out.shape

# Flatten and round everything for serialization.
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(data_out.ravel().tolist())
weights_formatted = []
for tensor in weights:
    weights_formatted.append({
        'data': format_decimal(tensor.ravel().tolist()),
        'shape': tensor.shape,
    })

DATA['wrappers.TimeDistributed.0'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': weights_formatted,
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}

[wrappers.TimeDistributed.1] wrap a Conv2D layer with 6 3x3 filters (input: 5x4x4x2)


In [5]:
# Fixture for a TimeDistributed wrapper around Conv2D (6 filters, 3x3 kernel,
# channels_last), fed one sample of shape (5, 4, 4, 2).
random_seed = 2000
data_in_shape = (5, 4, 4, 2)

# Single-layer model under test.
conv = Conv2D(6, (3,3), data_format='channels_last')
model_input = Input(shape=data_in_shape)
model_output = TimeDistributed(conv)(model_input)
model = Model(inputs=model_input, outputs=model_output)

# Reproducible input sample, rescaled from [0, 1) to [-1, 1).
np.random.seed(random_seed)
data_in = np.random.random(data_in_shape) * 2 - 1

# Replace the initial weights with reproducible random ones: the k-th weight
# tensor is drawn after re-seeding with random_seed + k.
# NOTE: for k == 0 this is the same seed used for data_in above, so the kernel
# starts with the same random sequence as the input.
weights = []
for offset, initial in enumerate(model.get_weights()):
    np.random.seed(random_seed + offset)
    weights.append(np.random.random(initial.shape) * 2 - 1)
model.set_weights(weights)

# Run the sample through the model as a batch of one and keep the single output.
batch = np.array([data_in])
data_out = model.predict(batch)[0]
data_out_shape = data_out.shape

# Flatten and round everything for serialization.
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(data_out.ravel().tolist())
weights_formatted = []
for tensor in weights:
    weights_formatted.append({
        'data': format_decimal(tensor.ravel().tolist()),
        'shape': tensor.shape,
    })

DATA['wrappers.TimeDistributed.1'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': weights_formatted,
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}

export for Keras.js tests


In [6]:
import os

# Destination of the fixture file, relative to the notebook's working directory.
filename = '../../../test/data/layers/wrappers/TimeDistributed.json'

# Create the target directory (and any missing parents) if needed.
# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the directory between the check and makedirs().
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Serialize every fixture collected in DATA into a single JSON document.
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [7]:
# Echo the serialized fixtures so the exported JSON is visible in the notebook output.
print(json.dumps(DATA))


{"wrappers.TimeDistributed.0": {"input": {"data": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795], "shape": [3, 6]}, "weights": [{"data": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008], "shape": [6, 4]}, {"data": [-0.665749, -0.838004, 0.920451, 0.8769], "shape": [4]}], "expected": {"data": [-0.435401, -0.729574, 0.891208, 0.435142, 0.449463, -0.303688, 1.418705, 0.491531, -0.206694, 0.102946, 0.646889, 1.393312], "shape": [3, 4]}}, "wrappers.TimeDistributed.1": {"input": {"data": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008, 0.147894, -0.216289, -0.840972, 0.734562, -0.670993, 0.606963, -0.424144, -0.462858, 0.434956, 0.762811, 0.98424, -0.0833, 0.570259, 0.477388, -0.052834, -0.030331, 0.86601, 0.505308, -0.681422, -0.730379, -0.178646, 0.513073, -0.574974, -0.371941, -0.597453, 0.87685, 0.00883, 0.207463, 0.675097, 0.220365, 0.471146, -0.180468, -0.02072, 0.017849, 0.012965, 0.236682, 0.66921, 0.173131, -0.957385, 0.471247, 0.841267, 0.511354, -0.430488, 0.899198, 0.679766, 0.6299, 0.487356, 0.829739, 0.792468, -0.759192, -0.25087, -0.47263, -0.357234, 0.453933, 0.475895, -0.071184, 0.528075, -0.516599, 0.807998, 0.143212, -0.373437, -0.952796, 0.040104, 0.160803, -0.743346, 0.534637, 0.92295, -0.646198, 0.692993, 0.776042, -0.474673, 0.98543, 0.184755, -0.30293, -0.022748, 0.135056, 0.643602, 0.502182, 0.218371, -0.033313, 0.643199, 0.824619, -0.750012, 0.913739, 0.493227, -0.223239, 0.962789, -0.051118, 
-0.649982, 0.028419, 0.749459, 0.525479, -0.045743, -0.979086, -0.669124, -0.35604, -0.444705, 0.763697, 0.780956, 0.99184, 0.253779, -0.425107, 0.913943, 0.032233, 0.04514, 0.302513, 0.234363, -0.133738, 0.062949, -0.420645, 0.063063, 0.515796, -0.982415, 0.2082, 0.685465, 0.948766, 0.572355, 0.526904, -0.070458, 0.363551, 0.61422, 0.201742, 0.831489, 0.011985, -0.426671, 0.597698, -0.259001, 0.138813, 0.482189, -0.30467, 0.396641, 0.487531, 0.455809, 0.803298, 0.753098, 0.123971], "shape": [5, 4, 4, 2]}, "weights": [{"data": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008, 0.147894, -0.216289, -0.840972, 0.734562, -0.670993, 0.606963, -0.424144, -0.462858, 0.434956, 0.762811, 0.98424, -0.0833, 0.570259, 0.477388, -0.052834, -0.030331, 0.86601, 0.505308, -0.681422, -0.730379, -0.178646, 0.513073, -0.574974, -0.371941, -0.597453, 0.87685, 0.00883, 0.207463, 0.675097, 0.220365, 0.471146, -0.180468, -0.02072, 0.017849, 0.012965, 0.236682, 0.66921, 0.173131, -0.957385, 0.471247, 0.841267, 0.511354, -0.430488, 0.899198, 0.679766, 0.6299, 0.487356, 0.829739, 0.792468, -0.759192, -0.25087, -0.47263, -0.357234, 0.453933, 0.475895, -0.071184, 0.528075, -0.516599, 0.807998, 0.143212, -0.373437, -0.952796, 0.040104, 0.160803, -0.743346, 0.534637, 0.92295, -0.646198, 0.692993, 0.776042, -0.474673, 0.98543, 0.184755, -0.30293, -0.022748, 0.135056, 0.643602, 0.502182, 0.218371, -0.033313, 0.643199, 0.824619, -0.750012, 0.913739], "shape": [3, 3, 2, 6]}, {"data": [-0.665749, -0.838004, 0.920451, 0.8769, 0.241718, -0.148796], "shape": [6]}], "expected": {"data": [-1.275021, -0.839414, 2.261197, 1.382413, -1.349136, -0.458723, 0.035198, 0.059575, 3.289305, -0.112961, 1.894473, 0.056991, 1.030516, -1.41503, 3.639513, 1.312105, 0.7178, 1.037027, -0.345448, -1.125389, 2.561248, 
1.919182, 2.434602, -0.628632, -1.548767, -0.63035, 1.276973, 2.352928, 0.291083, -0.290137, -0.282871, -1.546028, 1.196641, 0.191184, -1.299763, -0.378339, -1.150905, -2.640953, 1.198463, 1.328465, 0.673755, 0.544141, -0.052356, -0.622702, 1.430426, 1.480166, 0.264735, 0.941609, -0.76932, -0.524748, -0.097613, -0.037621, 0.036755, 0.66339, -1.016433, -0.324838, -1.200336, 0.691055, 0.364328, -1.434147, -0.663102, -1.222518, 2.413622, 0.589465, 2.182118, 0.722914, 1.059836, -2.253524, 1.283919, 2.72919, -1.923285, 3.112732, 0.110403, -2.895395, -0.192991, 2.223259, 0.511342, 0.309826, -1.961463, -0.378234, -1.644435, 0.665466, -0.049178, -2.035013, -0.420975, -2.059799, 0.095355, -0.333549, -1.101064, 0.390433, -1.458754, -2.171488, -0.077547, 0.216551, -0.499852, -0.68166, -0.994649, -1.396949, 0.72207, 1.646976, -1.717063, 1.138718, 0.033565, -1.549827, 1.732043, 2.246031, 0.799031, 1.133765, -1.329821, -0.199264, 1.268272, 3.943292, -0.412988, 1.989369, 0.331454, -1.667789, 1.287873, 1.011695, -0.393542, 0.922799], "shape": [5, 2, 2, 6]}}}

In [ ]: