问题描述:

确定环境中的最佳操作的规则叫做策略,学习这些策略的网络称为策略网络。

代码展示:

import numpy as np
import gym
import tensorflow as tf
import matplotlib.pyplot as plt#Pong env
env = gym.make("Pong-v0")
observation = env.reset()
for i in range(22):#20 帧之后发球if i>20:plt.imshow(observation)plt.show()#得到下一个观察observation,_,_,_ = env.step(1)#函数预处理输入数据
def preprocess_frame(frame):# 移去图像顶部和某些背景frame = frame[35:195,10:150] # 图像帧度灰度化并缩小1/2frame = frame [::2,::2,0]# 设置背景值为0frame[frame==144] =0frame[frame ==109] = 0# 设置球拍及拍数为1frame[frame != 0] =1return frame.astype(np.float).ravel()obs_preprocessed = preprocess_frame(observation).reshape(80,70)
plt.imshow(obs_preprocessed,cmap ='gray')
plt.show()observation_next,_,_,_ = env.step(1)
diff = preprocess_frame(observation_next) - preprocess_frame(observation)
plt.imshow(diff.reshape(80,70),cmap='gray')
plt.show()input_dim = 80*70
hidden_L1 = 400
hidden_L2 = 200
actions = [1,2,3]
n_actions = len(actions)
model = {}
with tf.compat.v1.variable_scope('L1',reuse=False):inint_W1 = tf.compat.v1.truncated_normal_initializer(mean = 0,stddev=1./np.sqrt(input_dim),dtype=tf.float32)model['W1'] = tf.compat.v1.get_variable('W1',[input_dim,hidden_L1],initializer=inint_W1)with tf.compat.v1.variable_scope('L2',reuse=False):init_W2 = tf.compat.v1.truncated_normal_initializer(mean = 0,stddev=1./np.sqrt(hidden_L1),dtype=tf.float32)model['W2']= tf.compat.v1.get_variable('W2',[hidden_L1,n_actions],initializer=init_W2)#策略函数
def policy_forward(x):tf.compat.v1.disable_eager_execution()x = tf.matmul(x,model['W1'])x = tf.nn.relu(x)x = tf.matmul(x,model['W2'])p = tf.nn.softmax(x)return p#折扣奖励函数
def discounted_rewards(reward,gamma):discounted_function = lambda a,v:a*gamma +v;reward_reverse = tf.scan(discounted_function,tf.reverse(reward,[True,False]))discounted_reward = tf.reverse(reward_reverse,[True,False])return discounted_rewardlearning_rate = 0.001
gamma = 0.99
batch_size = 10tf.compat.v1.disable_eager_execution()
#定义占位符并单独设置反向更新
episode_x = tf.compat.v1.placeholder(dtype=tf.float32,shape=[None,input_dim])
episode_y  = tf.compat.v1.placeholder(dtype=tf.float32,shape = [None,n_actions])
episode_reward = tf.compat.v1.placeholder(dtype = tf.float32,shape=[None,1])episode_discounted_reward = discounted_rewards(episode_reward,gamma)
episode_mean,episode_variance = tf.nn.moments(episode_discounted_reward,[0],shift = None)#标准化折扣后的收益
episode_discounted_reward -= episode_mean
episode_discounted_reward /= tf.sqrt(episode_variance + 1e-6)#优化器设定
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()tf_aprob = policy_forward(episode_x)
loss = tf.nn.l2_loss(episode_y - tf_aprob)
optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate)
gradients = optimizer.compute_gradients(loss,var_list= tf.compat.v1.trainable_variables(),grad_loss = episode_discounted_reward)
train_op = optimizer.apply_gradients(gradients)#图像初始化
sess = tf.compat.v1.InteractiveSession()
tf.compat.v1.global_variables_initializer().run()#训练模型存储设定saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
save_path = 'checkpoints/pong_rl.ckt'obs_prev = None
xs,ys,rs = [],[],[]
reward_sum = 0
episode_number = 0
reward_window = None
reward_best = -22
history = []observation = env.reset()
while True:if True:env.render()#预处理观测值,加载不同的图像给网络obs_cur = preprocess_frame(observation)obs_diff = obs_cur - obs_prev if obs_prev is not None else np.zeros(input_dim)obs_prev = obs_cur#策略采样一次动作feed = {episode_x:np.reshape(obs_diff,(1,-1))}aprob = sess.run(tf_aprob,feed)aprob = aprob[0,:]action = np.random.choice(n_actions,p=aprob)label = np.zeros_like(aprob)label[action] =1#返回环境动作并提取下一个观测,回报和状态observation,reward,done,info = env.step(action +1)if done:observation = env.reset()reward_sum += reward#记录游戏历史xs.append(obs_diff)ys.append(label)rs.append(reward)if done:history.append(reward_sum)reward_window = -21 if reward_window is None else np.mean(history[-100:])#用存储值更新权重 - 更新策略feed = {episode_x : np.vstack(xs),episode_y:np.vstack(ys),episode_reward:np.vstack(rs),}_ = sess.run(train_op,feed)print('epochs {:2d}: reward :{:2.0f}'.format(episode_number,reward_sum))xs,ys,rs = [],[],[]episode_number += 1observation = env.reset()reward_sum = 0#10个场景后存储最佳模型if (episode_number % 10 == 0) & (reward_window > reward_best):saver.save(sess,save_path,global_step=episode_number)reward_best = reward_windowprint('save best model {:2d}:{:2.5f} (reward window)'.format(episode_number,reward_window))

实现截图:

参考:

《Python深度学习实战:75个有关神经网络建模、强化学习与迁移》

【强化学习】在Pong环境下实现策略梯度相关推荐

  1. 【深入浅出强化学习-编程实战】 7 基于策略梯度的强化学习-Cartpole(小车倒立摆系统)

    [深入浅出强化学习-编程实战] 7 基于策略梯度的强化学习-Cartpole 小车倒立摆MDP模型 代码 代码解析 小车倒立摆MDP模型 状态输入:s=[x,x˙,θ,θ˙]s = [x,\dot{x ...

  2. 【githubshare】深度学习蘑菇书,覆盖了强化学习、马尔可夫决策过程、策略梯度、模仿学习

    GitHub 上的深度学习技术书籍:<蘑菇书 EasyRL>,覆盖了强化学习.马尔可夫决策过程.策略梯度.模仿学习等多个知识点. GitHub:github.com/datawhalech ...

  3. 强化学习(Reinforcement Learning)之策略梯度(Policy Gradient)的一点点理解以及代码的对应解释

    一.策略梯度算法推导以及解释 1.1 背景 设πθ(s)\pi_{\theta }(s)πθ​(s)是一个有网络参数θ\thetaθ的actor,然后我们让这个actor和环境(environment ...

  4. 【强化学习】spinningup最简单的策略梯度(VPG)代码详细注释——基于pytorch实现

    参考链接:https://spinningup.qiwihui.com/zh_CN/latest/spinningup/rl_intro3.html 需要配合spinningup的公式推导 全部代码 ...

  5. 不等式视角下的策略梯度算法

    本文首发于:行者AI 强化学习(Reinforcement Learning,RL),也叫增强学习,是指一类从(与环境)交互中不断学习的问题以及解决这类问题的方法.强化学习问题可以描述为一个智能体从与 ...

  6. ML之RL:基于MovieLens电影评分数据集利用强化学习算法(多臂老虎机+EpsilonGreedy策略)实现对用户进行Top电影推荐案例

    ML之RL:基于MovieLens电影评分数据集利用强化学习算法(多臂老虎机+EpsilonGreedy策略)实现对用户进行Top电影推荐案例 目录 基于MovieLens电影评分数据集利用强化学习算 ...

  7. 【大咖说Ⅲ】谢娟英教授:基于深度学习的野外环境下蝴蝶物种自动识别

    欢迎来到2022 CCF BDCI 大咖说系列专题报告 听顶级专家学者围绕特定技术领域或选题,讲述自身成果的研究价值与实际应用价值 便于广大技术发烧友.大赛参赛者吸收学术知识,强化深度学习 每周一.三 ...

  8. java spring 实现策略,Spring 环境下实现策略模式的示例

    背景 最近在忙一个需求,大致就是给满足特定条件的用户发营销邮件,但是用户的来源有很多方式:从 ES 查询的.从 csv 导入的.从 MongoDB 查询-.. 需求很简单,但是怎么写的优雅,方便后续扩 ...

  9. Linux环境下基于策略的路由

    Linux环境下基于策略的路由 原文作者:Matthew G. Marsh 原文出处:[url]http://www.sysadminmag.com/linux/articles/v09/i01/a3 ...

  10. 【PyTorch深度强化学习】带基线的蒙特卡洛策略梯度法(REINFOECE)在短走廊和CartPole环境下的实战(超详细 附源码)

    需要源码请点赞关注收藏后评论区留言留下QQ~~~ 一.带基线的REINFORCE REINFORCE的优势在于只需要很小的更新步长就能收敛到局部最优,并保证了每次更新都是有利的,但是假设每个动作的奖赏 ...

最新文章

  1. 安卓 java内存碎片_理解Android Java垃圾回收机制
  2. jQuery mobile 之三
  3. JAVA软件图片浏览下载_java模拟浏览器下载图片
  4. Bootstrap3 带条纹的表格样式
  5. 【IDEA】向IntelliJ IDEA创建的项目导入Jar包的两种方式
  6. 【设计模式:单例模式】单例模式01:饿汉模式
  7. JAVA正则提取字符串中的日期
  8. opengl笔记——OpenGL好资料备忘
  9. poj 2378 树型dp
  10. RVCT31编译问题
  11. 清理谷歌浏览器注册表_chrome注册表怎么清理_如何清理没用的chrome注册表-win7之家...
  12. 子进程 已安装 pre-removal 脚本 返回了错误号 1或2 与 子进程 已安装 post-installation 脚本 返回了错误号 1或2
  13. 计算机维护系统Win8PE,U盘启动计算机维护系统
  14. 百度快照劫持之JS劫持诊断与恢复教程
  15. linux tmp php文件怎么打开,tmp文件用什么打开
  16. OutputFormat类——Hadoop
  17. Android脑图--Android动画
  18. 计算机系统分盘作用,电脑为什么要分区,分区的好处
  19. 【记录】Nginx开源版安装与部署
  20. 小程序app.js的配置

热门文章

  1. KK集团旗下公司又遭处罚:招股书已“失效”一个月,快客电商曾被罚30万元
  2. Kubuntu终端中文显示一半解决办法
  3. CSS 3.0实现八卦图
  4. 孪生网络 应用_数字孪生照进现实,Unity如何打造数字世界的基础设施?
  5. 【错误】E45: ‘readonly‘ option is set (add to override)
  6. 【Python】使用torrentParser1.03对单文件torrent的分析结果
  7. 用神经网络实现机器翻译实战
  8. 中国人民银行招聘计算机考什么,求中国人民银行招聘计算机专业人员的考试题。...
  9. 正则表达式(正则表达式的方法和属性、正则的修饰符、表达式、元字符、量词)
  10. python 网络设备管理_「python」使用Telnet进行网络设备巡检