学习了莫烦讲解的PPO,写了点自己的理解笔记,希望能帮到你们。


代码

代码可以去上面的链接自己下载跑一下,这边也给出我参考莫烦自己学的,基本是一样的:

import gym
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as pltEP_MAX = 1000
EP_LEN = 200
BATCH = 32
GAMMA = 0.9
C_LR = 0.0002
A_LR = 0.0001
A_UPDATE_STEPS = 10
C_UPDATE_STEPS = 10
METHOD = [dict(name='kl_pen', kl_target=0.01, lam=0.5),   # KL penaltydict(name='clip', epsilon=0.2),                 # Clipped surrogate objective, find this is better
][1]class PPO:def __init__(self):self.sess = tf.Session()self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')self._build_anet('Critic')with tf.variable_scope('closs'):self.tfdc_r = tf.placeholder(tf.float32, [None, 1], name='discounted_r')self.adv = self.tfdc_r - self.vcloss = tf.reduce_mean(tf.square(self.adv))self.ctrain = tf.train.AdamOptimizer(C_LR).minimize(closs)pi, pi_params = self._build_anet('pi', trainable=True)oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)with tf.variable_scope('sample_action'):self.sample_op = tf.squeeze(pi.sample(1), axis=0)with tf.variable_scope('update_oldpi'):self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]with tf.variable_scope('aloss'):self.tfa = tf.placeholder(tf.float32, [None, A_DIM], 'action')self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')with tf.variable_scope('surrogate'):ratio = pi.prob(self.tfa) / oldpi.prob(self.tfa)surr = ratio * self.tfadvif METHOD['name'] == 'kl_pen':self.tflam = tf.placeholder(tf.float32, None, 'lambda')kl = tf.distributions.kl_divergence(oldpi, pi)self.kl_mean = tf.reduce_mean(kl)self.aloss = -(tf.reduce_mean(surr - self.tflam * kl))else:  # clipping method, find this is betterself.aloss = -tf.reduce_mean(tf.minimum(surr,tf.clip_by_value(ratio, 1. - METHOD['epsilon'], 1. + METHOD['epsilon']) * self.tfadv))self.atrain = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)tf.summary.FileWriter('log/', self.sess.graph)self.sess.run(tf.global_variables_initializer())def _build_anet(self, name, trainable=True):if name == 'Critic':with tf.variable_scope(name):# self.s_Critic = tf.placeholder(tf.float32, [None, S_DIM], 'state')l1_Critic = tf.layers.dense(self.tfs, 100, tf.nn.relu, trainable=trainable, name='l1')self.v = tf.layers.dense(l1_Critic, 1, trainable=trainable, name='value_predict')else:with tf.variable_scope(name):# self.s_Actor = tf.placeholder(tf.float32, [None, S_DIM], 'state')l1_Actor = tf.layers.dense(self.tfs, 100, tf.nn.relu, trainable=trainable, name='l1')mu = 2 * tf.layers.dense(l1_Actor, A_DIM, tf.nn.tanh, trainable=trainable, name='mu')sigma = tf.layers.dense(l1_Actor, A_DIM, tf.nn.softplus, trainable=trainable, name='sigma')norm_list = tf.distributions.Normal(loc=mu, scale=sigma)params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)return norm_list, paramsdef update(self, s, a, r):self.sess.run(self.update_oldpi_op)adv = self.sess.run(self.adv, {self.tfdc_r: r, self.tfs: s})if METHOD['name'] == 'kl_pen':for _ in range(A_UPDATE_STEPS):_, kl = self.sess.run([self.atrain, self.kl_mean], {self.tfa: a, self.tfadv: adv, self.tfs: s, self.tflam: METHOD['lam']})if kl > 4 * METHOD['kl_target']:  # this in in google's paperbreakif kl < METHOD['kl_target'] / 1.5:  # adaptive lambda, this is in OpenAI's paperMETHOD['lam'] /= 2elif kl > METHOD['kl_target'] * 1.5:METHOD['lam'] *= 2METHOD['lam'] = np.clip(METHOD['lam'], 1e-4, 10)    # sometimes explode, this clipping is my solutionelse:[self.sess.run(self.atrain, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(A_UPDATE_STEPS)][self.sess.run(self.ctrain, {self.tfs: s, self.tfdc_r: r}) for _ in range(C_UPDATE_STEPS)]def choose_action(self, s):s = s[np.newaxis, :]a = self.sess.run(self.sample_op, {self.tfs: s})[0]return np.clip(a, -2, 2)def get_v(self, s):if s.ndim < 2:s = s[np.newaxis, :]return self.sess.run(self.v, {self.tfs: s})env = gym.make('Pendulum-v0').unwrapped
S_DIM = env.observation_space.shape[0]
A_DIM = env.action_space.shape[0]
ppo = PPO()
all_ep_r = []for ep in range(EP_MAX):s = env.reset()buffer_s, buffer_a, buffer_r = [], [], []ep_r = 0for t in range(EP_LEN):env.render()a = ppo.choose_action(s)s_, r, done, _ = env.step(a)buffer_s.append(s)buffer_a.append(a)buffer_r.append((r+8)/8)s = s_ep_r += rif (t+1) % BATCH == 0 or t == EP_LEN - 1:v_s_ = ppo.get_v(s_)discounted_r = []for r in buffer_r[::-1]:v_s_ = r + GAMMA*v_s_discounted_r.append(v_s_)discounted_r.reverse()bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(discounted_r)buffer_s, buffer_a, buffer_r = [], [], []ppo.update(bs, ba, br)if ep == 0:all_ep_r.append(ep_r)else:all_ep_r.append(all_ep_r[-1]*0.9 + ep_r*0.1)print('Ep:%d | Ep_r:%f' % (ep, ep_r))plt.plot(np.arange(len(all_ep_r)), all_ep_r)
plt.xlabel('Episode')
plt.ylabel('Moving averaged episode reward')
plt.show()

理解流程图

PPO算法本质上是一个AC算法,有Actor和Critic神经网络,其中,Critic网络的更新方式和AC算法差不多,Actor网络我感觉和Q-Learning一样有新旧神经网络,并周期性的更新旧神经网络。Critci网络就不多说了,不懂的可以参考一下莫烦的教程和我之前写得一篇理解,adv相当于AC中的TD_error。Actor网络主要作用就是决定策略π\piπ(pi),程序中实现的时候假设策略是一个正态分布,所以神经网络主要是预测合适的μ\muμ(mu)和σ\sigmaσ(sigma)。然后根据这个分布选择动作,作用于环境,环境反馈下个状态等等信息。程序更新神经网络实现的时候,会存储32个动作及其环境输出的信息来更新网络,就是上面流程图中提到的batch和buffer,其中缓存下来的reward还需要做一个discounted的转换(就是一个累计的reward)。

程序讲解

大致思路

程序主要分为两部分,一部分是PPO类,还有一部分就是主程序。实现的思路莫烦老师已经已经讲的很清楚了,这边就不赘述了,截图蹭页数

主程序讲的就是环境env和算法交互的内容,输入环境内容,输出算法决策并更新算法参数。

实现注意点

几个我自己看程序的疑惑点记录一下,主要是在PPO类中:

1、tf.squeeze

__init__sample_action中self.sample_op = tf.squeeze(pi.sample(1), axis=0),这边pi是一个正态分布,sample(1)就是采一个点(就是选一个动作),我们调整一下程序调试一下:

 # self.sample_op = tf.squeeze(pi.sample(1), axis=0)self.sample_op = pi.sample(1)…………a = self.sess.run(self.sample_op, {self.tfs: s})

看下输出结果:

再调整一下程序:

 self.sample_op = tf.squeeze(pi.sample(1), axis=0)# self.sample_op = pi.sample(1)…………a = self.sess.run(self.sample_op, {self.tfs: s})

输出结果:

明显压缩了一维。写一段程序帮助理解一下:

简单的来说,squeeze是改变shape的,里面的内容是不变的:把所有一维的抹去(参数axis是锁定要抹去的维数的,下面一段程序axis=0,就是指抹去shape中的第一个1)。

2、Normal().prob

ratio = pi.prob(self.tfa) / oldpi.prob(self.tfa)这个函数是用来求对应点的概率密度的。

简单的PPO算法笔记相关推荐

  1. 维吉尼亚加密算法 (C语言实现简单的加密算法) ------- 算法笔记007

    概念理解 什么是维吉尼亚加密算法 加密步骤:1.创建一个匹配循环链表:2.接受需要加密的明文:3.根据随机生成的密钥配合链表进行移位:4.输出/保存对应的密文 解密步骤: 1.接受加密后的密钥:2.根 ...

  2. 【原创】强化学习笔记|从零开始学习PPO算法编程(pytorch版本)

    从零开始学习PPO算法编程(pytorch版本)_melody_cjw的博客-CSDN博客_ppo算法 pytorch 从零开始学习PPO算法编程(pytorch版本)(二)_melody_cjw的博 ...

  3. 胡凡 《算法笔记》 上机实战训练指南 3.1 简单模拟

    胡凡 <算法笔记> 上机实战训练指南 3.1 持续更新中 , 菜鸡的刷题笔记- 大学到现在了还没咋好好刷过题,该push自己了- 文章目录 胡凡 <算法笔记> 上机实战训练指南 ...

  4. 强化学习经典算法笔记(十二):近端策略优化算法(PPO)实现,基于A2C(下)

    强化学习经典算法笔记(十二):近端策略优化算法(PPO)实现,基于A2C 本篇实现一个基于A2C框架的PPO算法,应用于连续动作空间任务. import torch import torch.nn a ...

  5. 算法笔记--简单实现栈的先入后出(FILO,First In Last Out)功能

    算法笔记–简单实现栈的先入后出(FILO,First In Last Out)功能 stack 栈,是一个 先入后出(FILO,First In Last Out)的 有序列表,可以形象地理解为手枪的 ...

  6. 算法笔记 简单贪心(月饼问题)

    ** 概念 ** 贪心法是求解一类最优问题的方法,它总是考虑当前状态下局部最优(或较优)的策略,来使全局的结果达到最优(或较优).显然,如果采取较优而非最优的策略(最优策略可能不存在或是不易想到),得 ...

  7. C语言、Java学习笔记(三)---几种简单的排序算法

    假期已经过了一半,整个人都变得颓废了许多.今天没有出去玩,就学了几个简单的排序算法,以求安慰自己,好歹也是在假期里学习过了.(瘫- C 这里一次性给出三种排序方法的代码,分别是冒泡排序,选择排序和归并 ...

  8. 【算法笔记题解】《算法笔记知识点记录》第三章——入门模拟1——简单模拟

    如果喜欢大家还希望给个收藏点赞呀0.0 相关知识点大家没基础的还是要看一下的,链接: <算法笔记知识点记录>第三章--入门模拟 由于放原题的话文章实在太长,所以题多的话我只放思路和题解,大 ...

  9. ChatGPT通俗导论:从RL之PPO算法、RLHF到GPT-N、instructGPT

    前言 自从我那篇BERT通俗笔记一经发布,然后就不断改.不断找人寻求反馈.不断改,其中一位朋友倪老师(之前我司NLP高级班学员现课程助教老师之一)在谬赞BERT笔记无懈可击的同时,给我建议到,&quo ...

最新文章

  1. H.264 基础及 RTP 封包详解
  2. 差异表达基因-火山图和聚类图解释
  3. 10a大电流稳压芯片_稳压二极管你见过,但是它的这些参数你知道吗
  4. C. Kefa and Park【树的遍历】
  5. 里bl2和bl3为什么分开_分手挽回:为什么不建议过早同居
  6. 02-初识CoreData
  7. NOD32客户端更新文件
  8. C/C++线程与多线程工作笔记0007---单线程实现文件查找系统
  9. docker配置 注册中心
  10. C# BackgroundWorker使用总结
  11. 18个基于Web的代码开发编辑器
  12. 软件架构入门及分类——微服务架构
  13. 计算机一级安装包怎么升级,详细教您win7如何升级为sp1
  14. 木马开发的基本理论基础(四)
  15. 异常解决 java.lang.UnsupportedOperationException: Required method destroyItem was not overridden
  16. 2013年互联网江湖格局观
  17. android pppd log,未记录的pppd退出代码
  18. dolphinscheduler v2.0.1 master和worker执行流程分析(一)
  19. 人工智能 | ShowMeAI资讯日报 #2022.06.22
  20. 悼念512汶川大地震遇难同胞——珍惜现在,感恩生活 dp

热门文章

  1. HTML中插入自动播放的背景音乐-亲测有效
  2. 2019 年移动安全总结汇报演讲稿
  3. 火星开发的价值_开发火星是幌子,月球才是必争之地
  4. Java SE核心API(2) —— 正则表达式、Object、包装类
  5. unicode 和 GB2312 编码对应表
  6. AngularJS控制器(Controller)
  7. 吊打java面试官之 Hashtable详细介绍(源码解析)和使用示例
  8. 宿州市空间数据库管理系统(2)
  9. 检查SSD固态硬盘的使用量和寿命
  10. windows 单机 - elasticsearch-7.11.1 、kibana-7.11.1 安装部署