In [1]:
import numpy as np
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Dense
from keras.layers.convolutional import Conv2D
from keras.layers.wrappers import TimeDistributed
from keras import backend as K
import json
from collections import OrderedDict


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=6):
    """Return a new list with every number in ``arr`` rounded to ``places`` decimals.

    Each value is scaled by ``10 ** places``, rounded to the nearest integer
    with the built-in ``round`` (ties go to the even integer), then divided
    back down. This keeps the JSON fixtures compact.
    """
    scale = 10 ** places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Registry of fixtures, keyed by test-case name; each cell below adds one entry.
# An OrderedDict keeps the cases in insertion order when dumped to JSON.
DATA = OrderedDict([])

TimeDistributed

[wrappers.TimeDistributed.0] wrap a Dense layer with 4 units (input: 3x6)


In [4]:
data_in_shape = (3, 6)

# Model under test: a Dense(4) layer applied by TimeDistributed to every slice
# along the first (time) axis of the input.
layer_0 = Input(shape=data_in_shape)
layer_1 = TimeDistributed(Dense(4))(layer_0)
model = Model(inputs=layer_0, outputs=layer_1)

# Replace the initial weights with values uniform in [-1, 1). The global RNG is
# re-seeded (4000 + tensor index) before each tensor so every tensor is reproducible.
weights = []
for idx, template in enumerate(model.get_weights()):
    np.random.seed(4000 + idx)
    weights.append(2 * np.random.random(template.shape) - 1)
model.set_weights(weights)

weight_names = ['W', 'b']  # printed labels, in get_weights() order
weights_formatted = [format_decimal(arr.ravel().tolist()) for arr in weights]
for name, arr, arr_formatted in zip(weight_names, weights, weights_formatted):
    print('{} shape:'.format(name), arr.shape)
    print('{}:'.format(name), arr_formatted)

# No re-seed here: the input draw continues the RNG stream left by the last
# weight seed, so it is deterministic but tied to the number of weight tensors.
data_in = 2 * np.random.random(data_in_shape) - 1

# predict() expects a batch, so wrap the single sample and unwrap the output.
result = model.predict(np.array([data_in]))
sample_out = result[0]
data_out_shape = sample_out.shape
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(sample_out.ravel().tolist())

print('')
for label, value in [('in shape:', data_in_shape),
                     ('in:', data_in_formatted),
                     ('out shape:', data_out_shape),
                     ('out:', data_out_formatted)]:
    print(label, value)

# Fixture entry: flattened input, every weight tensor, and the expected output.
DATA['wrappers.TimeDistributed.0'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': [{'data': values, 'shape': arr.shape}
                for arr, values in zip(weights, weights_formatted)],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}


W shape: (6, 4)
W: [0.317596, 0.688515, -0.688309, -0.48247, 0.387223, -0.718263, 0.281673, -0.106311, 0.576861, -0.083926, 0.631691, 0.92647, 0.579655, -0.024215, -0.805793, -0.842947, -0.955415, 0.656415, 0.44667, 0.633739, 0.701525, 0.917507, -0.185671, -0.105247]
b shape: (4,)
b: [-0.332867, 0.650317, 0.995501, -0.458367]

in shape: (3, 6)
in: [-0.30351, 0.37881, -0.248093, 0.372204, -0.698964, -0.408058, -0.103801, 0.376217, -0.724015, 0.708616, -0.513219, -0.46074, -0.125163, -0.76111, -0.153798, 0.729255, 0.556458, -0.671966]
out shape: (3, 4)
out: [0.171595, -0.652137, 0.618031, -1.295817, -0.05994, -0.407387, 0.000875, -1.993142, -1.33639, 0.854801, 0.555804, -0.650907]

[wrappers.TimeDistributed.1] wrap a Conv2D layer with 6 3x3 filters (input: 5x4x4x2)


In [5]:
data_in_shape = (5, 4, 4, 2)

# Model under test: a Conv2D with 6 filters of size 3x3 (channels_last),
# applied by TimeDistributed to every slice along the first (time) axis.
layer_0 = Input(shape=data_in_shape)
layer_1 = TimeDistributed(Conv2D(6, (3,3), data_format='channels_last'))(layer_0)
model = Model(inputs=layer_0, outputs=layer_1)

# Replace the initial weights with values uniform in [-1, 1). The global RNG is
# re-seeded (4010 + tensor index) before each tensor so every tensor is reproducible.
weights = []
for idx, template in enumerate(model.get_weights()):
    np.random.seed(4010 + idx)
    weights.append(2 * np.random.random(template.shape) - 1)
model.set_weights(weights)

weight_names = ['W', 'b']  # printed labels, in get_weights() order
weights_formatted = [format_decimal(arr.ravel().tolist()) for arr in weights]
for name, arr, arr_formatted in zip(weight_names, weights, weights_formatted):
    print('{} shape:'.format(name), arr.shape)
    print('{}:'.format(name), arr_formatted)

# No re-seed here: the input draw continues the RNG stream left by the last
# weight seed, so it is deterministic but tied to the number of weight tensors.
data_in = 2 * np.random.random(data_in_shape) - 1

# predict() expects a batch, so wrap the single sample and unwrap the output.
result = model.predict(np.array([data_in]))
sample_out = result[0]
data_out_shape = sample_out.shape
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(sample_out.ravel().tolist())

print('')
for label, value in [('in shape:', data_in_shape),
                     ('in:', data_in_formatted),
                     ('out shape:', data_out_shape),
                     ('out:', data_out_formatted)]:
    print(label, value)

# Fixture entry: flattened input, every weight tensor, and the expected output.
DATA['wrappers.TimeDistributed.1'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': [{'data': values, 'shape': arr.shape}
                for arr, values in zip(weights, weights_formatted)],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}


W shape: (3, 3, 2, 6)
W: [0.971827, -0.898904, -0.987921, 0.529589, 0.043586, -0.541366, 0.316759, 0.351387, -0.292323, 0.445466, -0.922655, 0.437413, -0.483267, -0.478014, 0.7408, -0.595028, -0.718381, 0.349594, -0.091293, 0.14291, 0.633818, -0.686841, -0.925272, -0.740397, 0.070594, 0.67408, 0.455314, -0.402251, 0.288807, 0.001378, 0.42892, -0.251869, 0.06113, -0.703784, 0.002676, 0.965023, 0.758788, 0.1193, 0.749321, -0.017408, -0.004115, 0.18981, -0.91507, 0.132792, -0.219057, 0.19682, -0.512841, 0.954544, 0.794403, -0.663179, -0.05377, -0.855038, -0.486641, 0.625844, -0.945869, -0.474979, 0.922345, -0.334843, -0.469456, -0.394364, 0.543681, -0.817676, 0.6093, -0.77635, -0.508683, 0.22456, 0.696262, 0.079806, -0.182646, -0.718939, 0.962504, -0.386231, 0.860488, -0.918945, -0.800484, -0.590285, 0.409804, -0.822098, 0.3489, -0.4508, 0.913208, -0.414455, 0.97663, 0.956314, -0.55547, 0.594094, -0.552044, -0.137467, 0.539049, -0.320055, -0.335577, 0.974746, -0.634747, 0.085161, -0.127183, -0.061717, -0.411844, 0.774181, 0.223395, 0.163937, -0.606967, 0.178549, -0.005153, 0.452476, 0.373127, -0.726827, -0.395458, -0.769671]
b shape: (6,)
b: [0.180389, 0.629217, -0.656262, -0.476575, -0.36398, 0.987756]

in shape: (5, 4, 4, 2)
in: [-0.579677, 0.883193, 0.651172, -0.820251, -0.64795, 0.857328, -0.4689, 0.356044, -0.641528, -0.531973, -0.33586, -0.438823, 0.682186, 0.215781, -0.401735, 0.169171, 0.869358, -0.204078, -0.661876, -0.616139, -0.453943, -0.569439, -0.25218, 0.156473, 0.194797, -0.923921, 0.652204, -0.11765, 0.86293, 0.314218, -0.878496, -0.364761, -0.647821, 0.296841, 0.280105, 0.2753, -0.959741, -0.148037, -0.489424, -0.88939, 0.704443, 0.08354, 0.930112, -0.87023, -0.212285, 0.750133, 0.343506, -0.82568, 0.391491, 0.149626, 0.003594, -0.181464, -0.499632, 0.20694, 0.1007, 0.39826, 0.609736, -0.765775, -0.728474, -0.011711, 0.543543, 0.174309, 0.105794, -0.009876, -0.694421, -0.157031, 0.670853, -0.581331, 0.739486, -0.886014, -0.637039, 0.725753, 0.61919, 0.447635, 0.167298, 0.164242, -0.615436, -0.503061, 0.981698, -0.392795, 0.532215, 0.761817, 0.735562, -0.236234, -0.856381, 0.22419, -0.221125, 0.133757, -0.011162, -0.88018, -0.433047, -0.825617, 0.693626, -0.185243, -0.824829, 0.07932, 0.336478, 0.370138, -0.685905, -0.462037, 0.563862, 0.490274, 0.934239, -0.129323, 0.717792, -0.73658, -0.939587, 0.796637, -0.131382, -0.79957, -0.271279, 0.816961, -0.082096, 0.64553, -0.106661, 0.651369, -0.843208, -0.221077, 0.758074, 0.156006, -0.429501, 0.191698, 0.988067, -0.277344, 0.757645, -0.877824, 0.053841, 0.394075, 0.786359, 0.735302, 0.247852, -0.310899, 0.703408, -0.848404, 0.455067, 0.295289, -0.629316, 0.626332, -0.075289, -0.442735, -0.219408, -0.766048, 0.303257, 0.142211, 0.910002, -0.780858, 0.333242, -0.533434, 0.572575, 0.355883, -0.671924, 0.22028, -0.505951, -0.317892, 0.609641, -0.360548, 0.490007, 0.441024, 0.660294, 0.850007]
out shape: (5, 2, 2, 6)
out: [2.089554, -2.186939, -1.436176, -0.951733, -0.212962, 2.449681, 1.053569, -0.592297, -0.875753, -0.803289, -0.834779, -0.568349, -0.842922, 3.976765, -1.054281, 0.581773, 0.235047, 0.103039, -0.079684, 0.225164, -2.408352, -1.116154, 1.561833, -0.491674, 2.43274, -0.158393, -0.874487, -1.96851, -0.106465, 1.602375, 0.941225, 0.480547, 0.002478, 1.246195, -1.388929, -1.133004, 1.476556, -0.459852, -2.130519, -0.126113, -1.162246, 1.398016, -0.61384, 1.539333, -0.466156, 0.0395, 0.506595, -1.590958, -1.044266, 0.736233, 0.61792, -0.923799, 1.275832, 1.491487, 1.903216, -2.385963, -1.553725, -0.554848, -0.456638, 1.645426, 0.690056, 0.190637, -2.015925, 1.143469, -2.530136, 1.025159, -0.150503, 2.627801, -1.352068, 1.245647, 1.235627, -0.915363, 0.682646, 0.854592, -0.030856, 0.949627, 1.204568, 1.052329, -0.942961, 2.039314, 0.892454, -1.925232, 0.046332, 2.315713, -2.358422, 1.724373, -1.528506, 1.794933, 0.342617, -0.191888, -0.026605, 0.475714, -1.332559, -1.158213, 0.028725, 1.890396, -0.305622, 0.890336, -3.426138, 1.245994, -2.027975, -0.505022, 1.32001, 0.477823, -2.460816, -0.984189, 1.221664, 0.339475, 1.26535, 2.228118, 0.207158, -0.455112, -0.64988, 0.688864, 0.574933, 1.911587, -1.642423, -1.385077, 0.744757, -0.567276]

