Sarsa 简介

Sarsa全称是state-action-reward-state’-action’,目的是学习特定的state下,特定action的价值Q,最终建立和优化一个Q表格,以state为行,action为列,根据与环境交互得到的reward来更新Q表格,更新公式为:

Sarsa在训练中为了更好的探索环境,采用ε-greedy方式来训练,有一定概率随机选择动作输出。

悬崖问题

找到绕过悬崖通往终端的最短路径(快速到达目的地),每走一步都有-1的惩罚,掉进悬崖会有-100的惩罚(并被拖回出发点),直到到达目的地结束游戏,如下图所示。

源程序

# Step1 导入依赖
import gym
import numpy as np
import time
import matplotlib.pyplot as plt# Step2 定义Agent
class SarsaAgent(object):def __init__(self, obs_n, act_n, lr, gamma, epsilon):self.obs_n = obs_nself.act_n = act_nself.lr = lrself.gamma = gammaself.epsilon = epsilonself.Q_table = np.zeros((obs_n, act_n))def sample(self, obs):"""根据输入观察值,采样输出的动作值,带探索:param obs:当前state:return: 下一个动作"""action = 0if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # 根据table的Q值选动作action = self.predict(obs)else:action = np.random.choice(self.act_n)  # 有一定概率随机探索选取一个动作return actiondef predict(self, obs):'''根据输入观察值,预测输出的动作值:param obs:当前state:return:预测的动作'''Q_list = self.Q_table[obs, :]maxQ = np.max(Q_list)action_list = np.where(Q_list == maxQ)[0]  # maxQ可能对应多个actionaction = np.random.choice(action_list)return actiondef learn(self, obs, act, reward, next_obs, next_act, done):'''on-policy:param obs:交互前的obs, s_t:param act:本次交互选择的action, a_t:param reward:本次动作获得的奖励r:param next_obs:本次交互后的obs, s_t+1:param next_act:根据当前Q表格, 针对next_obs会选择的动作, a_t+1:param done:episode是否结束:return:null'''predict_Q = self.Q_table[obs, act]if done:target_Q = reward  # 没有下一个状态了else:target_Q = reward + self.gamma * self.Q_table[next_obs, next_act]  # Sarsaself.Q_table[obs, act] += self.lr * (target_Q - predict_Q)  # 修正q# 保存Q表格数据到文件def save(self):npy_file = './q_table.npy'np.save(npy_file, self.Q_table)print(npy_file + ' saved.')# 从文件中读取Q值到Q表格中def restore(self, npy_file='./q_table.npy'):self.Q_table = np.load(npy_file)print(npy_file + ' loaded.')# Step3 Training && Test(训练&&测试)
def train_episode(env, agent, render=False):total_reward = 0total_steps = 0  # 记录每个episode走了多少stepobs = env.reset()act = agent.sample(obs)while True:next_obs, reward, done, _ = env.step(act)  # 与环境进行一个交互next_act = agent.sample(next_obs)  # 根据算法选择一个动作# 训练Sarsa算法agent.learn(obs, act, reward, next_obs, next_act, done)act = next_actobs = next_obs  # 存储上一个观察值total_reward += rewardtotal_steps += 1if render:env.render()  # 渲染新的一帧图形if done:breakreturn total_reward, total_stepsdef test_episode(env, agent):total_reward = 0total_steps = 0  # 记录每个episode走了多少stepobs = env.reset()while True:action = agent.predict(obs)  # greedynext_obs, reward, done, _ = env.step(action)total_reward += rewardtotal_steps += 1obs = next_obs# time.sleep(0.5)# env.render()if done:breakreturn total_reward, total_steps# Step4 创建环境和Agent,启动训练# 使用gym创建悬崖环境
env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left# 创建一个agent实例,输入超参数
agent = SarsaAgent(obs_n=env.observation_space.n,act_n=env.action_space.n,lr=0.001,gamma=0.99,epsilon=0.1
)print("Start training ...")
total_reward_list = []
# 训练1000个episode,打印每个episode的分数
for episode in range(1000):ep_reward, ep_steps = train_episode(env, agent, False)total_reward_list.append(ep_reward)if episode % 50 == 0:print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps, ep_reward))print("Train end")
def show_reward(total_reward):N = len(total_reward)x = np.linspace(0, N, 1000)plt.plot(x, total_reward, 'b-', lw=1, ms=5)plt.show()show_reward(total_reward_list)# 全部训练结束,查看算法效果
test_reward, test_steps = test_episode(env, agent)
print('test steps = %.1f , reward = %.1f' % (test_steps, test_reward))

实验结果

训练1000个episode的奖励情况如下图:

Start training ...
Episode 0: steps = 1160 , reward = -3239.0
Episode 50: steps = 138 , reward = -237.0
Episode 100: steps = 94 , reward = -94.0
Episode 150: steps = 1103 , reward = -1301.0
Episode 200: steps = 1257 , reward = -1257.0
Episode 250: steps = 213 , reward = -213.0
Episode 300: steps = 106 , reward = -106.0
Episode 350: steps = 73 , reward = -73.0
Episode 400: steps = 98 , reward = -98.0
Episode 450: steps = 194 , reward = -194.0
Episode 500: steps = 229 , reward = -229.0
Episode 550: steps = 198 , reward = -495.0
Episode 600: steps = 125 , reward = -125.0
Episode 650: steps = 76 , reward = -76.0
Episode 700: steps = 186 , reward = -285.0
Episode 750: steps = 126 , reward = -225.0
Episode 800: steps = 1429 , reward = -2023.0
Episode 850: steps = 132 , reward = -132.0
Episode 900: steps = 743 , reward = -1040.0
Episode 950: steps = 136 , reward = -136.0
Train end
test steps = 125.0 , reward = -125.0Process finished with exit code 0

强化学习实战-使用Sarsa算法解决悬崖问题相关推荐

  1. 强化学习实战-使用Q-learning算法解决悬崖问题

    Q-learning简介 Q-learning也是采用Q表格的方式存储Q值(状态动作价值),决策部分与Sarsa是一样的,采用ε-greedy方式增加探索. Q-learning跟Sarsa不一样的地 ...

  2. 强化学习之Q-learning与Sarsa算法解决悬崖寻路问题

    之前有写过利用Q-learning算法去解决-> 一维二维探宝游戏:https://blog.csdn.net/MR_kdcon/article/details/109612413 有风格子寻路 ...

  3. 强化学习笔记:Sarsa算法

    1 Sarsa(0) Sarsa算法和TD类似,只不过TD是更新状态的奖励函数V,这里是更新Q函数强化学习笔记:Q-learning :temporal difference 方法_UQI-LIUWJ ...

  4. 【强化学习实战】基于gym和tensorflow的强化学习算法实现

    [新智元导读]知乎专栏强化学习大讲堂作者郭宪博士开讲<强化学习从入门到进阶>,我们为您节选了其中的第二节<基于gym和tensorflow的强化学习算法实现>,希望对您有所帮助 ...

  5. 【经典书籍】深度强化学习实战(附最新PDF和源代码下载)

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 深度强化学习可以说是人工智能领域现在最热门的方向,吸引了众多该领域优秀的科学家去发 ...

  6. 华为诺亚ICLR 2020满分论文:基于强化学习的因果发现算法

    2019-12-30 13:04:12 人工智能顶会 ICLR 2020 将于明年 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行,不久之前,大会官方公布论文接收结果:在最终提交的 2594 篇论文 ...

  7. 【强化学习实战-04】DQN和Double DQN保姆级教程(2):以MountainCar-v0

    [强化学习实战-04]DQN和Double DQN保姆级教程(2):以MountainCar-v0 实战:用Double DQN求解MountainCar问题 MountainCar问题详解 Moun ...

  8. PyTorch强化学习实战(1)——强化学习环境配置与PyTorch基础

    PyTorch强化学习实战(1)--强化学习环境配置与PyTorch基础 0. 前言 1. 搭建 PyTorch 环境 2. OpenAI Gym简介与安装 3. 模拟 Atari 环境 4. 模拟 ...

  9. 基于强化学习的服务链映射算法

    2018年1月   <通信学报>    魏亮,黄韬,张娇,王泽南,刘江,刘韵洁 摘要 提出基于人工智能技术的多智能体服务链资源调度架构,设计一种基于强化学习的服务链映射算法.通过Q-lea ...

最新文章

  1. 2014年数字:我的人生在命令行中
  2. 架构设计之分布式文件系统
  3. 【PM模块】维护处理的控制和报告
  4. 信息系统项目管理师-信息安全管理考点笔记
  5. 使用类计算矩形的面积
  6. html 加入li的作用,HTML的li有什么作用?
  7. 不积跬步无以至千里[转]
  8. vs 没有足够的内存继续执行程序_科赋内存条:韩国和台湾产的有不同?
  9. mysql 交叉统计_统计知识——交叉分组表
  10. INV标准报表+INVARAAS.rdf -- ABC分配报表
  11. java集合详解_Map、Set、List及其子类和接口你都明白吗?看这篇Java集合超详解
  12. 堆化 二叉堆一般用数组来表示。typedef struct _minHeapNodetypedef struct _otherInfo-icoding-C-数据结构
  13. Codeforces Global Round 12 C1 C2. Errich-Tac-Toe 思维构造 好题
  14. linux mysql启动_MySQL 安装(二)
  15. linux c c 常用的日志库,mslog: 一款超轻量级的C日志库,无需依赖额外的库,测试或移植过的系统有Linux(ubuntu,centos),Windows以及部分嵌入式设备;...
  16. 批处理作业调度问题 ——回溯法详解
  17. 计算机网络 HTTP工作机制 TCP三次握手四次挥手 TCP滑动窗口
  18. UEditor 之初体验后记
  19. JQuery Easy Ui dataGrid 数据表格 --转
  20. 编译器--简单数学表达式计算器(一)

热门文章

  1. QT MetaImage 一款图片工具软件
  2. java button属性设置_java的JButton怎样设置内边距
  3. 元数据管理平台技术白皮书
  4. 【诗歌】值得背诵古诗(一)
  5. DiabloFX展示模板 joomla多用途二手房公司企业博客商业主题
  6. win10配置docker环境
  7. 地税局工资管理系统(论文+设计)新
  8. 计算机主机安装图,最新版本:计算机主机插件安装图_布法罗计算机主板安装图...
  9. 缓存的穿透、击穿、雪崩分别是什么,有什么解决方法
  10. 如何为谷歌浏览器启用暗模式