Thinking about reinforcement learning from the viewpoint of actually "using" it, you usually need to wrap your own problem as an OpenAI Gym environment.
Reference: inoory's Qiita article
In a custom environment (a subclass of gym.core.Env), the methods to override are _reset(), _step(), and _render() (the last one only if you want visualization).
You also need to fill in the attributes action_space and observation_space.
According to the reference, implementing just that minimum is enough; a minimal skeleton is sketched right below.
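A rough sketch of that interface, assuming the pre-0.9 gym API that the code below also uses (the class name MinimalEnv and its spaces are placeholders, not part of the original notebook):

import gym
import gym.spaces
import numpy as np

class MinimalEnv(gym.core.Env):
    def __init__(self):
        # Declare what the agent can do and what it observes
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,))

    def _reset(self):
        # Return the initial observation
        return np.zeros(4)

    def _step(self, action):
        # Return (observation, reward, done, info)
        return np.zeros(4), 0.0, False, {}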
In [1]:
import copy
import gym
import gym.spaces
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
# If you don't have seaborn installed, comment out the next two lines
import seaborn as sns
sns.set_style('darkgrid')
In [3]:
class SimpleVision(gym.core.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self):
        self._n_fruit = 2
        self._pos = np.random.randint(size=2, low=0, high=5)
        self._fruit_pos = np.random.randint(size=(self._n_fruit, 2), low=0, high=5)
        # Actions: up / down / left / right
        self.action_space = gym.spaces.Discrete(4)
        # G G G G G
        # G R G G G
        # G G G R G
        # G G B G G
        # G G G G G
        # R: fruit, G: background, B: agent
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(1, 5*5*3))

    def _reset(self):
        self._pos = np.random.randint(size=2, low=0, high=5)
        self._fruit_pos = np.random.randint(size=(self._n_fruit, 2), low=0, high=5)
        return self._get_observation()

    def _step(self, action):
        # Move the agent, clipping to the 5x5 grid
        if action == 0:    # UP
            self._pos[1] += 1
            if self._pos[1] >= 5:
                self._pos[1] = 4
        elif action == 1:  # DOWN
            self._pos[1] -= 1
            if self._pos[1] < 0:
                self._pos[1] = 0
        elif action == 2:  # LEFT
            self._pos[0] += 1
            if self._pos[0] >= 5:
                self._pos[0] = 4
        elif action == 3:  # RIGHT
            self._pos[0] -= 1
            if self._pos[0] < 0:
                self._pos[0] = 0
        # Reward: small penalty per step, +1 for reaching a fruit
        reward = -0.1
        new_fruit_pos = copy.deepcopy(self._fruit_pos)
        for i, f_pos in enumerate(self._fruit_pos):
            if np.array_equal(self._pos, f_pos):
                # Reached a fruit, so give the reward
                reward = 1.0
                # Respawn that fruit somewhere not under the agent
                new_fruit_pos[i] = np.random.randint(size=2, low=0, high=5)
                while np.array_equal(self._pos, new_fruit_pos[i]):
                    new_fruit_pos[i] = np.random.randint(size=2, low=0, high=5)
        self._fruit_pos = copy.deepcopy(new_fruit_pos)
        # There is no terminal condition
        done = False
        # No extra info, so the last element is an empty dict
        return self._get_observation(), reward, done, {}

    def _get_observation(self):
        observation = np.zeros((5, 5, 3))
        observation[:, :, 1] = 0.3
        # Fruit positions
        for f_pos in self._fruit_pos:
            observation[f_pos[0], f_pos[1]] = [1, 0, 0]
        # Agent position
        observation[self._pos[0], self._pos[1]] = [0, 0, 1]
        return observation.reshape(1, 5*5*3)

    def _render(self, mode='human', close=False):
        plt.figure(1)
        plt.clf()
        plt.imshow(self._get_observation().reshape(5, 5, 3), interpolation='nearest')
        plt.pause(0.0001)
env = SimpleVision()
obs = env.reset()
plt.imshow(obs.reshape(5, 5, 3), interpolation='nearest')
plt.show()
nb_actions = env.action_space.n
input_shape = (1,) + env.observation_space.shape
#input_shape = env.observation_space.shape
print("# of Actions : {}".format(nb_actions))
print("Shape of Observation : {}".format(env.observation_space.shape))
In [4]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Permute
from keras.layers.convolutional import Convolution2D
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory
# Define the DQN network
model = Sequential()
model.add(Flatten(input_shape=input_shape))
#model.add(Convolution2D(32, 3, 3, input_shape=input_shape[1:]))
model.add(Activation('relu'))
#model.add(Flatten())
model.add(Dense(400))
model.add(Activation('relu'))
model.add(Dense(400))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())
# Memory for experience replay
memory = SequentialMemory(limit=20000, window_length=1)
policy = EpsGreedyQPolicy(eps=0.3)
dqn = DQNAgent(model=model,
               gamma=0.99,
               nb_actions=nb_actions,
               memory=memory,
               nb_steps_warmup=1000,
               target_model_update=1e-3,
               policy=policy)
# Set up the optimizer
optimizer = Adam(lr=1e-3, epsilon=0.01)
dqn.compile(optimizer=optimizer, metrics=['mae'])
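A possible variation (my own suggestion, not what this notebook uses): instead of a fixed eps=0.3, keras-rl's LinearAnnealedPolicy can decay epsilon from a large value to a small one over the first part of training, which gives more exploration early on:

from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy

# Anneal eps linearly from 1.0 to 0.1 over the first 10000 steps, use 0.05 at test time
annealed_policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps',
                                       value_max=1.0, value_min=0.1,
                                       value_test=0.05, nb_steps=10000)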
In [5]:
from rl.callbacks import Callback
class PlotReward(Callback):
    def on_train_begin(self, logs={}):
        self.episode_reward = []
        self.fig = plt.figure(0)

    def on_episode_end(self, episode, logs={}):
        # Record the total reward of the finished episode and redraw the plot
        self.episode_reward.append(logs['episode_reward'])
        self.show_result()

    def show_result(self):
        # Redraw the reward curve in place inside the notebook
        display.clear_output(wait=True)
        display.display(plt.gcf())
        plt.clf()
        plt.plot(self.episode_reward, 'r')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.pause(0.001)
callbacks = [PlotReward()]
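As an optional variant of the cell above (not part of the original notebook), keras-rl also ships a FileLogger callback that writes training metrics to a JSON file, which can be used alongside the live plot; the file name here is arbitrary:

from rl.callbacks import FileLogger

# Write training metrics to JSON every 100 steps, in addition to the live plot
callbacks = [PlotReward(), FileLogger('dqn_simplevision_log.json', interval=100)]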
In [6]:
dqn.fit(env,
        nb_steps=30000,
        visualize=False,
        callbacks=callbacks,
        nb_max_episode_steps=100)
Out[6]:
In [7]:
dqn.test(env, nb_episodes=5, visualize=True, nb_max_episode_steps=100)
Out[7]:
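Not in the original notebook, but if the trained agent is worth keeping, keras-rl agents can persist and restore their network weights (the file name is arbitrary):

# Save the trained Q-network weights to disk; load_weights restores them later
dqn.save_weights('dqn_simplevision_weights.h5f', overwrite=True)
# dqn.load_weights('dqn_simplevision_weights.h5f')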
In [ ]: