表格型方法——Sarsa

  • 简介
  • 实战

简介

  • Sarsa全称是state-action-reward-state’-action’,目的是学习特定的state下,特定action的价值Q,最终建立和优化一个Q表格,以state为行,action为列,根据与环境交互得到的reward来更新Q表格,更新公式为:

  • Sarsa在训练中为了更好的探索环境,采用ε-greedy方式来训练,有一定概率随机选择动作输出。

实战

  • 使用 Sarsa 解决机器人找金币问题。
    机器人找金币环境下载

Agent

  • Agent是和环境environment交互的主体。
  • predict()方法:输入观察值observation(或者说状态state),输出动作值
  • sample()方法:再predict()方法基础上使用ε-greedy增加探索
  • learn()方法:输入训练数据,完成一轮Q表格的更新
# agent.py
class SarsaAgent(object):def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):self.act_n = act_n      # 动作维度,有几个动作可选self.lr = learning_rate # 学习率self.gamma = gamma      # reward的衰减率self.epsilon = e_greed  # 按一定概率随机选动作self.Q = np.zeros((obs_n, act_n))# 根据输入观察值,采样输出的动作值,带探索def sample(self, obs):if np.random.uniform(0, 1) < (1.0 - self.epsilon): #根据table的Q值选动作action = self.predict(obs)else:action = np.random.choice(self.act_n) #有一定概率随机探索选取一个动作return action# 根据输入观察值,预测输出的动作值def predict(self, obs):Q_list = self.Q[obs, :]maxQ = np.max(Q_list)action_list = np.where(Q_list == maxQ)[0]  # maxQ可能对应多个actionaction = np.random.choice(action_list)return action# 学习方法,也就是更新Q-table的方法def learn(self, obs, action, reward, next_obs, next_action, done):""" on-policyobs: 交互前的obs, s_taction: 本次交互选择的action, a_treward: 本次动作获得的奖励rnext_obs: 本次交互后的obs, s_t+1next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1done: episode是否结束"""predict_Q = self.Q[obs, action]if done:target_Q = reward # 没有下一个状态了else:target_Q = reward + self.gamma * self.Q[next_obs, next_action] # Sarsaself.Q[obs, action] += self.lr * (target_Q - predict_Q) # 修正q# 保存Q表格数据到文件def save(self):npy_file = './q_table.npy'np.save(npy_file, self.Q)print(npy_file + ' saved.')# 从文件中读取Q值到Q表格中def restore(self, npy_file='./q_table.npy'):self.Q = np.load(npy_file)print(npy_file + ' loaded.')

Training && Test(训练&&测试)

  • run_episode():agent在一个episode中训练的过程,使用agent.sample()与环境交互,使用agent.learn()训练Q表格。
  • test_episode():agent在一个episode中测试效果,评估目前的agent能在一个episode中拿到多少总reward。
def run_episode(env, agent, render=False):total_steps = 0 # 记录每个episode走了多少steptotal_reward = 0obs = env.reset() # 重置环境, 重新开一局(即开始新的一个episode)action = agent.sample(obs) # 根据算法选择一个动作while True:next_obs, reward, done, _ = env.step(env.actions[action]) # 与环境进行一个交互next_action = agent.sample(next_obs-1) # 根据算法选择一个动作# 训练 Sarsa 算法agent.learn(obs-1, action, reward, next_obs-1, next_action, done)action = next_actionobs = next_obs  # 存储上一个观察值total_reward += rewardtotal_steps += 1 # 计算step数if render:env.render() #渲染新的一帧图形if done:breakreturn total_reward, total_stepsdef test_episode(env, agent):total_reward = 0obs = env.reset()while True:action = agent.predict(obs-1) # greedynext_obs, reward, done, _ = env.step(env.actions[action])total_reward += rewardobs = next_obstime.sleep(0.5)env.render()if done:breakreturn total_reward

创建环境和Agent,启动训练

env = GridEnv()agent = SarsaAgent(obs_n = len(env.states),    # 状态维度act_n = len(env.actions),   # 动作维度,有几个动作可选learning_rate = 0.1,        # 学习率gamma = 0.9,                # reward的衰减率e_greed = 0.1               # 按一定概率随机选动作
)# 训练500个episode,打印每个episode的分数
for episode in range(500):ep_reward, ep_steps = run_episode(env, agent, True)print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps, ep_reward))
print("保存权重")
agent.save()
# 全部训练结束,查看算法效果
# agent.restore()test_reward = test_episode(env, agent)
print('test reward = %.1f' % (test_reward))
env.close()

强化学习——Sarsa算法相关推荐

  1. 强化学习—— TD算法(Sarsa算法+Q-learning算法)

    强化学习-- TD算法(Sarsa算法+Q-learning算法) 1. Sarsa算法 1.1 TD Target 1.2 表格形式的Sarsa算法 1.3 神经网络形式的Sarsa算法 2. Q- ...

  2. 强化学习常用算法总结

    强化学习常用算法总结 本文为2020年6月参加的百度PaddlePaddle强化学习训练营总结 1. 表格型方法:Sarsa和Q-Learning算法 State-action-reward-stat ...

  3. 【人工智能II】实验2 强化学习Q-Learning算法

    强化学习Q-Learning算法 核心思想 实验原理 实验流程图 实验分析 理解Q-Learning算法 GYM库 更换实验环境 实验代码 Q-Learning: Sarsa代码 搞不懂我一个本科生为 ...

  4. [PARL强化学习]Sarsa和Q—learning的实现

    [PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...

  5. 深度强化学习-DDPG算法原理和实现

    全文共3077个字,8张图,预计阅读时间15分钟. 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作.如果我们省略中间的步骤,即直接根据当前的状态来选 ...

  6. 基于强化学习SAC_LSTM算法的机器人导航

    [前言]在人群之间导航的机器人通常使用避碰算法来实现安全高效的导航.针对人群中机器人的导航问题,本文采用强化学习SAC算法,并结合LSTM长短期记忆网络,提高移动机器人的导航性能.在我们的方法中,机器 ...

  7. 【强化学习PPO算法】

    强化学习PPO算法 一.PPO算法 二.伪代码 三.相关的简单理论 1.ratio 2.裁断 3.Advantage的计算 4.loss的计算 四.算法实现 五.效果 六.感悟   最近再改一个代码, ...

  8. DRL:强化学习-Q-Learning算法

    文章目录 强化学习 Q-Learning算法 1. 问题及原因 2. Estimator原理与思想 (1)单估计器方法(Single Estimator) (2)双估计器方法(Double Estim ...

  9. 深度强化学习主流算法介绍(二):DPG系列

    之前的文章可以看这里 深度强化学习主流算法介绍(一):DQN系列 相关论文在这里 开始介绍DPG之前,先回顾下DQN系列 DQN直接训练一个Q Network 去估计每个离散动作的Q值,使用时选择Q值 ...

  10. 强化学习 五子棋算法

    强化学习 五子棋算法 蒙特卡洛树搜索 MCTS 蒙特卡洛树搜索算法 上限置信区间算法 UCT Minimax算法与纳什均衡 alpha beta剪枝 估值函数 优化与总结 本文会以AI五子棋展开,讲解 ...

最新文章

  1. 每日一皮:当我突然有一个很棒的调试想法...
  2. 9.Windows线程切换_TSS
  3. linux下使用syslog日志调试程序快速的调试代码信息的过程
  4. 华为NIP网络***检测系统
  5. hdu 3068 最长回文(manacher算法)
  6. 数据链路层: HDLC
  7. Linux第二次作业
  8. NeurlPS 2019丨微软亚洲研究院 5 篇精选论文解读
  9. 如何通过 Tampermonkey 快速查找 JavaScript 加密入口
  10. orabbix监控oracle11g,orabbix 监控oracle
  11. 友元函数类图_要达到形式的公平,需要具备的前提条件是()。
  12. IEEE-SA董事刘东:开放+开源将带来新一波SDNFV创新
  13. 前端HTML银行管理系统界面部分实现
  14. 我从零开始学黑莓开发的过程
  15. 计算机设备维修与及日常保养,电脑主机日常的维护保养计划
  16. java for语句 实现一个功能:
  17. 【推荐】《Java 并发编程的艺术》迷你书
  18. vue单页应用首屏加载速度慢如何解决
  19. HashMap数据结构
  20. 数据结构 图(一)丛林中的路

热门文章

  1. 基于Arduino、STM32进行红外遥控信号接收
  2. DA-TLC5615
  3. 【雅思阅读】王希伟阅读P3(Heading)
  4. 千锋教育java开发_千锋Java学院-中国Java培训|Java开发培训开拓者
  5. TestNG - 运行失败的test
  6. 洛谷 P2689 东南西北
  7. java fifo lifo_栈方法 LIFO - 队方法 FIFO
  8. 多款比较好用又免费的设计工具
  9. 泊松分布分布与Python图解
  10. python爬取bilibili数据_用 Python 抓取 bilibili 弹幕并分析!