In [10]:
import random
import numpy as np
from collections import deque
from keras.models import Sequential, Model
from keras.layers import Dense, Input, Conv2D, Flatten, Activation, MaxPooling2D
from keras.optimizers import Adam
import keras

import logging
import pickle
import os.path

In [11]:
import nnutils

name = "data/CattleG1"

guylaine_input_size = 100

state_width = nnutils.tileWidth
state_height = nnutils.tileHeight
state_channels = 14
ship_input_size = 4
output_size = 6
memory = deque(maxlen=2000)
gamma = 0.95    # discount r
epsilon = 1.0  # exploration
epsilon_min = 0.01
epsilon_decay = 0.995
learning_rate = 0.001

In [12]:
guylaine_input = Input(shape=(guylaine_input_size,), name='ship_guylaine_input')

ship_input = Input(shape=(ship_input_size,), name='ship_input')

x = keras.layers.concatenate([guylaine_input, ship_input])

x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
ship_output = Dense(output_size, activation='sigmoid', name='cattle_output')(x)

model = Model(inputs=[guylaine_input, ship_input], outputs=ship_output)

model.compile(loss='mse',
              optimizer=Adam(lr=learning_rate))

In [13]:
model.load_weights(name)
memory = pickle.load(open(name + '_memory', 'rb'))
epsilon = pickle.load(open(name + '_epsilon', 'rb'))

In [14]:
from matplotlib import pyplot as plt
from IPython.display import clear_output
# updatable plot
# a minimal example (sort of)

class PlotLosses(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        
        self.fig = plt.figure()
        
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1
        
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()
        plt.show();
        
plot_losses = PlotLosses()

In [18]:
minBatchSize = batch_size
if (len(memory) < batch_size):
    minBatchSize = len(memory)

minibatch = random.sample(memory, minBatchSize)
for guylaine_output, ship_state, action, reward, next_guylaine_output, next_ship_state, done in minibatch:
    target = reward

    if not done:
        target = (reward + gamma * model.predict({'ship_guylaine_input': next_guylaine_output, 'ship_input': next_ship_state}))

    target_f = model.predict({'ship_guylaine_input': guylaine_output, 'ship_input': ship_state})

    action_index = np.argmax(action)

    target_f[0][action_index] = target
    model.fit(state, target_f, epochs=1, verbose=0)
if epsilon > epsilon_min:
    epsilon *= epsilon_decay


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-8d6792d757f9> in <module>()
      8 
      9     if not done:
---> 10         target = (reward + gamma * model.predict({'ship_guylaine_input': next_guylaine_output, 'ship_input': next_ship_state}))
     11 
     12     target_f = model.predict({'ship_guylaine_input': guylaine_output, 'ship_input': ship_state})

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
   1693         x = _standardize_input_data(x, self._feed_input_names,
   1694                                     self._feed_input_shapes,
-> 1695                                     check_batch_axis=False)
   1696         if self.stateful:
   1697             if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    142                             ' to have shape ' + str(shapes[i]) +
    143                             ' but got array with shape ' +
--> 144                             str(array.shape))
    145     return arrays
    146 

ValueError: Error when checking : expected ship_guylaine_input to have shape (None, 100) but got array with shape (100, 1)

In [ ]: