上一期 MyEncyclopedia公众号文章 通过代码学Sutton强化学习:从Q-Learning 演化到 DQN,我们从原理上讲解了DQN算法,这一期,让我们通过代码来实现DQN 在任天堂经典的超级玛丽游戏中的自动通关吧。本系列将延续通过代码学Sutton 强化学习系列,逐步通过代码实现经典深度强化学习应用在各种游戏环境中。本文所有代码在

https://github.com/MyEncyclopedia/reinforcement-learning-2nd/tree/master/super_mario

最终训练第一关结果动画

DQN 算法回顾

上期详细讲解了DQN中的两个重要的技术:Target Network 和 Experience Replay,正是有了它们才使得 Deep Q Network在实战中容易收敛,以下是Deepmind 发表在Nature 的 Human-level control through deep reinforcement learning 的完整算法流程。

超级玛丽 NES OpenAI 环境

安装基于OpenAI gym的超级玛丽环境执行下面的 pip 命令即可。

pip install gym-super-mario-bros

我们先来看一下游戏环境的输入和输出。下面代码采用随机的action来和游戏交互。有了 组合游戏系列3: 井字棋、五子棋的OpenAI Gym GUI环境 关于OpenAI Gym 的介绍,现在对于其基本的交互步骤已经不陌生了。

import gym_super_mario_bros
from random import random, randrange
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace
from gym import wrappersenv = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, RIGHT_ONLY)# Play randomly
done = False
env.reset()step = 0
while not done:action = randrange(len(RIGHT_ONLY))state, reward, done, info = env.step(action)print(done, step, info)env.render()step += 1env.close()

随机策略的效果如下

注意我们在游戏环境初始化的时候用了参数 RIGHT_ONLY,它定义成五种动作的list,表示仅使用右键的一些组合,适用于快速训练来完成Mario第一关。

RIGHT_ONLY = [['NOOP'],['right'],['right', 'A'],['right', 'B'],['right', 'A', 'B'],
]

观察一些 info 输出内容,coins表示金币获得数量,flag_get 表示是否取得最后的旗子,time 剩余时间,以及 Mario 大小状态和所在的 x,y位置。

{"coins":0,"flag_get":False,"life":2,"score":0,"stage":1,"status":"small","time":381,"world":1,"x_pos":594,"y_pos":89
}

游戏图像处理

Deep Reinforcement Learning 一般是 end-to-end learning,意味着将游戏的 screen image,即 observed state 直接视为真实状态 state,喂给神经网络去训练。于此相反的另一种做法是,通过游戏环境拿到内部状态,例如所有相关物品的位置和属性作为模型输入。这两种方式的区别在我看来有两点。第一点,用观察到的屏幕像素代替真正的状态 state,在partially observable 的环境时可能因为 non-stationarity 导致无法很好的工作,而拿内部状态利用了额外的作弊信息,在partially observable环境中也可以工作。第二点,第一种方式屏幕像素维度比较高,输入数据量大,需要神经网络的大量训练拟合,第二种方式,内部真实状态往往维度低得多,训练起来很快,但缺点是因为除了内部状态往往还需要游戏相关规则作为输入,因此generalization能力不如前者强。

这里,我们当然采样屏幕像素的 end-to-end 方式了,自然首要任务是将游戏帧图像有效处理。超级玛丽游戏环境的屏幕输出是 (240, 256, 3) shape的 numpy array,通过下面一系列的转换,尽可能的在不影响训练效果的情况下减小采样到的数据量。

  1. MaxAndSkipFrameWrapper:每4个frame连在一起,采取同样的动作,降低frame数量

  2. FrameDownsampleWrapper:将原始的 (240, 256, 3) down sample 到 (84, 84, 1)

  3. ImageToPyTorchWrapper:转换成适合 pytorch 的 shape (1, 84, 84)

  4. FrameBufferWrapper:保存最后4次屏幕采样

  5. NormalizeFloats:Normalize 成 [0., 1.0] 的浮点值

def wrap_environment(env_name: str, action_space: list) -> Wrapper:env = make(env_name)env = JoypadSpace(env, action_space)env = MaxAndSkipFrameWrapper(env)env = FrameDownsampleWrapper(env)env = ImageToPyTorchWrapper(env)env = FrameBufferWrapper(env, 4)env = NormalizeFloats(env)return env

CNN 模型

模型比较简单,三个卷积层后做 softmax输出,输出维度数为离散动作数。act() 采用了epsilon-greedy 模式,即在epsilon小概率时采取随机动作来 explore,大于epsilon时采取估计的最可能动作来 exploit。

class DQNModel(nn.Module):def __init__(self, input_shape, num_actions):super(DQNModel, self).__init__()self._input_shape = input_shapeself._num_actions = num_actionsself.features = nn.Sequential(nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=4, stride=2),nn.ReLU(),nn.Conv2d(64, 64, kernel_size=3, stride=1),nn.ReLU())self.fc = nn.Sequential(nn.Linear(self.feature_size, 512),nn.ReLU(),nn.Linear(512, num_actions))def forward(self, x):x = self.features(x).view(x.size()[0], -1)return self.fc(x)def act(self, state, epsilon, device):if random() > epsilon:state = torch.FloatTensor(np.float32(state)).unsqueeze(0).to(device)q_value = self.forward(state)action = q_value.max(1)[1].item()else:action = randrange(self._num_actions)return action

Experience Replay 缓存

实现采用了 Pytorch CartPole DQN 的官方代码,本质是一个最大为 capacity 的 list 保存了采样到的 (s, a, r, s', is_done)  五元组。

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))class ReplayMemory:def __init__(self, capacity):self.capacity = capacityself.memory = []self.position = 0def push(self, *args):if len(self.memory) < self.capacity:self.memory.append(None)self.memory[self.position] = Transition(*args)self.position = (self.position + 1) % self.capacitydef sample(self, batch_size):return random.sample(self.memory, batch_size)def __len__(self):return len(self.memory)

DQNAgent

我们将 DQN 的逻辑封装在 DQNAgent 类中。DQNAgent 成员变量包括两个 DQNModel,一个ReplayMemory。

train() 方法中会每隔一定时间将 Target Network 的参数同步成现行Network的参数。在td_loss_backprop()方法中采样 ReplayMemory 中的五元组,通过minimize TD error方式来改进现行 Network 参数 。Loss函数为:

class DQNAgent():def act(self, state, episode_idx):self.update_epsilon(episode_idx)action = self.model.act(state, self.epsilon, self.device)return actiondef process(self, episode_idx, state, action, reward, next_state, done):self.replay_mem.push(state, action, reward, next_state, done)self.train(episode_idx)def train(self, episode_idx):if len(self.replay_mem) > self.initial_learning:if episode_idx % self.target_update_frequency == 0:self.target_model.load_state_dict(self.model.state_dict())self.optimizer.zero_grad()self.td_loss_backprop()self.optimizer.step()def td_loss_backprop(self):transitions = self.replay_mem.sample(self.batch_size)batch = Transition(*zip(*transitions))state = Variable(FloatTensor(np.float32(batch.state))).to(self.device)action = Variable(LongTensor(batch.action)).to(self.device)reward = Variable(FloatTensor(batch.reward)).to(self.device)next_state = Variable(FloatTensor(np.float32(batch.next_state))).to(self.device)done = Variable(FloatTensor(batch.done)).to(self.device)q_values = self.model(state)next_q_values = self.target_net(next_state)q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1)next_q_value = next_q_values.max(1)[0]expected_q_value = reward + self.gamma * next_q_value * (1 - done)loss = (q_value - expected_q_value.detach()).pow(2)loss = loss.mean()loss.backward()

外层控制代码

最后是外层调用代码,基本和以前文章一样。

def train(env, args, agent):for episode_idx in range(args.num_episodes):episode_reward = 0.0state = env.reset()while True:action = agent.act(state, episode_idx)if args.render:env.render()next_state, reward, done, stats = env.step(action)agent.process(episode_idx, state, action, reward, next_state, done)state = next_stateepisode_reward += rewardif done:print(f'{episode_idx}: {episode_reward}')break

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
获取本站知识星球优惠券,复制链接直接打开:
https://t.zsxq.com/qFiUFMV
本站qq群704220115。加入微信群请扫码:

【深度强化学习】DQN训练超级玛丽闯关相关推荐

  1. ROS开发笔记(10)——ROS 深度强化学习dqn应用之tensorflow版本(double dqn/dueling dqn/prioritized replay dqn)

    ROS开发笔记(10)--ROS 深度强化学习dqn应用之tensorflow版本(double dqn/dueling dqn/prioritized replay dqn) 在ROS开发笔记(9) ...

  2. 系统回顾深度强化学习预训练,在线、离线等研究这一篇就够了

    关注公众号,发现CV技术之美 本文转载自机器之心. 本文中,来自上海交通大学和腾讯的研究者系统地回顾了现有深度强化学习预训练研究,并提供了这些方法的分类,同时对每个子领域进行了探讨. 近年来,强化学习 ...

  3. 深度强化学习DRL训练指南和现存问题(D3QN(Dueling Double DQN))

    目录 参数 iteration episode epoch Batch_Size Experimence Replay Buffer经验回放缓存 Reward discount factor或gamm ...

  4. 深度强化学习——DQN

    联系方式:860122112@qq.com DQN(Deep Q-Learning)可谓是深度强化学习(Deep Reinforcement Learning,DRL)的开山之作,是将深度学习与强化学 ...

  5. 深度强化学习DQN网络

    DQN网络 DQN(Deep Q Networks)网络属于深度强化学习中的一种网络,它是深度学习与Q学习的结合,在传统的Q学习中,我们需要维护一张Q(s,a)表,在实际运用中,Q表往往是巨大的,并且 ...

  6. 深度强化学习——DQN算法原理

    DQN算法原理 一.DQN算法是什么 二.DQN训练过程 三.经验回放 (Experience Replay) 四.目标网络(Target Network) 1.自举(Bootstrapping) 2 ...

  7. 深度强化学习-DQN算法原理与代码

    DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,论文和代码的链接见下方. 论文:Human-level ...

  8. Pytorch 深度强化学习模型训练速度慢

    最近一直在用Pytorch来训练深度强化学习模型,但是速度一直很慢,Gpu利用率也很低. 一.起初开始在训练参数 batch_size = 200, graph_size = 40, epoch_si ...

  9. 深度强化学习DQN详解CartPole

    一. 获取并处理环境图像 本文所刨析的代码是"pytorch官网的DQN示例"(页面),用卷积层配合强化训练去学习小车立杆,所使用的环境是"小车立杆环境"(Ca ...

最新文章

  1. 写程序是最轻松的事情
  2. python批量下载网页文件-python使用selenium实现批量文件下载
  3. python学习笔记 day25 封装
  4. 计算机里的dump是什么意思?(转储、转储文件)
  5. win防火墙禁止访问php文件,windows通过netsh设置防火墙
  6. Android IOS WebRTC 音视频开发总结(十一)-- stunturn部署
  7. smtplib python教程_Python使用poplib模块和smtplib模块收发电子邮件的教程
  8. python查找路径代码_Python搜索路径
  9. 在Ubuntu 14.04上安装 Webmin
  10. Linux中断线程化的优势,记一个实时Linux的中断线程化问题
  11. 2018年手机保值排行榜出炉:华为P20成最大赢家?
  12. 共享usb接口给虚拟机_多网卡虚拟机如何设置?收藏绝对有用
  13. BlackBerry HTML5 WebWorks 平台下,让BB10应用连接上BBM
  14. R语言:企业风险分析(2)【蒙特卡罗模拟,Monte-Carlo Simulation】
  15. Mac OSX x86 10.4.6 安装小记(1)
  16. creo打不开stp文件_为什么stp网站打不开 creo打不开stp文件
  17. Win7: Logoff被用户Lock的屏幕
  18. 如何在Guitar Pro上添加吉他和弦
  19. python key=lambda函数_使用’key’和lambda表达式的python max函数
  20. 免疫组库vdj的数据处理(TCR/BCR)

热门文章

  1. 数据库菜鸟不可不看 简单SQL语句小结
  2. 雷林鹏分享:jQuery EasyUI 数据网格 - 条件设置行背景颜色
  3. 钢琴快案例及手风琴案例
  4. Linux命令在线查询
  5. Map与object的区别
  6. Python学习笔记 setdict
  7. WPF:仿WIN7窗体打开关闭效果
  8. 2021年东南大学附属中大医院公布SCI预警期刊列表的通知
  9. java 微信高级群发_微信高级群发接口demo
  10. css 浮动在最上层_CSS的“层”峦“叠”翠