In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, BatchNormalization, Maximum
from keras import backend as K
from collections import OrderedDict


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=6):
    """Round each number in ``arr`` to ``places`` digits after the decimal point.

    Each value is multiplied by ``10 ** places``, passed through the built-in
    ``round`` (ties follow that built-in's rule), then divided by the same
    factor. The input may be any iterable of numbers; a new list is returned.
    """
    scale = 10 ** places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Fixtures built by the following cells, keyed by test name ('graph_04', ...);
# an OrderedDict so the keys serialize in insertion order.
DATA = OrderedDict([])

graph 4: two inputs, each through its own `Conv2D` branch, merged with `Maximum`


In [4]:
random_seed = 10004
data_in_shape = (8, 8, 2)
num_inputs = 2

# Model: every input gets its own Conv2D branch (4 filters, 3x3 kernel, 'valid'
# padding, stride 1, ReLU); the branch outputs are merged element-wise by a
# Maximum layer. Layers are instantiated in the same order as the hand-unrolled
# version (input, conv, input, conv).
input_layers = []
branches = []
for _ in range(num_inputs):
    layer_in = Input(shape=data_in_shape)
    conv = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1),
                  data_format='channels_last', use_bias=True)
    input_layers.append(layer_in)
    branches.append(conv(layer_in))

output_layer = Maximum()(branches)
model = Model(inputs=input_layers, outputs=output_layer)

# Inputs: values uniform in [-1, 1); input k is drawn right after reseeding numpy
# with random_seed + k, then given a leading batch axis of size 1.
data_in = []
for k in range(num_inputs):
    np.random.seed(random_seed + k)
    sample = 2 * np.random.random(data_in_shape) - 1
    data_in.append(np.expand_dims(sample, axis=0))

# Weights: replace every array from get_weights() with [-1, 1) noise of the same
# shape, reseeding with random_seed + k for the k-th array.
# NOTE(review): inputs and weights share the same seed sequence (both start at
# random_seed), so the first weight array (the 3x3x2x4 kernel) repeats the
# leading values of input 0, and the second (the length-4 bias) repeats the first
# four values of input 1 -- visible in the printed JSON. Kept as-is to reproduce
# the existing fixture; offset one sequence if independent data is wanted.
weights = []
for k, template in enumerate(model.get_weights()):
    np.random.seed(random_seed + k)
    weights.append(2 * np.random.random(template.shape) - 1)
model.set_weights(weights)

# Run the graph on the batch-of-one inputs and keep the single output sample.
result = model.predict(data_in)
data_out_shape = result[0].shape
data_in_formatted = [format_decimal(batch.ravel().tolist()) for batch in data_in]
data_out_formatted = format_decimal(result[0].ravel().tolist())

# JSON-ready fixture: flattened values rounded by format_decimal, plus shapes.
DATA['graph_04'] = {
    'inputs': [{'data': flat, 'shape': data_in_shape} for flat in data_in_formatted],
    'weights': [{'data': format_decimal(arr.ravel().tolist()), 'shape': arr.shape} for arr in weights],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}

Export the fixture for the Keras.js tests


In [5]:
import os

# Write all collected fixtures to the JSON file used by the Keras.js tests.
# The path is relative to the notebook's current working directory.
filename = '../../test/data/graph/04.json'
out_dir = os.path.dirname(filename)
# exist_ok=True replaces the racy exists()/makedirs() pair; the guard skips
# directory creation when the path has no directory part (os.makedirs('')
# would raise).
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Echo the full fixture JSON (same content as the exported file) for inspection.
serialized = json.dumps(DATA)
print(serialized)


{"graph_04": {"inputs": [{"data": [-0.592446, -0.52773, 0.08531, -0.949905, -0.127145, 0.411354, 0.419711, -0.622227, 0.872841, -0.88662, 0.701828, 0.153318, 0.327738, 0.178798, 0.813647, -0.366809, 0.712692, 0.986871, 0.668377, 0.736488, -0.264375, 0.19263, -0.308472, 0.407634, -0.951991, 0.764595, -0.73637, 0.313222, 0.629502, 0.444145, -0.198403, 0.174327, -0.411546, 0.35668, -0.479344, -0.451058, 0.094248, -0.172221, 0.8534, 0.999072, 0.705663, 0.149181, -0.316913, 0.880756, 0.17336, 0.883104, -0.88683, -0.455842, -0.982796, -0.645087, -0.728562, -0.492119, -0.941125, -0.696325, 0.703916, 0.751858, -0.828058, 0.145984, 0.967902, 0.566607, 0.620443, 0.060608, 0.960336, 0.077866, -0.260331, -0.995759, 0.872716, 0.516793, -0.53123, 0.709423, -0.436639, 0.143448, -0.351875, -0.464221, 0.994688, -0.157409, -0.233078, 0.572034, -0.951472, 0.079706, -0.750931, -0.230296, -0.489441, -0.219935, -0.686203, -0.24279, 0.567421, -0.961362, -0.966846, 0.618548, 0.987486, 0.167981, -0.901755, 0.067314, 0.689805, 0.720812, -0.288329, 0.793488, -0.940847, 0.267783, 0.778441, -0.929597, -0.553821, -0.702763, 0.778919, 0.428746, -0.098998, 0.345026, -0.220246, -0.720604, 0.355463, -0.364463, 0.039299, 0.426763, -0.114002, -0.388025, 0.016499, -0.40977, 0.068633, -0.835767, -0.263438, 0.161918, -0.050737, -0.40455, 0.048353, -0.375367, -0.112814, 0.742089], "shape": [8, 8, 2]}, {"data": [-0.905266, -0.940244, 0.975308, -0.726779, 0.318662, -0.747347, -0.333653, -0.875921, 0.959097, -0.046026, 0.500018, -0.914037, 0.884544, 0.190366, -0.635631, -0.160835, -0.946491, 0.638743, -0.403733, 0.431154, 0.771673, -0.000895, -0.667179, -0.734761, 0.366933, -0.524729, 0.171283, 0.611351, 0.477013, 0.586021, 0.75193, -0.305236, -0.374709, 0.756282, 0.61714, -0.823425, -0.135917, -0.45961, -0.35282, 0.165981, -0.875342, -0.910735, 0.54216, 0.791704, -0.363715, -0.379062, -0.778289, -0.503017, -0.498858, 0.73821, -0.560404, 0.383332, -0.162873, 0.462676, -0.888228, 0.603298, -0.211376, 
-0.410015, 0.969717, -0.800772, -0.326232, -0.903871, -0.472227, 0.527646, 0.845871, -0.555119, -0.242145, -0.720775, 0.230819, 0.098654, -0.136847, -0.75402, -0.526179, -0.035179, -0.592459, -0.656289, -0.379236, -0.727621, 0.320888, -0.769713, -0.606966, 0.368143, 0.40487, -0.269795, -0.724819, -0.941842, 0.852848, -0.641574, 0.309148, 0.993258, -0.642038, 0.330661, 0.749424, -0.088987, 0.480838, 0.820957, -0.81524, 0.597648, 0.665083, 0.971523, -0.716728, -0.592397, -0.09027, -0.321656, 0.650358, -0.517554, 0.110004, 0.153027, -0.582732, 0.693639, 0.176338, -0.061709, -0.20594, -0.404174, 0.614825, 0.744034, -0.611636, -0.090203, -0.063809, 0.111312, -0.41878, -0.268608, 0.490234, -0.907873, 0.49094, 0.962297, 0.209586, -0.363235], "shape": [8, 8, 2]}], "weights": [{"data": [-0.592446, -0.52773, 0.08531, -0.949905, -0.127145, 0.411354, 0.419711, -0.622227, 0.872841, -0.88662, 0.701828, 0.153318, 0.327738, 0.178798, 0.813647, -0.366809, 0.712692, 0.986871, 0.668377, 0.736488, -0.264375, 0.19263, -0.308472, 0.407634, -0.951991, 0.764595, -0.73637, 0.313222, 0.629502, 0.444145, -0.198403, 0.174327, -0.411546, 0.35668, -0.479344, -0.451058, 0.094248, -0.172221, 0.8534, 0.999072, 0.705663, 0.149181, -0.316913, 0.880756, 0.17336, 0.883104, -0.88683, -0.455842, -0.982796, -0.645087, -0.728562, -0.492119, -0.941125, -0.696325, 0.703916, 0.751858, -0.828058, 0.145984, 0.967902, 0.566607, 0.620443, 0.060608, 0.960336, 0.077866, -0.260331, -0.995759, 0.872716, 0.516793, -0.53123, 0.709423, -0.436639, 0.143448], "shape": [3, 3, 2, 4]}, {"data": [-0.905266, -0.940244, 0.975308, -0.726779], "shape": [4]}, {"data": [0.87859, 0.511819, -0.499979, 0.899103, -0.273074, -0.448988, 0.778226, -0.725762, 0.533488, -0.254761, 0.563277, 0.149018, -0.800132, -0.323823, -0.687353, -0.853655, 0.344338, -0.804382, -0.339254, -0.814494, -0.931578, -0.554643, 0.526105, 0.197187, -0.423878, -0.86257, -0.805159, 0.875224, -0.45214, -0.480211, -0.152902, 0.569441, -0.211393, -0.317172, 
-0.578325, 0.08373, -0.290923, -0.917902, -0.79171, -0.507596, -0.688969, 0.13645, -0.227688, -0.984034, -0.12649, 0.788219, 0.613676, 0.486748, 0.810045, 0.79828, 0.258775, -0.365513, 0.780025, -0.883853, 0.036758, -0.106986, 0.678987, -0.117617, 0.719819, -0.904588, -0.003723, 0.744139, 0.344811, -0.832108, -0.69998, 0.185014, 0.256453, -0.047385, -0.870467, -0.492493, 0.739994, 0.674833], "shape": [3, 3, 2, 4]}, {"data": [-0.719503, 0.396686, 0.710021, -0.058892], "shape": [4]}], "expected": {"data": [1.39413, 0.0, 1.075018, 1.296832, 2.040395, 2.111687, 0.957007, 2.882732, 1.19891, 1.124668, 4.20438, 0.760041, 0.0, 1.000653, 2.520276, 1.464747, 1.241573, 0.857682, 1.930741, 0.063745, 0.913974, 0.0, 3.889487, 0.92037, 2.426044, 0.0, 1.051703, 0.0, 0.0, 0.270952, 1.906174, 0.875992, 1.944361, 1.366018, 1.141477, 0.72947, 0.0, 1.072188, 2.305962, 0.5909, 0.888642, 2.472507, 3.45503, 2.577835, 0.0, 0.0, 4.231712, 1.314825, 0.0, 1.235292, 2.742102, 0.622673, 0.716516, 2.57412, 0.417932, 1.805181, 0.919551, 0.545131, 4.514028, 0.897809, 0.0, 1.011779, 1.214847, 0.063795, 0.0, 0.677629, 2.073788, 0.196829, 0.0, 2.320562, 0.598812, 2.904462, 0.0, 0.007377, 0.547975, 0.0, 0.0, 0.544882, 0.151824, 1.892519, 0.549325, 0.167612, 1.650692, 0.816664, 1.200346, 1.272109, 2.168049, 0.0, 2.14383, 1.757349, 3.249456, 0.776992, 0.158002, 1.49249, 2.390257, 1.487363, 2.583734, 0.383135, 3.27299, 0.077432, 0.990541, 0.976744, 2.891695, 0.207515, 0.0, 3.511224, 1.954723, 0.822155, 0.332409, 1.710274, 2.226701, 2.458072, 0.0, 1.031972, 1.565825, 0.051234, 0.0, 2.441141, 0.619311, 0.0, 0.366473, 0.364582, 1.202601, 1.53687, 1.284916, 0.120911, 0.967073, 2.086205, 0.0, 0.283486, 1.712032, 0.0, 0.0, 1.928514, 0.712862, 0.958929, 0.0, 0.0, 1.825192, 0.729307, 0.0, 1.030668, 2.201941, 0.0], "shape": [6, 6, 4]}}}

In [ ]: