Preface: I've wanted to write about deep reinforcement learning for a long time, but a year ago I couldn't get DQN to work, so I was too embarrassed to write anything, haha. Recently I happened to train the brick-breaking game Breakout up to an average score of nearly 40, with top scores easily above 50 (not exactly great, but at least it shows what DRL can do), so here we go~

When people think of the work that made deep reinforcement learning famous, many point to AlphaGo, which caused a sensation in 2016. From the public's point of view that's fair, but what really set deep RL alight and sent its academic profile soaring was the Deep Q-Network (DQN), first presented in the 2013 paper "Playing Atari with Deep Reinforcement Learning".

In 2012, deep learning had just swept the ImageNet competition, and right on its heels the DeepMind team thought of combining deep networks with reinforcement learning. The idea builds on value function approximation, which had existed in RL for a long time, but by using a deep neural network as the approximator it elegantly tames the explosion of the state space!

How exactly? Let's step inside DQN and find out.

Implementing Q(s, a) with a CNN

Viewed purely mathematically, the action-value function $Q(s, a)$ is nothing more than a mapping that takes a state from the state space $\mathcal{S}$ and assigns a value to each action in the action space $\mathcal{A}$. What concrete form that mapping takes is entirely up to us; as long as it gets close to the true optimal $Q^*(s, a)$, we win. Hence the idea of letting a CNN realize the mapping. First, an architecture diagram:

(Figure: the CNN architecture used in DQN)

The Atari games are rendered through the gym module, and the observation space is always (210, 160, 3), i.e., a 210×160 image with 3 channels; before it goes into the CNN it is binarized and shrunk to 84×84. A single frame as the state would throw away a lot of hidden information (for example, which way the ball is flying), so the paper stacks 4 consecutive frames as one state. In PyTorch the CNN input is therefore a tensor of shape (batch_size, 4, 84, 84). The convolutional structure used here (which differs slightly from the figure above and from the original paper) is as follows:
  • First conv layer: 8×8 kernel, stride 4, 32 output channels, ReLU
  • Second conv layer: 4×4 kernel, stride 2, 64 output channels, ReLU
  • Third conv layer: 3×3 kernel, stride 1, 64 output channels, ReLU
  • The flattened third-layer output has dimension 3136 (verified in the quick shape check below), and the fourth layer is a fully connected layer of size 512
  • The fifth layer is the output layer, one unit per action (4 in Breakout), each giving the estimated Q-value of that action
(Figure: Breakout after binarization)
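The 3136 in the list above is easy to verify by hand. Here is a quick sanity check of my own (not part of q_model.py), applying the usual conv output formula floor((size − kernel) / stride) + 1 to an 84×84 input:

# Quick shape check for the conv stack above
size = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = (size - kernel) // stride + 1
    print(size)          # prints 20, 9, 7
print(64 * size * size)  # 64 channels * 7 * 7 = 3136, the flattened dimension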

The CNN is built in q_model.py, and a random torch.randn(32, 4, 84, 84) tensor is generated by hand to check that the network architecture is correct:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import gym
import matplotlib.pyplot as plt


class QNetwork(nn.Module):
    """Actor (Policy) Model."""

    def __init__(self, state_size, action_size, seed):
        """Initialize parameters and build model.
        Params
        ======
            state_size (tuple): shape of the input batch (batch, frames, height, width)
            action_size (int): Dimension of each action
            seed (int): Random seed
        """
        super(QNetwork, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.conv = nn.Sequential(
            nn.Conv2d(state_size[1], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, action_size)
        )

    def forward(self, state):
        """Build a network that maps state -> action values."""
        conv_out = self.conv(state).view(state.size()[0], -1)
        return self.fc(conv_out)


def pre_process(observation):
    """Process a (210, 160, 3) frame into (1, 84, 84)."""
    x_t = cv2.cvtColor(cv2.resize(observation, (84, 84)), cv2.COLOR_BGR2GRAY)
    ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(x_t, (1, 84, 84)), x_t


def stack_state(processed_obs):
    """Stack four copies of one frame to form the initial state."""
    return np.stack((processed_obs, processed_obs, processed_obs, processed_obs), axis=0)


if __name__ == '__main__':
    env = gym.make('Breakout-v0')
    print('State shape: ', env.observation_space.shape)
    print('Number of actions: ', env.action_space.n)
    obs = env.reset()
    x_t, img = pre_process(obs)
    state = stack_state(img)
    print(np.shape(state[0]))
    # plt.imshow(img, cmap='gray')
    # Display with the cv2 module instead:
    # cv2.imshow('Breakout', img)
    # cv2.waitKey(0)

    state = torch.randn(32, 4, 84, 84)  # (batch_size, color_channel, img_height, img_width)
    state_size = state.size()
    cnn_model = QNetwork(state_size, action_size=4, seed=1)
    outputs = cnn_model(state)
    print(outputs)

With the curse of dimensionality handled, the next concerns are training stability and sample efficiency, perennial questions in deep learning. The DQN paper introduces two now-famous tricks to address them, the replay buffer and the target network, which we'll discuss in turn.

Replay Buffer

The word "replay" is evocative: in English it's what you do with recorded footage. "Buffer" is a computing term. Put together, the picture is of pulling past data back out of a cache and using it again, which neatly addresses two problems that plague Q-learning: sample efficiency and sample correlation. In the original Q-learning formulation, each sample is used for exactly one Q-value update and then discarded; and when frames are collected continuously from a game, consecutive states are highly correlated. If instead we store transitions in a large buffer and sample from it uniformly at random, each sample can be reused many times and the correlation between samples is broken.

(Figure: replay buffer, illustrated)

It is implemented as the ReplayBuffer class in dqn_agent.py:

class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Initialize a ReplayBuffer object.
        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            seed (int): random seed
        """
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.stack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
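For a quick feel of how the class is used, here is a minimal smoke test of my own (assuming the imports and the device variable defined at the top of dqn_agent.py, shown below): fill the buffer with dummy transitions, then draw one batch.

# Hypothetical smoke test, not part of the repo
buffer = ReplayBuffer(action_size=4, buffer_size=10000, batch_size=32, seed=0)
dummy_state = np.zeros((4, 84, 84), dtype=np.float32)   # a stacked 4-frame state
for _ in range(100):
    buffer.add(dummy_state, 0, 0.0, dummy_state, False)
states, actions, rewards, next_states, dones = buffer.sample()
print(states.shape, actions.shape)  # torch.Size([32, 4, 84, 84]) torch.Size([32, 1])

Uniform sampling from the deque is exactly what breaks the frame-to-frame correlation discussed above.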

Target Network

Suppose the true action-value function is $q_{\pi}(s, a)$, and our goal is to train a CNN whose estimate $\hat{q}(s, a, w)$ approaches it. When we work out the gradient, it feels as though the TD target $r + \gamma \max_{a'} \hat{q}(s', a', w)$ and $q_{\pi}(s, a)$ should play the same role; but $q_{\pi}$ does not depend on the parameters $w$, whereas the TD target does, and that is mathematically unsound. Put differently, the TD target is itself an estimate, and updating $\hat{q}(s, a, w)$ toward the difference between it and the old estimate is like chasing a moving target you never quite catch. The authors therefore introduce a target network, while the network that interacts with the environment and produces actions is called the behavior network. The two start training with identical architectures and parameters, and every fixed number of iterations during training the behavior network's parameters are synchronized into the target network. That is the update rule proposed in the original paper; in my code I borrow the soft update from DDPG instead.
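Spelled out, with $w^-$ denoting the target-network parameters and $\mathcal{D}$ the replay buffer, the loss minimized in the learn() method below and the soft update that follows each gradient step are (just a compact restatement of what the code implements, not extra machinery):

$$L(w) = \mathbb{E}_{(s,a,r,s',d)\sim\mathcal{D}}\Big[\big(r + \gamma\,(1-d)\max_{a'}\hat{q}(s',a',w^-) - \hat{q}(s,a,w)\big)^2\Big]$$

$$w^- \leftarrow \tau\, w + (1-\tau)\, w^-$$

Here $d$ is the done flag, so terminal transitions contribute only the immediate reward, and with $\tau = 10^{-3}$ the target network trails the behavior network slowly rather than being copied wholesale.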
import numpy as np
import random
from collections import namedtuple, deque

from q_model import QNetwork

import torch
import torch.nn.functional as F
import torch.optim as optim

BUFFER_SIZE = int(1e6)  # replay buffer size
BATCH_SIZE = 32         # minibatch size
GAMMA = 0.99            # discount factor
TAU = 1e-3              # for soft update of target parameters
LR = 1e-5               # learning rate
UPDATE_EVERY = 4        # how often to update the network

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, seed):
        """Initialize an Agent object.
        Params
        ======
            state_size (tuple): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)   # behavior network
        self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)  # target network
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # row = sample index, column = action taken (gather picks that Q value)
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

The DQN Algorithm Flow

(Figure: the algorithm as presented in the original paper)

We test the whole pipeline on the brick-breaking game Breakout. The full training loop lives in Deep_Q_network.py, and the hyperparameters are set in dqn_agent.py. The most important of them are the replay buffer size, the number of episodes, the learning rate, and the ε decay rate. The replay buffer should hold at least one million transitions, and the more episodes the better.
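As a rough back-of-the-envelope check on the ε schedule used below (assuming eps_start=1.0, eps_end=0.01, eps_decay=0.9995 applied once per episode), exploration only bottoms out fairly deep into training:

import math
# Episodes until epsilon decays from eps_start down to its floor eps_end
eps_start, eps_end, eps_decay = 1.0, 0.01, 0.9995
episodes = math.log(eps_end / eps_start) / math.log(eps_decay)
print(round(episodes))  # roughly 9200 episodes before epsilon hits its floor

With that in mind, here is the full Deep_Q_network.py: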
import gym
import random
import torch
import numpy as np
from collections import deque
from dqn_agent import Agent
import matplotlib.pyplot as plt
import cv2
import time

env = gym.make('Breakout-v0')
state_size = env.observation_space.shape
action_size = env.action_space.n
print('Original state shape: ', state_size)
print('Number of actions: ', env.action_space.n)

agent = Agent((32, 4, 84, 84), action_size, seed=1)  # state size (batch_size, 4 frames, img_height, img_width)
TRAIN = False  # train or test flag


def pre_process(observation):
    """Process a (210, 160, 3) frame into an (84, 84) binary image."""
    x_t = cv2.cvtColor(cv2.resize(observation, (84, 84)), cv2.COLOR_BGR2GRAY)
    ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)
    return x_t


def init_state(processed_obs):
    """Repeat the first frame four times to build the initial state."""
    return np.stack((processed_obs, processed_obs, processed_obs, processed_obs), axis=0)


def dqn(n_episodes=30000, max_t=40000, eps_start=1.0, eps_end=0.01, eps_decay=0.9995):
    """Deep Q-Learning.
    Params
    ======
        n_episodes (int): maximum number of training episodes
        max_t (int): maximum number of timesteps per episode, maximum frames
        eps_start (float): starting value of epsilon, for epsilon-greedy action selection
        eps_end (float): minimum value of epsilon
        eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
    """
    scores = []                        # list containing scores from each episode
    scores_window = deque(maxlen=100)  # last 100 scores
    eps = eps_start                    # initialize epsilon
    for i_episode in range(1, n_episodes + 1):
        obs = env.reset()
        obs = pre_process(obs)
        state = init_state(obs)
        score = 0
        for t in range(max_t):
            action = agent.act(state, eps)
            next_state, reward, done, _ = env.step(action)
            # last three frames and current frame as the next state
            next_state = np.stack((state[1], state[2], state[3], pre_process(next_state)), axis=0)
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward
            if done:
                break
        scores_window.append(score)          # save most recent score
        scores.append(score)                 # save most recent score
        eps = max(eps_end, eps_decay * eps)  # decrease epsilon
        print('\tEpsilon now : {:.2f}'.format(eps))
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
        if i_episode % 1000 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
            print('\rEpisode {}\tThe length of replay buffer now: {}'.format(i_episode, len(agent.memory)))
        if np.mean(scores_window) >= 50.0:
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode - 100,
                                                                                         np.mean(scores_window)))
            torch.save(agent.qnetwork_local.state_dict(), 'checkpoint/dqn_checkpoint_solved.pth')
            break
        torch.save(agent.qnetwork_local.state_dict(), 'checkpoint/dqn_checkpoint_8.pth')
    return scores


if __name__ == '__main__':
    if TRAIN:
        start_time = time.time()
        scores = dqn()
        print('COST: {} min'.format((time.time() - start_time) / 60))
        print("Max score:", np.max(scores))

        # plot the scores
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.plot(np.arange(len(scores)), scores)
        plt.ylabel('Score')
        plt.xlabel('Episode #')
        plt.show()
    else:
        # load the weights from file
        agent.qnetwork_local.load_state_dict(torch.load('checkpoint/dqn_checkpoint_8.pth'))
        rewards = []
        for i in range(10):  # episodes, play ten times
            total_reward = 0
            obs = env.reset()
            obs = pre_process(obs)
            state = init_state(obs)
            for j in range(10000):  # frames, in case stuck in one frame
                action = agent.act(state)
                env.render()
                next_state, reward, done, _ = env.step(action)
                state = np.stack((state[1], state[2], state[3], pre_process(next_state)), axis=0)
                total_reward += reward
                # time.sleep(0.01)
                if done:
                    rewards.append(total_reward)
                    break
        print("Test rewards are:", *rewards)
        print("Average reward:", np.mean(rewards))
        env.close()

Running on a Titan XP, I trained for 10,000 and 30,000 episodes respectively; the resulting return curves are shown in the two figures below. The return curve fluctuates a lot, but the overall trend is upward. Across the 30,000 games the best single episode scored close to 350, yet one exceptionally good episode has little lasting effect on what comes afterward; this is the "catastrophic forgetting" problem that has long plagued DQN.

(Figure: return curve over 30,000 episodes)
(Figure: return curve over 10,000 episodes)

And finally, here are the demo video and the GitHub repo (the code is adapted from the Udacity course code)!

Breakout playing itself: https://www.zhihu.com/video/1178626355478061056
GitHub repo: zhengsizuo/DRL_udacity (github.com)
