In [1]:
import numpy as np
import keras as ks
from keras.models import Model
from keras.layers import Input, Dense, Lambda, Reshape, Permute
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling3D, MaxPooling2D
from keras.layers.merge import concatenate
import keras.backend as K
from dataset import *
In [2]:
# Build the train and test splits from the same gridworld archive; the
# generator is consumed left to right, so 'train' is constructed first.
train_data, test_data = (
    Dataset("data/gridworld_8x8.npz", mode=split, imsize=8)
    for split in ('train', 'test')
)
In [3]:
def VIN_Block(r, k, ch_q):
    """Apply k rounds of the convolve / channel-max recurrence to ``r``.

    A first 3x3 convolution maps ``r`` to ``ch_q`` channels. Then, ``k``
    times, the maximum over the channel axis is taken, concatenated with
    ``r`` along axis 3, and passed through a second 3x3 convolution whose
    weights are shared by every iteration.

    Args:
        r: 4-D Keras tensor. Axis 3 is treated as the channel axis
            (channels-last layout is assumed).
        k: Number of recurrence iterations; ``0`` returns the output of
            the first convolution unchanged.
        ch_q: Number of filters in both convolutions.

    Returns:
        The Keras tensor produced by the last convolution.
    """
    # Both convolutions are created once so the second one shares its
    # weights across all k iterations.
    conv3 = Conv2D(filters=ch_q,
                   kernel_size=(3, 3),
                   padding='same',
                   use_bias=False)
    conv3b = Conv2D(filters=ch_q,
                    kernel_size=(3, 3),
                    padding='same',
                    use_bias=False)

    q = conv3(r)
    for _ in range(k):
        # Max over the channel axis, keeping a singleton channel so the
        # result can be concatenated with r. The output shape is given as
        # a function of the input shape because this block does not know
        # the spatial size.
        v = Lambda(lambda x: K.max(x, axis=3, keepdims=True),
                   output_shape=lambda shape: tuple(shape[:-1]) + (1,))(q)
        rv = concatenate([r, v], axis=3)
        q = conv3b(rv)
    return q
In [7]:
def VIN(sz, ch_i, k, ch_h, ch_q, ch_a):
    """Build the value-iteration-style Keras model used in this notebook.

    The model takes two inputs: an ``(sz, sz, ch_i)`` image and an int32
    tensor of shape ``(1,)`` holding a flat index into the ``sz * sz``
    grid. The image goes through a ReLU convolution and a single-filter
    convolution to give ``r``; ``r`` is expanded to ``ch_q`` channels and
    refined by ``k`` rounds of channel-max / concatenate / convolve. The
    ``ch_q`` values at the indexed grid cell are then selected and fed to
    a bias-free softmax layer.

    Args:
        sz: Side length of the square input grid.
        ch_i: Number of channels in the input image.
        k: Number of recurrence iterations.
        ch_h: Number of filters in the first (ReLU) convolution.
        ch_q: Number of filters in the recurrent convolutions.
        ch_a: Number of softmax output units.

    Returns:
        An uncompiled ``Model`` with inputs ``[map_in, s]``.
    """
    map_in = Input(shape=(sz, sz, ch_i))
    s = Input(shape=(1,), dtype='int32')

    h = Conv2D(filters=ch_h,
               kernel_size=(3, 3),
               padding='same',
               activation='relu')(map_in)
    r = Conv2D(filters=1,
               kernel_size=(3, 3),
               padding='same',
               use_bias=False,
               activation=None)(h)

    # conv3b is created once so its weights are shared by all k iterations.
    conv3 = Conv2D(filters=ch_q,
                   kernel_size=(3, 3),
                   padding='same',
                   use_bias=False)
    conv3b = Conv2D(filters=ch_q,
                    kernel_size=(3, 3),
                    padding='same',
                    use_bias=False)

    q = conv3(r)
    for _ in range(k):
        # Max over the channel axis, kept as a singleton channel so it can
        # be concatenated with r.
        v = Lambda(lambda x: K.max(x, axis=3, keepdims=True),
                   output_shape=(sz, sz, 1))(q)
        rv = concatenate([r, v], axis=3)
        q = conv3b(rv)

    # Flatten the grid so a single integer index selects one cell.
    q = Reshape(target_shape=(sz * sz, ch_q))(q)

    def attention(inputs):
        """Pick, for each sample, the row of q addressed by its index."""
        q_flat, state = inputs
        batch_size = K.shape(q_flat)[0]
        return K.map_fn(lambda i: K.gather(q_flat[i], state[i, 0]),
                        K.arange(0, batch_size),
                        dtype='float32')

    # The index tensor is passed as a real layer input rather than captured
    # from the enclosing scope, so it is part of the model's layer graph.
    q_out = Lambda(attention, output_shape=(ch_q,))([q, s])
    out = Dense(units=ch_a, activation='softmax', use_bias=False)(q_out)
    return Model(inputs=[map_in, s], outputs=out)
# Instantiate the network with every hyper-parameter named explicitly.
model = VIN(sz=8, ch_i=2, k=10, ch_h=150, ch_q=10, ch_a=8)
In [5]:
# Plain SGD on categorical cross-entropy, reporting accuracy.
model.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Second model input: s1 and s2 folded into one flat index (s1 * 8 + s2).
# NOTE(review): presumably a row-major cell index on the 8x8 grid; confirm
# against the Dataset class.
Xtrain = [train_data.images, train_data.s1 * 8 + train_data.s2]
# One-hot encode the labels into 8 classes via the Keras backend.
Ytrain = K.get_value(K.one_hot(train_data.labels, 8))

# Shape check of the image input and the targets, one per line.
print(np.shape(Xtrain[0]), np.shape(Ytrain), sep='\n')

model.fit(x=Xtrain, y=Ytrain, batch_size=32, epochs=5)
Out[5]:
In [6]:
# Prepare the held-out split the same way as the training split:
# images plus the flat index s1 * 8 + s2, and 8-way one-hot labels.
Xtest = [test_data.images, test_data.s1 * 8 + test_data.s2]
Ytest = K.get_value(K.one_hot(test_data.labels, 8))

# Left as the final expression so the cell displays the returned metrics.
model.evaluate(x=Xtest, y=Ytest)
Out[6]: