In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K


Using TensorFlow backend.

In [2]:
def format_decimal(arr, places=8):
    """Round every number in ``arr`` to ``places`` decimal places.

    Used below to shorten the numbers written into the printed fixture dicts.

    Args:
        arr: iterable of numbers (callers pass ``ndarray.ravel().tolist()``).
        places: number of digits to keep after the decimal point.

    Returns:
        A new list with one rounded value per input element, in input order.
    """
    scale = 10 ** places
    return [round(value * scale) / scale for value in arr]

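pipeline 8

Input (8, 8, 2) → Conv2D (4 filters, 3×3, `same` padding, ReLU) → Conv2D (4 filters, 3×3, `valid` padding, ReLU) → MaxPooling2D (2×2, stride 1, `same` padding); the expected output shape is (6, 6, 4).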


In [6]:
# Pipeline 8 fixture: Conv2D('same') -> Conv2D('valid') -> MaxPooling2D('same', stride 1).
# Layers use Keras 1.x-style kwargs (border_mode / subsample / dim_ordering / bias);
# Conv2D(4, 3, 3, ...) is 4 filters with a 3x3 kernel.
data_in_shape = (8, 8, 2)

conv_0 = Conv2D(4, 3, 3, activation='relu', border_mode='same', subsample=(1, 1), dim_ordering='tf', bias=True)
conv_1 = Conv2D(4, 3, 3, activation='relu', border_mode='valid', subsample=(1, 1), dim_ordering='tf', bias=True)
pool_0 = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), border_mode='same', dim_ordering='tf')

# Chain the two convolutions, then pool; `x` is left holding conv_1's output.
input_layer = Input(shape=data_in_shape)
x = input_layer
for layer in (conv_0, conv_1):
    x = layer(x)
output_layer = pool_0(x)
model = Model(input=input_layer, output=output_layer)


def _uniform_like(shape, seed):
    """Reseed numpy's global RNG with `seed` and return uniform samples in [-1, 1) of `shape`."""
    np.random.seed(seed)
    return 2 * np.random.random(shape) - 1


# Input sample, seeded for reproducibility.
data_in = _uniform_like(data_in_shape, 9000)

# Overwrite every weight tensor (in model.get_weights() order); tensor i is drawn
# with seed 9000 + i.
# NOTE(review): tensor 0 uses seed 9000, the same seed as `data_in`, so the first
# kernel (3*3*2*4 = 72 values) equals the first 72 values of the input. Changing the
# seeds would regenerate every fixture value, so it is left as-is here.
weights = []
for i, w in enumerate(model.get_weights()):
    weights.append(_uniform_like(w.shape, 9000 + i))
model.set_weights(weights)

# Forward pass on a batch containing the single input sample.
result = model.predict(np.array([data_in]))


def _as_fixture_entry(arr, shape):
    """Flatten `arr` into rounded values and pair it with `shape` as a {'data', 'shape'} dict."""
    return {'data': format_decimal(arr.ravel().tolist()), 'shape': list(shape)}


print({
    'input': _as_fixture_entry(data_in, data_in_shape),
    'weights': [_as_fixture_entry(tensor, tensor.shape) for tensor in weights],
    'expected': _as_fixture_entry(result[0], result[0].shape)
})


{'input': {'shape': [8, 8, 2], 'data': [-0.41555038, 0.97461397, 0.75821321, -0.21914657, -0.97631015, -0.73121474, 0.33607099, -0.19573518, -0.90314981, 0.57328105, -0.31008384, -0.04097988, -0.7944409, -0.59867349, 0.11736359, 0.44927613, -0.45658166, 0.53144057, 0.88095177, -0.26818166, -0.78497489, 0.52071551, -0.77909559, 0.27316327, -0.70334372, 0.68283568, -0.42097229, -0.56161334, 0.32549452, -0.28786984, 0.91681983, -0.40536397, -0.20665999, -0.19036561, -0.42646282, 0.79665162, 0.15267526, 0.59278252, 0.73005423, -0.8377447, -0.40203092, 0.49379166, -0.17837407, -0.24601101, -0.4719636, 0.72086455, 0.6045019, -0.04170968, -0.07386621, -0.58061159, 0.7559498, 0.67788764, 0.15911773, -0.33300175, 0.27103546, 0.4685299, 0.7996605, 0.39635589, -0.62500296, 0.13468982, 0.41144838, 0.10678898, 0.13393713, -0.51624453, 0.75973462, -0.15889839, -0.13356371, -0.24345362, 0.54311613, 0.24951841, 0.86556601, 0.93970336, -0.44348215, -0.14442047, -0.7798827, -0.5337385, 0.2702182, -0.8605702, 0.0616022, 0.2277329, 0.78379221, 0.65838309, -0.34385938, -0.43427408, -0.120432, 0.54343261, 0.50475903, 0.02010303, -0.97302154, 0.13699901, -0.38906848, -0.42578161, 0.59754615, -0.01172449, 0.30915267, -0.91441695, 0.34111593, 0.35975558, 0.42907848, 0.07654757, -0.38833753, -0.56917265, 0.80524169, 0.86392902, 0.56319301, -0.26450284, -0.85947563, -0.90760189, -0.31486217, -0.66787224, -0.64252826, -0.53166519, 0.79790673, 0.71681791, 0.93662571, -0.17843694, -0.34482929, 0.79757309, 0.01266515, -0.05330285, -0.49386733, -0.27143997, -0.19199762, 0.30088794, 0.53047715, -0.8311494, -0.81386653, -0.04151039]}, 'expected': {'shape': [6, 6, 4], 'data': [11.20991135, 2.4998405, 0.0, 2.75583458, 7.7911706, 2.4998405, 1.25839007, 1.43201602, 6.23500109, 0.0, 1.25839007, 1.43201602, 6.23500109, 0.0, 0.0, 3.07163763, 6.42382383, 0.96967769, 0.14339605, 3.07163763, 6.42382383, 0.96967769, 0.14339605, 1.8516854, 11.20991135, 2.4998405, 0.0, 2.00711751, 8.63814354, 2.4998405, 
1.25839007, 4.27138615, 8.63814354, 0.0, 1.25839007, 4.27138615, 6.23500109, 0.0, 2.00964189, 3.07163763, 5.667027, 0.96967769, 2.00964189, 3.07163763, 5.667027, 0.96967769, 0.0, 1.8516854, 8.80222321, 0.0, 1.48333907, 0.0, 8.80222321, 0.0, 0.65547383, 4.27138615, 8.63814354, 0.0, 0.25925896, 4.27138615, 4.48585844, 0.0, 2.00964189, 0.00787196, 5.667027, 0.13031936, 2.00964189, 0.00787196, 5.667027, 0.13031936, 0.0, 0.0, 8.80222321, 0.0, 4.75704288, 0.0, 8.80222321, 0.0, 1.40301585, 0.0, 6.61095762, 0.0, 1.40301585, 1.52760398, 3.8409071, 0.0, 0.25925896, 3.58441162, 3.53277445, 0.0, 0.0, 3.58441162, 3.379951, 0.0, 0.0, 0.0, 8.47192097, 0.0, 4.75704288, 0.0, 7.86143923, 0.0, 2.58718634, 3.06576085, 5.90569162, 0.0, 1.40301585, 3.06576085, 5.17238426, 0.0, 0.0, 3.58441162, 5.17238426, 0.0, 0.0, 3.58441162, 4.71007776, 0.0, 0.0, 0.0, 8.47192097, 0.0, 2.58718634, 0.0, 7.67782354, 0.0, 2.58718634, 3.06576085, 5.90569162, 0.0, 0.0, 3.06576085, 5.17238426, 0.0, 0.0, 2.03451777, 5.17238426, 0.0, 0.0, 2.03451777, 4.71007776, 0.0, 0.0, 0.0]}, 'weights': [{'shape': [3, 3, 2, 4], 'data': [-0.41555038, 0.97461397, 0.75821321, -0.21914657, -0.97631015, -0.73121474, 0.33607099, -0.19573518, -0.90314981, 0.57328105, -0.31008384, -0.04097988, -0.7944409, -0.59867349, 0.11736359, 0.44927613, -0.45658166, 0.53144057, 0.88095177, -0.26818166, -0.78497489, 0.52071551, -0.77909559, 0.27316327, -0.70334372, 0.68283568, -0.42097229, -0.56161334, 0.32549452, -0.28786984, 0.91681983, -0.40536397, -0.20665999, -0.19036561, -0.42646282, 0.79665162, 0.15267526, 0.59278252, 0.73005423, -0.8377447, -0.40203092, 0.49379166, -0.17837407, -0.24601101, -0.4719636, 0.72086455, 0.6045019, -0.04170968, -0.07386621, -0.58061159, 0.7559498, 0.67788764, 0.15911773, -0.33300175, 0.27103546, 0.4685299, 0.7996605, 0.39635589, -0.62500296, 0.13468982, 0.41144838, 0.10678898, 0.13393713, -0.51624453, 0.75973462, -0.15889839, -0.13356371, -0.24345362, 0.54311613, 0.24951841, 0.86556601, 0.93970336]}, {'shape': 
[4], 'data': [-0.81559274, 0.70130393, 0.80150024, 0.18722638]}, {'shape': [3, 3, 4, 4], 'data': [-0.49129679, 0.99335292, -0.09384886, 0.99927519, 0.1555095, -0.14215545, 0.59698772, 0.06297022, 0.80461406, -0.03825503, -0.72438552, 0.55677282, -0.21721763, 0.08387028, -0.88896861, -0.83734603, -0.96649633, -0.17975805, 0.1351767, -0.56443661, 0.49017687, 0.99862875, 0.66386746, -0.77783428, 0.35908482, -0.68846291, 0.9712287, -0.60346077, -0.32886263, -0.98975221, -0.57730462, 0.09302021, 0.87707981, 0.64105047, -0.68734847, 0.47072703, 0.6643682, -0.35243655, 0.58872234, -0.96003558, -0.42869414, -0.58263149, 0.87760874, -0.68192418, 0.31116538, 0.78138193, -0.8448471, -0.40173266, 0.19118648, -0.71754657, 0.36678324, 0.26604816, -0.54068262, 0.62284268, -0.69058066, -0.50971755, 0.13375695, -0.5099981, -0.51845976, 0.76798738, 0.20009001, -0.46146058, -0.65743523, -0.62921159, -0.00295129, -0.4576503, -0.27900797, 0.31531608, 0.80159974, -0.61869678, 0.83833313, -0.50569751, -0.00669321, -0.0733309, -0.59472012, 0.31042903, 0.56855354, -0.6554806, -0.84646972, 0.22104567, 0.49975684, 0.62930184, -0.26526444, -0.43224213, -0.31080217, 0.97250306, 0.00184188, 0.24078872, 0.87760641, -0.91302742, -0.14860063, 0.47567594, 0.44041543, 0.87551868, 0.58701064, 0.31502402, -0.61876824, -0.74926491, 0.91711402, 0.35428865, 0.50698658, 0.12919918, 0.14279217, 0.848144, 0.06334585, -0.74902194, -0.43910205, -0.14009199, 0.37387057, -0.14434464, -0.53373154, 0.70043858, 0.28036532, -0.5061185, -0.40290755, -0.42522438, 0.45044635, 0.61366226, 0.88598663, -0.67757282, 0.31285953, 0.01199989, -0.62344537, 0.32380209, 0.11713391, 0.61614678, -0.84702926, 0.35170671, -0.42284564, 0.33130915, -0.51675575, 0.61663542, 0.34359166, -0.853249, 0.0515108, -0.61426739, 0.36041111, -0.27616768, 0.61014143, 0.09021638, 0.83476087, -0.25739965, -0.22953653, 0.75112283]}, {'shape': [4], 'data': [-0.53630771, 0.71085857, -0.10986421, -0.17026386]}]}

In [ ]: