In [1]:
import json
import os
from collections import OrderedDict

import numpy as np
from keras import backend as K
from keras.layers import Input, Permute, Concatenate
from keras.models import Model


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=6):
    """Return a new list with each value of ``arr`` rounded to ``places`` decimals.

    Each value is scaled by ``10**places``, rounded with the built-in
    ``round``, and scaled back down again.
    """
    def _round_one(value):
        factor = 10**places
        return round(value * factor) / factor

    return list(map(_round_one, arr))

In [3]:
# Accumulates the test fixtures built by the cells below, keyed by graph name
# (e.g. 'graph_00'); the whole mapping is later serialized to JSON.
DATA = OrderedDict()

graph 0: two inputs, each passed through its own Permute layer, merged by Concatenate


In [4]:
random_seed = 10000
data_in_shape = (3, 3, 5)

# Two inputs of the same shape, each routed through its own Permute layer
# before the branches are joined by Concatenate (default axis).
input_layer_0 = Input(shape=data_in_shape)
input_layer_1 = Input(shape=data_in_shape)
branch_0 = Permute((1, 3, 2))(input_layer_0)
branch_1 = Permute((2, 3, 1))(input_layer_1)

output_layer = Concatenate()([branch_0, branch_1])
model = Model(inputs=[input_layer_0, input_layer_1], outputs=output_layer)

# One batch-of-one sample per input, drawn uniformly from [-1, 1).
# Each input is seeded separately (random_seed + position) for reproducibility.
data_in = []
for offset in (0, 1):
    np.random.seed(random_seed + offset)
    sample = 2 * np.random.random(data_in_shape) - 1
    data_in.append(sample[np.newaxis, ...])

# Overwrite every weight tensor the model reports with seeded values in [-1, 1).
weights = []
for position, current in enumerate(model.get_weights()):
    np.random.seed(random_seed + position)
    weights.append(2 * np.random.random(current.shape) - 1)
model.set_weights(weights)

# Run the graph and keep the first (only) batch entry as the expected output.
result = model.predict(data_in)
first_output = result[0]
data_out_shape = first_output.shape
data_in_formatted = [format_decimal(sample.ravel().tolist()) for sample in data_in]
data_out_formatted = format_decimal(first_output.ravel().tolist())

# Assemble the fixture: flattened, rounded values alongside their shapes.
inputs_payload = [
    {'data': flat_values, 'shape': data_in_shape}
    for flat_values in data_in_formatted
]
weights_payload = [
    {'data': format_decimal(tensor.ravel().tolist()), 'shape': tensor.shape}
    for tensor in weights
]
expected_payload = {'data': data_out_formatted, 'shape': data_out_shape}

DATA['graph_00'] = {
    'inputs': inputs_payload,
    'weights': weights_payload,
    'expected': expected_payload
}

export for Keras.js tests


In [5]:
# Write the collected fixtures to the JSON file named below
# (this cell is the "export for Keras.js tests" step).
filename = '../../test/data/graph/00.json'
output_dir = os.path.dirname(filename)
# Create the target directory on first run. Skip this when the path has no
# directory component, because os.makedirs('') raises.
if output_dir and not os.path.exists(output_dir):
    os.makedirs(output_dir)
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Echo the serialized fixtures so the exported JSON is visible in the notebook output.
print(json.dumps(DATA))


{"graph_00": {"inputs": [{"data": [0.09094, -0.656269, 0.705864, -0.639553, 0.1774, -0.75845, 0.109398, -0.835948, 0.98541, -0.494561, 0.06592, 0.784278, -0.515813, -0.164759, 0.735891, -0.42734, -0.724601, -0.898413, 0.856582, 0.5599, 0.388949, 0.906877, 0.863885, -0.585201, -0.553286, 0.995037, -0.015873, -0.54845, 0.975432, 0.513181, 0.165166, 0.132054, -0.256313, -0.094073, 0.634537, -0.795987, 0.834313, -0.912258, 0.149097, -0.639002, 0.271403, 0.788287, 0.31112, -0.49155, 0.193158], "shape": [3, 3, 5]}, {"data": [0.614539, -0.766986, 0.408571, 0.047078, -0.6846, 0.784464, 0.388013, 0.800908, 0.168812, -0.389806, -0.317231, -0.455732, 0.005666, -0.753868, 0.656084, 0.433935, -0.129529, 0.017374, -0.074927, -0.853786, 0.046388, -0.110576, 0.585628, 0.932795, 0.616368, 0.041046, -0.725788, -0.466342, -0.866621, -0.401956, -0.83338, -0.006886, -0.895257, -0.871749, -0.763091, 0.195421, 0.652286, -0.133466, 0.539261, -0.252054, -0.385686, -0.272843, -0.743031, -0.914618, 0.082555], "shape": [3, 3, 5]}], "weights": [], "expected": {"data": [0.09094, -0.75845, 0.06592, 0.614539, 0.433935, -0.83338, -0.656269, 0.109398, 0.784278, -0.766986, -0.129529, -0.006886, 0.705864, -0.835948, -0.515813, 0.408571, 0.017374, -0.895257, -0.639553, 0.98541, -0.164759, 0.047078, -0.074927, -0.871749, 0.1774, -0.494561, 0.735891, -0.6846, -0.853786, -0.763091, -0.42734, 0.388949, 0.995037, 0.784464, 0.046388, 0.195421, -0.724601, 0.906877, -0.015873, 0.388013, -0.110576, 0.652286, -0.898413, 0.863885, -0.54845, 0.800908, 0.585628, -0.133466, 0.856582, -0.585201, 0.975432, 0.168812, 0.932795, 0.539261, 0.5599, -0.553286, 0.513182, -0.389806, 0.616368, -0.252054, 0.165166, -0.795987, 0.271403, -0.317231, 0.041046, -0.385686, 0.132054, 0.834313, 0.788287, -0.455732, -0.725788, -0.272843, -0.256313, -0.912258, 0.31112, 0.005666, -0.466342, -0.743031, -0.094073, 0.149097, -0.49155, -0.753868, -0.866621, -0.914618, 0.634537, -0.639002, 0.193158, 0.656084, -0.401956, 0.082555], "shape": 
[3, 5, 6]}}}

In [ ]: