PPN Dynamic Filtering


In [48]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1337)

from keras.datasets import mnist, cifar10
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, Merge, ThresholdedReLU
from keras.layers import Convolution2D, MaxPooling2D, InputLayer, Input, UpSampling2D, Deconvolution2D
from keras.regularizers import activity_l2
from keras.utils import np_utils
from keras.callbacks import Callback
from keras import backend as K

from ppap.layers import PPDFN
from utils import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [64]:
# Modified version of keras/examples/mnist_cnn.py

batch_size = 1024
nb_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Explicitly set dim ordering to theano
K.set_image_dim_ordering('th')

if K.image_dim_ordering() == 'th':
    X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
    X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)


X_train shape: (60000, 1, 28, 28)
60000 train samples
10000 test samples

Convolutional Auto-encoder


In [11]:
input_img = Input(shape=(1, 28, 28))

x = Convolution2D(16, 3, 3, activation='relu', border_mode='same')(input_img)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
encoded = MaxPooling2D((2, 2), border_mode='same')(x)

# at this point the representation is (8, 4, 4) i.e. 128-dimensional

x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = UpSampling2D((2, 2))(x)
x = Convolution2D(16, 3, 3, activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Convolution2D(1, 3, 3, activation='sigmoid', border_mode='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

In [35]:
autoencoder.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_3 (InputLayer)             (None, 1, 28, 28)     0                                            
____________________________________________________________________________________________________
convolution2d_12 (Convolution2D) (None, 16, 28, 28)    160         input_3[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_7 (MaxPooling2D)    (None, 16, 14, 14)    0           convolution2d_12[0][0]           
____________________________________________________________________________________________________
convolution2d_13 (Convolution2D) (None, 8, 14, 14)     1160        maxpooling2d_7[0][0]             
____________________________________________________________________________________________________
maxpooling2d_8 (MaxPooling2D)    (None, 8, 7, 7)       0           convolution2d_13[0][0]           
____________________________________________________________________________________________________
convolution2d_14 (Convolution2D) (None, 8, 7, 7)       584         maxpooling2d_8[0][0]             
____________________________________________________________________________________________________
maxpooling2d_9 (MaxPooling2D)    (None, 8, 4, 4)       0           convolution2d_14[0][0]           
____________________________________________________________________________________________________
convolution2d_15 (Convolution2D) (None, 8, 4, 4)       584         maxpooling2d_9[0][0]             
____________________________________________________________________________________________________
upsampling2d_4 (UpSampling2D)    (None, 8, 8, 8)       0           convolution2d_15[0][0]           
____________________________________________________________________________________________________
convolution2d_16 (Convolution2D) (None, 8, 8, 8)       584         upsampling2d_4[0][0]             
____________________________________________________________________________________________________
upsampling2d_5 (UpSampling2D)    (None, 8, 16, 16)     0           convolution2d_16[0][0]           
____________________________________________________________________________________________________
convolution2d_17 (Convolution2D) (None, 16, 14, 14)    1168        upsampling2d_5[0][0]             
____________________________________________________________________________________________________
upsampling2d_6 (UpSampling2D)    (None, 16, 28, 28)    0           convolution2d_17[0][0]           
____________________________________________________________________________________________________
convolution2d_18 (Convolution2D) (None, 1, 28, 28)     145         upsampling2d_6[0][0]             
====================================================================================================
Total params: 4385
____________________________________________________________________________________________________

In [16]:
autoencoder.fit(X_train, X_train,
                nb_epoch=50,
                batch_size=128,
                shuffle=True,
                validation_data=(X_test, X_test))


Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 7s - loss: 0.0993 - val_loss: 0.0972
Epoch 2/50
60000/60000 [==============================] - 7s - loss: 0.0991 - val_loss: 0.0971
Epoch 3/50
60000/60000 [==============================] - 7s - loss: 0.0989 - val_loss: 0.0992
Epoch 4/50
60000/60000 [==============================] - 7s - loss: 0.0988 - val_loss: 0.0973
Epoch 5/50
60000/60000 [==============================] - 7s - loss: 0.0988 - val_loss: 0.0989
Epoch 6/50
60000/60000 [==============================] - 7s - loss: 0.0987 - val_loss: 0.0976
Epoch 7/50
60000/60000 [==============================] - 7s - loss: 0.0985 - val_loss: 0.0989
Epoch 8/50
60000/60000 [==============================] - 7s - loss: 0.0984 - val_loss: 0.0978
Epoch 9/50
60000/60000 [==============================] - 7s - loss: 0.0984 - val_loss: 0.0974
Epoch 10/50
60000/60000 [==============================] - 7s - loss: 0.0982 - val_loss: 0.0976
Epoch 11/50
60000/60000 [==============================] - 7s - loss: 0.0983 - val_loss: 0.0967
Epoch 12/50
60000/60000 [==============================] - 7s - loss: 0.0979 - val_loss: 0.0977
Epoch 13/50
60000/60000 [==============================] - 7s - loss: 0.0977 - val_loss: 0.0972
Epoch 14/50
60000/60000 [==============================] - 7s - loss: 0.0976 - val_loss: 0.0956
Epoch 15/50
60000/60000 [==============================] - 7s - loss: 0.0975 - val_loss: 0.0955
Epoch 16/50
60000/60000 [==============================] - 7s - loss: 0.0973 - val_loss: 0.0978
Epoch 17/50
60000/60000 [==============================] - 7s - loss: 0.0972 - val_loss: 0.0959
Epoch 18/50
60000/60000 [==============================] - 7s - loss: 0.0973 - val_loss: 0.0964
Epoch 19/50
60000/60000 [==============================] - 7s - loss: 0.0970 - val_loss: 0.0952
Epoch 20/50
60000/60000 [==============================] - 7s - loss: 0.0970 - val_loss: 0.0955
Epoch 21/50
60000/60000 [==============================] - 7s - loss: 0.0970 - val_loss: 0.0961
Epoch 22/50
60000/60000 [==============================] - 7s - loss: 0.0968 - val_loss: 0.0958
Epoch 23/50
60000/60000 [==============================] - 7s - loss: 0.0970 - val_loss: 0.0947
Epoch 24/50
60000/60000 [==============================] - 7s - loss: 0.0970 - val_loss: 0.0987
Epoch 25/50
60000/60000 [==============================] - 7s - loss: 0.0965 - val_loss: 0.0950
Epoch 26/50
60000/60000 [==============================] - 7s - loss: 0.0961 - val_loss: 0.0946
Epoch 27/50
60000/60000 [==============================] - 7s - loss: 0.0965 - val_loss: 0.0960
Epoch 28/50
60000/60000 [==============================] - 7s - loss: 0.0964 - val_loss: 0.0953
Epoch 29/50
60000/60000 [==============================] - 7s - loss: 0.0964 - val_loss: 0.0947
Epoch 30/50
60000/60000 [==============================] - 7s - loss: 0.0966 - val_loss: 0.0955
Epoch 31/50
60000/60000 [==============================] - 7s - loss: 0.0964 - val_loss: 0.0951
Epoch 32/50
60000/60000 [==============================] - 7s - loss: 0.0964 - val_loss: 0.0946
Epoch 33/50
60000/60000 [==============================] - 7s - loss: 0.0961 - val_loss: 0.0956
Epoch 34/50
60000/60000 [==============================] - 7s - loss: 0.0959 - val_loss: 0.0940
Epoch 35/50
60000/60000 [==============================] - 7s - loss: 0.0959 - val_loss: 0.0957
Epoch 36/50
60000/60000 [==============================] - 7s - loss: 0.0959 - val_loss: 0.0951
Epoch 37/50
60000/60000 [==============================] - 7s - loss: 0.0960 - val_loss: 0.0943
Epoch 38/50
60000/60000 [==============================] - 7s - loss: 0.0961 - val_loss: 0.0963
Epoch 39/50
60000/60000 [==============================] - 7s - loss: 0.0957 - val_loss: 0.0958
Epoch 40/50
60000/60000 [==============================] - 7s - loss: 0.0957 - val_loss: 0.0957
Epoch 41/50
60000/60000 [==============================] - 7s - loss: 0.0959 - val_loss: 0.0972
Epoch 42/50
60000/60000 [==============================] - 7s - loss: 0.0956 - val_loss: 0.0943
Epoch 43/50
60000/60000 [==============================] - 7s - loss: 0.0956 - val_loss: 0.0947
Epoch 44/50
60000/60000 [==============================] - 7s - loss: 0.0955 - val_loss: 0.0957
Epoch 45/50
60000/60000 [==============================] - 7s - loss: 0.0953 - val_loss: 0.0936
Epoch 46/50
60000/60000 [==============================] - 7s - loss: 0.0953 - val_loss: 0.0960
Epoch 47/50
60000/60000 [==============================] - 7s - loss: 0.0951 - val_loss: 0.0948
Epoch 48/50
60000/60000 [==============================] - 7s - loss: 0.0952 - val_loss: 0.0952
Epoch 49/50
60000/60000 [==============================] - 7s - loss: 0.0953 - val_loss: 0.0955
Epoch 50/50
60000/60000 [==============================] - 7s - loss: 0.0954 - val_loss: 0.0934
Out[16]:
<keras.callbacks.History at 0x7f9cd9b294e0>

In [55]:
decoded_imgs = autoencoder.predict(X_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i+1)
    plt.imshow(X_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n +1 )
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()



In [85]:
decoded_imgs = model.predict(X_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i+1)
    plt.imshow(X_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n +1 )
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


DFN auto-encoder


In [82]:
# Works on fixed batch size as of now
model = Sequential()
model.add(InputLayer(batch_input_shape=(batch_size, 1, img_rows, img_cols)))
model.add(PPDFN(7))
model.add(Activation('relu'))
model.add(Convolution2D(16, 3, 3, activation='relu', border_mode='same'))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(16, 3, 3, activation='relu'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(1, 3, 3, activation='sigmoid', border_mode='same'))
model.compile(optimizer='adam', loss='binary_crossentropy')

In [83]:
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_23 (InputLayer)            (1024, 1, 28, 28)     0                                            
____________________________________________________________________________________________________
ppdfn_27 (PPDFN)                 (1024, 1, 28, 28)     1499        input_23[0][0]                   
____________________________________________________________________________________________________
activation_18 (Activation)       (1024, 1, 28, 28)     0           ppdfn_27[0][0]                   
____________________________________________________________________________________________________
convolution2d_100 (Convolution2D)(1024, 16, 28, 28)    160         activation_18[0][0]              
____________________________________________________________________________________________________
maxpooling2d_48 (MaxPooling2D)   (1024, 16, 14, 14)    0           convolution2d_100[0][0]          
____________________________________________________________________________________________________
convolution2d_101 (Convolution2D)(1024, 8, 14, 14)     1160        maxpooling2d_48[0][0]            
____________________________________________________________________________________________________
maxpooling2d_49 (MaxPooling2D)   (1024, 8, 7, 7)       0           convolution2d_101[0][0]          
____________________________________________________________________________________________________
convolution2d_102 (Convolution2D)(1024, 8, 7, 7)       584         maxpooling2d_49[0][0]            
____________________________________________________________________________________________________
maxpooling2d_50 (MaxPooling2D)   (1024, 8, 4, 4)       0           convolution2d_102[0][0]          
____________________________________________________________________________________________________
convolution2d_103 (Convolution2D)(1024, 8, 4, 4)       584         maxpooling2d_50[0][0]            
____________________________________________________________________________________________________
upsampling2d_40 (UpSampling2D)   (1024, 8, 8, 8)       0           convolution2d_103[0][0]          
____________________________________________________________________________________________________
convolution2d_104 (Convolution2D)(1024, 8, 8, 8)       584         upsampling2d_40[0][0]            
____________________________________________________________________________________________________
upsampling2d_41 (UpSampling2D)   (1024, 8, 16, 16)     0           convolution2d_104[0][0]          
____________________________________________________________________________________________________
convolution2d_105 (Convolution2D)(1024, 16, 14, 14)    1168        upsampling2d_41[0][0]            
____________________________________________________________________________________________________
upsampling2d_42 (UpSampling2D)   (1024, 16, 28, 28)    0           convolution2d_105[0][0]          
____________________________________________________________________________________________________
convolution2d_106 (Convolution2D)(1024, 1, 28, 28)     145         upsampling2d_42[0][0]            
====================================================================================================
Total params: 5884
____________________________________________________________________________________________________

In [84]:
model.fit(X_train, X_train,
                nb_epoch=20,
                batch_size=2048,
                shuffle=True,
                validation_data=(X_test, X_test))


Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 21s - loss: 0.5194 - val_loss: 0.4117
Epoch 2/20
60000/60000 [==============================] - 21s - loss: 0.3220 - val_loss: 0.2497
Epoch 3/20
60000/60000 [==============================] - 21s - loss: 0.2299 - val_loss: 0.2149
Epoch 4/20
60000/60000 [==============================] - 21s - loss: 0.2063 - val_loss: 0.1981
Epoch 5/20
60000/60000 [==============================] - 22s - loss: 0.1913 - val_loss: 0.1844
Epoch 6/20
60000/60000 [==============================] - 21s - loss: 0.1798 - val_loss: 0.1741
Epoch 7/20
60000/60000 [==============================] - 21s - loss: 0.1710 - val_loss: 0.1662
Epoch 8/20
60000/60000 [==============================] - 21s - loss: 0.1642 - val_loss: 0.1601
Epoch 9/20
60000/60000 [==============================] - 21s - loss: 0.1588 - val_loss: 0.1550
Epoch 10/20
60000/60000 [==============================] - 22s - loss: 0.1543 - val_loss: 0.1517
Epoch 11/20
60000/60000 [==============================] - 22s - loss: 0.1507 - val_loss: 0.1475
Epoch 12/20
60000/60000 [==============================] - 22s - loss: 0.1475 - val_loss: 0.1448
Epoch 13/20
60000/60000 [==============================] - 21s - loss: 0.1448 - val_loss: 0.1422
Epoch 14/20
60000/60000 [==============================] - 22s - loss: 0.1426 - val_loss: 0.1405
Epoch 15/20
60000/60000 [==============================] - 21s - loss: 0.1404 - val_loss: 0.1381
Epoch 16/20
60000/60000 [==============================] - 21s - loss: 0.1385 - val_loss: 0.1363
Epoch 17/20
60000/60000 [==============================] - 22s - loss: 0.1368 - val_loss: 0.1346
Epoch 18/20
60000/60000 [==============================] - 22s - loss: 0.1351 - val_loss: 0.1332
Epoch 19/20
60000/60000 [==============================] - 21s - loss: 0.1336 - val_loss: 0.1314
Epoch 20/20
60000/60000 [==============================] - 22s - loss: 0.1321 - val_loss: 0.1301
Out[84]:
<keras.callbacks.History at 0x7f9cad501a90>

In [86]:
plt.imshow(decoded_imgs[0, 0], cmap="Greys")


Out[86]:
<matplotlib.image.AxesImage at 0x7f9c9cba1390>

In [87]:
get_filter = get_filter_getter(model)
fs = get_filter(X_test[:20])
plt.imshow(filters_image(fs[0]))


Out[87]:
<matplotlib.image.AxesImage at 0x7f9c9c9c2e10>

In [51]:
gen = model.layers[1].gen
a = gen.coordinates_weights.get_value()

In [52]:
gen.coordinates.eval()[0, 0]


Out[52]:
array([[-1.46687257, -1.99878156, -1.47023892, -0.00562654, -0.26125073,
        -1.46154678, -1.93342435, -1.46938038,  0.03133357, -0.27354863,
        -1.45782709, -1.86868143, -1.47022164,  0.06816409, -0.28500938,
        -1.45589375, -1.80462193, -1.47295356,  0.10485051, -0.29553899,
        -1.45594442, -1.74132156, -1.47778535,  0.14137688, -0.30503446,
        -1.45819366, -1.67886233, -1.48494411],
       [ 0.17772591, -0.31338388, -1.46287096, -1.61733186, -1.49467289,
         0.21387909, -0.32046771, -1.47021806, -1.55682278, -1.5072273 ,
         0.24981686, -0.32615986, -1.48048306, -1.49742997, -1.52287006,
         0.28551927, -0.33033112, -1.49391377, -1.43924832, -1.54186344,
         0.32096627, -0.33285215, -1.51074874, -1.38236904, -1.56445992,
         0.35613871, -0.33359873, -1.53120625],
       [-1.32687545, -1.59089041,  0.3910189 , -0.33245715, -1.55547237,
        -1.27283895, -1.62135208,  0.42559183, -0.32933024, -1.58369029,
        -1.22031438, -1.65599632,  0.45984599, -0.32414347, -1.61595011,
        -1.16933608, -1.69491863,  0.49377409, -0.31684983, -1.6522826 ,
        -1.11991572, -1.73815131,  0.52737361, -0.30743331, -1.69265699,
        -1.07204175, -1.78566217,  0.56064719],
       [-0.29590997, -1.7369833 , -1.02567959, -1.83735561,  0.59360188,
        -0.28232673, -1.7851181 , -0.98077452, -1.89308023,  0.62624937,
        -0.2667582 , -1.83687556, -0.93725526, -1.95263875,  0.65860462,
        -0.24930145, -1.89203703, -0.89503837, -2.01580024,  0.69068527,
        -0.23007026, -1.95036447, -0.85403252, -2.08231258,  0.72251052,
        -0.20918892, -2.01160979, -0.814143  ],
       [-2.1519134 ,  0.75410038, -0.18678659, -2.07552481, -0.77527481,
        -2.22433972,  0.78547496, -0.16299266, -2.14186788, -0.73733556,
        -2.29933596,  0.81665361, -0.13793308, -2.2104094 , -0.70023739,
        -2.37665915,  0.84765488, -0.11172745, -2.28093505, -0.66389823,
        -2.45608234,  0.87849611, -0.08448774, -2.35324717, -0.62824243,
        -2.53739595,  0.90919322, -0.05631684],
       [-1.36891639, -1.92194057, -1.36766243, -0.04230534, -0.22951286,
        -1.36209595, -1.85601151, -1.36522198, -0.00522466, -0.24258989,
        -1.35682988, -1.79067695, -1.36442649,  0.03173061, -0.25485671,
        -1.35330796, -1.72600985, -1.36547709,  0.06854519, -0.26621434,
        -1.35174167, -1.6620909 , -1.36859751,  0.105202  , -0.27655265,
        -1.35236382, -1.59900904, -1.37403417],
       [ 0.14168227, -0.28575018, -1.35542846, -1.53686166, -1.382056  ,
         0.17796554, -0.29367465, -1.36120701, -1.47575247, -1.3929503 ,
         0.21402985, -0.3001844 , -1.36998355, -1.41579032, -1.40701759,
         0.24985231, -0.30513152, -1.3820467 , -1.35708535, -1.42456341,
         0.28540966, -0.30836543, -1.39767873, -1.2997458 , -1.44588685,
         0.32067913, -0.30973902, -1.41714239],
       [-1.24387217, -1.47126544,  0.35563946, -0.30911541, -1.44066453,
        -1.18955112, -1.50093973,  0.39027244, -0.30637637, -1.46842146,
        -1.13685012, -1.53509605,  0.4245638 , -0.3014299 , -1.500525  ,
        -1.08581197, -1.5738529 ,  0.4585045 , -0.29421771, -1.53701377,
        -1.03645146, -1.61725116,  0.49209145, -0.28471974, -1.57784915,
        -0.98875386, -1.66524982,  0.52532774],
       [-0.27295604, -1.62291944, -0.94267637, -1.71773064,  0.55822247,
        -0.25898504, -1.67204809, -0.89815146, -1.77450705,  0.59078974,
        -0.24289848, -1.72500837, -0.85509241, -1.83533871,  0.62304801,
        -0.22481474, -1.78153753, -0.8133986 , -1.89994788,  0.65501833,
        -0.20487067, -1.84135342, -0.77296227, -1.96803558,  0.68672353,
        -0.1832134 , -1.90416718, -0.7336728 ],
       [-2.03929639,  0.71818686, -0.15999353, -1.96969485, -0.69542158,
        -2.11342978,  0.74943125, -0.13535896, -2.03766513, -0.65810478,
        -2.19014788,  0.78047866, -0.10945126, -2.10782385, -0.62162519,
        -2.26918268,  0.81134951, -0.0824028 , -2.17993808, -0.58589381,
        -2.3502872 ,  0.84206259, -0.05433499, -2.25379634, -0.55082959,
        -2.43323755,  0.87263495, -0.02535808],
       [-1.27256632, -1.84571397, -1.26678586, -0.07911366, -0.19693793,
        -1.2641995 , -1.77919316, -1.26270866, -0.04190824, -0.21082097,
        -1.25731933, -1.71324134, -1.26020491, -0.00482277, -0.22392911,
        -1.25212336, -1.64793372, -1.25948358,  0.03212685, -0.23615935,
        -1.24883544, -1.58335602, -1.26078188,  0.06892253, -0.24739507,
        -1.24770689, -1.51960444, -1.26436555],
       [ 0.10554405, -0.25750524, -1.24901628, -1.45678556, -1.27052963,
         0.14196891, -0.26634455, -1.25306797, -1.39501584, -1.27959621,
         0.17817251, -0.27375442, -1.26018679, -1.33441937, -1.29190898,
         0.21412873, -0.27956557, -1.27070999, -1.27512538, -1.30782497,
         0.24981031, -0.2836023 , -1.28497362, -1.21726239, -1.3276999 ,
         0.28519014, -0.2856892 , -1.30329621],
       [-1.16095209, -1.35187078,  0.32024252, -0.28566039, -1.32595766,
        -1.10630178, -1.38063419,  0.35494494, -0.2833699 , -1.3531785 ,
        -1.05339575, -1.41422307,  0.38927951, -0.27870283, -1.38510001,
        -1.00228798, -1.45278728,  0.42323491, -0.27158558, -1.42177081,
        -0.95299715, -1.49637818,  0.45680717, -0.26199272, -1.4631424 ,
        -0.90550458, -1.54494429,  0.49000025],
       [-0.24994963, -1.50907314, -0.85975629, -1.59833598,  0.52282554,
        -0.23552996, -1.55934298, -0.81566799, -1.6563201 ,  0.55530083,
        -0.21884862, -1.61367166, -0.77313244, -1.71860027,  0.58744866,
        -0.20005153, -1.67174077, -0.73202783, -1.78483939,  0.61929476,
        -0.17930478, -1.73321438, -0.69222564, -1.85468161,  0.65086621,
        -0.15678348, -1.797755  , -0.65359676],
       [-1.92777002,  0.68219024, -0.13266341, -1.86503792, -0.61601704,
        -2.00376105,  0.71329314, -0.10711396, -1.9347589 , -0.57937008,
        -2.08233237,  0.74419928, -0.08029363, -2.006639  , -0.54354906,
        -2.16318917,  0.77493125, -0.05234787, -2.08042741, -0.50845814,
        -2.24606538,  0.80550927, -0.02340746, -2.1559    , -0.47401139,
        -2.33072424,  0.83595145,  0.00641084],
       [-1.1780026 , -1.77017069, -1.16780007, -0.1160661 , -0.16343179,
        -1.16804719, -1.70304227, -1.16204143, -0.07873252, -0.17814294,
        -1.15949297, -1.63644993, -1.15776575, -0.041512  , -0.19212368,
        -1.15254271, -1.57047129, -1.15518773, -0.00442087, -0.20526837,
        -1.1474303 , -1.5051955 , -1.15455496,  0.032522  , -0.2174551 ,
        -1.14442348, -1.44072545, -1.1561507 ],
       [ 0.06929503, -0.22854427, -1.14382565, -1.37717688, -1.1602962 ,
         0.10587374, -0.23837776, -1.14597595, -1.31467974, -1.16735029,
         0.14223073, -0.24677874, -1.15124488, -1.25337565, -1.17770505,
         0.17833617, -0.2535542 , -1.16002607, -1.19341516, -1.19177747,
         0.21415825, -0.25849879, -1.17272222, -1.13495255, -1.20999336,
         0.24966455, -0.26140279, -1.18972301],
       [-1.07813656, -1.23276532,  0.2848236 , -0.26206294, -1.21137846,
        -1.02310145, -1.26046371,  0.31960714, -0.26029691, -1.23796856,
        -0.96995401, -1.29338515,  0.35399261, -0.25595862, -1.26967502,
        -0.918764  , -1.33172166,  0.38796532, -0.24895343, -1.30656087,
        -0.86955541, -1.37554014,  0.42152029, -0.23924847, -1.34856308,
        -0.82230419, -1.42477381,  0.45466244],
       [-0.22687659, -1.39550006, -0.77694076, -1.4792304 ,  0.48740661,
        -0.21193257, -1.44709158, -0.73335809, -1.53861356,  0.51977521,
        -0.19456224, -1.50298786, -0.69142222, -1.60255277,  0.55179662,
        -0.1749481 , -1.56279886, -0.65098405, -1.67063546,  0.58350217,
        -0.15329334, -1.62612236, -0.61188954, -1.74243569,  0.61492437,
        -0.12980773, -1.69256437, -0.57398808],
       [-1.81753659,  0.64609504, -0.10469665, -1.76175439, -0.53713799,
        -1.8955462 ,  0.67704409, -0.07815305, -1.83335364, -0.50120956,
        -1.97610545,  0.70779872, -0.05035372, -1.90705848, -0.46608657,
        -2.05889344,  0.73838347, -0.02145683, -1.98260105, -0.43166673,
        -2.14362645,  0.76882005,  0.00839799, -2.0597477 , -0.39786047,
        -2.23005724,  0.79912716,  0.0390889 ],
       [-1.08542275, -1.69538653, -1.07091403, -0.15317857, -0.12889159,
        -1.07385051, -1.62763953, -1.06344402, -0.11571456, -0.14444551,
        -1.06357479, -1.56038857, -1.05734634, -0.07835516, -0.1593236 ,
        -1.05479991, -1.49371183, -1.05283713, -0.04111684, -0.17341937,
        -1.04776609, -1.427701  , -1.05017066, -0.00401898, -0.18660764,
        -1.04275465, -1.36246407, -1.04964471],
       [ 0.03291575, -0.19874172, -1.04009163, -1.29812551, -1.05160439,
         0.06966106, -0.20965166, -1.04015088, -1.23482835, -1.05644524,
         0.10618674, -0.21914265, -1.04335248, -1.17273331, -1.06461203,
         0.14245895, -0.22699568, -1.05015576, -1.1120162 , -1.0765909 ,
         0.1784406 , -0.23297134, -1.0610441 , -1.05286193, -1.09289336,
         0.21409273, -0.23681761, -1.07649899],
       [-0.99545461, -1.11402917,  0.24937646, -0.23828359, -1.09696388,
        -0.939964  , -1.14046764,  0.28425604, -0.23713806, -1.1228013 ,
        -0.88652861, -1.17259228,  0.31870222, -0.23319212, -1.15425003,
        -0.83524001, -1.21065617,  0.35269579, -0.22632134, -1.19139361,
        -0.78613001, -1.25474739,  0.3862299 , -0.21648201, -1.23414862,
        -0.7391668 , -1.30477774,  0.41931137],
       [-0.2037178 , -1.28227592, -0.69425881, -1.36049438,  0.45195949,
        -0.18815316, -1.33541346, -0.65126753, -1.42151368,  0.4842034 ,
        -0.16997705, -1.39311743, -0.6100232 , -1.48736608,  0.51607895,
        -0.14942062, -1.45490646, -0.57034171, -1.55754232,  0.54762495,
        -0.12673488, -1.52029729, -0.53203809, -1.63153064,  0.57888043,
        -0.10217169, -1.58883035, -0.49493665],
       [-1.70884478,  0.60988235, -0.07597055, -1.66008568, -0.45887664,
        -1.78904021,  0.64066482, -0.04835046, -1.73368943, -0.42371508,
        -1.87172127,  0.67125773, -0.01950623, -1.80931556, -0.38932714,
        -1.95654273,  0.70168751,  0.01039214, -1.88668275, -0.35560533,
        -2.04320693,  0.73197687,  0.04119804, -1.9655509 , -0.32245767,
        -2.13145971,  0.7621451 ,  0.0727863 ],
       [-0.99504167, -1.62144351, -0.97635502, -0.1904684 , -0.09320527,
        -0.98184228, -1.55307388, -0.96716297, -0.15287314, -0.10960737,
        -0.96981573, -1.4851532 , -0.95921206, -0.1153725 , -0.1253981 ,
        -0.95916265, -1.41775799, -0.95271504, -0.07798267, -0.14047286,
        -0.95012426, -1.35098028, -0.94792694, -0.0407231 , -0.15470599,
        -0.94298953, -1.28493094, -0.94515365],
       [-0.00361709, -0.16794683, -0.93810272, -1.2197417 , -0.94475961,
         0.03330762, -0.180016  , -0.93586993, -1.15556777, -0.94717461,
         0.07001819, -0.19070169, -0.93676227, -1.09258926, -0.95289725,
         0.10647669, -0.19975843, -0.94131273, -1.03101027, -0.96249175,
         0.14264007, -0.2069083 , -0.95010233, -0.97105306, -0.97657293,
         0.17846149, -0.21184857, -0.96373022],
       [-0.91294676, -0.99577498,  0.21389261, -0.2142669 , -0.98276663,
        -0.85690969, -1.02070141,  0.24888742, -0.21386597, -1.00769091,
        -0.80312496, -1.05185962,  0.2834073 , -0.21039607, -1.03882504,
        -0.75171602, -1.08959055,  0.3174262 , -0.20368917, -1.07628322,
        -0.70272636, -1.13401473,  0.35093495, -0.19368593, -1.11995137,
        -0.65611243, -1.18501163,  0.38394275]], dtype=float32)