In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras import backend as K
from collections import OrderedDict


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=6):
    """Round every number in ``arr`` to ``places`` decimal digits.

    Each value is multiplied by 10**places, rounded to the nearest integer
    with the built-in ``round``, then divided back down. This keeps the
    exported JSON fixtures compact.

    Args:
        arr: iterable of numbers (the notebook passes lists from ``.tolist()``).
        places: number of decimal digits to keep; defaults to 6.

    Returns:
        A new list with the rounded values, in the same order as ``arr``.
    """
    scale = 10 ** places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Registry of generated test cases, keyed by case name (e.g. 'pipeline_00').
# An OrderedDict keeps the cases in insertion order when serialized to JSON.
DATA = OrderedDict(())

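pipeline 0: three stacked Conv2D layers (4 filters, 3×3 kernel, stride 1, 'valid' padding, ReLU activation, with bias) on an 8×8×2 channels_last input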


In [4]:
# Pipeline 0: three stacked Conv2D layers (4 filters, 3x3 kernel, stride 1,
# 'valid' padding, channels_last, ReLU, with bias) on an 8x8x2 input.
# Each 'valid' 3x3 conv shrinks the spatial dims by 2: 8 -> 6 -> 4 -> 2.
random_seed = 1000
data_in_shape = (8, 8, 2)

layers = [
    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True),
    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True),
    Conv2D(4, (3,3), strides=(1,1), padding='valid', data_format='channels_last', activation='relu', use_bias=True)
]

# Apply every layer in order. A single uniform loop (rather than
# layers[0] / layers[1:-1] / layers[-1]) also stays correct for a
# one-layer pipeline, where the old form would call that layer twice.
input_layer = Input(shape=data_in_shape)
x = input_layer
for layer in layers:
    x = layer(x)
output_layer = x
model = Model(inputs=input_layer, outputs=output_layer)


def _uniform_pm1(shape, seed):
    """Return float64 samples uniform on [-1, 1) with the given shape.

    Uses a private RandomState seeded with ``seed``. This yields exactly the
    same values as ``np.random.seed(seed); 2 * np.random.random(shape) - 1``
    without touching NumPy's global RNG state.
    """
    return 2 * np.random.RandomState(seed).random_sample(shape) - 1


data_in = _uniform_pm1(data_in_shape, random_seed)

# Random weights, one seed per weight array for reproducibility.
# NOTE(review): array i is seeded with random_seed + i, so array 0 (the first
# kernel, 3*3*2*4 = 72 values) reuses the input's seed. Its values therefore
# equal the first 72 input values, as the printed JSON output shows.
# Kept as-is so the committed fixture does not change; consider offsetting
# the weight seeds when regenerating the test data.
weights = [_uniform_pm1(w.shape, random_seed + i)
           for i, w in enumerate(model.get_weights())]
model.set_weights(weights)

# predict expects a batch dimension: run a batch of one and keep sample 0.
result = model.predict(np.array([data_in]))
data_out_shape = result[0].shape
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(result[0].ravel().tolist())

# Weights are flattened in model.get_weights() order. The exported shapes
# suggest kernel then bias for each layer.
DATA['pipeline_00'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape}
}

Export the generated test cases to JSON for the Keras.js tests


In [5]:
import os

# Write the collected test cases to the Keras.js test-data directory.
# The path is relative to the notebook's working directory; the directory
# is created first if it is missing.
filename = '../../test/data/pipeline/00.json'
out_dir = os.path.dirname(filename)
if out_dir:  # dirname is '' for a bare filename, and os.makedirs('') would raise
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Echo the serialized test cases so they can be inspected in the notebook.
serialized = json.dumps(DATA)
print(serialized)


{"pipeline_00": {"input": {"data": [0.307179, -0.769986, 0.900566, -0.035617, 0.744949, -0.575335, -0.918581, -0.205611, -0.533736, 0.683481, -0.585835, 0.484939, -0.215692, -0.635487, 0.487079, -0.860836, 0.770674, 0.905289, 0.862287, -0.169138, -0.942037, 0.964055, -0.320725, 0.413374, -0.276246, -0.929788, 0.710117, 0.314507, 0.531366, 0.108174, 0.770186, 0.808395, -0.979157, -0.850887, -0.510742, -0.73339, 0.39585, -0.20359, 0.766244, -0.637985, -0.135002, -0.963714, 0.382876, -0.060619, -0.743556, 0.782674, 0.836407, -0.853758, -0.909104, -0.122854, 0.203442, -0.379546, 0.363816, -0.581974, 0.039209, 0.131978, -0.117665, -0.724888, -0.572914, -0.733256, -0.355407, -0.532226, 0.054996, 0.131942, -0.123549, -0.356255, 0.119282, 0.730691, 0.694566, -0.784366, -0.367361, -0.181043, 0.374178, 0.404471, -0.107604, -0.159356, 0.605261, 0.077235, 0.847001, -0.876185, -0.264833, 0.940798, 0.398208, 0.781951, -0.492895, 0.451054, -0.593011, 0.075024, -0.526114, -0.127016, 0.596049, -0.383182, 0.243033, -0.120704, 0.826647, 0.317372, 0.307263, -0.283084, 0.045883, -0.909825, -0.811381, 0.843224, -0.853911, 0.858835, 0.43403, 0.244621, 0.509979, -0.71666, 0.587644, 0.450806, 0.082879, -0.034831, -0.047045, -0.107934, 0.63214, -0.451297, -0.876076, 0.920593, -0.083764, 0.515784, 0.362317, 0.083556, -0.734335, -0.493589, -0.112289, 0.1175, 0.627729, -0.364343], "shape": [8, 8, 2]}, "weights": [{"data": [0.307179, -0.769986, 0.900566, -0.035617, 0.744949, -0.575335, -0.918581, -0.205611, -0.533736, 0.683481, -0.585835, 0.484939, -0.215692, -0.635487, 0.487079, -0.860836, 0.770674, 0.905289, 0.862287, -0.169138, -0.942037, 0.964055, -0.320725, 0.413374, -0.276246, -0.929788, 0.710117, 0.314507, 0.531366, 0.108174, 0.770186, 0.808395, -0.979157, -0.850887, -0.510742, -0.73339, 0.39585, -0.20359, 0.766244, -0.637985, -0.135002, -0.963714, 0.382876, -0.060619, -0.743556, 0.782674, 0.836407, -0.853758, -0.909104, -0.122854, 0.203442, -0.379546, 0.363816, -0.581974, 0.039209, 
0.131978, -0.117665, -0.724888, -0.572914, -0.733256, -0.355407, -0.532226, 0.054996, 0.131942, -0.123549, -0.356255, 0.119282, 0.730691, 0.694566, -0.784366, -0.367361, -0.181043], "shape": [3, 3, 2, 4]}, {"data": [-0.387536, -0.469873, -0.60788, -0.138957], "shape": [4]}, {"data": [-0.742023, -0.077688, -0.167692, 0.205448, -0.633864, -0.164175, -0.731823, 0.313236, 0.613465, -0.723716, -0.299231, 0.229032, 0.102561, 0.384949, -0.90948, -0.294898, -0.916217, -0.699031, -0.323329, -0.673445, 0.521949, -0.306796, -0.476018, -0.628623, 0.808028, -0.585043, -0.307429, -0.234868, -0.897584, 0.741743, 0.320785, 0.709132, -0.978084, 0.601894, -0.228816, -0.069558, -0.522066, -0.399597, -0.916222, 0.161549, -0.211915, 0.823372, -0.6549, -0.30403, 0.677588, -0.431259, 0.219659, -0.091937, -0.101636, -0.595218, -0.815428, 0.502932, 0.775249, 0.624226, 0.622601, -0.091075, 0.763603, 0.472659, 0.621131, -0.504549, -0.270214, 0.492749, 0.643055, -0.290058, -0.752162, 0.758918, 0.011832, -0.183967, 0.768298, 0.764241, 0.906398, 0.872853, -0.292238, 0.16788, -0.447741, 0.679196, 0.566614, 0.867549, -0.011606, -0.252108, 0.165669, -0.509362, 0.620632, -0.32465, -0.071143, -0.823613, 0.331067, -0.016903, -0.76138, -0.491146, 0.106088, -0.641492, 0.234893, 0.658853, -0.475623, 0.269103, 0.935505, -0.577134, 0.985015, -0.405957, -0.325882, 0.849518, -0.589155, 0.378331, -0.753075, 0.711411, 0.04547, 0.398327, -0.665657, 0.531142, -0.410293, -0.526649, 0.860648, 0.32795, -0.197082, -0.095526, -0.391361, 0.785465, -0.267269, -0.020154, -0.95189, -0.580742, 0.788104, -0.092433, 0.320354, 0.070651, 0.045416, 0.99799, 0.583116, -0.708131, -0.104784, -0.838947, -0.598224, 0.209105, 0.824956, 0.10438, 0.692046, -0.091308, 0.884896, 0.730617, 0.244486, -0.415624, -0.397714, -0.647236], "shape": [3, 3, 4, 4]}, {"data": [0.195612, -0.128132, -0.96626, 0.193375], "shape": [4]}, {"data": [-0.922097, 0.712992, 0.493001, 0.727856, 0.119969, -0.839034, -0.536727, -0.515472, 0.231, 0.214218, 
-0.791636, -0.148304, 0.309846, 0.742779, -0.123022, 0.427583, -0.882276, 0.818571, 0.043634, 0.454859, -0.007311, -0.744895, -0.368229, 0.324805, -0.388758, -0.556215, -0.542859, 0.685655, 0.350785, -0.312753, 0.591401, 0.95999, 0.136369, -0.58844, -0.506667, -0.208736, 0.548969, 0.653173, 0.128943, 0.180094, -0.16098, 0.208798, 0.666245, 0.347307, -0.384733, -0.88354, -0.328468, -0.515324, 0.479247, -0.360647, 0.09069, -0.221424, 0.091284, 0.202631, 0.208087, 0.582248, -0.164064, -0.925036, -0.678806, -0.212846, 0.960861, 0.536089, -0.038634, -0.473456, -0.409408, 0.620315, -0.873085, -0.695405, -0.024465, 0.762843, -0.928228, 0.557106, -0.65499, -0.918356, 0.815491, 0.996431, 0.115769, -0.751652, 0.075229, 0.969983, -0.80409, -0.080661, -0.644088, 0.160702, -0.486518, -0.09818, -0.191651, -0.961566, -0.238209, 0.260427, 0.085307, -0.664437, 0.458517, -0.824692, 0.312768, -0.253698, 0.761718, 0.551215, 0.566009, -0.85706, 0.687904, -0.283819, 0.5816, 0.820087, -0.028474, 0.588153, -0.221145, 0.049173, 0.529328, -0.359074, -0.463161, 0.493967, -0.852793, -0.552675, -0.695748, -0.178157, 0.477995, 0.858725, 0.120384, -0.515209, 0.204484, -0.025025, -0.654961, 0.239585, -0.654691, -0.651696, -0.699951, -0.054626, -0.232999, 0.464974, 0.285499, -0.311165, 0.18009, -0.100505, 0.303943, 0.265535, -0.960747, -0.542418, 0.195178, -0.848394, 0.0774, 0.250615, -0.690541, -0.106589], "shape": [3, 3, 4, 4]}, {"data": [0.318429, -0.858397, -0.059042, 0.68597], "shape": [4]}], "expected": {"data": [5.009162, 0.0, 0.0, 0.0, 1.770272, 3.243442, 0.0, 3.319521, 0.0, 2.15876, 0.0, 0.0, 4.509293, 0.188208, 0.0, 0.0], "shape": [2, 2, 4]}}}

In [ ]: