深層強化学習で画像入力のエージェントを学習させる

  • gym を pip で入れた人は: pip install gym[atari]
  • gym を GitHub のソースから入れた人は gym ディレクトリで: pip install -e .[atari]

この他の環境も pip install gym[ナントカ] で追加できます (参考: 環境)


In [1]:
# 数値計算に必須のもろもろ
import numpy as np

# 可視化パッケージ 
import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

# seaborn を入れてない人は以下をコメントアウト
import seaborn as sns
sns.set_style('darkgrid')

In [2]:
# OpenAI Gym and its environment registry
import gym
from gym import envs

# Print the spec of every environment currently registered with gym
registered_specs = envs.registry.all()
print(registered_specs)


[EnvSpec(PredictActionsCartpole-v0), EnvSpec(Asteroids-ramDeterministic-v0), EnvSpec(Asteroids-ramDeterministic-v3), EnvSpec(Gopher-ramDeterministic-v3), EnvSpec(Gopher-ramDeterministic-v0), EnvSpec(DoubleDunk-ramDeterministic-v3), EnvSpec(DoubleDunk-ramDeterministic-v0), EnvSpec(Tennis-ramNoFrameskip-v3), EnvSpec(RoadRunner-ramDeterministic-v0), EnvSpec(Robotank-ram-v3), EnvSpec(CartPole-v0), EnvSpec(CartPole-v1), EnvSpec(Gopher-ram-v3), EnvSpec(Gopher-ram-v0), EnvSpec(Pooyan-ram-v0), EnvSpec(Pooyan-ram-v3), EnvSpec(SpaceInvaders-ram-v3), EnvSpec(CarRacing-v0), EnvSpec(SpaceInvaders-ram-v0), EnvSpec(YarsRevenge-ramDeterministic-v0), EnvSpec(SpaceInvadersDeterministic-v0), EnvSpec(DoubleDunk-ram-v3), EnvSpec(DoubleDunk-ram-v0), EnvSpec(SpaceInvadersDeterministic-v3), EnvSpec(Centipede-v3), EnvSpec(Centipede-v0), EnvSpec(Pitfall-ramNoFrameskip-v3), EnvSpec(Pitfall-ramNoFrameskip-v0), EnvSpec(Frostbite-ramNoFrameskip-v0), EnvSpec(Phoenix-ram-v3), EnvSpec(AmidarNoFrameskip-v3), EnvSpec(SkiingNoFrameskip-v0), EnvSpec(SkiingNoFrameskip-v3), EnvSpec(HotterColder-v0), EnvSpec(RoadRunner-ramDeterministic-v3), EnvSpec(Phoenix-ram-v0), EnvSpec(Tennis-ramNoFrameskip-v0), EnvSpec(Berzerk-ramNoFrameskip-v3), EnvSpec(Berzerk-ramNoFrameskip-v0), EnvSpec(AirRaidDeterministic-v3), EnvSpec(AirRaidDeterministic-v0), EnvSpec(ChopperCommandDeterministic-v3), EnvSpec(AirRaidNoFrameskip-v0), EnvSpec(AirRaidNoFrameskip-v3), EnvSpec(ChopperCommandDeterministic-v0), EnvSpec(Asteroids-ram-v0), EnvSpec(Asteroids-ram-v3), EnvSpec(KrullDeterministic-v0), EnvSpec(Atlantis-ramDeterministic-v3), EnvSpec(Atlantis-ramDeterministic-v0), EnvSpec(KrullDeterministic-v3), EnvSpec(OffSwitchCartpoleProb-v0), EnvSpec(TimePilot-v3), EnvSpec(Go19x19-v0), EnvSpec(TimePilot-v0), EnvSpec(Solaris-ram-v0), EnvSpec(Solaris-ram-v3), EnvSpec(VentureDeterministic-v3), EnvSpec(FishingDerbyNoFrameskip-v3), EnvSpec(FishingDerbyNoFrameskip-v0), EnvSpec(Robotank-ram-v0), EnvSpec(Qbert-v3), EnvSpec(ReversedAddition-v0), 
EnvSpec(Qbert-v0), EnvSpec(Pitfall-v0), EnvSpec(Pitfall-v3), EnvSpec(RiverraidNoFrameskip-v0), EnvSpec(RiverraidNoFrameskip-v3), EnvSpec(BipedalWalkerHardcore-v2), EnvSpec(Venture-ram-v3), EnvSpec(Venture-ram-v0), EnvSpec(Tennis-v0), EnvSpec(Tennis-v3), EnvSpec(MontezumaRevenge-ramNoFrameskip-v0), EnvSpec(MontezumaRevenge-ramNoFrameskip-v3), EnvSpec(Go9x9-v0), EnvSpec(MountainCarContinuous-v0), EnvSpec(SemisuperPendulumNoise-v0), EnvSpec(Reacher-v1), EnvSpec(ChopperCommand-ramNoFrameskip-v0), EnvSpec(Taxi-v2), EnvSpec(Pong-v3), EnvSpec(Pong-v0), EnvSpec(UpNDownDeterministic-v0), EnvSpec(UpNDownDeterministic-v3), EnvSpec(Enduro-v0), EnvSpec(Enduro-v3), EnvSpec(Zaxxon-ramDeterministic-v3), EnvSpec(Krull-ramNoFrameskip-v0), EnvSpec(Krull-ramNoFrameskip-v3), EnvSpec(ElevatorAction-ramNoFrameskip-v3), EnvSpec(ElevatorAction-ramNoFrameskip-v0), EnvSpec(Venture-ramNoFrameskip-v3), EnvSpec(QbertNoFrameskip-v3), EnvSpec(Venture-ramNoFrameskip-v0), EnvSpec(StarGunner-ramNoFrameskip-v3), EnvSpec(StarGunner-ramNoFrameskip-v0), EnvSpec(NameThisGame-ram-v3), EnvSpec(YarsRevenge-ramDeterministic-v3), EnvSpec(Breakout-ram-v0), EnvSpec(Breakout-ram-v3), EnvSpec(PrivateEye-ramNoFrameskip-v3), EnvSpec(Bowling-v0), EnvSpec(Bowling-v3), EnvSpec(PrivateEye-ramNoFrameskip-v0), EnvSpec(BattleZoneDeterministic-v3), EnvSpec(BattleZoneDeterministic-v0), EnvSpec(PitfallNoFrameskip-v0), EnvSpec(PitfallNoFrameskip-v3), EnvSpec(AirRaid-ramDeterministic-v0), EnvSpec(AirRaid-ramDeterministic-v3), EnvSpec(CentipedeNoFrameskip-v0), EnvSpec(Skiing-ram-v0), EnvSpec(CentipedeNoFrameskip-v3), EnvSpec(EnduroDeterministic-v3), EnvSpec(VentureNoFrameskip-v0), EnvSpec(SpaceInvaders-ramNoFrameskip-v3), EnvSpec(Freeway-ram-v3), EnvSpec(Skiing-ram-v3), EnvSpec(ConvergenceControl-v0), EnvSpec(Riverraid-ramNoFrameskip-v3), EnvSpec(Riverraid-ramNoFrameskip-v0), EnvSpec(ChopperCommand-v3), EnvSpec(ChopperCommand-v0), EnvSpec(Pooyan-v0), EnvSpec(Pooyan-v3), EnvSpec(BattleZoneNoFrameskip-v0), EnvSpec(PrivateEye-v3), 
EnvSpec(PrivateEye-v0), EnvSpec(BattleZoneNoFrameskip-v3), EnvSpec(FrozenLake8x8-v0), EnvSpec(Alien-ramNoFrameskip-v0), EnvSpec(Alien-ramNoFrameskip-v3), EnvSpec(WizardOfWor-ramDeterministic-v0), EnvSpec(TutankhamDeterministic-v0), EnvSpec(TutankhamDeterministic-v3), EnvSpec(LunarLanderContinuous-v2), EnvSpec(UpNDown-ramDeterministic-v3), EnvSpec(UpNDown-ramDeterministic-v0), EnvSpec(Phoenix-ramNoFrameskip-v3), EnvSpec(Phoenix-ramNoFrameskip-v0), EnvSpec(Asterix-ram-v3), EnvSpec(Asterix-ram-v0), EnvSpec(Jamesbond-ramNoFrameskip-v3), EnvSpec(Jamesbond-ramNoFrameskip-v0), EnvSpec(JourneyEscape-v0), EnvSpec(JourneyEscape-v3), EnvSpec(BipedalWalker-v2), EnvSpec(CrazyClimberDeterministic-v3), EnvSpec(CrazyClimberDeterministic-v0), EnvSpec(FishingDerby-ramDeterministic-v3), EnvSpec(QbertDeterministic-v0), EnvSpec(SpaceInvaders-ramDeterministic-v3), EnvSpec(QbertDeterministic-v3), EnvSpec(SolarisDeterministic-v0), EnvSpec(SolarisDeterministic-v3), EnvSpec(YarsRevengeDeterministic-v0), EnvSpec(YarsRevengeDeterministic-v3), EnvSpec(SpaceInvaders-ramDeterministic-v0), EnvSpec(TwoRoundDeterministicReward-v0), EnvSpec(Bowling-ramNoFrameskip-v3), EnvSpec(Bowling-ramNoFrameskip-v0), EnvSpec(JourneyEscapeDeterministic-v3), EnvSpec(NameThisGame-ramNoFrameskip-v0), EnvSpec(TwoRoundNondeterministicReward-v0), EnvSpec(AmidarNoFrameskip-v0), EnvSpec(TimePilotNoFrameskip-v0), EnvSpec(MsPacmanDeterministic-v0), EnvSpec(MsPacmanDeterministic-v3), EnvSpec(Pooyan-ramDeterministic-v0), EnvSpec(Frostbite-ramNoFrameskip-v3), EnvSpec(PhoenixDeterministic-v3), EnvSpec(PhoenixDeterministic-v0), EnvSpec(CrazyClimber-ram-v0), EnvSpec(MontezumaRevenge-ram-v0), EnvSpec(MontezumaRevenge-ram-v3), EnvSpec(CrazyClimber-ram-v3), EnvSpec(StarGunner-ramDeterministic-v0), EnvSpec(StarGunner-ramDeterministic-v3), EnvSpec(Centipede-ramNoFrameskip-v3), EnvSpec(Centipede-ramNoFrameskip-v0), EnvSpec(BeamRider-ramDeterministic-v0), EnvSpec(BeamRider-ramDeterministic-v3), EnvSpec(KungFuMaster-v0), 
EnvSpec(KungFuMaster-v3), EnvSpec(Jamesbond-ramDeterministic-v0), EnvSpec(BreakoutDeterministic-v3), EnvSpec(BreakoutDeterministic-v0), EnvSpec(Jamesbond-ramDeterministic-v3), EnvSpec(IceHockey-v0), EnvSpec(IceHockey-v3), EnvSpec(Venture-ramDeterministic-v0), EnvSpec(Carnival-ram-v0), EnvSpec(Venture-ramDeterministic-v3), EnvSpec(PongDeterministic-v3), EnvSpec(RobotankDeterministic-v0), EnvSpec(RobotankDeterministic-v3), EnvSpec(PongDeterministic-v0), EnvSpec(Pong-ramDeterministic-v0), EnvSpec(Pong-ramDeterministic-v3), EnvSpec(NameThisGame-ramNoFrameskip-v3), EnvSpec(Berzerk-v3), EnvSpec(Berzerk-v0), EnvSpec(SemisuperPendulumDecay-v0), EnvSpec(MontezumaRevenge-v0), EnvSpec(JourneyEscape-ramDeterministic-v3), EnvSpec(JourneyEscape-ramDeterministic-v0), EnvSpec(MontezumaRevenge-v3), EnvSpec(AirRaid-ram-v0), EnvSpec(AirRaid-ram-v3), EnvSpec(Zaxxon-v0), EnvSpec(BeamRiderDeterministic-v3), EnvSpec(YarsRevenge-ramNoFrameskip-v0), EnvSpec(YarsRevenge-ramNoFrameskip-v3), EnvSpec(BeamRiderDeterministic-v0), EnvSpec(RoadRunner-ramNoFrameskip-v0), EnvSpec(TimePilotNoFrameskip-v3), EnvSpec(Phoenix-ramDeterministic-v0), EnvSpec(StarGunner-ram-v0), EnvSpec(Phoenix-ramDeterministic-v3), EnvSpec(Hex9x9-v0), EnvSpec(Skiing-v0), EnvSpec(Skiing-v3), EnvSpec(StarGunner-ram-v3), EnvSpec(Boxing-ramDeterministic-v0), EnvSpec(Boxing-ramDeterministic-v3), EnvSpec(AsteroidsDeterministic-v3), EnvSpec(PrivateEye-ram-v3), EnvSpec(Pooyan-ramDeterministic-v3), EnvSpec(Centipede-ramDeterministic-v0), EnvSpec(Centipede-ramDeterministic-v3), EnvSpec(JourneyEscapeNoFrameskip-v3), EnvSpec(JourneyEscapeNoFrameskip-v0), EnvSpec(BattleZone-ramDeterministic-v0), EnvSpec(BattleZone-ramDeterministic-v3), EnvSpec(NameThisGameNoFrameskip-v0), EnvSpec(NameThisGameNoFrameskip-v3), EnvSpec(Seaquest-ram-v3), EnvSpec(Seaquest-ram-v0), EnvSpec(AsteroidsDeterministic-v0), EnvSpec(ElevatorAction-ram-v3), EnvSpec(ElevatorAction-ram-v0), EnvSpec(ChopperCommand-ramDeterministic-v0), EnvSpec(Zaxxon-ramNoFrameskip-v3), 
EnvSpec(ChopperCommand-ramDeterministic-v3), EnvSpec(Krull-ramDeterministic-v3), EnvSpec(Krull-ramDeterministic-v0), EnvSpec(BankHeistDeterministic-v0), EnvSpec(BankHeistDeterministic-v3), EnvSpec(VideoPinballDeterministic-v3), EnvSpec(Reverse-v0), EnvSpec(Zaxxon-ramNoFrameskip-v0), EnvSpec(SeaquestDeterministic-v0), EnvSpec(SeaquestDeterministic-v3), EnvSpec(JourneyEscape-ram-v0), EnvSpec(JourneyEscape-ram-v3), EnvSpec(BerzerkDeterministic-v3), EnvSpec(BerzerkDeterministic-v0), EnvSpec(AssaultNoFrameskip-v0), EnvSpec(Enduro-ramDeterministic-v0), EnvSpec(Enduro-ramDeterministic-v3), EnvSpec(AssaultNoFrameskip-v3), EnvSpec(QbertNoFrameskip-v0), EnvSpec(Gopher-ramNoFrameskip-v0), EnvSpec(IceHockey-ramDeterministic-v0), EnvSpec(IceHockey-ramDeterministic-v3), EnvSpec(Gopher-ramNoFrameskip-v3), EnvSpec(PhoenixNoFrameskip-v0), EnvSpec(PhoenixNoFrameskip-v3), EnvSpec(Humanoid-v1), EnvSpec(NameThisGame-ram-v0), EnvSpec(Tutankham-ramNoFrameskip-v0), EnvSpec(Tutankham-ramNoFrameskip-v3), EnvSpec(MsPacman-ramNoFrameskip-v0), EnvSpec(ReversedAddition3-v0), EnvSpec(Assault-ramDeterministic-v0), EnvSpec(Atlantis-ramNoFrameskip-v0), EnvSpec(Atlantis-ramNoFrameskip-v3), EnvSpec(Assault-ramDeterministic-v3), EnvSpec(Skiing-ramNoFrameskip-v0), EnvSpec(BreakoutNoFrameskip-v0), EnvSpec(BreakoutNoFrameskip-v3), EnvSpec(KungFuMaster-ram-v0), EnvSpec(KungFuMaster-ram-v3), EnvSpec(OneRoundNondeterministicReward-v0), EnvSpec(NameThisGame-ramDeterministic-v0), EnvSpec(NameThisGame-ramDeterministic-v3), EnvSpec(RoadRunner-ramNoFrameskip-v3), EnvSpec(Frostbite-ramDeterministic-v3), EnvSpec(Frostbite-ramDeterministic-v0), EnvSpec(BankHeist-ramNoFrameskip-v0), EnvSpec(BankHeist-ramNoFrameskip-v3), EnvSpec(Qbert-ramNoFrameskip-v0), EnvSpec(Ant-v1), EnvSpec(Qbert-ramNoFrameskip-v3), EnvSpec(Skiing-ramNoFrameskip-v3), EnvSpec(YarsRevenge-ram-v0), EnvSpec(YarsRevenge-ram-v3), EnvSpec(FrostbiteNoFrameskip-v3), EnvSpec(FishingDerby-ram-v0), EnvSpec(FishingDerby-ram-v3), 
EnvSpec(FrostbiteNoFrameskip-v0), EnvSpec(BeamRiderNoFrameskip-v0), EnvSpec(Enduro-ramNoFrameskip-v3), EnvSpec(Enduro-ramNoFrameskip-v0), EnvSpec(BeamRiderNoFrameskip-v3), EnvSpec(CentipedeDeterministic-v3), EnvSpec(Gravitar-ramNoFrameskip-v0), EnvSpec(Gravitar-ramNoFrameskip-v3), EnvSpec(CentipedeDeterministic-v0), EnvSpec(Kangaroo-ram-v3), EnvSpec(Alien-ram-v3), EnvSpec(Kangaroo-ram-v0), EnvSpec(VideoPinball-ramNoFrameskip-v0), EnvSpec(VideoPinball-ramNoFrameskip-v3), EnvSpec(StarGunnerDeterministic-v3), EnvSpec(StarGunnerDeterministic-v0), EnvSpec(PongNoFrameskip-v0), EnvSpec(PongNoFrameskip-v3), EnvSpec(TimePilotDeterministic-v3), EnvSpec(TimePilotDeterministic-v0), EnvSpec(CNNClassifierTraining-v0), EnvSpec(Boxing-ram-v0), EnvSpec(Boxing-ram-v3), EnvSpec(Tennis-ramDeterministic-v0), EnvSpec(StarGunner-v0), EnvSpec(StarGunner-v3), EnvSpec(Tennis-ramDeterministic-v3), EnvSpec(DemonAttackNoFrameskip-v0), EnvSpec(DemonAttackNoFrameskip-v3), EnvSpec(PitfallDeterministic-v3), EnvSpec(Assault-ram-v3), EnvSpec(PooyanDeterministic-v0), EnvSpec(PooyanDeterministic-v3), EnvSpec(Assault-ram-v0), EnvSpec(Amidar-ram-v3), EnvSpec(PitfallDeterministic-v0), EnvSpec(Amidar-ram-v0), EnvSpec(ChopperCommandNoFrameskip-v0), EnvSpec(ChopperCommandNoFrameskip-v3), EnvSpec(Tutankham-ramDeterministic-v0), EnvSpec(VentureDeterministic-v0), EnvSpec(ElevatorActionDeterministic-v3), EnvSpec(Solaris-ramDeterministic-v3), EnvSpec(Solaris-ramDeterministic-v0), EnvSpec(ElevatorActionDeterministic-v0), EnvSpec(Riverraid-ram-v0), EnvSpec(Riverraid-ram-v3), EnvSpec(Solaris-v0), EnvSpec(KungFuMasterNoFrameskip-v3), EnvSpec(BattleZone-v3), EnvSpec(BattleZone-v0), EnvSpec(KungFuMasterNoFrameskip-v0), EnvSpec(MsPacmanNoFrameskip-v3), EnvSpec(MsPacmanNoFrameskip-v0), EnvSpec(VideoPinballNoFrameskip-v3), EnvSpec(Breakout-ramDeterministic-v0), EnvSpec(Breakout-ramDeterministic-v3), EnvSpec(VideoPinballNoFrameskip-v0), EnvSpec(PrivateEye-ramDeterministic-v0), EnvSpec(WizardOfWor-ram-v3), 
EnvSpec(WizardOfWor-ram-v0), EnvSpec(PrivateEye-ramDeterministic-v3), EnvSpec(Gravitar-v0), EnvSpec(RoadRunner-v3), EnvSpec(RoadRunner-v0), EnvSpec(Gravitar-v3), EnvSpec(RoadRunner-ram-v3), EnvSpec(Jamesbond-ram-v3), EnvSpec(RoadRunner-ram-v0), EnvSpec(MsPacman-ram-v0), EnvSpec(MsPacman-ram-v3), EnvSpec(Riverraid-ramDeterministic-v0), EnvSpec(Riverraid-ramDeterministic-v3), EnvSpec(Jamesbond-ram-v0), EnvSpec(UpNDownNoFrameskip-v3), EnvSpec(VideoPinball-ram-v3), EnvSpec(VideoPinball-ram-v0), EnvSpec(UpNDownNoFrameskip-v0), EnvSpec(OffSwitchCartpole-v0), EnvSpec(WizardOfWorNoFrameskip-v3), EnvSpec(WizardOfWorNoFrameskip-v0), EnvSpec(FreewayNoFrameskip-v3), EnvSpec(FreewayNoFrameskip-v0), EnvSpec(WizardOfWor-ramDeterministic-v3), EnvSpec(Asterix-ramNoFrameskip-v0), EnvSpec(Asterix-ramNoFrameskip-v3), EnvSpec(AlienNoFrameskip-v3), EnvSpec(AlienNoFrameskip-v0), EnvSpec(BankHeist-ramDeterministic-v3), EnvSpec(BankHeist-ramDeterministic-v0), EnvSpec(InvertedDoublePendulum-v1), EnvSpec(Asterix-v3), EnvSpec(WizardOfWor-ramNoFrameskip-v0), EnvSpec(Asterix-v0), EnvSpec(AsteroidsNoFrameskip-v0), EnvSpec(AsteroidsNoFrameskip-v3), EnvSpec(Pong-ramNoFrameskip-v3), EnvSpec(JamesbondDeterministic-v3), EnvSpec(WizardOfWor-ramNoFrameskip-v3), EnvSpec(ZaxxonDeterministic-v3), EnvSpec(ZaxxonDeterministic-v0), EnvSpec(Pong-ramNoFrameskip-v0), EnvSpec(ChopperCommand-ram-v3), EnvSpec(ChopperCommand-ram-v0), EnvSpec(SpaceInvaders-ramNoFrameskip-v0), EnvSpec(SeaquestNoFrameskip-v3), EnvSpec(JamesbondDeterministic-v0), EnvSpec(BowlingDeterministic-v3), EnvSpec(BowlingDeterministic-v0), EnvSpec(SemisuperPendulumRandom-v0), EnvSpec(BankHeist-v0), EnvSpec(BankHeist-v3), EnvSpec(TimePilot-ramDeterministic-v0), EnvSpec(TimePilot-ramDeterministic-v3), EnvSpec(NChain-v0), EnvSpec(FishingDerby-ramDeterministic-v0), EnvSpec(SeaquestNoFrameskip-v0), EnvSpec(StarGunnerNoFrameskip-v0), EnvSpec(Seaquest-v3), EnvSpec(CrazyClimber-ramNoFrameskip-v3), EnvSpec(CrazyClimber-ramNoFrameskip-v0), 
EnvSpec(Seaquest-v0), EnvSpec(CrazyClimber-v0), EnvSpec(CrazyClimber-v3), EnvSpec(MsPacman-ramDeterministic-v3), EnvSpec(Pitfall-ramDeterministic-v0), EnvSpec(Pitfall-ramDeterministic-v3), EnvSpec(Enduro-ram-v0), EnvSpec(MsPacman-ramDeterministic-v0), EnvSpec(Enduro-ram-v3), EnvSpec(GravitarDeterministic-v0), EnvSpec(GravitarDeterministic-v3), EnvSpec(Breakout-ramNoFrameskip-v3), EnvSpec(Swimmer-v1), EnvSpec(Alien-ram-v0), EnvSpec(Breakout-ramNoFrameskip-v0), EnvSpec(GravitarNoFrameskip-v0), EnvSpec(VideoPinballDeterministic-v0), EnvSpec(AsterixDeterministic-v0), EnvSpec(AsterixDeterministic-v3), EnvSpec(AlienDeterministic-v0), EnvSpec(AlienDeterministic-v3), EnvSpec(RoadRunnerDeterministic-v3), EnvSpec(RoadRunnerDeterministic-v0), EnvSpec(RepeatCopy-v0), EnvSpec(FrostbiteDeterministic-v0), EnvSpec(Bowling-ramDeterministic-v0), EnvSpec(Bowling-ramDeterministic-v3), EnvSpec(Carnival-ramDeterministic-v0), EnvSpec(EnduroNoFrameskip-v0), EnvSpec(EnduroNoFrameskip-v3), EnvSpec(Carnival-ramDeterministic-v3), EnvSpec(FrostbiteDeterministic-v3), EnvSpec(Asteroids-ramNoFrameskip-v3), EnvSpec(Asteroids-ramNoFrameskip-v0), EnvSpec(TennisNoFrameskip-v0), EnvSpec(DemonAttackDeterministic-v3), EnvSpec(Pitfall-ram-v0), EnvSpec(DemonAttackDeterministic-v0), EnvSpec(TennisNoFrameskip-v3), EnvSpec(DemonAttack-ram-v3), EnvSpec(DemonAttack-ram-v0), EnvSpec(UpNDown-v0), EnvSpec(BankHeistNoFrameskip-v3), EnvSpec(BankHeistNoFrameskip-v0), EnvSpec(UpNDown-v3), EnvSpec(Pitfall-ram-v3), EnvSpec(Kangaroo-ramDeterministic-v0), EnvSpec(Kangaroo-ramDeterministic-v3), EnvSpec(RobotankNoFrameskip-v3), EnvSpec(RobotankNoFrameskip-v0), EnvSpec(WizardOfWor-v3), EnvSpec(WizardOfWor-v0), EnvSpec(Hopper-v1), EnvSpec(Asterix-ramDeterministic-v3), EnvSpec(Robotank-v0), EnvSpec(BattleZone-ramNoFrameskip-v0), EnvSpec(PrivateEyeDeterministic-v3), EnvSpec(Pooyan-ramNoFrameskip-v0), EnvSpec(Pooyan-ramNoFrameskip-v3), EnvSpec(PrivateEyeDeterministic-v0), EnvSpec(ElevatorActionNoFrameskip-v0), 
EnvSpec(ElevatorActionNoFrameskip-v3), EnvSpec(TutankhamNoFrameskip-v0), EnvSpec(Zaxxon-ramDeterministic-v0), EnvSpec(Robotank-v3), EnvSpec(JamesbondNoFrameskip-v0), EnvSpec(JamesbondNoFrameskip-v3), EnvSpec(HumanoidStandup-v1), EnvSpec(KungFuMaster-ramDeterministic-v3), EnvSpec(KungFuMaster-ramDeterministic-v0), EnvSpec(Amidar-v3), EnvSpec(Amidar-v0), EnvSpec(BattleZone-ramNoFrameskip-v3), EnvSpec(BerzerkNoFrameskip-v0), EnvSpec(BerzerkNoFrameskip-v3), EnvSpec(Amidar-ramNoFrameskip-v3), EnvSpec(Amidar-ramNoFrameskip-v0), EnvSpec(Gravitar-ramDeterministic-v3), EnvSpec(Gravitar-ramDeterministic-v0), EnvSpec(Asterix-ramDeterministic-v0), EnvSpec(BattleZone-ram-v3), EnvSpec(BattleZone-ram-v0), EnvSpec(IceHockey-ram-v0), EnvSpec(IceHockey-ram-v3), EnvSpec(ChopperCommand-ramNoFrameskip-v3), EnvSpec(MountainCar-v0), EnvSpec(Qbert-ramDeterministic-v3), EnvSpec(Qbert-ramDeterministic-v0), EnvSpec(BeamRider-ramNoFrameskip-v3), EnvSpec(Carnival-ram-v3), EnvSpec(Carnival-v0), EnvSpec(FrozenLake-v0), EnvSpec(IceHockeyNoFrameskip-v0), EnvSpec(IceHockeyNoFrameskip-v3), EnvSpec(NameThisGameDeterministic-v3), EnvSpec(NameThisGameDeterministic-v0), EnvSpec(BeamRider-ramNoFrameskip-v0), EnvSpec(DoubleDunk-ramNoFrameskip-v0), EnvSpec(Tutankham-ram-v3), EnvSpec(DoubleDunk-ramNoFrameskip-v3), EnvSpec(YarsRevenge-v0), EnvSpec(IceHockey-ramNoFrameskip-v3), EnvSpec(IceHockey-ramNoFrameskip-v0), EnvSpec(YarsRevenge-v3), EnvSpec(MsPacman-v0), EnvSpec(Solaris-ramNoFrameskip-v0), EnvSpec(Solaris-ramNoFrameskip-v3), EnvSpec(MsPacman-v3), EnvSpec(Gopher-v3), EnvSpec(Walker2d-v1), EnvSpec(Gopher-v0), EnvSpec(Zaxxon-ram-v3), EnvSpec(Zaxxon-ram-v0), EnvSpec(DoubleDunkDeterministic-v0), EnvSpec(DoubleDunkDeterministic-v3), EnvSpec(PooyanNoFrameskip-v3), EnvSpec(PooyanNoFrameskip-v0), EnvSpec(Seaquest-ramNoFrameskip-v0), EnvSpec(Seaquest-ramNoFrameskip-v3), EnvSpec(FreewayDeterministic-v0), EnvSpec(FreewayDeterministic-v3), EnvSpec(Blackjack-v0), EnvSpec(TennisDeterministic-v3), 
EnvSpec(TennisDeterministic-v0), EnvSpec(Atlantis-v0), EnvSpec(Atlantis-v3), EnvSpec(EnduroDeterministic-v0), EnvSpec(GuessingGame-v0), EnvSpec(Copy-v0), EnvSpec(CrazyClimber-ramDeterministic-v0), EnvSpec(CrazyClimber-ramDeterministic-v3), EnvSpec(Phoenix-v3), EnvSpec(Phoenix-v0), EnvSpec(Alien-ramDeterministic-v3), EnvSpec(FishingDerbyDeterministic-v3), EnvSpec(CarnivalDeterministic-v3), EnvSpec(Asteroids-v0), EnvSpec(Asteroids-v3), EnvSpec(CarnivalDeterministic-v0), EnvSpec(Tutankham-ramDeterministic-v3), EnvSpec(Robotank-ramDeterministic-v3), EnvSpec(Robotank-ramDeterministic-v0), EnvSpec(IceHockeyDeterministic-v3), EnvSpec(IceHockeyDeterministic-v0), EnvSpec(Centipede-ram-v3), EnvSpec(FishingDerby-ramNoFrameskip-v0), EnvSpec(FishingDerby-ramNoFrameskip-v3), EnvSpec(Centipede-ram-v0), EnvSpec(Solaris-v3), EnvSpec(Tennis-ram-v0), EnvSpec(Assault-v3), EnvSpec(Assault-v0), EnvSpec(Tennis-ram-v3), EnvSpec(HalfCheetah-v1), EnvSpec(GopherNoFrameskip-v3), EnvSpec(GopherNoFrameskip-v0), EnvSpec(WizardOfWorDeterministic-v0), EnvSpec(WizardOfWorDeterministic-v3), EnvSpec(TimePilot-ram-v3), EnvSpec(DoubleDunk-v3), EnvSpec(DoubleDunk-v0), EnvSpec(Tutankham-v0), EnvSpec(LunarLander-v2), EnvSpec(Tutankham-v3), EnvSpec(BeamRider-v3), EnvSpec(BeamRider-v0), EnvSpec(CarnivalNoFrameskip-v0), EnvSpec(BoxingDeterministic-v3), EnvSpec(BoxingDeterministic-v0), EnvSpec(CarnivalNoFrameskip-v3), EnvSpec(Alien-v0), EnvSpec(Alien-v3), EnvSpec(Berzerk-ram-v3), EnvSpec(Berzerk-ram-v0), EnvSpec(PredictObsCartpole-v0), EnvSpec(AmidarDeterministic-v3), EnvSpec(AmidarDeterministic-v0), EnvSpec(SolarisNoFrameskip-v0), EnvSpec(GravitarNoFrameskip-v3), EnvSpec(AssaultDeterministic-v3), EnvSpec(Gravitar-ram-v0), EnvSpec(Gravitar-ram-v3), EnvSpec(AssaultDeterministic-v0), EnvSpec(Frostbite-v3), EnvSpec(Venture-v3), EnvSpec(Venture-v0), EnvSpec(Frostbite-v0), EnvSpec(Acrobot-v1), EnvSpec(Boxing-v0), EnvSpec(Boxing-v3), EnvSpec(ZaxxonNoFrameskip-v0), EnvSpec(DemonAttack-v3), EnvSpec(DemonAttack-v0), 
EnvSpec(ZaxxonNoFrameskip-v3), EnvSpec(Freeway-v0), EnvSpec(NameThisGame-v0), EnvSpec(NameThisGame-v3), EnvSpec(Freeway-v3), EnvSpec(KungFuMasterDeterministic-v0), EnvSpec(KungFuMasterDeterministic-v3), EnvSpec(RoadRunnerNoFrameskip-v0), EnvSpec(RoadRunnerNoFrameskip-v3), EnvSpec(Bowling-ram-v0), EnvSpec(Bowling-ram-v3), EnvSpec(Seaquest-ramDeterministic-v3), EnvSpec(Krull-ram-v3), EnvSpec(Krull-ram-v0), EnvSpec(Seaquest-ramDeterministic-v0), EnvSpec(TimePilot-ramNoFrameskip-v3), EnvSpec(TimePilot-ramNoFrameskip-v0), EnvSpec(BowlingNoFrameskip-v0), EnvSpec(BowlingNoFrameskip-v3), EnvSpec(UpNDown-ramNoFrameskip-v0), EnvSpec(UpNDown-ramNoFrameskip-v3), EnvSpec(Assault-ramNoFrameskip-v3), EnvSpec(Assault-ramNoFrameskip-v0), EnvSpec(KungFuMaster-ramNoFrameskip-v0), EnvSpec(KungFuMaster-ramNoFrameskip-v3), EnvSpec(PrivateEye-ram-v0), EnvSpec(BankHeist-ram-v0), EnvSpec(BankHeist-ram-v3), EnvSpec(YarsRevengeNoFrameskip-v3), EnvSpec(YarsRevengeNoFrameskip-v0), EnvSpec(ElevatorAction-v3), EnvSpec(ElevatorAction-v0), EnvSpec(DemonAttack-ramDeterministic-v0), EnvSpec(DemonAttack-ramDeterministic-v3), EnvSpec(Carnival-ramNoFrameskip-v3), EnvSpec(Carnival-ramNoFrameskip-v0), EnvSpec(FishingDerby-v0), EnvSpec(FishingDerby-v3), EnvSpec(MontezumaRevengeDeterministic-v0), EnvSpec(MontezumaRevengeDeterministic-v3), EnvSpec(UpNDown-ram-v3), EnvSpec(DoubleDunkNoFrameskip-v3), EnvSpec(DoubleDunkNoFrameskip-v0), EnvSpec(AsterixNoFrameskip-v3), EnvSpec(AsterixNoFrameskip-v0), EnvSpec(SolarisNoFrameskip-v3), EnvSpec(Amidar-ramDeterministic-v0), EnvSpec(Amidar-ramDeterministic-v3), EnvSpec(Pong-ram-v3), EnvSpec(Pong-ram-v0), EnvSpec(ElevatorAction-ramDeterministic-v0), EnvSpec(JourneyEscapeDeterministic-v0), EnvSpec(Pendulum-v0), EnvSpec(ElevatorAction-ramDeterministic-v3), EnvSpec(Freeway-ramDeterministic-v3), EnvSpec(Freeway-ramDeterministic-v0), EnvSpec(Breakout-v0), EnvSpec(Breakout-v3), EnvSpec(BeamRider-ram-v3), EnvSpec(Zaxxon-v3), EnvSpec(InvertedPendulum-v1), 
EnvSpec(BeamRider-ram-v0), EnvSpec(VideoPinball-v3), EnvSpec(VideoPinball-v0), EnvSpec(MsPacman-ramNoFrameskip-v3), EnvSpec(PrivateEyeNoFrameskip-v0), EnvSpec(PrivateEyeNoFrameskip-v3), EnvSpec(AtlantisDeterministic-v0), EnvSpec(AtlantisDeterministic-v3), EnvSpec(Berzerk-ramDeterministic-v0), EnvSpec(Berzerk-ramDeterministic-v3), EnvSpec(AirRaid-ramNoFrameskip-v3), EnvSpec(AirRaid-ramNoFrameskip-v0), EnvSpec(Roulette-v0), EnvSpec(Atlantis-ram-v0), EnvSpec(Atlantis-ram-v3), EnvSpec(Freeway-ramNoFrameskip-v0), EnvSpec(Freeway-ramNoFrameskip-v3), EnvSpec(Boxing-ramNoFrameskip-v3), EnvSpec(Boxing-ramNoFrameskip-v0), EnvSpec(Jamesbond-v3), EnvSpec(Jamesbond-v0), EnvSpec(Skiing-ramDeterministic-v0), EnvSpec(SpaceInvaders-v3), EnvSpec(Skiing-ramDeterministic-v3), EnvSpec(CrazyClimberNoFrameskip-v0), EnvSpec(CrazyClimberNoFrameskip-v3), EnvSpec(KangarooNoFrameskip-v0), EnvSpec(AtlantisNoFrameskip-v3), EnvSpec(AtlantisNoFrameskip-v0), EnvSpec(KangarooNoFrameskip-v3), EnvSpec(SpaceInvaders-v0), EnvSpec(SpaceInvadersNoFrameskip-v3), EnvSpec(StarGunnerNoFrameskip-v3), EnvSpec(Kangaroo-ramNoFrameskip-v3), EnvSpec(Kangaroo-ramNoFrameskip-v0), EnvSpec(FishingDerbyDeterministic-v0), EnvSpec(SpaceInvadersNoFrameskip-v0), EnvSpec(DuplicatedInput-v0), EnvSpec(Robotank-ramNoFrameskip-v0), EnvSpec(Robotank-ramNoFrameskip-v3), EnvSpec(Qbert-ram-v3), EnvSpec(DemonAttack-ramNoFrameskip-v3), EnvSpec(DemonAttack-ramNoFrameskip-v0), EnvSpec(Frostbite-ram-v3), EnvSpec(GopherDeterministic-v0), EnvSpec(GopherDeterministic-v3), EnvSpec(Frostbite-ram-v0), EnvSpec(Alien-ramDeterministic-v0), EnvSpec(VideoPinball-ramDeterministic-v3), EnvSpec(OneRoundDeterministicReward-v0), EnvSpec(VideoPinball-ramDeterministic-v0), EnvSpec(Qbert-ram-v0), EnvSpec(Tutankham-ram-v0), EnvSpec(TutankhamNoFrameskip-v3), EnvSpec(SkiingDeterministic-v3), EnvSpec(Freeway-ram-v0), EnvSpec(KangarooDeterministic-v3), EnvSpec(AirRaid-v0), EnvSpec(AirRaid-v3), EnvSpec(KangarooDeterministic-v0), EnvSpec(VentureNoFrameskip-v3), 
EnvSpec(Krull-v3), EnvSpec(JourneyEscape-ramNoFrameskip-v0), EnvSpec(JourneyEscape-ramNoFrameskip-v3), EnvSpec(Krull-v0), EnvSpec(KrullNoFrameskip-v3), EnvSpec(Riverraid-v0), EnvSpec(Riverraid-v3), EnvSpec(KrullNoFrameskip-v0), EnvSpec(MontezumaRevenge-ramDeterministic-v3), EnvSpec(MontezumaRevenge-ramDeterministic-v0), EnvSpec(RiverraidDeterministic-v3), EnvSpec(RiverraidDeterministic-v0), EnvSpec(Carnival-v3), EnvSpec(TimePilot-ram-v0), EnvSpec(MontezumaRevengeNoFrameskip-v3), EnvSpec(BoxingNoFrameskip-v0), EnvSpec(BoxingNoFrameskip-v3), EnvSpec(MontezumaRevengeNoFrameskip-v0), EnvSpec(SkiingDeterministic-v0), EnvSpec(UpNDown-ram-v0), EnvSpec(Kangaroo-v3), EnvSpec(Kangaroo-v0)]

In [3]:
# Environment with image observations (a DQN agent is used on it later)
ENV_NAME = 'Pong-v0'

# Build the environment, then seed numpy and the environment with the same value
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)

# Keep the interface sizes at hand: number of discrete actions and frame shape
nb_actions = env.action_space.n
observation_shape = env.observation_space.shape

print("# of Actions : {}".format(nb_actions))
print("Shape of Observation : {}".format(observation_shape))


[2017-02-25 17:33:58,059] Making new env: Pong-v0
# of Actions : 6
Shape of Observation : (210, 160, 3)

In [4]:
from rl.core import Processor

# Pillow が必要
from PIL import Image

INPUT_SHAPE = (84, 84)

class AtariProcessor(Processor):
    """Processing hooks for observations, state batches and rewards.

    Observations are reduced to ``INPUT_SHAPE`` grayscale ``uint8`` images,
    state batches are converted to ``float32`` and divided by 255, and
    rewards are clipped to the interval [-1, 1].
    """

    def process_observation(self, observation):
        """Turn one raw frame into an ``INPUT_SHAPE`` grayscale array.

        Args:
            observation: 3-D array laid out as (height, width, channel).

        Returns:
            A 2-D ``uint8`` array whose shape equals ``INPUT_SHAPE``.

        Raises:
            AssertionError: if the input is not 3-D or the resized image does
                not have shape ``INPUT_SHAPE``.
        """
        assert observation.ndim == 3  # (height, width, channel)
        img = Image.fromarray(observation)
        # PIL's resize() takes (width, height), while INPUT_SHAPE is compared
        # against a numpy (rows, cols) shape below, so the tuple is reversed.
        # For a square INPUT_SHAPE this is a no-op.
        img = img.resize(INPUT_SHAPE[::-1]).convert('L')  # resize, then grayscale
        processed_observation = np.array(img)
        assert processed_observation.shape == INPUT_SHAPE
        # uint8 keeps the frames small while they sit in the experience memory
        return processed_observation.astype('uint8')

    def process_state_batch(self, batch):
        """Return ``batch`` as ``float32`` divided by 255."""
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        """Return ``reward`` clipped to the interval [-1, 1]."""
        return np.clip(reward, -1., 1.)
    
# Processor instance that is later handed to the DQN agent
processor = AtariProcessor()

# Reset the environment, show the first raw frame and print its shape
obs = env.reset()
plt.imshow(obs)
print(obs.shape)


Using TensorFlow backend.
(210, 160, 3)

In [5]:
import keras.backend as K
from keras.models import Sequential
from keras.layers import Convolution2D, Dense, Activation, Flatten, Permute
from rl.memory import SequentialMemory

# Convolutional Q-network: a window of frames in, one linear output per action
WINDOW_LENGTH = 1
input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE
model = Sequential()
if K.image_dim_ordering() == 'tf':
    # TensorFlow ordering: move the window axis to the last position
    model.add(Permute((2, 3, 1), input_shape=input_shape))
elif K.image_dim_ordering() == 'th':
    # Theano ordering: identity permutation, the axes stay as they are
    model.add(Permute((1, 2, 3), input_shape=input_shape))
else:
    raise RuntimeError('Unknown image_dim_ordering.')
# Three convolution + ReLU stages with shrinking kernels and strides
model.add(Convolution2D(32, 8, 8, subsample=(4, 4)))
model.add(Activation('relu'))
model.add(Convolution2D(64, 4, 4, subsample=(2, 2)))
model.add(Activation('relu'))
model.add(Convolution2D(64, 3, 3, subsample=(1, 1)))
model.add(Activation('relu'))
# Fully connected head ending in nb_actions linear outputs
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
# summary() prints the table itself and returns None; wrapping it in print()
# would additionally emit a stray "None" line.
model.summary()

# Replay memory buffer; its window length must match the network input
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
permute_1 (Permute)              (None, 84, 84, 1)     0           permute_input_1[0][0]            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 20, 20, 32)    2080        permute_1[0][0]                  
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 20, 20, 32)    0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 9, 9, 64)      32832       activation_1[0][0]               
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 9, 9, 64)      0           convolution2d_2[0][0]            
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 7, 7, 64)      36928       activation_2[0][0]               
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 7, 7, 64)      0           convolution2d_3[0][0]            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 3136)          0           activation_3[0][0]               
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           1606144     flatten_1[0][0]                  
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 512)           0           dense_1[0][0]                    
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 6)             3078        activation_4[0][0]               
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 6)             0           dense_2[0][0]                    
====================================================================================================
Total params: 1,681,062
Trainable params: 1,681,062
Non-trainable params: 0
____________________________________________________________________________________________________
None

In [6]:
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.agents import DQNAgent
from keras.optimizers import Adam

# Policy that linearly decreases the probability of picking a random action
# (the `eps` attribute of the wrapped epsilon-greedy policy)
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(),
    attr='eps',
    value_max=1.,
    value_min=.1,
    value_test=.05,
    nb_steps=1000000,
)

# Assemble the DQN agent from the network, policy, memory and processor
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    policy=policy,
    memory=memory,
    processor=processor,
    nb_steps_warmup=1000,
    gamma=.99,
    target_model_update=1000,
    train_interval=4,
    delta_clip=1.,
)

# Compile the agent with Adam and track the mean absolute error
optimizer = Adam(lr=1e-4, epsilon=1e-4)
dqn.compile(optimizer=optimizer, metrics=['mae'])

In [7]:
from rl.callbacks import Callback, TestLogger, ModelIntervalCheckpoint

# コールバックを作成
class PlotReward(Callback):
    """Callback that plots the total reward of every finished episode."""

    # Number of the matplotlib figure that is reused for every redraw
    _FIGURE_NUM = 0

    def on_train_begin(self, episode, logs=None):
        """Start an empty reward history and create the plotting figure.

        Neither argument is read. The parameter list mirrors the original
        signature so existing call patterns keep working.
        NOTE(review): Keras' base ``Callback`` declares
        ``on_train_begin(self, logs=None)`` -- confirm whether the extra
        positional parameter is intended.
        """
        self.episode_reward = []
        self.fig = plt.figure(self._FIGURE_NUM)

    def on_episode_end(self, episode, logs=None):
        """Record ``logs['episode_reward']`` and redraw the plot.

        Raises:
            KeyError: if ``logs`` has no ``'episode_reward'`` entry.
        """
        # None instead of a shared mutable default; a missing key still
        # surfaces as KeyError exactly as with the former ``{}`` default.
        logs = {} if logs is None else logs
        self.episode_reward.append(logs['episode_reward'])
        self.show_result()

    def show_result(self):
        """Redraw the reward curve and replace the cell output with it."""
        # Re-select (or re-create, if it was closed) the callback's own figure
        # so the curve never lands on whatever figure happens to be current.
        self.fig = plt.figure(self._FIGURE_NUM)
        self.fig.clf()
        plt.plot(self.episode_reward, 'r')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        # Draw first, display afterwards: displaying before plotting would show
        # the previous episode's chart (and an empty figure the first time).
        display.clear_output(wait=True)
        display.display(self.fig)
        plt.pause(0.001)

# Callbacks handed to training: the reward plot and a ModelIntervalCheckpoint
# configured with filepath './weight_now.h5f' and interval=10000
callbacks = [
    PlotReward(),
    ModelIntervalCheckpoint(filepath='./weight_now.h5f', interval=10000),
]

In [ ]:
# Train the agent for 1,750,000 steps with rendering and the callbacks enabled
dqn.fit(env, verbose=2, visualize=True, callbacks=callbacks, nb_steps=1750000)

# Save the model as-is
# NOTE(review): 'mymdel.h5f' looks like a typo for 'mymodel.h5f' -- confirm
# before renaming, since other code may load the file under this name.
dqn.model.save('mymdel.h5f', overwrite=True)


<matplotlib.figure.Figure at 0x1185882d0>
    1271/1750000: episode: 1, duration: 38.105s, episode steps: 1271, steps per second: 33, episode reward: -21.000, mean reward: -0.017 [-1.000, 0.000], mean action: 2.445 [0.000, 5.000], mean observation: 105.698 [0.000, 236.000], loss: 0.006802, mean_absolute_error: 0.017334, mean_q: 0.008899, mean_eps: 0.998978

In [ ]:
# Finally, evaluate the agent for 10 episodes (rendering switched off)
dqn.test(env, nb_episodes=10, visualize=False)