In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K
from collections import OrderedDict


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=6):
    """Round every number in ``arr`` to ``places`` decimal places.

    Each value is scaled by ``10**places``, rounded to the nearest whole
    number, and scaled back down.

    Args:
        arr: iterable of numbers to round.
        places: number of decimal places to keep (default 6).

    Returns:
        A new list holding the rounded values, in the same order as ``arr``.
    """
    scale = 10**places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Collects each generated fixture under its pipeline key; the export cell serializes it to JSON.
DATA = OrderedDict()

pipeline 9: Conv2D (3x3 kernel, 4 filters, ReLU) followed by 2x2 AveragePooling2D


In [4]:
# Seed and input shape (channels_last) used to generate this fixture.
random_seed = 1009
data_in_shape = (8, 8, 2)

# Pipeline under test: a 3x3 ReLU convolution with 4 filters, then a 2x2 average pool.
layers = [
    Conv2D(
        4,
        (3,3),
        activation='relu',
        padding='valid',
        strides=(1,1),
        data_format='channels_last',
        use_bias=True,
    ),
    AveragePooling2D(
        pool_size=(2,2),
        strides=None,
        padding='valid',
        data_format='channels_last',
    ),
]

# Chain the layers in order, starting from the input tensor.
input_layer = Input(shape=data_in_shape)
output_layer = input_layer
for stage in layers:
    output_layer = stage(output_layer)
model = Model(inputs=input_layer, outputs=output_layer)

# Reproducible input values, uniformly distributed in [-1, 1).
np.random.seed(random_seed)
data_in = 2 * np.random.random(data_in_shape) - 1

# Replace every weight array with seeded random values in [-1, 1).
# The seed is random_seed + index, so the first array (index 0) is drawn from
# the same random stream as data_in.
weights = []
for offset, current in enumerate(model.get_weights()):
    np.random.seed(random_seed + offset)
    weights.append(2 * np.random.random(current.shape) - 1)
model.set_weights(weights)

# Run the model on a batch holding the single input sample.
result = model.predict(np.expand_dims(data_in, axis=0))
data_out_shape = result[0].shape
data_in_formatted = format_decimal(data_in.flatten().tolist())
data_out_formatted = format_decimal(result[0].flatten().tolist())

# Flatten and round each weight array the same way as the input and output data.
weights_formatted = [
    {'data': format_decimal(w.flatten().tolist()), 'shape': w.shape}
    for w in weights
]

DATA['pipeline_09'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': weights_formatted,
    'expected': {'data': data_out_formatted, 'shape': data_out_shape},
}

Export the fixture data for the Keras.js tests


In [5]:
import os

# Destination of the fixture file; the relative path is resolved against the
# current working directory.
filename = '../../test/data/pipeline/09.json'

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of calling os.path.exists() before os.makedirs().
output_dir = os.path.dirname(filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Serialize every collected fixture to the JSON file.
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Print the serialized fixtures so the exported JSON is visible in the notebook output.
print(json.dumps(DATA))


{"pipeline_09": {"input": {"data": [0.229761, -0.670851, -0.398017, -0.453102, 0.590253, 0.096484, -0.035974, 0.69389, -0.260438, 0.250677, -0.193674, 0.385879, -0.991763, -0.558061, -0.750891, 0.265556, 0.657619, -0.801723, -0.910484, 0.975926, 0.288351, -0.364498, -0.126447, 0.616925, 0.303193, -0.820024, 0.429665, 0.070106, -0.34556, -0.459221, -0.798301, 0.998903, 0.359268, -0.957685, -0.915632, 0.691774, 0.25984, 0.879641, 0.179776, -0.423543, -0.867718, -0.560873, -0.382501, -0.365361, -0.723601, -0.033603, 0.789498, 0.455418, 0.827951, 0.872766, -0.595781, 0.655746, -0.100867, -0.407415, -0.5028, -0.07739, -0.432834, 0.371304, -0.410596, -0.92421, 0.688391, -0.862805, 0.206217, -0.602252, -0.132926, -0.676393, -0.280987, 0.247491, -0.027147, -0.016659, -0.23147, -0.653909, 0.876245, -0.853047, 0.6974, 0.939353, 0.068803, -0.786326, -0.10847, -0.038066, 0.594608, -0.631675, -0.052451, 0.362388, 0.573359, -0.461688, -0.799513, 0.883652, -0.681072, 0.708611, -0.103832, 0.161348, -0.938899, -0.604966, 0.588076, -0.410424, 0.215717, -0.211664, 0.012579, -0.756258, 0.346215, -0.858999, 0.517538, -0.867138, -0.49847, -0.114061, 0.906122, -0.446741, -0.632029, 0.913293, -0.420459, 0.031947, -0.34244, 0.231598, 0.806704, 0.479821, 0.133654, -0.915619, 0.113133, -0.995506, -0.979647, 0.571933, 0.161284, 0.035078, 0.210916, 0.504783, 0.667269, 0.356783], "shape": [8, 8, 2]}, "weights": [{"data": [0.229761, -0.670851, -0.398017, -0.453102, 0.590253, 0.096484, -0.035974, 0.69389, -0.260438, 0.250677, -0.193674, 0.385879, -0.991763, -0.558061, -0.750891, 0.265556, 0.657619, -0.801723, -0.910484, 0.975926, 0.288351, -0.364498, -0.126447, 0.616925, 0.303193, -0.820024, 0.429665, 0.070106, -0.34556, -0.459221, -0.798301, 0.998903, 0.359268, -0.957685, -0.915632, 0.691774, 0.25984, 0.879641, 0.179776, -0.423543, -0.867718, -0.560873, -0.382501, -0.365361, -0.723601, -0.033603, 0.789498, 0.455418, 0.827951, 0.872766, -0.595781, 0.655746, -0.100867, -0.407415, -0.5028, 
-0.07739, -0.432834, 0.371304, -0.410596, -0.92421, 0.688391, -0.862805, 0.206217, -0.602252, -0.132926, -0.676393, -0.280987, 0.247491, -0.027147, -0.016659, -0.23147, -0.653909], "shape": [3, 3, 2, 4]}, {"data": [-0.211487, -0.648815, -0.854588, -0.616238], "shape": [4]}], "expected": {"data": [0.509739, 0.237036, 0.293895, 0.282399, 0.423908, 0.008688, 0.304865, 0.666234, 0.022924, 0.541168, 1.17086, 0.0, 0.203349, 0.417828, 0.0, 0.398297, 0.175827, 0.763362, 0.328252, 0.0, 0.780266, 0.128615, 0.007518, 0.537463, 0.079002, 0.150453, 0.0, 0.507228, 0.436817, 0.946573, 0.496548, 0.50641, 1.045505, 0.082524, 0.213519, 0.0], "shape": [3, 3, 4]}}}

In [ ]: