In [26]:
import gym
import tensorflow as tf  # backend for Keras; not used directly in this notebook
import numpy as np

In [27]:
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

In [28]:
env = gym.make('LunarLander-v2')


[2017-09-01 15:00:36,734] Making new env: LunarLander-v2

In [4]:
n_actions = env.action_space.n
n_states = env.observation_space.shape  # (8,): the state is an 8-dimensional vector
print(n_actions)
print(n_states)


4
(8,)

In [48]:
# roll out one episode with a purely random policy, just to see the environment
s = env.reset()
for i in range(10000):
    new_s, reward, done, _ = env.step(env.action_space.sample())
    env.render()
    if done:
        break
env.close()
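
Before any training, it is useful to have a baseline: the average return of a purely random policy. A minimal sketch (the helper name random_baseline and the episode count are illustrative, not part of the original run):

def random_baseline(episodes=20):
    # average episode reward of a uniformly random policy
    totals = []
    for _ in range(episodes):
        s = env.reset()
        total, done = 0.0, False
        while not done:
            _, r, done, _ = env.step(env.action_space.sample())
            total += r
        totals.append(total)
    return np.mean(totals)

print(random_baseline())  # roughly in line with the first-iteration mean rewards below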

In [120]:
from sklearn.neural_network import MLPClassifier
agent = MLPClassifier(hidden_layer_sizes=(20,20),
                      activation='tanh',
                      warm_start=True, #keep progress between .fit(...) calls
                      max_iter=1 #make only 1 iteration on each .fit(...)
                     )
# initialize the agent: one dummy sample per action, so the network sees the
# state dimensionality and each of the n_actions classes once
agent.fit([env.reset()] * n_actions, range(n_actions));


C:\Users\Abdul\Anaconda3\envs\dlnd-tf-lab\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py:563: ConvergenceWarning: Stochastic Optimizer: Maximum iterations reached and the optimization hasn't converged yet.
  % (), ConvergenceWarning)
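
The warm_start=True / max_iter=1 combination turns every .fit(...) call into a single incremental training pass, which is why the dummy fit above is needed just to initialize the network and register all four action classes; the ConvergenceWarning is expected and harmless. An alternative (not what this notebook uses) would be partial_fit, sketched below with the same hyperparameters and a hypothetical agent_alt:

# alternative initialization: partial_fit declares the classes up front and
# performs one optimizer pass per call, without raising a ConvergenceWarning
agent_alt = MLPClassifier(hidden_layer_sizes=(20, 20), activation='tanh')
agent_alt.partial_fit([env.reset()] * n_actions, list(range(n_actions)),
                      classes=list(range(n_actions)))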

In [121]:
t_max = 10000
def generate_sample():
    # play one episode with the current stochastic policy;
    # return the visited states, the actions taken and the total reward
    s = env.reset()
    batch_s = []
    batch_a = []
    total_reward = 0

    for i in range(t_max):
        # sample an action from the predicted action probabilities
        probs = agent.predict_proba(s.reshape(1, -1))
        a = int(np.random.choice(n_actions, 1, p=probs[0]))
        new_s, r, done, _ = env.step(a)
        batch_s.append(s)
        batch_a.append(a)
        s = new_s
        total_reward = total_reward + r
        if done:
            break
    env.close()
    return batch_s, batch_a, total_reward
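
A quick sanity check of the sampler (illustrative, not part of the original run):

states, actions, reward = generate_sample()
print(np.shape(states), np.shape(actions), reward)  # (T, 8), (T,), episode return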

In [125]:
iterations = 100
percentile = 70
samples = 250

for i in range(iterations):
    # roll out a population of episodes with the current policy
    population = [generate_sample() for _ in range(samples)]
    batch_states, batch_actions, batch_rewards = map(np.array, zip(*population))
    # keep only the episodes whose total reward is above the percentile ("elites")
    threshold = np.percentile(batch_rewards, percentile)
    elite_states = batch_states[batch_rewards > threshold]
    elite_actions = batch_actions[batch_rewards > threshold]
    elite_states, elite_actions = map(np.concatenate, [elite_states, elite_actions])
    # refit the policy on the elite (state, action) pairs
    agent.fit(X=elite_states, y=elite_actions)
    print('Iteration: {0}, Mean Reward: {1:.2f}, Threshold: {2:.2f}'.format(i + 1, np.mean(batch_rewards), threshold))


Iteration: 1, Mean Reward: -269.72, Threshold: -203.86
Iteration: 2, Mean Reward: -214.88, Threshold: -174.57
Iteration: 3, Mean Reward: -193.35, Threshold: -163.07
Iteration: 4, Mean Reward: -181.44, Threshold: -150.71
Iteration: 5, Mean Reward: -179.10, Threshold: -150.28
Iteration: 6, Mean Reward: -164.94, Threshold: -146.30
Iteration: 7, Mean Reward: -166.23, Threshold: -143.38
Iteration: 8, Mean Reward: -164.63, Threshold: -142.32
Iteration: 9, Mean Reward: -158.53, Threshold: -137.96
Iteration: 10, Mean Reward: -151.76, Threshold: -135.14
Iteration: 11, Mean Reward: -146.64, Threshold: -131.94
Iteration: 12, Mean Reward: -143.42, Threshold: -128.11
Iteration: 13, Mean Reward: -141.43, Threshold: -125.82
Iteration: 14, Mean Reward: -137.67, Threshold: -121.26
Iteration: 15, Mean Reward: -134.23, Threshold: -119.86
Iteration: 16, Mean Reward: -136.16, Threshold: -120.03
Iteration: 17, Mean Reward: -136.39, Threshold: -120.20
Iteration: 18, Mean Reward: -133.29, Threshold: -117.34
Iteration: 19, Mean Reward: -129.45, Threshold: -111.65
Iteration: 20, Mean Reward: -116.96, Threshold: -102.58
Iteration: 21, Mean Reward: -111.81, Threshold: -97.48
Iteration: 22, Mean Reward: -106.67, Threshold: -91.14
Iteration: 23, Mean Reward: -103.95, Threshold: -87.00
Iteration: 24, Mean Reward: -99.29, Threshold: -85.47
Iteration: 25, Mean Reward: -97.34, Threshold: -80.83
Iteration: 26, Mean Reward: -92.53, Threshold: -72.83
Iteration: 27, Mean Reward: -86.11, Threshold: -69.51
Iteration: 28, Mean Reward: -80.60, Threshold: -63.33
Iteration: 29, Mean Reward: -75.42, Threshold: -58.15
Iteration: 30, Mean Reward: -70.59, Threshold: -50.97
Iteration: 31, Mean Reward: -68.76, Threshold: -49.85
Iteration: 32, Mean Reward: -61.13, Threshold: -42.77
Iteration: 33, Mean Reward: -51.18, Threshold: -33.70
Iteration: 34, Mean Reward: -43.07, Threshold: -28.81
Iteration: 35, Mean Reward: -39.56, Threshold: -23.69
Iteration: 36, Mean Reward: -36.65, Threshold: -23.10
Iteration: 37, Mean Reward: -33.14, Threshold: -21.38
Iteration: 38, Mean Reward: -29.72, Threshold: -17.45
Iteration: 39, Mean Reward: -33.77, Threshold: -18.18
Iteration: 40, Mean Reward: -32.35, Threshold: -17.26
Iteration: 41, Mean Reward: -31.75, Threshold: -13.29
Iteration: 42, Mean Reward: -24.68, Threshold: -10.19
Iteration: 43, Mean Reward: -25.64, Threshold: -8.92
Iteration: 44, Mean Reward: -23.88, Threshold: -8.31
Iteration: 45, Mean Reward: -23.91, Threshold: -8.06
Iteration: 46, Mean Reward: -20.34, Threshold: -1.55
Iteration: 47, Mean Reward: -17.85, Threshold: -5.08
Iteration: 48, Mean Reward: -14.60, Threshold: -2.69
Iteration: 49, Mean Reward: -15.16, Threshold: -1.01
Iteration: 50, Mean Reward: -13.45, Threshold: -0.35
Iteration: 51, Mean Reward: -11.53, Threshold: 1.86
Iteration: 52, Mean Reward: -7.97, Threshold: 12.26
Iteration: 53, Mean Reward: -6.37, Threshold: 6.34
Iteration: 54, Mean Reward: -7.71, Threshold: 8.11
Iteration: 55, Mean Reward: -6.05, Threshold: 14.80
Iteration: 56, Mean Reward: -10.54, Threshold: 18.15
Iteration: 57, Mean Reward: -12.18, Threshold: 11.31
Iteration: 58, Mean Reward: -9.79, Threshold: 14.14
Iteration: 59, Mean Reward: -5.03, Threshold: 16.74
Iteration: 60, Mean Reward: -1.74, Threshold: 27.77
Iteration: 61, Mean Reward: -0.54, Threshold: 29.65
Iteration: 62, Mean Reward: 2.07, Threshold: 34.06
Iteration: 63, Mean Reward: 3.84, Threshold: 36.15
Iteration: 64, Mean Reward: 14.03, Threshold: 44.41
Iteration: 65, Mean Reward: 15.56, Threshold: 50.69
Iteration: 66, Mean Reward: 16.35, Threshold: 50.72
Iteration: 67, Mean Reward: 26.99, Threshold: 56.30
Iteration: 68, Mean Reward: 17.75, Threshold: 47.04
Iteration: 69, Mean Reward: 8.56, Threshold: 46.65
Iteration: 70, Mean Reward: 13.78, Threshold: 49.92
Iteration: 71, Mean Reward: 17.87, Threshold: 50.34
Iteration: 72, Mean Reward: 26.04, Threshold: 56.95
Iteration: 73, Mean Reward: 20.07, Threshold: 58.79
Iteration: 74, Mean Reward: 24.12, Threshold: 60.84
Iteration: 75, Mean Reward: 32.86, Threshold: 62.61
Iteration: 76, Mean Reward: 20.36, Threshold: 59.52
Iteration: 77, Mean Reward: 25.65, Threshold: 58.96
Iteration: 78, Mean Reward: 19.73, Threshold: 53.70
Iteration: 79, Mean Reward: 30.65, Threshold: 65.19
Iteration: 80, Mean Reward: 30.73, Threshold: 64.52
Iteration: 81, Mean Reward: 26.07, Threshold: 62.50
Iteration: 82, Mean Reward: 27.76, Threshold: 71.62
Iteration: 83, Mean Reward: 29.18, Threshold: 69.42
Iteration: 84, Mean Reward: 34.19, Threshold: 67.37
Iteration: 85, Mean Reward: 31.08, Threshold: 67.35
Iteration: 86, Mean Reward: 32.46, Threshold: 65.26
Iteration: 87, Mean Reward: 25.25, Threshold: 67.23
Iteration: 88, Mean Reward: 32.50, Threshold: 65.15
Iteration: 89, Mean Reward: 33.05, Threshold: 68.15
Iteration: 90, Mean Reward: 33.94, Threshold: 62.61
Iteration: 91, Mean Reward: 34.96, Threshold: 62.65
Iteration: 92, Mean Reward: 35.79, Threshold: 67.68
Iteration: 93, Mean Reward: 35.53, Threshold: 62.79
Iteration: 94, Mean Reward: 34.33, Threshold: 61.85
Iteration: 95, Mean Reward: 26.08, Threshold: 54.09
Iteration: 96, Mean Reward: 28.57, Threshold: 55.93
Iteration: 97, Mean Reward: 29.68, Threshold: 56.43
Iteration: 98, Mean Reward: 28.07, Threshold: 54.67
Iteration: 99, Mean Reward: 40.48, Threshold: 68.97
Iteration: 100, Mean Reward: 37.26, Threshold: 63.06
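
The loop above is the cross-entropy method: sample a population of episodes, keep the episodes above the chosen reward percentile, and fit the policy to the (state, action) pairs from those elite episodes. The elite-selection step could be factored into a helper such as the hypothetical select_elites below; it is the same logic as in the loop, only isolated for clarity:

def select_elites(batch_states, batch_actions, batch_rewards, percentile=70):
    # keep the episodes whose total reward exceeds the given percentile
    threshold = np.percentile(batch_rewards, percentile)
    keep = batch_rewards > threshold
    elite_states = np.concatenate(batch_states[keep])
    elite_actions = np.concatenate(batch_actions[keep])
    return elite_states, elite_actions, threshold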

In [155]:
from PIL import Image
s = env.reset()
for i in range(1000):
    render = env.render(mode='rgb_array')   # grab the frame as an RGB array
    if i % 5 == 0:
        # save every 5th frame to disk
        img = Image.fromarray(render, 'RGB')
        img.save(''.join(['./renders/', str(i), '.jpg']))
    probs = agent.predict_proba(s.reshape(1, -1))
    a = int(np.random.choice(n_actions, 1, p=probs[0]))
    new_s, reward, done, _ = env.step(a)
    s = new_s
    if done:
        break
env.close()
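
The saved frames can then be stitched into an animation with Pillow; a minimal sketch assuming the same ./renders/ directory and 5-frame step as above (the output filename landing.gif is arbitrary):

import os
frames = [Image.open('./renders/{0}.jpg'.format(i))
          for i in range(0, 1000, 5)
          if os.path.exists('./renders/{0}.jpg'.format(i))]
frames[0].save('landing.gif', save_all=True,
               append_images=frames[1:], duration=100, loop=0)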

In [8]:
# the same policy network as before, now in Keras: two hidden layers and a
# softmax over the four actions, trained with categorical cross-entropy
agent = Sequential()
agent.add(Dense(20, input_shape=n_states, activation='relu'))   # n_states == (8,)
agent.add(Dense(20, activation='relu'))
agent.add(Dense(n_actions, activation='softmax'))
agent.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
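
The Adam import above is only needed if the optimizer is configured explicitly instead of via the 'adam' string; an optional variant (not required for the run below) would be:

agent.compile(loss='categorical_crossentropy', optimizer=Adam(lr=1e-3), metrics=['accuracy'])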

In [9]:
action_lookup = np.eye(n_actions)  # one-hot encoding for each discrete action

t_max = 10000
def generate_sample():
    # same episode sampler as before, but actions are stored one-hot encoded,
    # as expected by the categorical cross-entropy loss
    s = env.reset()
    batch_s = []
    batch_a = []
    total_reward = 0

    for i in range(t_max):
        # the softmax output is already a probability distribution over actions
        probs = agent.predict(s.reshape(1, -1))
        a = int(np.random.choice(n_actions, 1, p=probs[0]))
        new_s, r, done, _ = env.step(a)
        batch_s.append(s)
        batch_a.append(action_lookup[a])
        s = new_s
        total_reward = total_reward + r
        if done:
            break
    env.close()
    return batch_s, batch_a, total_reward

In [10]:
iterations = 100
percentile = 70
samples = 250

for i in range(iterations):
    # same cross-entropy method loop, now fitting the Keras policy network
    population = [generate_sample() for _ in range(samples)]
    batch_states, batch_actions, batch_rewards = map(np.array, zip(*population))
    threshold = np.percentile(batch_rewards, percentile)
    elite_states = batch_states[batch_rewards > threshold]
    elite_actions = batch_actions[batch_rewards > threshold]
    elite_states, elite_actions = map(np.concatenate, [elite_states, elite_actions])
    agent.fit(epochs=1, x=elite_states, y=elite_actions)
    print('Iteration: {0}, Mean Reward: {1:.2f}, Threshold: {2:.2f}'.format(i + 1, np.mean(batch_rewards), threshold))


Epoch 1/1
5798/5798 [==============================] - 0s - loss: 1.3708 - acc: 0.2913      
Iteration: 1, Mean Reward: -246.71, Threshold: -169.54
Epoch 1/1
5812/5812 [==============================] - 1s - loss: 1.3746 - acc: 0.3004     
Iteration: 2, Mean Reward: -207.32, Threshold: -155.99
Epoch 1/1
5976/5976 [==============================] - 0s - loss: 1.3704 - acc: 0.3030     
Iteration: 3, Mean Reward: -192.20, Threshold: -149.63
Epoch 1/1
5968/5968 [==============================] - 0s - loss: 1.3675 - acc: 0.3170     
Iteration: 4, Mean Reward: -183.17, Threshold: -150.88
Epoch 1/1
6770/6770 [==============================] - 0s - loss: 1.3674 - acc: 0.3161     
Iteration: 5, Mean Reward: -167.66, Threshold: -146.26
Epoch 1/1
5600/5600 [==============================] - 0s - loss: 1.3710 - acc: 0.3046     
Iteration: 6, Mean Reward: -157.78, Threshold: -140.23
Epoch 1/1
6254/6254 [==============================] - 0s - loss: 1.3648 - acc: 0.3184     
Iteration: 7, Mean Reward: -157.42, Threshold: -138.23
Epoch 1/1
6117/6117 [==============================] - 0s - loss: 1.3651 - acc: 0.3199     
Iteration: 8, Mean Reward: -152.80, Threshold: -134.03
Epoch 1/1
6199/6199 [==============================] - 0s - loss: 1.3587 - acc: 0.3251     
Iteration: 9, Mean Reward: -147.33, Threshold: -132.69
Epoch 1/1
6494/6494 [==============================] - 0s - loss: 1.3514 - acc: 0.3372     
Iteration: 10, Mean Reward: -140.57, Threshold: -126.92
Epoch 1/1
6132/6132 [==============================] - 0s - loss: 1.3465 - acc: 0.3403     
Iteration: 11, Mean Reward: -133.28, Threshold: -120.11
Epoch 1/1
6597/6597 [==============================] - 0s - loss: 1.3357 - acc: 0.3596     
Iteration: 12, Mean Reward: -126.61, Threshold: -113.13
Epoch 1/1
6998/6998 [==============================] - 0s - loss: 1.3178 - acc: 0.3753     
Iteration: 13, Mean Reward: -120.42, Threshold: -106.94
Epoch 1/1
7233/7233 [==============================] - 0s - loss: 1.3095 - acc: 0.3838     
Iteration: 14, Mean Reward: -115.18, Threshold: -102.58
Epoch 1/1
7970/7970 [==============================] - 0s - loss: 1.3009 - acc: 0.3945     
Iteration: 15, Mean Reward: -106.40, Threshold: -93.97
Epoch 1/1
7450/7450 [==============================] - 0s - loss: 1.2738 - acc: 0.4285     
Iteration: 16, Mean Reward: -97.64, Threshold: -86.01
Epoch 1/1
7738/7738 [==============================] - 0s - loss: 1.2478 - acc: 0.4499     
Iteration: 17, Mean Reward: -88.04, Threshold: -76.59
Epoch 1/1
7865/7865 [==============================] - 0s - loss: 1.2215 - acc: 0.4641     
Iteration: 18, Mean Reward: -81.78, Threshold: -70.13
Epoch 1/1
8025/8025 [==============================] - 0s - loss: 1.2050 - acc: 0.4824     
Iteration: 19, Mean Reward: -74.36, Threshold: -60.83
Epoch 1/1
8131/8131 [==============================] - 0s - loss: 1.1638 - acc: 0.5089     
Iteration: 20, Mean Reward: -64.26, Threshold: -49.42
Epoch 1/1
7981/7981 [==============================] - 0s - loss: 1.1609 - acc: 0.5107     
Iteration: 21, Mean Reward: -66.96, Threshold: -55.82
Epoch 1/1
8767/8767 [==============================] - 1s - loss: 1.1553 - acc: 0.5039     
Iteration: 22, Mean Reward: -65.04, Threshold: -49.57
Epoch 1/1
8747/8747 [==============================] - 0s - loss: 1.1477 - acc: 0.5141     
Iteration: 23, Mean Reward: -55.89, Threshold: -40.91
Epoch 1/1
10661/10661 [==============================] - 1s - loss: 1.1637 - acc: 0.4974     
Iteration: 24, Mean Reward: -51.65, Threshold: -38.73
Epoch 1/1
11102/11102 [==============================] - 1s - loss: 1.1463 - acc: 0.5127     
Iteration: 25, Mean Reward: -43.39, Threshold: -29.46
Epoch 1/1
16863/16863 [==============================] - 2s - loss: 1.1944 - acc: 0.4762     
Iteration: 26, Mean Reward: -37.23, Threshold: -23.08
Epoch 1/1
24933/24933 [==============================] - 3s - loss: 1.2113 - acc: 0.4637     
Iteration: 27, Mean Reward: -33.38, Threshold: -18.70
Epoch 1/1
27749/27749 [==============================] - 4s - loss: 1.2242 - acc: 0.4479     
Iteration: 28, Mean Reward: -35.30, Threshold: -10.81
Epoch 1/1
35526/35526 [==============================] - 5s - loss: 1.2402 - acc: 0.4333     
Iteration: 29, Mean Reward: -30.18, Threshold: -11.61
Epoch 1/1
35035/35035 [==============================] - 4s - loss: 1.2205 - acc: 0.4395     
Iteration: 30, Mean Reward: -36.27, Threshold: -7.00
Epoch 1/1
39270/39270 [==============================] - 5s - loss: 1.2204 - acc: 0.4377     
Iteration: 31, Mean Reward: -37.48, Threshold: -9.01
Epoch 1/1
47809/47809 [==============================] - 6s - loss: 1.2356 - acc: 0.4242     
Iteration: 32, Mean Reward: -26.80, Threshold: 0.23
Epoch 1/1
48907/48907 [==============================] - 5s - loss: 1.2207 - acc: 0.4322     
Iteration: 33, Mean Reward: -47.82, Threshold: 0.67
Epoch 1/1
48623/48623 [==============================] - 5s - loss: 1.2409 - acc: 0.4216      
Iteration: 34, Mean Reward: -21.85, Threshold: -0.13
Epoch 1/1
47020/47020 [==============================] - 5s - loss: 1.2413 - acc: 0.4188     
Iteration: 35, Mean Reward: -21.77, Threshold: 0.26
Epoch 1/1
42239/42239 [==============================] - 5s - loss: 1.2475 - acc: 0.4200     
Iteration: 36, Mean Reward: -15.84, Threshold: 0.72
Epoch 1/1
49557/49557 [==============================] - 5s - loss: 1.2594 - acc: 0.4100     
Iteration: 37, Mean Reward: -19.57, Threshold: 3.72
Epoch 1/1
46020/46020 [==============================] - 4s - loss: 1.2592 - acc: 0.4057     
Iteration: 38, Mean Reward: -15.04, Threshold: 4.79
Epoch 1/1
54108/54108 [==============================] - 5s - loss: 1.2697 - acc: 0.4014      
Iteration: 39, Mean Reward: -12.72, Threshold: 7.08
Epoch 1/1
35476/35476 [==============================] - 3s - loss: 1.2335 - acc: 0.4269     
Iteration: 40, Mean Reward: -23.13, Threshold: -1.08
Epoch 1/1
39006/39006 [==============================] - 4s - loss: 1.2396 - acc: 0.4193     
Iteration: 41, Mean Reward: -18.85, Threshold: 3.34
Epoch 1/1
51349/51349 [==============================] - 5s - loss: 1.2628 - acc: 0.4042     
Iteration: 42, Mean Reward: -10.84, Threshold: 8.46
Epoch 1/1
52131/52131 [==============================] - 5s - loss: 1.2666 - acc: 0.4031     
Iteration: 43, Mean Reward: -4.90, Threshold: 9.57
Epoch 1/1
50383/50383 [==============================] - 5s - loss: 1.2572 - acc: 0.4106     
Iteration: 44, Mean Reward: -13.89, Threshold: 6.27
Epoch 1/1
59239/59239 [==============================] - 6s - loss: 1.2716 - acc: 0.3977      
Iteration: 45, Mean Reward: -7.34, Threshold: 14.10
Epoch 1/1
66333/66333 [==============================] - 7s - loss: 1.2750 - acc: 0.3974      
Iteration: 46, Mean Reward: -2.99, Threshold: 21.84
Epoch 1/1
62733/62733 [==============================] - 7s - loss: 1.2621 - acc: 0.4071      
Iteration: 47, Mean Reward: -3.61, Threshold: 15.95
Epoch 1/1
69785/69785 [==============================] - 8s - loss: 1.2657 - acc: 0.4008      
Iteration: 48, Mean Reward: 2.29, Threshold: 26.73
Epoch 1/1
61035/61035 [==============================] - 6s - loss: 1.2524 - acc: 0.4069      
Iteration: 49, Mean Reward: -5.58, Threshold: 14.47
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2604 - acc: 0.4030      
Iteration: 50, Mean Reward: 2.01, Threshold: 35.44
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2613 - acc: 0.4072      
Iteration: 51, Mean Reward: 3.71, Threshold: 54.39
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2560 - acc: 0.4129      
Iteration: 52, Mean Reward: 2.36, Threshold: 37.29
Epoch 1/1
68905/68905 [==============================] - 7s - loss: 1.2512 - acc: 0.4199      
Iteration: 53, Mean Reward: -6.79, Threshold: 16.69
Epoch 1/1
68820/68820 [==============================] - 8s - loss: 1.2487 - acc: 0.4209      
Iteration: 54, Mean Reward: 7.55, Threshold: 27.06
Epoch 1/1
75000/75000 [==============================] - 10s - loss: 1.2558 - acc: 0.4191     
Iteration: 55, Mean Reward: 13.18, Threshold: 50.57
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2468 - acc: 0.4285      
Iteration: 56, Mean Reward: 7.59, Threshold: 62.85
Epoch 1/1
75000/75000 [==============================] - 10s - loss: 1.2530 - acc: 0.4210     
Iteration: 57, Mean Reward: 16.63, Threshold: 65.11
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2408 - acc: 0.4285      
Iteration: 58, Mean Reward: 15.37, Threshold: 63.02
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2323 - acc: 0.4449      
Iteration: 59, Mean Reward: 12.34, Threshold: 59.09
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2350 - acc: 0.4449      
Iteration: 60, Mean Reward: 16.17, Threshold: 60.25
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2352 - acc: 0.4456      
Iteration: 61, Mean Reward: 23.71, Threshold: 71.98
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2315 - acc: 0.4459      
Iteration: 62, Mean Reward: 32.03, Threshold: 76.89
Epoch 1/1
75000/75000 [==============================] - 11s - loss: 1.2260 - acc: 0.4545     
Iteration: 63, Mean Reward: 33.70, Threshold: 80.04
Epoch 1/1
74375/74375 [==============================] - 9s - loss: 1.2199 - acc: 0.4598      
Iteration: 64, Mean Reward: 34.01, Threshold: 79.55
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2172 - acc: 0.4597      
Iteration: 65, Mean Reward: 26.59, Threshold: 75.58
Epoch 1/1
75000/75000 [==============================] - 10s - loss: 1.2204 - acc: 0.4556     
Iteration: 66, Mean Reward: 27.01, Threshold: 76.43
Epoch 1/1
74981/74981 [==============================] - 9s - loss: 1.2328 - acc: 0.4460      
Iteration: 67, Mean Reward: 32.06, Threshold: 79.41
Epoch 1/1
74850/74850 [==============================] - 11s - loss: 1.2303 - acc: 0.4440     
Iteration: 68, Mean Reward: 32.40, Threshold: 79.01
Epoch 1/1
75000/75000 [==============================] - 10s - loss: 1.2405 - acc: 0.4377     
Iteration: 69, Mean Reward: 39.68, Threshold: 80.17
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2296 - acc: 0.4459      
Iteration: 70, Mean Reward: 31.33, Threshold: 77.70
Epoch 1/1
75000/75000 [==============================] - 12s - loss: 1.2294 - acc: 0.4468      
Iteration: 71, Mean Reward: 12.63, Threshold: 63.89
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2249 - acc: 0.4512      
Iteration: 72, Mean Reward: 43.28, Threshold: 90.36
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2275 - acc: 0.4529      
Iteration: 73, Mean Reward: 29.03, Threshold: 80.73
Epoch 1/1
74160/74160 [==============================] - 9s - loss: 1.2098 - acc: 0.4661      
Iteration: 74, Mean Reward: 48.58, Threshold: 93.25
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2158 - acc: 0.4573      
Iteration: 75, Mean Reward: 49.29, Threshold: 94.42
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2138 - acc: 0.4608      
Iteration: 76, Mean Reward: 45.40, Threshold: 88.01
Epoch 1/1
74484/74484 [==============================] - 8s - loss: 1.1995 - acc: 0.4697      
Iteration: 77, Mean Reward: 60.99, Threshold: 93.92
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.1922 - acc: 0.4709      
Iteration: 78, Mean Reward: 65.24, Threshold: 101.97
Epoch 1/1
74908/74908 [==============================] - 9s - loss: 1.1960 - acc: 0.4726      
Iteration: 79, Mean Reward: 66.52, Threshold: 100.93
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2005 - acc: 0.4641      
Iteration: 80, Mean Reward: 56.13, Threshold: 98.51
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1919 - acc: 0.4667      
Iteration: 81, Mean Reward: 49.21, Threshold: 97.87
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2161 - acc: 0.4520      
Iteration: 82, Mean Reward: 40.45, Threshold: 91.59
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2072 - acc: 0.4646      
Iteration: 83, Mean Reward: 43.01, Threshold: 94.91
Epoch 1/1
75000/75000 [==============================] - 10s - loss: 1.2063 - acc: 0.4665     
Iteration: 84, Mean Reward: 68.31, Threshold: 101.13
Epoch 1/1
74577/74577 [==============================] - 10s - loss: 1.2136 - acc: 0.4599     
Iteration: 85, Mean Reward: 65.18, Threshold: 101.01
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2221 - acc: 0.4567      
Iteration: 86, Mean Reward: 49.69, Threshold: 92.36
Epoch 1/1
75000/75000 [==============================] - 11s - loss: 1.2137 - acc: 0.4640     
Iteration: 87, Mean Reward: 47.63, Threshold: 91.59
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2108 - acc: 0.4704      
Iteration: 88, Mean Reward: 51.39, Threshold: 93.24
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2101 - acc: 0.4667      
Iteration: 89, Mean Reward: 50.54, Threshold: 93.95
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.2131 - acc: 0.4636      
Iteration: 90, Mean Reward: 47.11, Threshold: 91.77
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.2073 - acc: 0.4657      
Iteration: 91, Mean Reward: 52.18, Threshold: 100.29
Epoch 1/1
75000/75000 [==============================] - 11s - loss: 1.2035 - acc: 0.4684     
Iteration: 92, Mean Reward: 59.83, Threshold: 101.71
Epoch 1/1
74688/74688 [==============================] - 11s - loss: 1.2003 - acc: 0.4688     
Iteration: 93, Mean Reward: 73.85, Threshold: 105.06
Epoch 1/1
74699/74699 [==============================] - 8s - loss: 1.1955 - acc: 0.4715      
Iteration: 94, Mean Reward: 69.55, Threshold: 100.98
Epoch 1/1
75000/75000 [==============================] - 9s - loss: 1.1876 - acc: 0.4767      
Iteration: 95, Mean Reward: 68.21, Threshold: 104.20
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1945 - acc: 0.4761      
Iteration: 96, Mean Reward: 66.72, Threshold: 101.45
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1722 - acc: 0.4922      
Iteration: 97, Mean Reward: 74.13, Threshold: 106.18
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1757 - acc: 0.4914      
Iteration: 98, Mean Reward: 80.46, Threshold: 105.36
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1619 - acc: 0.5002      
Iteration: 99, Mean Reward: 78.15, Threshold: 108.92
Epoch 1/1
75000/75000 [==============================] - 8s - loss: 1.1643 - acc: 0.4963      
Iteration: 100, Mean Reward: 64.93, Threshold: 104.45

In [54]:
s = env.reset()
t_max = 1000
for i in range(t_max):
    env.render()
    # sample an action from the trained Keras policy
    probs = agent.predict(s.reshape(1, -1))
    a = int(np.random.choice(n_actions, 1, p=probs[0]))
    new_s, r, done, _ = env.step(a)
    s = new_s
    if done:
        break
env.close()
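
To check how reliable the landing really is, the trained agent can be scored over many episodes without rendering; LunarLander-v2 is conventionally considered solved at an average reward of 200 over 100 consecutive episodes. A minimal sketch (the helper name evaluate is illustrative):

def evaluate(episodes=100):
    # average return of the current stochastic policy, no rendering
    totals = []
    for _ in range(episodes):
        s = env.reset()
        total, done = 0.0, False
        while not done:
            probs = agent.predict(s.reshape(1, -1))
            a = int(np.random.choice(n_actions, 1, p=probs[0]))
            s, r, done, _ = env.step(a)
            total += r
        totals.append(total)
    return np.mean(totals)

print(evaluate())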