1 引言

在深度强化学习-策略梯度算法推导博文中,采用了两种方法推导策略梯度算法,并给出了Reinforce算法的伪代码。可能会有小伙伴对策略梯度算法的形式比较疑惑,本文就带领大家剖析其中的原理,深入理解策略梯度算法的公式。本文主要参考了百度飞桨的视频Policy Gradient算法有兴趣的小伙伴可以看看,我觉得讲的非常透彻。

2 手写数字识别

我们先来看一下手写数字识别案列,采用LeNet网络,其输入为一张手写数字照片,输出为0-9每个数字对应的概率。LeNet网络结构不是本文介绍的重点,我们主要看损失函数部分。

假设网络的输入为数字5,标签为one-hot编码形式,即数字5对应概率值为1,其余为0,网络的输出如上图所示。对于分类问题,通常采用交叉熵(Cross Entropy) 损失函数

交叉熵:

分别表示两个不同的分布,交叉熵可以衡量两个分布的差距,通过最小化交叉熵损失,就可以缩小两个分布之间的距离。将标签看作分布,预测概率看作分布,根据交叉熵公式,计算上图中的交叉熵

将其作为损失进行梯度反传,更新网络参数,从而让预测概率分布更加接近标签。

3 策略梯度算法

看完手写数字识别案列后,回到策略梯度算法,单步损失和策略梯度的形式为

单步损失:

策略梯度:

假设智能体的动作空间为离散形式,包括“左、停、右”三个动作,策略网络的输入为状态,输出为每个动作对应的概率。如下图所示

其中预测概率为网络输出的概率分布,真实动作为智能体真正执行的动作,但是它并一定是一个正确的动作,无法作为标签。计算预测概率与真实动作之间的交叉熵,得到

发现它与单步损失中的形式一致。由于真实动作不一定是正确的标签,所以加上累积奖励作为权重。越大,对应的损失越需要重视,反之越小,对应的损失就不那么重要。可以认为是一个缩放因子,始终为正数,并不影响梯度的方向,因此可以忽略。综上,单步损失具体可以表示为

其中表示真实动作。对单步损失求梯度即为策略梯度的蒙特卡洛近似,通过梯度反传不断优化策略网络参数,让网络输出的概率分布接近累积回报较大的动作。

4 总结

本文利用离散动作模型剖析了策略梯度公式,发现它与分类模型类似。对于连续动作模型也是同样的道理,利用交叉熵衡量网络预测的概率分布与真实动作的概率分布,并采用累积奖励加权作为单步损失。对损失求梯度,然后沿着梯度的反方向不断更新策略网络参数,从而不断提升策略。

深度强化学习-策略梯度算法深入理解相关推荐

  1. 深度强化学习-Double DQN算法原理与代码

    深度强化学习-Double DQN算法原理与代码 引言 1 DDQN算法简介 2 DDQN算法原理 3 DDQN算法伪代码 4 仿真验证 引言 Double Deep Q Network(DDQN)是 ...

  2. 【深度强化学习】DRL算法实现pytorch

    DRL Algorithms DQN (deep Q network) Policiy_Gradient 策略梯度是强化学习的一类方法,大致的原理是使用神经网络构造一个策略网络,输入是状态,输出为动作 ...

  3. 【深度强化学习】DDPG算法

    1 DDPG简介 确定性策略梯度(Deterministic Policy Gradient,DPG):确定性策略是和随机策略相对而言的.作为随机策略,在同一个状态处,采用的动作是基于一个概率分布,即 ...

  4. 强化学习 | 策略梯度 | Natural PG | TRPO | PPO

    学习情况:

  5. 深度强化学习Soft-Actor Critic算法高性能Pytorch代码(改写自spinningup,低环境依赖,低阅读障碍)

    写在前面 DRL各种算法在github上各处都是,例如莫凡的DRL代码.ElegantDRL(易读性NO.1) 很多代码不是原算法的最佳实现,在具体实现细节上也存在差异,不建议直接用在科研上. 这篇博 ...

  6. 重温强化学习之深度强化学习

    1.简介                输入特征和真实特征相距比较远,加一个深度学习提取源的特征 2.基于值函数的深度强化学习 意义:不用函数近似无法解决大规模的问题,用函数近似训练不稳定,首次证明了 ...

  7. 赠票 | 深度强化学习的理论、算法与应用专题探索班

    文末有数据派赠票福利呦! 深度强化学习是人工智能领域的一个新的研究热点.它以一种通用的形式将深度学习的感知能力与强化学习的决策能力相结合,并能够通过端对端的学习方式实现从原始输入到输出的直接控制.自提 ...

  8. 线下报名 | YOCSEF TDS:深度强化学习的理论、算法与应用

    时间:7月29日9:00-17:20 地点:北京中科院计算所,一层/四层报告厅(暂定) 报名方式:1.报名链接:http://conf2.ccf.org.cn/TDS  2.点击文末阅读原文报名  3 ...

  9. 深度学习学习笔记-论文研读4-基于深度强化学习的多用户边缘计算任务卸载调度与资源分配算法

    本人学识浅薄,如有理解不到位的地方还请大佬们指出,相互学习,共同进步 概念引入 强化学习 DQN算法 边缘计算 边缘计算,是指在靠近物或数据源头的一侧,采用网络.计算.存储.应用核心能力为一体的开放平 ...

  10. 深度强化学习-D3QN算法原理与代码

    Dueling Double Deep Q Network(D3QN)算法结合了Double DQN和Dueling DQN算法的思想,进一步提升了算法的性能.如果对Doubel DQN和Duelin ...

最新文章

  1. kettle将文件路径定义为_kettle_步骤解释
  2. ar9344 9382 8035 编程器固件_沈阳熔铜炉设计,紧固件加热炉_宏祥电炉
  3. Java虚拟机运行时的数据区域
  4. Oracle 存储过程的导出导入序列的导出
  5. html奇淫技巧 2 教你如何进行图文环绕布局 原创
  6. 菜鸟——首个页面——奇葩问题
  7. 【英语学习】【医学】Unit 03 Blood
  8. java ajax传递到action_ajax传值到action,后台取不到值。
  9. Python 格式化输出 —— %r 与 %s 的区别(__repr__ 与 __str__)
  10. osea/ 5.0-6.0
  11. 10.docker build
  12. 在中国从事什么职业最赚钱_中国最好的十大职业2(转)
  13. 袋鼠云数智之旅·上海站|探索“十四五”智慧校园新图景
  14. 纯js浏览器h5调用摄像头扫描识别解析 条形码+二维码
  15. 完美解决IE9浏览器出现的对象未定义问题
  16. 2018.06~7 阅读随笔
  17. 腹腰部肌肉锻炼(腰会变粗)
  18. 社团部部长工作计划计算机学院,社团部长的工作计划(共9篇).doc
  19. DCOS到底是啥?看完这篇你就懂了
  20. 在Window10子系统Ubantu创建conda环境

热门文章

  1. 财报识别OCR,披露虚假财务报表
  2. python中ix用法_pandas中ix的使用详细讲解
  3. c语言坐标三角形判断,C语言输入三角形边长判断其类型并输出面积实例代码
  4. 软件测试必学的16个高频数据库操作及命令
  5. linux 主机管理平台,Linux虚拟主机管理系统directadmin使用中文教程
  6. DFS判断回路及回路个数
  7. 【转载】数据中心网络架构浅谈
  8. luogu P4234 最小差值生成树
  9. 实测Maven上传jar包到私服的方法归纳
  10. 困扰了很久的ubuntu下智能拼音输入法