REINFORCE 算法实现

REINFORCE算法是策略梯度算法最原始的实现算法,这里采用tensorflow2.0进行实现

import tensorflow as tf
import gym
from matplotlib import pyplot as plt
import numpy as npdef PGReinforce_run(PGReinforce_agent=None, episode=1000):PGReinforce_agent = PGReinforce_agent.PGReinforce(n_actions=2, n_features=4)PGReinforce_agent.net_init()score = []env = gym.make('CartPole-v1')bias = 5for i_episode in range(episode):# 初始化,observation = env.reset()done = Falset = 0while not done:env.render()action = PGReinforce_agent.choose_action(observation)PGReinforce_agent.traj_store(observation, action)observation_, reward, done, info = env.step(action)x, x_dot, theta, theta_dot = observationr2 = - abs(theta)*5# r1 = - abs(x)PGReinforce_agent.r_calculate(reward + r2)observation = observation_t += 1# PGReinforce_agent.loss_calculate()print("Episode finished after {} time steps".format(t + 1))score.append(t + 1)PGReinforce_agent.learn(5)if (i_episode + 1) % 100 == 0:plt.plot(score)  # 绘制波形# plt.draw()plt.savefig(f"RL_algorithm_package/img/pic_{0}_bias-{i_episode + 1}.png")class PGReinforce:def __init__(self, n_actions, n_features, gamma=0.9, learning_rate=0.01):self.gamma = gammaself.n_actions = n_actionsself.n_features = n_features# 轨迹self.traj = []# 网络self.pg_model = Noneself.net_init()# 一个轨迹序列的总回报self.r = []# 优化器self.opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)def choose_action(self, s):"""注意,这里动作的选择与之前DQN不同,DQN是选择最大的可能性的动作,而这里要对输出进行采样,来选择动作:param s::return:"""s = s.reshape(1, 4)action_value = self.pg_model.predict(np.array(s))action = np.random.choice(np.arange(action_value.shape[1]), p=action_value[0])return actiondef r_calculate(self, r):"""每次输入r,累加算计每一个回合的总回报值r:param r: 回报值:return:"""self.r.append(r)def loss_calculate(self, bias):a_list = []s_list = []G_list = []for index in range(len(self.r)):G_list.append([self.gamma ** index * sum(self.r[index:])])for index, item in enumerate(self.traj):a_list.append(self.traj[index][1])s_list.append(self.traj[index][0])a_one_hot = tf.one_hot(a_list, self.n_actions)s = np.array(s_list)G = np.matmul(np.array(G_list), np.ones([1, 2]))out_put_a = self.pg_model(s)# 注意loss函数的计算。log_pro = tf.reduce_sum(a_one_hot * tf.math.log(out_put_a), axis=1)# 因为是计算真实的动作和网络输出动作的相似程度,需要采用reduce_mean函数来进行计算loss# 这里loss就一个值,因为是一个回合进行更新一次。意思是这个回合的相似程度,这里相当于是一个分类问题loss = - tf.reduce_mean((G - bias)*log_pro)self.r.clear()self.traj.clear()return lossdef traj_store(self, s, a):"""轨迹存储函数,将s,a存储到列表中,done以及r不必存储:param s: 状态:param a: 动作:return:"""s_list = []for index in range(self.n_features):s_list.append(s[index])self.traj.append([s_list, a])def net_init(self):inputs = tf.keras.Input(shape=(self.n_features,))d1 = tf.keras.layers.Dense(32, activation='relu')(inputs)output = tf.keras.layers.Dense(self.n_actions, activation='softmax')(d1)self.pg_model = tf.keras.Model(inputs=inputs, outputs=output)def learn(self, bias):# 更新梯度with tf.GradientTape() as Tape:loss = self.loss_calculate(bias)grads = Tape.gradient(loss, self.pg_model.trainable_variables)w = self.pg_model.get_weights()# print(f"w_before = {w}")self.opt.apply_gradients(zip(grads, self.pg_model.trainable_variables))w = self.pg_model.get_weights()# print(f"w_after = {w}")

①这里需要注意的点是REINFORCE程序的动作选择是通过对输出采样得到的,而不是采用最大化的算法。

    def choose_action(self, s):"""注意,这里动作的选择与之前DQN不同,DQN是选择最大的可能性的动作,而这里要对输出进行采样,来选择动作:param s::return:"""s = s.reshape(1, 4)action_value = self.pg_model.predict(np.array(s))action = np.random.choice(np.arange(action_value.shape[1]), p=action_value[0])return action

②这里是蒙特卡洛方法,在一轮游戏结束后,才会对神经网络进行更新。所以这里的回报值也是这一轮游戏的回报值。

    def r_calculate(self, r):"""每次输入r,累加算计每一个回合的总回报值r:param r: 回报值:return:"""self.r.append(r)

③loss计算

def loss_calculate(self, bias):a_list = []s_list = []G_list = []for index in range(len(self.r)):G_list.append([self.gamma ** index * sum(self.r[index:])])for index, item in enumerate(self.traj):a_list.append(self.traj[index][1])s_list.append(self.traj[index][0])a_one_hot = tf.one_hot(a_list, self.n_actions)s = np.array(s_list)G = np.matmul(np.array(G_list), np.ones([1, 2]))out_put_a = self.pg_model(s)# 注意loss函数的计算。log_pro = tf.reduce_sum(a_one_hot * tf.math.log(out_put_a), axis=1)# 因为是计算真实的动作和网络输出动作的相似程度,需要采用reduce_mean函数来进行计算loss# 这里loss就一个值,因为是一个回合进行更新一次。意思是这个回合的相似程度,这里相当于是一个分类问题loss = - tf.reduce_mean((G - bias)*log_pro)self.r.clear()self.traj.clear()return loss

④在训练过程中,感觉会出现过拟合的现象。网络到达一个比较好的参数之后,如果继续进行训练,系统的性能反而会下降。

PG-REINFORCE tensorflow 2.0相关推荐

  1. 请注意更新TensorFlow 2.0的旧代码

    TensorFlow 2.0 将包含许多 API 变更,例如,对参数进行重新排序.重新命名符号和更改参数的默认值.手动执行所有这些变更不仅枯燥乏味,而且容易出错.为简化变更过程并让您尽可能顺畅地过渡到 ...

  2. 独家 | TensorFlow 2.0将把Eager Execution变为默认执行模式,你该转向动态计算图了...

    机器之心报道 作者:邱陆陆 8 月中旬,谷歌大脑成员 Martin Wicke 在一封公开邮件中宣布,新版本开源框架--TensorFlow 2.0 预览版将在年底之前正式发布.今日,在上海谷歌开发者 ...

  3. 资料分享:推荐一本《简单粗暴TensorFlow 2.0》开源电子书!

    背景 本开源电子书是一篇精简的 TensorFlow 2.0 入门指导,基于 TensorFlow 的 Eager Execution(动态图)模式,力图让具备一定机器学习及 Python 基础的开发 ...

  4. 简单粗暴上手TensorFlow 2.0,北大学霸力作,必须人手一册!

    (图片付费下载自视觉中国) 整理 | 夕颜 出品 | AI科技大本营(ID:rgznai100) [导读] TensorFlow 2.0 于近期正式发布后,立即受到学术界与科研界的广泛关注与好评.此前 ...

  5. TensorFlow 2.0来了,为什么他却说“深度学习框架之争,现在谈结果为时尚早”?...

    记者 | 琥珀 出品 | AI科技大本营(ID:rgznai100) 半个多世纪前,浙江大学老校长竺可桢曾有两个非常经典的教育问题:"诸位在校,有两个问题应该自己问问,第一,到浙大来做什么? ...

  6. 掌声送给TensorFlow 2.0!用Keras搭建一个CNN | 入门教程

    作者 | Himanshu Rawlani 译者 | Monanfei,责编 | 琥珀 出品 | AI科技大本营(id:rgznai100) 2019 年 3 月 6 日,谷歌在 TensorFlow ...

  7. 我们期待的TensorFlow 2.0还有哪些变化?

    来源 | Google TensorFlow 团队 为提高 TensorFlow 的工作效率,TensorFlow 2.0 进行了多项更改,包括删除了多余的 API,使API 更加一致统一,例如统一的 ...

  8. TensorFlow 2.0新特性解读,Keras API成核心

    来源 | Google TensorFlow 团队 2018 年 11 月,TensorFlow 迎来了它的 3 岁生日,我们回顾了几年来它增加的功能,进而对另一个重要里程碑 TensorFlow 2 ...

  9. TensorFlow 2.0开发者预览版发布

    整理 | Jane 出品 | AI科技大本营 从去年 8 月 Google 公开发布消息正在研发 TensorFlow 2.0 ,让我们在 12 月 提前看到了一些 高级 API 的变化,今天我们终于 ...

  10. TensorFlow 2.0发布在即,高级API变化抢先看

    作者 | Sandeep Gupta, Josh Gordon, and Karmel Allison 整理 | 非主流.Jane 出品 | AI科技大本营 [导语]早在今年 8 月的时候,谷歌开源战 ...

最新文章

  1. 方舟手游服务器设置文件翻译,方舟生存进化手游界面翻译 方舟生存进化手机版中文对照翻译一览...
  2. mvn 打包可执行包_用Maven打包发布可执行的jar包
  3. mvc html.antiforgerytoken,MVC Html.AntiForgeryToken() 防止CSRF***
  4. 使用feed_dict不一定要用占位符
  5. 有一次面一非常想去的 飞鸽传书绿色版 公司
  6. 数据结构中三表合一的实现
  7. DEIGRP 的配置
  8. [k8s]dashboard1.8.1搭建( heapster1.5+influxdb+grafana)
  9. php取excel中的值,在Php Excel中使用列名获取单元格值
  10. sonarqube如何导入规则_webpack如何使用Vue
  11. -webkit-padding-start: 40px;ul的padding-left:40px;问题
  12. 机器学习- 吴恩达Andrew Ng Week6 知识总结 Machine Learning System Design
  13. DSSM模型的原理简介,预测两个句子的语义相似度
  14. pom文件显示删除线
  15. 索尼a5100_索尼a5100像素是多少?索尼a5100分辨率是多少?
  16. 语音特征:spectrogram、Fbank(fiterbank)、MFCC
  17. vue插槽,分分钟理解
  18. nuScenes 数据集(CVPR 2020)
  19. 分享一些实用的生活软件
  20. 最美的时候你遇见了谁

热门文章

  1. SDL(Simple DirectMedia Layer) 简介
  2. 八进制在计算机系统中的应用场景,二进制、八进制、十进制、十六进制都能干什么? 十六进制计算器使用场景...
  3. uva10635Prince and Princess(LIS)
  4. 2.1、用JsonParser解析json树模型
  5. 华为天才少年稚晖君自制硬萌机器人,开源 5 天,GitHub 收获 2900 星!
  6. 【摸鱼神器】基于python的BOSS识别系统
  7. 模型评价方法及代码实现
  8. 系统背景描述_舞台灯光网络系统及光源角度资料免费分享
  9. 【下一步计划】毕业后
  10. jsp新代码第45课