【强化学习】在Pong环境下实现策略梯度
问题描述:
确定环境中的最佳操作的规则叫做策略,学习这些策略的网络称为策略网络。
代码展示:
import numpy as np
import gym
import tensorflow as tf
import matplotlib.pyplot as plt#Pong env
env = gym.make("Pong-v0")
observation = env.reset()
for i in range(22):#20 帧之后发球if i>20:plt.imshow(observation)plt.show()#得到下一个观察observation,_,_,_ = env.step(1)#函数预处理输入数据
def preprocess_frame(frame):# 移去图像顶部和某些背景frame = frame[35:195,10:150] # 图像帧度灰度化并缩小1/2frame = frame [::2,::2,0]# 设置背景值为0frame[frame==144] =0frame[frame ==109] = 0# 设置球拍及拍数为1frame[frame != 0] =1return frame.astype(np.float).ravel()obs_preprocessed = preprocess_frame(observation).reshape(80,70)
plt.imshow(obs_preprocessed,cmap ='gray')
plt.show()observation_next,_,_,_ = env.step(1)
diff = preprocess_frame(observation_next) - preprocess_frame(observation)
plt.imshow(diff.reshape(80,70),cmap='gray')
plt.show()input_dim = 80*70
hidden_L1 = 400
hidden_L2 = 200
actions = [1,2,3]
n_actions = len(actions)
model = {}
with tf.compat.v1.variable_scope('L1',reuse=False):inint_W1 = tf.compat.v1.truncated_normal_initializer(mean = 0,stddev=1./np.sqrt(input_dim),dtype=tf.float32)model['W1'] = tf.compat.v1.get_variable('W1',[input_dim,hidden_L1],initializer=inint_W1)with tf.compat.v1.variable_scope('L2',reuse=False):init_W2 = tf.compat.v1.truncated_normal_initializer(mean = 0,stddev=1./np.sqrt(hidden_L1),dtype=tf.float32)model['W2']= tf.compat.v1.get_variable('W2',[hidden_L1,n_actions],initializer=init_W2)#策略函数
def policy_forward(x):tf.compat.v1.disable_eager_execution()x = tf.matmul(x,model['W1'])x = tf.nn.relu(x)x = tf.matmul(x,model['W2'])p = tf.nn.softmax(x)return p#折扣奖励函数
def discounted_rewards(reward,gamma):discounted_function = lambda a,v:a*gamma +v;reward_reverse = tf.scan(discounted_function,tf.reverse(reward,[True,False]))discounted_reward = tf.reverse(reward_reverse,[True,False])return discounted_rewardlearning_rate = 0.001
gamma = 0.99
batch_size = 10tf.compat.v1.disable_eager_execution()
#定义占位符并单独设置反向更新
episode_x = tf.compat.v1.placeholder(dtype=tf.float32,shape=[None,input_dim])
episode_y = tf.compat.v1.placeholder(dtype=tf.float32,shape = [None,n_actions])
episode_reward = tf.compat.v1.placeholder(dtype = tf.float32,shape=[None,1])episode_discounted_reward = discounted_rewards(episode_reward,gamma)
episode_mean,episode_variance = tf.nn.moments(episode_discounted_reward,[0],shift = None)#标准化折扣后的收益
episode_discounted_reward -= episode_mean
episode_discounted_reward /= tf.sqrt(episode_variance + 1e-6)#优化器设定
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()tf_aprob = policy_forward(episode_x)
loss = tf.nn.l2_loss(episode_y - tf_aprob)
optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate)
gradients = optimizer.compute_gradients(loss,var_list= tf.compat.v1.trainable_variables(),grad_loss = episode_discounted_reward)
train_op = optimizer.apply_gradients(gradients)#图像初始化
sess = tf.compat.v1.InteractiveSession()
tf.compat.v1.global_variables_initializer().run()#训练模型存储设定saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
save_path = 'checkpoints/pong_rl.ckt'obs_prev = None
xs,ys,rs = [],[],[]
reward_sum = 0
episode_number = 0
reward_window = None
reward_best = -22
history = []observation = env.reset()
while True:if True:env.render()#预处理观测值,加载不同的图像给网络obs_cur = preprocess_frame(observation)obs_diff = obs_cur - obs_prev if obs_prev is not None else np.zeros(input_dim)obs_prev = obs_cur#策略采样一次动作feed = {episode_x:np.reshape(obs_diff,(1,-1))}aprob = sess.run(tf_aprob,feed)aprob = aprob[0,:]action = np.random.choice(n_actions,p=aprob)label = np.zeros_like(aprob)label[action] =1#返回环境动作并提取下一个观测,回报和状态observation,reward,done,info = env.step(action +1)if done:observation = env.reset()reward_sum += reward#记录游戏历史xs.append(obs_diff)ys.append(label)rs.append(reward)if done:history.append(reward_sum)reward_window = -21 if reward_window is None else np.mean(history[-100:])#用存储值更新权重 - 更新策略feed = {episode_x : np.vstack(xs),episode_y:np.vstack(ys),episode_reward:np.vstack(rs),}_ = sess.run(train_op,feed)print('epochs {:2d}: reward :{:2.0f}'.format(episode_number,reward_sum))xs,ys,rs = [],[],[]episode_number += 1observation = env.reset()reward_sum = 0#10个场景后存储最佳模型if (episode_number % 10 == 0) & (reward_window > reward_best):saver.save(sess,save_path,global_step=episode_number)reward_best = reward_windowprint('save best model {:2d}:{:2.5f} (reward window)'.format(episode_number,reward_window))
实现截图:
参考:
《Python深度学习实战:75个有关神经网络建模、强化学习与迁移》
【强化学习】在Pong环境下实现策略梯度相关推荐
- 【深入浅出强化学习-编程实战】 7 基于策略梯度的强化学习-Cartpole(小车倒立摆系统)
[深入浅出强化学习-编程实战] 7 基于策略梯度的强化学习-Cartpole 小车倒立摆MDP模型 代码 代码解析 小车倒立摆MDP模型 状态输入:s=[x,x˙,θ,θ˙]s = [x,\dot{x ...
- 【githubshare】深度学习蘑菇书,覆盖了强化学习、马尔可夫决策过程、策略梯度、模仿学习
GitHub 上的深度学习技术书籍:<蘑菇书 EasyRL>,覆盖了强化学习.马尔可夫决策过程.策略梯度.模仿学习等多个知识点. GitHub:github.com/datawhalech ...
- 强化学习(Reinforcement Learning)之策略梯度(Policy Gradient)的一点点理解以及代码的对应解释
一.策略梯度算法推导以及解释 1.1 背景 设πθ(s)\pi_{\theta }(s)πθ(s)是一个有网络参数θ\thetaθ的actor,然后我们让这个actor和环境(environment ...
- 【强化学习】spinningup最简单的策略梯度(VPG)代码详细注释——基于pytorch实现
参考链接:https://spinningup.qiwihui.com/zh_CN/latest/spinningup/rl_intro3.html 需要配合spinningup的公式推导 全部代码 ...
- 不等式视角下的策略梯度算法
本文首发于:行者AI 强化学习(Reinforcement Learning,RL),也叫增强学习,是指一类从(与环境)交互中不断学习的问题以及解决这类问题的方法.强化学习问题可以描述为一个智能体从与 ...
- ML之RL:基于MovieLens电影评分数据集利用强化学习算法(多臂老虎机+EpsilonGreedy策略)实现对用户进行Top电影推荐案例
ML之RL:基于MovieLens电影评分数据集利用强化学习算法(多臂老虎机+EpsilonGreedy策略)实现对用户进行Top电影推荐案例 目录 基于MovieLens电影评分数据集利用强化学习算 ...
- 【大咖说Ⅲ】谢娟英教授:基于深度学习的野外环境下蝴蝶物种自动识别
欢迎来到2022 CCF BDCI 大咖说系列专题报告 听顶级专家学者围绕特定技术领域或选题,讲述自身成果的研究价值与实际应用价值 便于广大技术发烧友.大赛参赛者吸收学术知识,强化深度学习 每周一.三 ...
- java spring 实现策略,Spring 环境下实现策略模式的示例
背景 最近在忙一个需求,大致就是给满足特定条件的用户发营销邮件,但是用户的来源有很多方式:从 ES 查询的.从 csv 导入的.从 MongoDB 查询-.. 需求很简单,但是怎么写的优雅,方便后续扩 ...
- Linux环境下基于策略的路由
Linux环境下基于策略的路由 原文作者:Matthew G. Marsh 原文出处:[url]http://www.sysadminmag.com/linux/articles/v09/i01/a3 ...
- 【PyTorch深度强化学习】带基线的蒙特卡洛策略梯度法(REINFOECE)在短走廊和CartPole环境下的实战(超详细 附源码)
需要源码请点赞关注收藏后评论区留言留下QQ~~~ 一.带基线的REINFORCE REINFORCE的优势在于只需要很小的更新步长就能收敛到局部最优,并保证了每次更新都是有利的,但是假设每个动作的奖赏 ...
最新文章
- 安卓 java内存碎片_理解Android Java垃圾回收机制
- jQuery mobile 之三
- JAVA软件图片浏览下载_java模拟浏览器下载图片
- Bootstrap3 带条纹的表格样式
- 【IDEA】向IntelliJ IDEA创建的项目导入Jar包的两种方式
- 【设计模式:单例模式】单例模式01:饿汉模式
- JAVA正则提取字符串中的日期
- opengl笔记——OpenGL好资料备忘
- poj 2378 树型dp
- RVCT31编译问题
- 清理谷歌浏览器注册表_chrome注册表怎么清理_如何清理没用的chrome注册表-win7之家...
- 子进程 已安装 pre-removal 脚本 返回了错误号 1或2 与 子进程 已安装 post-installation 脚本 返回了错误号 1或2
- 计算机维护系统Win8PE,U盘启动计算机维护系统
- 百度快照劫持之JS劫持诊断与恢复教程
- linux tmp php文件怎么打开,tmp文件用什么打开
- OutputFormat类——Hadoop
- Android脑图--Android动画
- 计算机系统分盘作用,电脑为什么要分区,分区的好处
- 【记录】Nginx开源版安装与部署
- 小程序app.js的配置
热门文章
- KK集团旗下公司又遭处罚:招股书已“失效”一个月,快客电商曾被罚30万元
- Kubuntu终端中文显示一半解决办法
- CSS 3.0实现八卦图
- 孪生网络 应用_数字孪生照进现实,Unity如何打造数字世界的基础设施?
- 【错误】E45: ‘readonly‘ option is set (add to override)
- 【Python】使用torrentParser1.03对单文件torrent的分析结果
- 用神经网络实现机器翻译实战
- 中国人民银行招聘计算机考什么,求中国人民银行招聘计算机专业人员的考试题。...
- 正则表达式(正则表达式的方法和属性、正则的修饰符、表达式、元字符、量词)
- python 网络设备管理_「python」使用Telnet进行网络设备巡检