export for Keras.js tests


In [6]:
# Serialize every collected fixture as a single JSON object for Keras.js tests.
fixtures_json = json.dumps(DATA)
print(fixtures_json)


{"wrappers.TimeDistributed.0": {"expected": {"data": [0.171595, -0.652137, 0.618031, -1.295817, -0.05994, -0.407387, 0.000875, -1.993142, -1.33639, 0.854801, 0.555804, -0.650907], "shape": [3, 4]}, "input": {"data": [-0.30351, 0.37881, -0.248093, 0.372204, -0.698964, -0.408058, -0.103801, 0.376217, -0.724015, 0.708616, -0.513219, -0.46074, -0.125163, -0.76111, -0.153798, 0.729255, 0.556458, -0.671966], "shape": [3, 6]}, "weights": [{"data": [0.317596, 0.688515, -0.688309, -0.48247, 0.387223, -0.718263, 0.281673, -0.106311, 0.576861, -0.083926, 0.631691, 0.92647, 0.579655, -0.024215, -0.805793, -0.842947, -0.955415, 0.656415, 0.44667, 0.633739, 0.701525, 0.917507, -0.185671, -0.105247], "shape": [6, 4]}, {"data": [-0.332867, 0.650317, 0.995501, -0.458367], "shape": [4]}]}, "wrappers.TimeDistributed.1": {"expected": {"data": [2.089554, -2.186939, -1.436176, -0.951733, -0.212962, 2.449681, 1.053569, -0.592297, -0.875753, -0.803289, -0.834779, -0.568349, -0.842922, 3.976765, -1.054281, 0.581773, 0.235047, 0.103039, -0.079684, 0.225164, -2.408352, -1.116154, 1.561833, -0.491674, 2.43274, -0.158393, -0.874487, -1.96851, -0.106465, 1.602375, 0.941225, 0.480547, 0.002478, 1.246195, -1.388929, -1.133004, 1.476556, -0.459852, -2.130519, -0.126113, -1.162246, 1.398016, -0.61384, 1.539333, -0.466156, 0.0395, 0.506595, -1.590958, -1.044266, 0.736233, 0.61792, -0.923799, 1.275832, 1.491487, 1.903216, -2.385963, -1.553725, -0.554848, -0.456638, 1.645426, 0.690056, 0.190637, -2.015925, 1.143469, -2.530136, 1.025159, -0.150503, 2.627801, -1.352068, 1.245647, 1.235627, -0.915363, 0.682646, 0.854592, -0.030856, 0.949627, 1.204568, 1.052329, -0.942961, 2.039314, 0.892454, -1.925232, 0.046332, 2.315713, -2.358422, 1.724373, -1.528506, 1.794933, 0.342617, -0.191888, -0.026605, 0.475714, -1.332559, -1.158213, 0.028725, 1.890396, -0.305622, 0.890336, -3.426138, 1.245994, -2.027975, -0.505022, 1.32001, 0.477823, -2.460816, -0.984189, 1.221664, 0.339475, 1.26535, 2.228118, 0.207158, 
-0.455112, -0.64988, 0.688864, 0.574933, 1.911587, -1.642423, -1.385077, 0.744757, -0.567276], "shape": [5, 2, 2, 6]}, "input": {"data": [-0.579677, 0.883193, 0.651172, -0.820251, -0.64795, 0.857328, -0.4689, 0.356044, -0.641528, -0.531973, -0.33586, -0.438823, 0.682186, 0.215781, -0.401735, 0.169171, 0.869358, -0.204078, -0.661876, -0.616139, -0.453943, -0.569439, -0.25218, 0.156473, 0.194797, -0.923921, 0.652204, -0.11765, 0.86293, 0.314218, -0.878496, -0.364761, -0.647821, 0.296841, 0.280105, 0.2753, -0.959741, -0.148037, -0.489424, -0.88939, 0.704443, 0.08354, 0.930112, -0.87023, -0.212285, 0.750133, 0.343506, -0.82568, 0.391491, 0.149626, 0.003594, -0.181464, -0.499632, 0.20694, 0.1007, 0.39826, 0.609736, -0.765775, -0.728474, -0.011711, 0.543543, 0.174309, 0.105794, -0.009876, -0.694421, -0.157031, 0.670853, -0.581331, 0.739486, -0.886014, -0.637039, 0.725753, 0.61919, 0.447635, 0.167298, 0.164242, -0.615436, -0.503061, 0.981698, -0.392795, 0.532215, 0.761817, 0.735562, -0.236234, -0.856381, 0.22419, -0.221125, 0.133757, -0.011162, -0.88018, -0.433047, -0.825617, 0.693626, -0.185243, -0.824829, 0.07932, 0.336478, 0.370138, -0.685905, -0.462037, 0.563862, 0.490274, 0.934239, -0.129323, 0.717792, -0.73658, -0.939587, 0.796637, -0.131382, -0.79957, -0.271279, 0.816961, -0.082096, 0.64553, -0.106661, 0.651369, -0.843208, -0.221077, 0.758074, 0.156006, -0.429501, 0.191698, 0.988067, -0.277344, 0.757645, -0.877824, 0.053841, 0.394075, 0.786359, 0.735302, 0.247852, -0.310899, 0.703408, -0.848404, 0.455067, 0.295289, -0.629316, 0.626332, -0.075289, -0.442735, -0.219408, -0.766048, 0.303257, 0.142211, 0.910002, -0.780858, 0.333242, -0.533434, 0.572575, 0.355883, -0.671924, 0.22028, -0.505951, -0.317892, 0.609641, -0.360548, 0.490007, 0.441024, 0.660294, 0.850007], "shape": [5, 4, 4, 2]}, "weights": [{"data": [0.971827, -0.898904, -0.987921, 0.529589, 0.043586, -0.541366, 0.316759, 0.351387, -0.292323, 0.445466, -0.922655, 0.437413, -0.483267, -0.478014, 0.7408, 
-0.595028, -0.718381, 0.349594, -0.091293, 0.14291, 0.633818, -0.686841, -0.925272, -0.740397, 0.070594, 0.67408, 0.455314, -0.402251, 0.288807, 0.001378, 0.42892, -0.251869, 0.06113, -0.703784, 0.002676, 0.965023, 0.758788, 0.1193, 0.749321, -0.017408, -0.004115, 0.18981, -0.91507, 0.132792, -0.219057, 0.19682, -0.512841, 0.954544, 0.794403, -0.663179, -0.05377, -0.855038, -0.486641, 0.625844, -0.945869, -0.474979, 0.922345, -0.334843, -0.469456, -0.394364, 0.543681, -0.817676, 0.6093, -0.77635, -0.508683, 0.22456, 0.696262, 0.079806, -0.182646, -0.718939, 0.962504, -0.386231, 0.860488, -0.918945, -0.800484, -0.590285, 0.409804, -0.822098, 0.3489, -0.4508, 0.913208, -0.414455, 0.97663, 0.956314, -0.55547, 0.594094, -0.552044, -0.137467, 0.539049, -0.320055, -0.335577, 0.974746, -0.634747, 0.085161, -0.127183, -0.061717, -0.411844, 0.774181, 0.223395, 0.163937, -0.606967, 0.178549, -0.005153, 0.452476, 0.373127, -0.726827, -0.395458, -0.769671], "shape": [3, 3, 2, 6]}, {"data": [0.180389, 0.629217, -0.656262, -0.476575, -0.36398, 0.987756], "shape": [6]}]}}

In [ ]: