在强化学习(十四) Actor-Critic中,我们讨论了Actor-Critic的算法流程,但是由于普通的Actor-Critic算法难以收敛,需要一些其他的优化。而Asynchronous Advantage Actor-critic(以下简称A3C)就是其中比较好的优化算法。本文我们讨论A3C的算法原理和算法流程。

    本文主要参考了A3C的论文,以及ICML 2016的deep RL tutorial。

1. A3C的引入

    上一篇Actor-Critic算法的代码,其实很难收敛,无论怎么调参,最后的CartPole都很难稳定在200分,这是Actor-Critic算法的问题。但是我们还是有办法去有优化这个难以收敛的问题的。

    回忆下之前的DQN算法,为了方便收敛使用了经验回放的技巧。那么我们的Actor-Critic是不是也可以使用经验回放的技巧呢?当然可以!不过A3C更进一步,还克服了一些经验回放的问题。经验回放有什么问题呢? 回放池经验数据相关性太强,用于训练的时候效果很可能不佳。举个例子,我们学习下棋,总是和同一个人下,期望能提高棋艺。这当然没有问题,但是到一定程度就再难提高了,此时最好的方法是另寻高手切磋。

    A3C的思路也是如此,它利用多线程的方法,同时在多个线程里面分别和环境进行交互学习,每个线程都把学习的成果汇总起来,整理保存在一个公共的地方。并且,定期从公共的地方把大家的齐心学习的成果拿回来,指导自己和环境后面的学习交互。

    通过这种方法,A3C避免了经验回放相关性过强的问题,同时做到了异步并发的学习模型。

2. A3C的算法优化

    现在我们来看看相比Actor-Critic,A3C到底做了哪些具体的优化。

    相比Actor-Critic,A3C的优化主要有3点,分别是异步训练框架,网络结构优化,Critic评估点的优化。其中异步训练框架是最大的优化。

    我们首先来看这个异步训练框架,如下图所示:

    图中上面的Global Network就是上一节说的共享的公共部分,主要是一个公共的神经网络模型,这个神经网络包括Actor网络和Critic网络两部分的功能。下面有n个worker线程,每个线程里有和公共的神经网络一样的网络结构,每个线程会独立的和环境进行交互得到经验数据,这些线程之间互不干扰,独立运行。

    每个线程和环境交互到一定量的数据后,就计算在自己线程里的神经网络损失函数的梯度,但是这些梯度却并不更新自己线程里的神经网络,而是去更新公共的神经网络。也就是n个线程会独立的使用累积的梯度分别更新公共部分的神经网络模型参数。每隔一段时间,线程会将自己的神经网络的参数更新为公共神经网络的参数,进而指导后面的环境交互。

    可见,公共部分的网络模型就是我们要学习的模型,而线程里的网络模型主要是用于和环境交互使用的,这些线程里的模型可以帮助线程更好的和环境交互,拿到高质量的数据帮助模型更快收敛。

    现在我们来看看第二个优化,网络结构的优化。之前在强化学习(十四) Actor-Critic中,我们使用了两个不同的网络Actor和Critic。在A3C这里,我们把两个网络放到了一起,即输入状态SS,可以输出状态价值VV,和对应的策略ππ, 当然,我们仍然可以把Actor和Critic看做独立的两块,分别处理,如下图所示:

    第三个优化点是Critic评估点的优化,在强化学习(十四) Actor-Critic第2节中,我们讨论了不同的Critic评估点的选择,其中d部分讲到了使用优势函数AA来做Critic评估点,优势函数AA在时刻t不考虑参数的默认表达式为:

A(S,A,t)=Q(S,A)−V(S)A(S,A,t)=Q(S,A)−V(S)

    Q(S,A)Q(S,A)的值一般可以通过单步采样近似估计,即:

Q(S,A)=R+γV(S′)Q(S,A)=R+γV(S′)

    这样优势函数去掉动作可以表达为:

A(S,t)=R+γV(S′)−V(S)A(S,t)=R+γV(S′)−V(S)

    其中V(S)V(S)的值需要通过Critic网络来学习得到。

    在A3C中,采样更进一步,使用了N步采样,以加速收敛。这样A3C中使用的优势函数表达为:

A(S,t)=Rt++γRt+1+...γn−1Rt+n−1+γnV(S′)−V(S)A(S,t)=Rt++γRt+1+...γn−1Rt+n−1+γnV(S′)−V(S)

    对于Actor和Critic的损失函数部分,和Actor-Critic基本相同。有一个小的优化点就是在Actor-Critic策略函数的损失函数中,加入了策略ππ的熵项,系数为c, 即策略参数的梯度更新和Actor-Critic相比变成了这样:

θ=θ+α∇θlogπθ(st,at)A(S,t)+c∇θH(π(St,θ))θ=θ+α∇θlogπθ(st,at)A(S,t)+c∇θH(π(St,θ))

    以上就是A3C和Actor-Critic相比有优化的部分。下面我们来总价下A3C的算法流程。

3. A3C算法流程

    这里我们对A3C算法流程做一个总结,由于A3C是异步多线程的,我们这里给出任意一个线程的算法流程。

    输入:公共部分的A3C神经网络结构,对应参数位θ,wθ,w,本线程的A3C神经网络结构,对应参数θ′,w′θ′,w′, 全局共享的迭代轮数TT,全局最大迭代次数TmaxTmax, 线程内单次迭代时间序列最大长度TlocalTlocal,状态特征维度nn, 动作集AA, 步长α,βα,β,熵系数c, 衰减因子γγ, 探索率ϵϵ

    输入:公共部分的A3C神经网络参数θ,wθ,w

    1. 更新时间序列t=1t=1

    2. 重置Actor和Critic的梯度更新量:dθ←0,dw←0dθ←0,dw←0

    3. 从公共部分的A3C神经网络同步参数到本线程的神经网络:θ′=θ,w′=wθ′=θ,w′=w

    4. tstart=ttstart=t,初始化状态stst

    5. 基于策略π(at|st;θ)π(at|st;θ)选择出动作atat

    6. 执行动作atat得到奖励rtrt和新状态st+1st+1

    7. t←t+1,T←T+1t←t+1,T←T+1

    8. 如果stst是终止状态,或t−tstart==tlocalt−tstart==tlocal,则进入步骤9,否则回到步骤5

    9. 计算最后一个时间序列位置stst的Q(s,t)Q(s,t):

Q(s,t)={0V(st,w′)terminalstatenoneterminalstate,bootstrappingQ(s,t)={0terminalstateV(st,w′)noneterminalstate,bootstrapping

    10. for i∈(t−1,t−2,...tstart)i∈(t−1,t−2,...tstart):

      1) 计算每个时刻的Q(s,i)Q(s,i):Q(s,i)=ri+γQ(s,i+1)Q(s,i)=ri+γQ(s,i+1)

      2) 累计Actor的本地梯度更新:

dθ←dθ+∇θ′logπθ′(si,ai)(Q(s,i)−V(Si,w′))+c∇θ′H(π(si,θ′))dθ←dθ+∇θ′logπθ′(si,ai)(Q(s,i)−V(Si,w′))+c∇θ′H(π(si,θ′))

      3) 累计Critic的本地梯度更新:

dw←dw+∂(Q(s,i)−V(Si,w′))2∂w′dw←dw+∂(Q(s,i)−V(Si,w′))2∂w′

    11. 更新全局神经网络的模型参数:

θ=θ−αdθ,w=w−βdwθ=θ−αdθ,w=w−βdw

    12. 如果T>TmaxT>Tmax,则算法结束,输出公共部分的A3C神经网络参数θ,wθ,w,否则进入步骤3

    以上就是A3C算法单个线程的算法流程。

4. A3C算法实例

    下面我们基于上述算法流程给出A3C算法实例。仍然使用了OpenAI Gym中的CartPole-v0游戏来作为我们算法应用。CartPole-v0游戏的介绍参见这里。它比较简单,基本要求就是控制下面的cart移动使连接在上面的pole保持垂直不倒。这个任务只有两个离散动作,要么向左用力,要么向右用力。而state状态就是这个cart的位置和速度, pole的角度和角速度,4维的特征。坚持到200分的奖励则为过关。

    算法代码大部分参考了莫烦的A3C代码,增加了模型测试部分的代码并调整了部分模型参数。完整的代码参见我的Github:https://github.com/ljpzzz/machinelearning/blob/master/reinforcement-learning/a3c.py

    整个算法的Actor和Critic的网络结构都定义在这里, 所有的线程中的网络结构,公共部分的网络结构都在这里定义。

    def _build_net(self, scope):w_init = tf.random_normal_initializer(0., .1)with tf.variable_scope('actor'):l_a = tf.layers.dense(self.s, 200, tf.nn.relu6, kernel_initializer=w_init, name='la')a_prob = tf.layers.dense(l_a, N_A, tf.nn.softmax, kernel_initializer=w_init, name='ap')with tf.variable_scope('critic'):l_c = tf.layers.dense(self.s, 100, tf.nn.relu6, kernel_initializer=w_init, name='lc')v = tf.layers.dense(l_c, 1, kernel_initializer=w_init, name='v')  # state valuea_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/actor')c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/critic')return a_prob, v, a_params, c_params

    所有线程初始化部分,以及本线程和公共的网络结构初始化部分如下:

    with tf.device("/cpu:0"):OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # we only need its paramsworkers = []# Create workerfor i in range(N_WORKERS):i_name = 'W_%i' % i   # worker nameworkers.append(Worker(i_name, GLOBAL_AC))

    本线程神经网络将本地的梯度更新量用于更新公共网络参数的逻辑在update_global函数中,而从公共网络把参数拉回到本线程神经网络的逻辑在pull_global中。

    def update_global(self, feed_dict):  # run by a localSESS.run([self.update_a_op, self.update_c_op], feed_dict)  # local grads applies to global netdef pull_global(self):  # run by a localSESS.run([self.pull_a_params_op, self.pull_c_params_op])

    详细的内容大家可以对照代码和算法流程一起看。在主函数里我新加了一个测试模型效果的过程,大家可以试试看看最后的模型效果如何。

5. A3C小结

    A3C解决了Actor-Critic难以收敛的问题,同时更重要的是,提供了一种通用的异步的并发的强化学习框架,也就是说,这个并发框架不光可以用于A3C,还可以用于其他的强化学习算法。这是A3C最大的贡献。目前,已经有基于GPU的A3C框架,这样A3C的框架训练速度就更快了。

    除了A3C, DDPG算法也可以改善Actor-Critic难收敛的问题。它使用了Nature DQN,DDQN类似的思想,用两个Actor网络,两个Critic网络,一共4个神经网络来迭代更新模型参数。

A3C的算法原理和算法流程相关推荐

  1. 【分布式技术专题】「分布式技术架构」一文带你厘清分布式事务协议及分布式一致性协议的算法原理和核心流程机制(上篇)

    背景介绍 最近大家都相比遇到了就业瓶颈了,很多公司要不就是不招人了,要不就是把门槛抬的很高,所以针对于一些分布式角度而言的技术知识点,更是必备条件以及重中之重了.那么今天笔者就针对于分布式协议以及一些 ...

  2. 集成学习(Bagging、Boosting、Stacking)算法原理与算法步骤

    集成学习 概述 严格意义上来说,集成学习算法不能算是一种机器学习算法,而像是一种模型优化手段,是一种能在各种机器学习任务上提高准确率的强有力技术.在很多数据挖掘竞赛中,集成学习算法是比赛大杀器,能很好 ...

  3. 多变异位自适应遗传算法(MMAdapGA)的算法原理、算法步骤和matlab实现

    算法原理 自适应遗传算法是交叉概率和变异概率能够随使用度自动改变,以求得相对某个解的最佳交叉概率和变异概率.本算法是在自适应遗传算法中引进多变异位,以增加种群的多样性. 自适应遗传算法中的交叉概率和变 ...

  4. 连通域最小外接矩形算法原理_算法|图论 2W字知识点整理(超全面)

    作者:SovietPower✨ 链接:https://ac.nowcoder.com/discuss/186584 来源:牛客网 度数序列 对于无向图, 为每个点的度数.有 (每条边被计算两次).有偶 ...

  5. 人工神经网络的算法原理,神经网络算法的原理是

    神经网络算法原理 4.2.1概述人工神经网络的研究与计算机的研究几乎是同步发展的. 1943年心理学家McCulloch和数学家Pitts合作提出了形式神经元的数学模型,20世纪50年代末,Rosen ...

  6. DFS算法原理及其具体流程,包你看一遍就能理解

    目录 写在前面 DFS算法 所解决的问题 所需要的数据结构 代码结构及解释 方法一:递归 解释 递归dfs总结 方法二:栈 解释 栈dfs总结 写在前面 因为楼主也是刚开始刷leetcode,所以下面 ...

  7. 【分布式技术专题】「分布式技术架构」一文带你厘清分布式事务协议及分布式一致性协议的算法原理和核心流程机制(Paxos篇)

    概念简介 Paxos是一种基于消息传递具有高度容错特性的一致性算法,是目前公认的解决分布式一致性问题最有效的算法之一. 发展历史 Paxos算法的发展历史追溯到古希腊,当时有一个名为"Pax ...

  8. 简述dijkstra算法原理_Dijkstra算法之 Java详解

    迪杰斯特拉算法介绍 迪杰斯特拉(Dijkstra)算法是典型最短路径算法,用于计算一个节点到其他节点的最短路径. 它的主要特点是以起始点为中心向外层层扩展(广度优先搜索思想),直到扩展到终点为止. 基 ...

  9. php md5 file算法原理,MD5算法原理与实现

    //MessageDigestAlgorithm5.cpp #include "stdafx.h" #include "MessageDigestAlgorithm5.h ...

  10. 雪花算法原理_算法越来越强,我们的判断力却越来越弱

    我们每天在网络平台上看到的信息,很大程度上是由运营平台的商业公司所编写的算法决定的.但在工具理性的背后,确实存在着网络用户对新闻平台的不满与抱怨:点开全文毫无意义的标题党.缺少"石锤&quo ...

最新文章

  1. Apache httpd设置HTTPS双向认证
  2. oracle sql 执行计划分析_《真正读懂Oracle SQL执行计划》
  3. Educational Codeforces Round 75 (Rated for Div. 2)
  4. mysql 两表管理查询_mysql两表查询
  5. 程序员面试100题之九:求子数组的最大和
  6. hdfs中与file数组类似的数组_如何在 JavaScript 中克隆数组
  7. 损失函数、python实现均方误差、交叉熵误差函数、mini-batch的损失函数
  8. Google research 一行预处理代码,让你的CV模型更强!
  9. 强大的 pdf 编辑器 —— Acrobat
  10. vb 运行错误429 mysql_运行时错误429 ActiveX部件不能创建对象的终极解决方法
  11. Fréchet Inception Distance(FID)
  12. 深度学习调参经验分享(遥感建筑提取)
  13. android 多个sdcard路径,Android中访问sdcard路径的几种方式
  14. 55岁的大妈被儿媳嫌弃,二次创业,靠洗地毯一年就开了一家公司!
  15. 怎么制作出一张证件照?分享几种好用的证件照制作方法
  16. connect 连接超时
  17. 招行193亿港元收购永隆银行53.1%股份
  18. html wmf 不显示,在Word、Excel、PPT中不能显示WMF图片
  19. PPT 如何取消幻灯片自动播放
  20. 【Aegisub相关】_G 简化代码写法的有效范围

热门文章

  1. 搭建无人机仿真环境之PX4安装中出现的一些问题的解决
  2. sop28和so28j封装_sop28封装尺寸
  3. 8086 CPU 寄存器
  4. Vlan的划分;配置trunk中继链路;以太通道配置;DHCP服务配置
  5. 如何在vue中使用Cesium加载shp文件、wms服务、WMTS服务
  6. TNF8SLNO 华为OSN1800全新4路STM-16/8路STM-4/8路STM-1业务板
  7. 【MATLAB生信分析】MATLAB生物信息分析工具箱(二)
  8. css3ps插件,css3ps插件
  9. 深度学习软件安装及环境配置(Win10)
  10. 微信公众号文章排版编辑器推荐