深度Q学习原理及相关实例

  • 8. 深度Q学习
    • 8.1 经验回放
    • 8.2 目标网络
    • 8.3 相关算法
    • 8.4 训练算法
    • 8.5 深度Q学习实例
      • 8.5.1 主程序
        • 程序注释
      • 8.5.2 DQN模型构建程序
        • 程序注释
      • 8.5.3 程序测试
    • 8.6 双重深度Q网络
    • 8.7 对偶深度Q网络

8. 深度Q学习

深度Q学习将深度学习和强化学习相结合,是第一个深度强化学习算法。深度Q学习的核心就是用一个人工神经网络q(s,a;θ),s∈S,a∈Aq(s,a;\theta),s∈\mathcal{S},a∈\mathcal{A}q(s,a;θ),s∈S,a∈A来代替动作价值函数。其中θ\thetaθ为神经网络权重,在前面文章中,也使用过w\text{w}w。由于神经网络具有强大的表达能力,能够自动寻找特征,所以采用神经网络有潜力比传统人工特征强大得多。最近基于深度Q网络的深度强化学习算法有了重大的进展,在目前学术界有非常大的影响力。当同时出现异策、自益和函数近似时,无法保证收敛性,会出现训练不稳定或训练困难等问题。针对出现的各种问题,研究人员主要从以下两方面进行了改进。

  • 经验回放(experience replay):将经验(即历史的状态、动作、奖励等)存储起来,再在存储的经验中按一定的规则采样。
  • 目标网络(target network):修改网络的更新方式,例如不把刚学习到的网络权重马上用于后续的自益过程。本节后续内容将从这两条主线出发,介绍基于深度Q网络的强化学习算法。

8.1 经验回放

V. Mnih 等在 2013 年 发 表 文 章 《Playing Atari with deep reinforcement leaming》,提出了基于经验回放的深度Q网络,标志着深度Q网络的诞生,也标志着深度强化学习的诞生1

采用批处理的模式能够提供稳定性。经验回放就是一种让经验的概率分布变得稳定的技术,它能提高训练的稳定性。
经验回放主要有“存储”和 “采样回放”两大关键步骤。其相关算法在之后会介绍, 现在主要来看其特征。

  • 存储:将轨迹以(St,At,Rt+1,St+1)(S_t,A_t, R_{t+1}, S_{t+1})(St​,At​,Rt+1​,St+1​) 等形式存储起来;
  • 采样回放:使用某种规则从存储的(St,At,Rt+1,St+1)(S_t,A_t, R_{t+1}, S_{t+1})(St​,At​,Rt+1​,St+1​) 中随机取出一条或多条经验。

经验回放有以下好处。

  • 在训练Q网络时,可以消除数据的关联,使得数据更像是独立同分布的(独立同分布是很多有监督学习的证明条件)。这样可以减小参数更新的方差,加快收敛。
  • 能够重复使用经验,对于数据获取困难的情况尤其有用。从存储的角度,经验回放可以分为集中式回放和分布式回放。

回放可以分为以下几种,

  • 集中式回放:智能体在一个环境中运行,把经验统一存储在经验池中。
  • 分布式回放:智能体的多份拷贝(worker)同时在多个环境中运行,并将经验统一存 储于经验池中。由于多个智能体拷贝同时生成经验,所以能够在使用更多资源的同
    时更快地收集经验。从采样的角度,经验回放可以分为均匀回放和优先回放。
  • 均匀回放:等概率从经验集中取经验,并且用取得的经验来更新最优价值函数。
  • 优先回放(PrioritizedExperienceReplay, PER): 为经验池里的每个经验指定一个优
    先级,在选取经验时更倾向于选择优先级高的经验。

T. Schaul等 于 2016年发表文章《Prioritized experience replay》,提出了优先回放。优先回放的基本思想是为经验池里的经验指定一个优先级,在选取经验时更倾向于选择优先级高的经验。一般的做法是,如果某个经验(例如经验iii)的优先级为pip_ipi​,那么选取该经验的概率为
pi=pi∑kpkp_i = \frac{p_i}{\sum_{k} p_k}pi​=∑k​pk​pi​​

经验值有许多不同的选取方法,最常见的选取方法有成比例优先基于排序优先

  • 成比例优先(proportional priority):第iii个经验的优先级为
    pi=(δi+ε)αp_i = {(\delta_i + \varepsilon)^{\alpha}}pi​=(δi​+ε)α
    其中δi\delta_iδi​是时序差分误差,ε\varepsilonε是预先选择的一个小正数,α\alphaα是正参数。
  • 基于排序优先(rank-basedpriority):第iii个经验的优先级为
    pi=(1ranki)αp_i = (\frac{1}{\text{rank}_{i}})^{\alpha}pi​=(ranki​1​)α
    其中ranki\text{rank}_{i}ranki​是第iii个经验从大到小排序的排名, 排名从1开始。

经验回放也不是完全没有缺点。例如,它也会导致回合更新和多步学习算法无法使用。一般情况下,如果我们将经验回放用于Q学习,就规避了这个缺点。

8.2 目标网络

对于基于自益的Q学习,其回报的估计和动作价值的估计都和权重θ\thetaθ有关。当权重值变化时,回报的估计和动作价值的估计都会变化。在学习的过程中,动作价值试图追逐一个变化的回报,也容易出现不稳定的情况。可以使用之前介绍的半梯度下降的算法来解决这个问题。在半梯度下降中,在更新价值参数θ\thetaθ时,不对基于自益得到的回报估计UtU_{t}Ut​求梯度。其中一种阻止对UtU_tUt​求梯度的方法就是将价值参数复制一份得到θtarget\theta_{\text{target}}θtarget​, 在计算UtU_tUt​时用θtarget\theta_{\text{target}}θtarget​目标计算。

基于这一方法,V. Mnih等 在 2015年发表了论文《Human-level control through deep reinforcement learning》提出了目标网络(target network) 这一概念。 目标网络是在原有的神经网络之外再搭建一份结构完全相同的网络。原先就有的神经网络称为评估网络( evaluation network)。在学习的过程中,使用目标网络来进行自益得到回报的评估值,作 为学习的目标。在权重更新的过程中,只更新评估网络的权重,而不更新目标网络的权重。这样,更新权重时针对的目标不会在每次迭代都变化,是一个固定的目标。在完成一定次数的更新后,再将评估网络的权重值赋给目标网络,进而进行下一批更新。这样,目标网络也能得到更新。由于在目标网络没有变化的一段时间内回报的估计是相对固定的,目标网络的引入增加了学习的稳定性。所以,目标网络目前已经成为深度Q学习的主流做法。

8.3 相关算法

现在我们考虑使用深度Q学习算法来训练智能体玩游戏2

在每一个时间步骤中,智能体从游戏动作集A=1,...K\mathcal{A} = {1, ... K}A=1,...K中选择一个动作。该动作被传递给模拟器并修改其内部状态和游戏分数。在一般情况下,环境可能是随机的。仿真器的内部状态不被智能体观察到,相反,智能体观察到一个来自仿真器的图像xt∈Rdx_t\in \mathbb{R}^dxt​∈Rd,这是一个代表当前屏幕的像素值的向量。此外,它还会收到代表游戏分数变化的奖励 rtr_trt​。需要注意的是,一般情况下,游戏得分可能取决于之前的整个动作和观察序列;关于一个动作的反馈可能只有在经过数千次的时间步长之后才会收到。

由于智能体只能观察当前屏幕,任务是部分观察,许多模拟器状态在感知上是异构的(即不可能只从当前屏幕xtx_txt​中完全了解当前情况)。因此,动作和观察的序列st=x1,a1,x2,...,at−1,xts_t = x_1,a_1,x_2,...,a_{t-1},x_tst​=x1​,a1​,x2​,...,at−1​,xt​ 被输入到算法中,然后算法根据这些序列学习游戏策略。仿真器中的所有序列都被假定为在有限的时间步长内终止。这个形式化的过程产生了一个大而有限的马尔科夫决策过程(MDP),在这个过程中,每个序列都是一个独立的状态。因此,我们可以将标准的强化学习方法应用于MDP,只需将完整序列sts_tst​作为时间ttt的状态表示即可。

智能体的任务是在模拟器中选择最佳的动作最大化未来的损失.我们做一个标准的假设,对未来的每一步回报采用一个折扣因子γ\gammaγ(γ\gammaγ从始至终设置为0.99),然后定义了在时间ttt上经过折扣后的回报Rt=∑t′=tTγt′−trt′R_t = \sum_{t'=t}^{T}\gamma^{t'-t}r_{t'}Rt​=∑t′=tT​γt′−trt′​,其中TTT为最终停止的时间步。我们定义最佳动作价值函数Q∗(s,a)Q^*(s, a)Q∗(s,a)作为遵循任何策略所能获得的最大预期收益。在经过一些状态sss和采取一些动作aaa后,Q∗(s,a)=max⁡πE[Rt∣st=s,at=a,π]Q^*(s, a) = \max_{\pi}\mathbb{E}[R_t|s_t = s, a_t =a, \pi]Q∗(s,a)=maxπ​E[Rt​∣st​=s,at​=a,π],其中π\piπ作为在状态sss采取的动作aaa的映射,即策略。

最优行为价值函数遵循一个重要的恒等式,这个恒等式被称为贝尔曼方程(Bellman equation)。这基于以下直觉:如果状态s′s's′在下一个时间步的最优值Q∗(s′,a′)Q^*(s', a')Q∗(s′,a′)对于所有可能的行动a′a'a′都已知,那么最优策略就是选择使期望值r+γQ∗(s′,a′)r + \gamma Q^*(s', a')r+γQ∗(s′,a′)最大化的行动a′a'a′:
Q∗(s,a)=Es′[r+γmax⁡a′Q∗(s′,a′)∣s,a]Q^*(s, a) = \mathbb{E}_{s'}[r + \gamma \max_{a'}Q^*(s', a')|s, a]Q∗(s,a)=Es′​[r+γa′max​Q∗(s′,a′)∣s,a]

许多强化学习算法背后的基本思想是通过使用贝尔曼方程作为迭代更新来估计动作价值函数,Qi+1(s,a)=Es′[r+γmax⁡a′Qi(s′,a′)∣s,a]Q_{i+1}(s, a) = \mathbb{E}_{s'}[r + \gamma \max_{a'}Q_{i}(s', a')|s, a]Qi+1​(s,a)=Es′​[r+γmaxa′​Qi​(s′,a′)∣s,a]。这些价值迭代算法都收敛于最优动作价值函数,当i→∞i\to \infini→∞时Qi→Q∗Q_i \to Q^*Qi​→Q∗。在实践中,这种基本的方法是不切实际的,因为动作-价值函数是对每个状态分别估计的,没有任何泛化。相反,通常使用函数逼近器来估计动作价值函数Q(s,a;θ)≈Q∗(s,a)Q(s, a;\theta) \approx Q^*(s, a)Q(s,a;θ)≈Q∗(s,a)。在强化学习中这是典型的线性函数逼近器,但是有时用非线性函数逼近器代替,如神经网络。我们把带有权值θ\thetaθ的神经网络函数逼近器称为Q网络。Q网络可以通过在迭代iii中调整参数θi\theta_iθi​来训练减少贝尔曼方程中的均方误差, 其中最佳目标值r+γmax⁡a′Q∗(s′,a′)r+\gamma \max_{a'}Q^*(s', a')r+γmaxa′​Q∗(s′,a′)被替代为近似目标值y=r+γmax⁡a′Q(s′,a′;θi−)y =r+\gamma \max_{a'}Q(s', a';\theta_i^{-})y=r+γmaxa′​Q(s′,a′;θi−​),其使用先前的一些迭代中的参数θi−\theta_{i}^{-}θi−​。这就产生了一个损失函数Li(θi)L_i(\theta_i)Li​(θi​)的序列,它在每次迭代iii时发生变化,

Li(θi)=Es,a,r[(Es′[y∣s,a]−Q(s,a;θi))2]=Es,a,r,s′[(y−Q(s,a;θi))2]+Es,a,r[Vs′[y]]\begin{aligned}L_i(\theta_i) & = \mathbb{E}_{s, a,r} [(\mathrm{E}_{s'}[y|s,a] - Q(s,a;\theta_i))^2] \\ &= \mathbb{E}_{s, a,r, s'}[(y - Q(s, a;\theta_i))^2]+ \mathrm{E}_{s, a, r}[\mathrm{V}_{s'}[y]]\end{aligned}Li​(θi​)​=Es,a,r​[(Es′​[y∣s,a]−Q(s,a;θi​))2]=Es,a,r,s′​[(y−Q(s,a;θi​))2]+Es,a,r​[Vs′​[y]]​

请注意,目标取决于网络权重;这与用于监督学习的目标不同,后者在学习开始前是固定的。在优化的每一个阶段,我们在优化第iii个损失函数Li(θi)L_i(\theta_i)Li​(θi​)时,保持上一次迭代的参数θi−\theta_{i}^-θi−​固定,从而产生一系列定义明确的优化问题。最后一项是目标的方差,它不依赖于我们当前优化的参数θi\theta_iθi​,因此可以忽略。将损失函数相对于权重进行微分,我们得出以下梯度:

∇θiL(θi)=Es,a,r,s′[(r+γmax⁡a′Q(s′,a′;θi−)−Q(s,a;θi))∇θiQ(s,a;θi))]\nabla_{\theta_i}L(\theta_i) = \mathbb{E}_{s,a,r,s'}[(r+\gamma\max_{a'}Q(s',a';\theta_{i}^-)-Q(s, a;\theta_i))\nabla_{\theta_i}Q(s, a;\theta_i))]∇θi​​L(θi​)=Es,a,r,s′​[(r+γa′max​Q(s′,a′;θi−​)−Q(s,a;θi​))∇θi​​Q(s,a;θi​))]

与其计算上述梯度中的全部期望值,不如通过随机梯度下降来优化损失函数,这通常是计算上的便利。在这个框架中,通过在每一个时间步长后更新权重,使用单样本替换期望值,并设置θi−=θi−1\theta_{i}^- = \theta_{i-1}θi−​=θi−1​,可以恢复熟悉的Q-learning算法。

需要注意的是,这个算法是无模型的:它直接使用仿真器的样本来解决强化学习任务,而不需要明确地估计奖赏和过渡动态P(r,s′∣s,a)P(r, s'|s, a)P(r,s′∣s,a).它也是off-policy:它学习贪婪的策略a=arg max⁡a′Q(s,a′;θ)a = \argmax_{a'}Q(s,a';\theta)a=a′argmax​Q(s,a′;θ),以确保充分探索状态空间。在实际工作中,行为分布往往由ε\varepsilonε-greedy策略选择,遵循概率1−ε1-\varepsilon1−ε贪婪策略,选择概率ε\varepsilonε的随机行动。

8.4 训练算法

训练深度Q-网络的完整算法在下图所示的算法1中提出。智能体根据基于Q表的ε\varepsilonε-贪婪策略选择和执行动作。 由于使用任意长度的历史作为神经网络的输入可能是困难的,Q函数因此工作在由上述函数ϕ\phiϕ产生的固定长度的历史表征上。该算法以两种方式修改了标准的在线Q-learning,使其适用于训练大型神经网络而不产生分歧。

首先,这里使用了经验回放,我们将智能体在每个时间步的经验et=(st,at,rt,st+1)e_t=(s_t,a_t,r_t,s_{t+1})et​=(st​,at​,rt​,st+1​)存储在一个数据集Dt=e1,...,etD_t={e_1,...,e_t}Dt​=e1​,...,et​中,将许多情节(其中一个情节的结束发生在达到终端状态时)汇集到重放存储器中。在算法的内循环过程中,我们对从存储样本池中随机抽取的经验样本(s,a,r,s′)∼U(D)(s,a,r,s') \thicksim U(D)(s,a,r,s′)∼U(D)进行Q-learning更新,或称minibatch更新。这种方法比标准的在线Q-learning有几个优势。

  • 第一,每一步的经验都有可能被用于许多权重更新,这使得数据效率更高。
  • 第二,直接从连续的样本中学习是低效的,因为样本之间有很强的相关性;随机化样本可以打破这些相关性,从而降低更新的方差。
  • 第三,在对策略进行学习时,当前的参数决定了参数训练的下一个数据样本。例如,如果最大化动作是向左移动,那么训练样本将以左手边的样本为主;如果最大化动作随后切换到右边,那么训练分布也将切换。

很容易看出,不需要的反馈循环可能会出现,参数可能会被卡在一个糟糕的局部最小值中,甚至是灾难性的偏离。通过使用经验重放,行为分布是对其以前的许多状态进行平均,平滑学习,避免参数的振荡或发散。需要注意的是,通过经验重放学习时,需要进行off-policy学习(因为我们当前的参数与用于生成样本的参数不同),这也是选择Q-learning的动机。

在实践中,算法只在重放存储器中存储最后的NNN个经验元组,并在执行更新时从DDD中随机均匀取样。这种方法在某些方面是有局限性的,因为内存缓冲区并不能区分重要的转折,而且由于内存大小NNN是有限的,所以总是用最近的转折来覆盖。同样,均匀采样对重放内存中的所有转折给予同等的重要性。

对在线Q-learning的第二个修改旨在进一步提高方法与神经网络的稳定性,就是在Q-learning更新中使用一个单独的网络来生成目标yjy_jyj​,即,建立目标网络。更准确的说,每一次C更新,我们都会克隆网络Q,得到一个目标网络Q^\hat{Q}Q^​,并使用Q^\hat{Q}Q^​来生成Q-learning目标yjy_jyj​,用于后续C更新Q。与标准的在线Q-learning相比,这种修改使得算法更加稳定,在标准的在线Q-learning中,增加Q(st,at)Q(s_t,a_t)Q(st​,at​)的更新往往也会增加所有aaa的Q(st+1,a)Q(s_{t+1},a)Q(st+1​,a),因此也会增加目标yjy_jyj​,可能会导致策略的振荡或分歧。使用较旧的参数集生成目标,在对Q进行更新和更新影响目标yjy_jyj​之间增加了一个延迟,使得分歧或振荡的可能性大大降低。

将更新r+γmax⁡a′Q(s′,a′;θi−)−Q(s,a;,θi)r+\gamma \max_{a'}Q(s',a';\theta_{i}^{-})-Q(s, a;,\theta_i)r+γmaxa′​Q(s′,a′;θi−​)−Q(s,a;,θi​)中的误差项约束为-1和1之间是很有帮助的.因为绝对值损失函数∣x∣|x|∣x∣对x的所有负值都有-1的导数,对x的所有正值都有1的导数,所以将平方误差剪裁为−1-1−1和111之间相当于对(−1,1)(-1,1)(−1,1)区间外的误差使用绝对值损失函数.这种形式的误差剪裁进一步提高了算法的稳定性.

8.5 深度Q学习实例

在这个实例里我们采用"LunarLander-v2"环境。

LunarLander-v2 着陆台总是在坐标(0,0)处。坐标是状态向量的前两个数字。从屏幕顶部移动到着陆台并以零速度降落的奖励大约是100…140点。如果着陆器远离着陆台,就会失去奖励。如果着陆器坠毁或静止,则事件结束,获得额外的-100或+100分。每条腿的地面接触是+10。发射主引擎每格为-0.3分。解决了就是+200分。可以在起落架外降落。燃料是无限的,所以智能体可以学习飞行,然后在第一次尝试降落。有四个离散动作可供选择:什么都不做、发射左方位引擎、发射主引擎、发射右方位引擎。

8.5.1 主程序

这是我们的主程序, 在其中我们建立相关环境并调用了子函数来建立深度强化网络模型进行训练。神经网络框架采用pytorch, 神经网络部分简单的采用3层全连接神经网络。

import gym
import random
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from dqn_agent import Agent
import osdef dqn(n_episode=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995, mode='train'):"""Deep Q-Learning:param n_episode:maximum number of training episodes:param max_t:maximum number of timesteps per episode:param eps_start:starting value of epsilon, for epsilon-greedy action selection:param eps_end:minimum value of epsilon:param eps_decay:multiplicative factor (per episode) for decreasing epsilon:return: final score"""scores = []scores_window = deque(maxlen=100)eps = eps_startif mode == 'train':for i_episode in range(1, n_episode+1):# 初始化状态state = env.reset()score = 0for t in range(max_t):action = agent.act(state, eps)next_state, reward, done, _ = env.step(action)agent.step(state, action, reward, next_state, done)state = next_statescore += rewardif done:breakscores_window.append(score)scores.append(score)eps = max(eps_end, eps_decay*eps)print('\rEpisode {}\t Average Score:{:.2f}'.format(i_episode, np.mean(scores_window)), end="")if i_episode % 100 == 0:print('\rEpisode {}\rAverage Score :{:.2f}'.format(i_episode, np.mean(scores_window)))if np.mean(scores_window) >= 200.0:print('\nEnvironment solved in {:d} episode! \t Average Score: {:.2f}'.format(i_episode, np.mean(scores_window)))torch.save(agent.qnetwork_local.state_dict(), 'checkpoint.pth')breakelse:# 训练一次state = env.reset()for j in range(200):action = agent.act(state, eps)print('state :{} action :{}'. format(state, action))env.render()next_state, reward, done, _ = env.step(action)print('next_state={}, reward={}, done={}'.format(next_state, reward, done))agent.step(state, action, reward, next_state, done)if done:breakreturn scoresif __name__ == '__main__':os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'env = gym.make('LunarLander-v2')env.seed(0)print('State shape: ', env.observation_space.shape)print('Number of actions: ', env.action_space.n)MODE = 'train'if MODE == 'debug':# 调试模式agent = Agent(state_size=8, action_size=4, seed=1,debug_mode=True)scores = dqn(mode='test')elif MODE == 'run':agent = Agent(state_size=8, action_size=4, seed=1)agent.qnetwork_local.load_state_dict(torch.load('checkpoint.pth'))# 以当前策略运行for i in range(3):state = env.reset()for j in range(200):action = agent.act(state)env.render()state, reward, done, _ = env.step(action)if done:breakenv.close()else:# 训练模式agent = Agent(state_size=8, action_size=4, seed=1)scores = dqn()# plot the scoresfig = plt.figure()ax = fig.add_subplot(111)plt.plot(np.arange(len(scores)), scores)plt.ylabel('Score')plt.xlabel('Episode #')plt.show()
程序注释
  MODE = 'train'if MODE == 'debug':# 调试模式agent = Agent(state_size=8, action_size=4, seed=1,debug_mode=True)scores = dqn(mode='test')elif MODE == 'run':agent = Agent(state_size=8, action_size=4, seed=1)agent.qnetwork_local.load_state_dict(torch.load('checkpoint.pth'))...else:# 训练模式agent = Agent(state_size=8, action_size=4, seed=1)scores = dqn()

在这里提供了程序运行的三种模式,“debug”, “run”, "train"模式。debug模式是为了方便查看在程序运行过程中的各种参数,方便程序调试和后期更改而设置的。run模式是在模型训练完成后可以使用训练完成的神经网络来查看最终效果。 train模式即训练模式,没有太多相关数据输出。

8.5.2 DQN模型构建程序

这部分为模型构建子程序,包含了DQN最重要的算法。程序包含了3个类, 分别是class QNetwork, class Agentclass ReplayBuffer

import numpy as np
import random
from collections import namedtuple, dequeimport torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optimBUFFER_SIZE = int(1e4)  # 经验回放的缓冲区的大小
BATCH_SIZE = 64  # 最小训练批数量
GAMMA = 0.99  # 折扣率
TAU = 1e-3  # 用于目标函数的柔性策略更新
LR = 5e-4  # 学习率
UPDATE_EVERY = 4  # 更新网络的频率device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class QNetwork(nn.Module):"""Actor (Policy) Model."""def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64):"""Initialize parameters and build model.Params======state_size (int): Dimension of each stateaction_size (int): Dimension of each actionseed (int): Random seed"""super(QNetwork, self).__init__()self.seed = torch.manual_seed(seed)self.fc1 = nn.Linear(state_size, fc1_units)self.fc2 = nn.Linear(fc1_units, fc2_units)self.fc3 = nn.Linear(fc2_units, action_size)def forward(self, state):"""Build a network that maps state -> action values."""x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))return self.fc3(x)class Agent():"""与环境相互作用,从环境中学习。"""def __init__(self, state_size, action_size, seed, debug_mode=False):"""初始化智能体对象。Params======state_size (int): dimension of each stateaction_size (int): dimension of each actionseed (int): random seed"""self.state_size = state_sizeself.action_size = action_sizeself.seed = random.seed(seed)self.debug_mode = debug_modeprint('Program running in {}'.format(device))# Q-Networkself.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)  # 自适应梯度算法# print('Q-Network_local:{}\nQ-Network_target:{}'.format(self.qnetwork_local, self.qnetwork_target))# 经验回放if self.debug_mode is True:self.memory = ReplayBuffer(action_size, BUFFER_SIZE, 1, seed)else:self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)# 初始化时间步 (for updating every UPDATE_EVERY steps)self.t_step = 0def step(self, state, action, reward, next_state, done):# 在经验回放中保存经验self.memory.add(state, action, reward, next_state, done)# 在每个时间步UPDATE_EVERY中学习self.t_step = (self.t_step + 1) % UPDATE_EVERYif self.t_step == 0:# 如果内存中有足够的样本,取随机子集进行学习if len(self.memory) > BATCH_SIZE:experiences = self.memory.sample()self.learn(experiences, GAMMA)if self.debug_mode is True:experiences = self.memory.sample()self.learn(experiences, GAMMA)def act(self, state, eps=0.):"""根据当前策略返回给定状态的操作.Params======state (array_like): 当前的状态eps (float): epsilon, 用于 epsilon-greedy action selection"""state = torch.from_numpy(state).float().unsqueeze(0).to(device)# 将qn更改成评估形式self.qnetwork_local.eval()# 禁用梯度with torch.no_grad():# 获得动作价值action_values = self.qnetwork_local(state)# 将qn更改成训练模式self.qnetwork_local.train()# Epsilon-greedy action selectionif random.random() > eps:return np.argmax(action_values.cpu().data.numpy())else:return random.choice(np.arange(self.action_size))def learn(self, experiences, gamma):"""使用给定的一批经验元组更新值参数。Params======experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples gamma (float): discount factor"""states, actions, rewards, next_states, dones = experiencesif self.debug_mode is True:print('\nstates={}, actions={}, rewards={}, next_states={}, dones={}'.format(states, actions, rewards, next_states, dones))# compute and minimize the loss# 从目标网络得到最大的预测Q值(下一个状态)Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)# 计算当前状态的Q目标Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))# 从评估网络中获得期望的Q值Q_expected = self.qnetwork_local(states).gather(1, actions)if self.debug_mode is True:print('Q_target_next={}, \nQ_target ={}, \nQ_expected={},'.format(Q_targets_next, Q_targets, Q_expected))# Compute lossloss = F.mse_loss(Q_expected, Q_targets)# Minimize the lossself.optimizer.zero_grad()loss.backward()# 执行单个优化步骤self.optimizer.step()# ------------------- update target network ------------------- #self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)def soft_update(self, local_model, target_model, tau):""":柔性更新模型参数。θ_target = τ*θ_local + (1 - τ)*θ_targetParams======local_model (PyTorch model): weights will be copied fromtarget_model (PyTorch model): weights will be copied totau (float): 插值参数"""for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):# 柔性更新, 将src中数据复制到self中target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)class ReplayBuffer:"""Fixed-size buffer to store experience tuples."""def __init__(self, action_size, buffer_size, batch_size, seed):"""Initialize a ReplayBuffer object.Params======action_size (int): dimension of each actionbuffer_size (int): maximum size of bufferbatch_size (int): size of each training batchseed (int): random seed"""self.action_size = action_sizeself.memory = deque(maxlen=buffer_size)self.batch_size = batch_sizeself.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])self.seed = random.seed(seed)def add(self, state, action, reward, next_state, done):"""在memory中添加一段新的经验."""e = self.experience(state, action, reward, next_state, done)self.memory.append(e)def sample(self):"""从memory中随机抽取一批经验."""experiences = random.sample(self.memory, k=self.batch_size)states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)return (states, actions, rewards, next_states, dones)def __len__(self):"""Return the current size of internal memory."""return len(self.memory)
程序注释

class QNetwork类构建了三层的神经网络模型,class ReplayBuffer类定义了关于经验回访的相关功能。class Agent是最重要的类,它调用了class QNetworkclass ReplayBuffer来创建DQN模型。所以我们主要看一下class Agent的相关函数和功能。


        self.state_size = state_sizeself.action_size = action_sizeself.seed = random.seed(seed)self.debug_mode = debug_mode# Q-Networkself.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)  # 自适应梯度算法

初始化智能体对象,并构建神经网络。在这里我们需要建立两个神经网络,其中“qnetwork_local”作为训练使用的神经网络,在此之外,我们还要建立qnetwork_target目标网络,来优化我们的训练过程。在这里使用了“自适应梯度算法”来作为神经网络的优化器。


        # 经验回放if self.debug_mode is True:self.memory = ReplayBuffer(action_size, BUFFER_SIZE, 1, seed)else:self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)

根据相关的模式来建立经验回放功能的类。


 def step(self, state, action, reward, next_state, done):...def act(self, state, eps=0.):...def learn(self, experiences, gamma):...def soft_update(self, local_model, target_model, tau):

这些是在训练过程中使用到的函数,它们的功能如下所示。其作用是方便与理解,其关系并不是完全如图所示的流线型关系。例如,soft_update函数是在learn函数中调用的的一个函数,其关系并不算是线性的。

#mermaid-svg-0SBzETJSKcSPGAql .label{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);fill:#333;color:#333}#mermaid-svg-0SBzETJSKcSPGAql .label text{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .node rect,#mermaid-svg-0SBzETJSKcSPGAql .node circle,#mermaid-svg-0SBzETJSKcSPGAql .node ellipse,#mermaid-svg-0SBzETJSKcSPGAql .node polygon,#mermaid-svg-0SBzETJSKcSPGAql .node path{fill:#ECECFF;stroke:#9370db;stroke-width:1px}#mermaid-svg-0SBzETJSKcSPGAql .node .label{text-align:center;fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .node.clickable{cursor:pointer}#mermaid-svg-0SBzETJSKcSPGAql .arrowheadPath{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .edgePath .path{stroke:#333;stroke-width:1.5px}#mermaid-svg-0SBzETJSKcSPGAql .flowchart-link{stroke:#333;fill:none}#mermaid-svg-0SBzETJSKcSPGAql .edgeLabel{background-color:#e8e8e8;text-align:center}#mermaid-svg-0SBzETJSKcSPGAql .edgeLabel rect{opacity:0.9}#mermaid-svg-0SBzETJSKcSPGAql .edgeLabel span{color:#333}#mermaid-svg-0SBzETJSKcSPGAql .cluster rect{fill:#ffffde;stroke:#aa3;stroke-width:1px}#mermaid-svg-0SBzETJSKcSPGAql .cluster text{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);font-size:12px;background:#ffffde;border:1px solid #aa3;border-radius:2px;pointer-events:none;z-index:100}#mermaid-svg-0SBzETJSKcSPGAql .actor{stroke:#ccf;fill:#ECECFF}#mermaid-svg-0SBzETJSKcSPGAql text.actor>tspan{fill:#000;stroke:none}#mermaid-svg-0SBzETJSKcSPGAql .actor-line{stroke:grey}#mermaid-svg-0SBzETJSKcSPGAql .messageLine0{stroke-width:1.5;stroke-dasharray:none;stroke:#333}#mermaid-svg-0SBzETJSKcSPGAql .messageLine1{stroke-width:1.5;stroke-dasharray:2, 2;stroke:#333}#mermaid-svg-0SBzETJSKcSPGAql #arrowhead path{fill:#333;stroke:#333}#mermaid-svg-0SBzETJSKcSPGAql .sequenceNumber{fill:#fff}#mermaid-svg-0SBzETJSKcSPGAql #sequencenumber{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql #crosshead path{fill:#333;stroke:#333}#mermaid-svg-0SBzETJSKcSPGAql .messageText{fill:#333;stroke:#333}#mermaid-svg-0SBzETJSKcSPGAql .labelBox{stroke:#ccf;fill:#ECECFF}#mermaid-svg-0SBzETJSKcSPGAql .labelText,#mermaid-svg-0SBzETJSKcSPGAql .labelText>tspan{fill:#000;stroke:none}#mermaid-svg-0SBzETJSKcSPGAql .loopText,#mermaid-svg-0SBzETJSKcSPGAql .loopText>tspan{fill:#000;stroke:none}#mermaid-svg-0SBzETJSKcSPGAql .loopLine{stroke-width:2px;stroke-dasharray:2, 2;stroke:#ccf;fill:#ccf}#mermaid-svg-0SBzETJSKcSPGAql .note{stroke:#aa3;fill:#fff5ad}#mermaid-svg-0SBzETJSKcSPGAql .noteText,#mermaid-svg-0SBzETJSKcSPGAql .noteText>tspan{fill:#000;stroke:none}#mermaid-svg-0SBzETJSKcSPGAql .activation0{fill:#f4f4f4;stroke:#666}#mermaid-svg-0SBzETJSKcSPGAql .activation1{fill:#f4f4f4;stroke:#666}#mermaid-svg-0SBzETJSKcSPGAql .activation2{fill:#f4f4f4;stroke:#666}#mermaid-svg-0SBzETJSKcSPGAql .mermaid-main-font{font-family:"trebuchet ms", verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .section{stroke:none;opacity:0.2}#mermaid-svg-0SBzETJSKcSPGAql .section0{fill:rgba(102,102,255,0.49)}#mermaid-svg-0SBzETJSKcSPGAql .section2{fill:#fff400}#mermaid-svg-0SBzETJSKcSPGAql .section1,#mermaid-svg-0SBzETJSKcSPGAql .section3{fill:#fff;opacity:0.2}#mermaid-svg-0SBzETJSKcSPGAql .sectionTitle0{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .sectionTitle1{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .sectionTitle2{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .sectionTitle3{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .sectionTitle{text-anchor:start;font-size:11px;text-height:14px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .grid .tick{stroke:#d3d3d3;opacity:0.8;shape-rendering:crispEdges}#mermaid-svg-0SBzETJSKcSPGAql .grid .tick text{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .grid path{stroke-width:0}#mermaid-svg-0SBzETJSKcSPGAql .today{fill:none;stroke:red;stroke-width:2px}#mermaid-svg-0SBzETJSKcSPGAql .task{stroke-width:2}#mermaid-svg-0SBzETJSKcSPGAql .taskText{text-anchor:middle;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .taskText:not([font-size]){font-size:11px}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutsideRight{fill:#000;text-anchor:start;font-size:11px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutsideLeft{fill:#000;text-anchor:end;font-size:11px}#mermaid-svg-0SBzETJSKcSPGAql .task.clickable{cursor:pointer}#mermaid-svg-0SBzETJSKcSPGAql .taskText.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutsideLeft.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutsideRight.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-0SBzETJSKcSPGAql .taskText0,#mermaid-svg-0SBzETJSKcSPGAql .taskText1,#mermaid-svg-0SBzETJSKcSPGAql .taskText2,#mermaid-svg-0SBzETJSKcSPGAql .taskText3{fill:#fff}#mermaid-svg-0SBzETJSKcSPGAql .task0,#mermaid-svg-0SBzETJSKcSPGAql .task1,#mermaid-svg-0SBzETJSKcSPGAql .task2,#mermaid-svg-0SBzETJSKcSPGAql .task3{fill:#8a90dd;stroke:#534fbc}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutside0,#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutside2{fill:#000}#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutside1,#mermaid-svg-0SBzETJSKcSPGAql .taskTextOutside3{fill:#000}#mermaid-svg-0SBzETJSKcSPGAql .active0,#mermaid-svg-0SBzETJSKcSPGAql .active1,#mermaid-svg-0SBzETJSKcSPGAql .active2,#mermaid-svg-0SBzETJSKcSPGAql .active3{fill:#bfc7ff;stroke:#534fbc}#mermaid-svg-0SBzETJSKcSPGAql .activeText0,#mermaid-svg-0SBzETJSKcSPGAql .activeText1,#mermaid-svg-0SBzETJSKcSPGAql .activeText2,#mermaid-svg-0SBzETJSKcSPGAql .activeText3{fill:#000 !important}#mermaid-svg-0SBzETJSKcSPGAql .done0,#mermaid-svg-0SBzETJSKcSPGAql .done1,#mermaid-svg-0SBzETJSKcSPGAql .done2,#mermaid-svg-0SBzETJSKcSPGAql .done3{stroke:grey;fill:#d3d3d3;stroke-width:2}#mermaid-svg-0SBzETJSKcSPGAql .doneText0,#mermaid-svg-0SBzETJSKcSPGAql .doneText1,#mermaid-svg-0SBzETJSKcSPGAql .doneText2,#mermaid-svg-0SBzETJSKcSPGAql .doneText3{fill:#000 !important}#mermaid-svg-0SBzETJSKcSPGAql .crit0,#mermaid-svg-0SBzETJSKcSPGAql .crit1,#mermaid-svg-0SBzETJSKcSPGAql .crit2,#mermaid-svg-0SBzETJSKcSPGAql .crit3{stroke:#f88;fill:red;stroke-width:2}#mermaid-svg-0SBzETJSKcSPGAql .activeCrit0,#mermaid-svg-0SBzETJSKcSPGAql .activeCrit1,#mermaid-svg-0SBzETJSKcSPGAql .activeCrit2,#mermaid-svg-0SBzETJSKcSPGAql .activeCrit3{stroke:#f88;fill:#bfc7ff;stroke-width:2}#mermaid-svg-0SBzETJSKcSPGAql .doneCrit0,#mermaid-svg-0SBzETJSKcSPGAql .doneCrit1,#mermaid-svg-0SBzETJSKcSPGAql .doneCrit2,#mermaid-svg-0SBzETJSKcSPGAql .doneCrit3{stroke:#f88;fill:#d3d3d3;stroke-width:2;cursor:pointer;shape-rendering:crispEdges}#mermaid-svg-0SBzETJSKcSPGAql .milestone{transform:rotate(45deg) scale(0.8, 0.8)}#mermaid-svg-0SBzETJSKcSPGAql .milestoneText{font-style:italic}#mermaid-svg-0SBzETJSKcSPGAql .doneCritText0,#mermaid-svg-0SBzETJSKcSPGAql .doneCritText1,#mermaid-svg-0SBzETJSKcSPGAql .doneCritText2,#mermaid-svg-0SBzETJSKcSPGAql .doneCritText3{fill:#000 !important}#mermaid-svg-0SBzETJSKcSPGAql .activeCritText0,#mermaid-svg-0SBzETJSKcSPGAql .activeCritText1,#mermaid-svg-0SBzETJSKcSPGAql .activeCritText2,#mermaid-svg-0SBzETJSKcSPGAql .activeCritText3{fill:#000 !important}#mermaid-svg-0SBzETJSKcSPGAql .titleText{text-anchor:middle;font-size:18px;fill:#000;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql g.classGroup text{fill:#9370db;stroke:none;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);font-size:10px}#mermaid-svg-0SBzETJSKcSPGAql g.classGroup text .title{font-weight:bolder}#mermaid-svg-0SBzETJSKcSPGAql g.clickable{cursor:pointer}#mermaid-svg-0SBzETJSKcSPGAql g.classGroup rect{fill:#ECECFF;stroke:#9370db}#mermaid-svg-0SBzETJSKcSPGAql g.classGroup line{stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql .classLabel .box{stroke:none;stroke-width:0;fill:#ECECFF;opacity:0.5}#mermaid-svg-0SBzETJSKcSPGAql .classLabel .label{fill:#9370db;font-size:10px}#mermaid-svg-0SBzETJSKcSPGAql .relation{stroke:#9370db;stroke-width:1;fill:none}#mermaid-svg-0SBzETJSKcSPGAql .dashed-line{stroke-dasharray:3}#mermaid-svg-0SBzETJSKcSPGAql #compositionStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #compositionEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #aggregationStart{fill:#ECECFF;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #aggregationEnd{fill:#ECECFF;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #dependencyStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #dependencyEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #extensionStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql #extensionEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql .commit-id,#mermaid-svg-0SBzETJSKcSPGAql .commit-msg,#mermaid-svg-0SBzETJSKcSPGAql .branch-label{fill:lightgrey;color:lightgrey;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .pieTitleText{text-anchor:middle;font-size:25px;fill:#000;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .slice{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql g.stateGroup text{fill:#9370db;stroke:none;font-size:10px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql g.stateGroup text{fill:#9370db;fill:#333;stroke:none;font-size:10px}#mermaid-svg-0SBzETJSKcSPGAql g.statediagram-cluster .cluster-label text{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql g.stateGroup .state-title{font-weight:bolder;fill:#000}#mermaid-svg-0SBzETJSKcSPGAql g.stateGroup rect{fill:#ECECFF;stroke:#9370db}#mermaid-svg-0SBzETJSKcSPGAql g.stateGroup line{stroke:#9370db;stroke-width:1}#mermaid-svg-0SBzETJSKcSPGAql .transition{stroke:#9370db;stroke-width:1;fill:none}#mermaid-svg-0SBzETJSKcSPGAql .stateGroup .composit{fill:white;border-bottom:1px}#mermaid-svg-0SBzETJSKcSPGAql .stateGroup .alt-composit{fill:#e0e0e0;border-bottom:1px}#mermaid-svg-0SBzETJSKcSPGAql .state-note{stroke:#aa3;fill:#fff5ad}#mermaid-svg-0SBzETJSKcSPGAql .state-note text{fill:black;stroke:none;font-size:10px}#mermaid-svg-0SBzETJSKcSPGAql .stateLabel .box{stroke:none;stroke-width:0;fill:#ECECFF;opacity:0.7}#mermaid-svg-0SBzETJSKcSPGAql .edgeLabel text{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .stateLabel text{fill:#000;font-size:10px;font-weight:bold;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-0SBzETJSKcSPGAql .node circle.state-start{fill:black;stroke:black}#mermaid-svg-0SBzETJSKcSPGAql .node circle.state-end{fill:black;stroke:white;stroke-width:1.5}#mermaid-svg-0SBzETJSKcSPGAql #statediagram-barbEnd{fill:#9370db}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-cluster rect{fill:#ECECFF;stroke:#9370db;stroke-width:1px}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-cluster rect.outer{rx:5px;ry:5px}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-state .divider{stroke:#9370db}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-state .title-state{rx:5px;ry:5px}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-cluster.statediagram-cluster .inner{fill:white}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-cluster.statediagram-cluster-alt .inner{fill:#e0e0e0}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-cluster .inner{rx:0;ry:0}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-state rect.basic{rx:5px;ry:5px}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-state rect.divider{stroke-dasharray:10,10;fill:#efefef}#mermaid-svg-0SBzETJSKcSPGAql .note-edge{stroke-dasharray:5}#mermaid-svg-0SBzETJSKcSPGAql .statediagram-note rect{fill:#fff5ad;stroke:#aa3;stroke-width:1px;rx:0;ry:0}:root{--mermaid-font-family: '"trebuchet ms", verdana, arial';--mermaid-font-family: "Comic Sans MS", "Comic Sans", cursive}#mermaid-svg-0SBzETJSKcSPGAql .error-icon{fill:#522}#mermaid-svg-0SBzETJSKcSPGAql .error-text{fill:#522;stroke:#522}#mermaid-svg-0SBzETJSKcSPGAql .edge-thickness-normal{stroke-width:2px}#mermaid-svg-0SBzETJSKcSPGAql .edge-thickness-thick{stroke-width:3.5px}#mermaid-svg-0SBzETJSKcSPGAql .edge-pattern-solid{stroke-dasharray:0}#mermaid-svg-0SBzETJSKcSPGAql .edge-pattern-dashed{stroke-dasharray:3}#mermaid-svg-0SBzETJSKcSPGAql .edge-pattern-dotted{stroke-dasharray:2}#mermaid-svg-0SBzETJSKcSPGAql .marker{fill:#333}#mermaid-svg-0SBzETJSKcSPGAql .marker.cross{stroke:#333}:root { --mermaid-font-family: "trebuchet ms", verdana, arial;}#mermaid-svg-0SBzETJSKcSPGAql {color: rgba(0, 0, 0, 0.75);font: ;}

当前状态
若经验回放样本数足够
act 根据策略选择动作
step 执行单步操作
learn 更新神经网络参数
soft_update 更新目标网络

8.5.3 程序测试

接下来将模式设置为Mode = train运行程序进行训练,要实现平均分数大于200分的目标,我的电脑需要跑40分钟左右。使用run模式运行模型如下,

8.6 双重深度Q网络

之前曾提到Q学习会带来最大化偏差,而双重Q学习却可以消除最大化偏差。基于查找表的双重Q学习引入了两个动作价值的估计Q(0)Q(0)Q(0)和Q(1)Q(1)Q(1),每次更新动作价值时用其中的一个网络确定动作,用确定的动作和另外一个网络来估计回报。对于深度Q学习也有同样的结论。Deepmind于2015年发表论文《Deepreinforcement learning with double Q-learning》,将双重Q学习用于深度Q网络,得到了双重深度Q网络(Double Deep Q Network,Double DQN)。考虑到深度Q网络已经有了评估网络和目标网络两个网络,所以双重深度Q学习在估计回报时只需要用评估网络确定动作,用目标网络确定回报的估计即可。所以,只需要将
y=r+γmax⁡aQ(s′,a;θtarget)y =r+\gamma \max_{a}Q(s', a;\theta_{target})y=r+γamax​Q(s′,a;θtarget​)
更改为
y=r+γQ(s′,arg max⁡aQ(s′,a;θi);θ−)y =r+\gamma Q(s', \argmax_{a}Q(s', a;\theta_i);\theta^-)y=r+γQ(s′,aargmax​Q(s′,a;θi​);θ−)
就得到了带经验回放的双重深度Q网络算法。

8.7 对偶深度Q网络

Z.Wang等在2015年发表论文《Dueling network architectures for deepreinforcement learning》,提出了一种神经网络的结构——对偶网络(duelnetwork)。对偶网络理论利用动作价值函数和状态价值函数之差定义了一个新的函数——优势函数(advantage function):
A(s,a)=Q(s,a)−V(s,a)A(s,a) = Q(s, a) - V(s, a)A(s,a)=Q(s,a)−V(s,a)
对偶Q网络仍然用Q(θ)Q(\theta)Q(θ)来估计动作价值,只不过这时候Q(θ)Q(\theta)Q(θ)是状态价值估计V(s;θ)V(s;\theta)V(s;θ)和优势函数估计A(s,a;θ)A(s,a;\theta)A(s,a;θ)的叠加,即
Q(s,a;θ)=V(s;θ)+A(s,a;θ)Q(s,a;\theta)=V(s;\theta)+A(s,a;\theta)Q(s,a;θ)=V(s;θ)+A(s,a;θ)

其中V(θ)V(\theta)V(θ)和A(θ)A(\theta)A(θ)可能都只用到了θ\thetaθ中的部分参数。在训练的过程中,V(θ)V(\theta)V(θ)和A(θ)A(\theta)A(θ)是共同训练的,训练过程和单独训练普通深度QQQ网络并无不同之处。


  1. 来自于《强化学习:原理与python实现》 ↩︎

  2. 参考于《Human-level control through deep reinforcement learning》Volodymyr Mnih等 ↩︎

强化学习(八) - 深度Q学习(Deep Q-learning, DQL,DQN)原理及相关实例相关推荐

  1. [Java并发包学习八]深度剖析ConcurrentHashMap

    转载----http://qifuguang.me/2015/09/10/[Java并发包学习八]深度剖析ConcurrentHashMap/ HashMap是非线程安全的,并发情况下使用,可能会导致 ...

  2. 【强化学习】Playing Atari with Deep Reinforcement Learning (2013)

    Playing Atari with Deep Reinforcement Learning (2013) 这篇文章提出了第一个可以直接用强化学习成功学习控制policies的深度学习模型. 输入是r ...

  3. 【机器学习网络】神经网络与深度学习-6 深度神经网络(deep neural Networks DNN)

    目录 深度神经网络(deep neural Networks DNN) DNN的底层原理 深度学习网络的问题: 案例1:书写数字识别(梯度下降法详解) 男女头发长短区分案例(为什么隐藏层追求深度): ...

  4. 强化学习论文分析4---异构网络_强化学习_功率控制《Deep Reinforcement Learning for Multi-Agent....》

    目录 一.文章概述 二.系统目标 三.应用场景 四.算法架构 1.微基站处----DQN 2.宏基站处---Actor-Critic 五.伪代码 六.算法流程图 七.性能表征 1.收敛时间 2.信道总 ...

  5. 整理学习之深度迁移学习

    迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的.世间万 ...

  6. 迁移学习之深度迁移学习

    深度迁移学习即采用深度学习的方法进行迁移学习,这是当前深度学习的一个比较热门的研究方向. 深度学习方法对非深度方法两个优势: 一.自动化地提取更具表现力的特征: 二.满⾜了实际应用中的端到端 (End ...

  7. 深度残差网络 - Deep Residual Learning for Image Recognition

    CVPR2016 code: https://github.com/KaimingHe/deep-residual-networks 针对CNN网络深度问题,本文提出了一个叫深度残差学习网络,可以使得 ...

  8. 聊天机器人(chatbot)终极指南:自然语言处理(NLP)和深度机器学习(Deep Machine Learning)

    为了这份爱 在过去的几个月中,我一直在收集自然语言处理(NLP)以及如何将NLP和深度学习(Deep Learning)应用到聊天机器人(Chatbots)方面的最好的资料. 时不时地我会发现一个出色 ...

  9. 聊天机器人(chatbot)终极指南:自然语言处理(NLP)和深度机器学习(Deep Machine Learning)...

    在过去的几个月中,我一直在收集自然语言处理(NLP)以及如何将NLP和深度学习(Deep Learning)应用到聊天机器人(Chatbots)方面的最好的资料. 时不时地我会发现一个出色的资源,因此 ...

最新文章

  1. Nginx 虚拟主机配置及负载均衡
  2. linkin大话面向对象--多态
  3. maven创建webapp项目
  4. 常考数据结构与算法:将字符串转为整数
  5. AttributeError: ‘Model‘ object has no attribute ‘_get_distribution_strategy
  6. Codeforces Round #587 (Div. 3)
  7. python高阶函数和匿名函数
  8. 玩转oracle 11g(15):命令学习3
  9. 搭建spring MVC项目
  10. YOLO系列:YOLOv1,YOLOv2,YOLOv3,YOLOv4,YOLOv5简介
  11. 如何使用Puppeteer从任何网站创建自定义API
  12. python import MySQLdb 解决报错 Error:Reason: image not found
  13. sae php api,api.php · silenceper/saeApi - Gitee.com
  14. 文件上传_文件下载_后端获取登录用户---SpringCloud Alibaba_若依微服务框架改造---工作笔记003
  15. System.setOut 重定向 memcached 的输出
  16. Linux查看某个进程的启动时间
  17. (10)机器学习_K邻近算法
  18. 视频教程-CCNA自学视频课程专题四:CCNA认证重点难点解析3(扩展篇)-思科认证
  19. 从设计心理学理解交互设计的原则
  20. Typecho独立下载插件安装与使用

热门文章

  1. 2019.03.01 bzoj2555: SubString(sam+lct)
  2. Shiro01 功能点框图、架构图、身份认证逻辑、身份认证代码实现
  3. ubuntu kylin 18.04 安装 Qt Creator 5.11
  4. day13 paramiko、数据库表操作
  5. 深度学习 vs 机器学习 vs 模式识别
  6. 测试与封装5.1.5.2
  7. Spring学习笔记_IOC
  8. 智能会议白板系统每日开发记录
  9. 引导修复_怎么使用bcdrepair引导修复系统【详细步骤】
  10. eclipse关闭mysql数据库,有关于用eclipse连接mysql数据库出现的问题以及解决办法