强化学习——Sarsa算法
表格型方法——Sarsa
- 简介
- 实战
简介
Sarsa全称是state-action-reward-state’-action’,目的是学习特定的state下,特定action的价值Q,最终建立和优化一个Q表格,以state为行,action为列,根据与环境交互得到的reward来更新Q表格,更新公式为:
Sarsa在训练中为了更好的探索环境,采用ε-greedy方式来训练,有一定概率随机选择动作输出。
实战
- 使用 Sarsa 解决机器人找金币问题。
机器人找金币环境下载
Agent
- Agent是和环境environment交互的主体。
- predict()方法:输入观察值observation(或者说状态state),输出动作值
- sample()方法:再predict()方法基础上使用ε-greedy增加探索
- learn()方法:输入训练数据,完成一轮Q表格的更新
# agent.py
class SarsaAgent(object):def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):self.act_n = act_n # 动作维度,有几个动作可选self.lr = learning_rate # 学习率self.gamma = gamma # reward的衰减率self.epsilon = e_greed # 按一定概率随机选动作self.Q = np.zeros((obs_n, act_n))# 根据输入观察值,采样输出的动作值,带探索def sample(self, obs):if np.random.uniform(0, 1) < (1.0 - self.epsilon): #根据table的Q值选动作action = self.predict(obs)else:action = np.random.choice(self.act_n) #有一定概率随机探索选取一个动作return action# 根据输入观察值,预测输出的动作值def predict(self, obs):Q_list = self.Q[obs, :]maxQ = np.max(Q_list)action_list = np.where(Q_list == maxQ)[0] # maxQ可能对应多个actionaction = np.random.choice(action_list)return action# 学习方法,也就是更新Q-table的方法def learn(self, obs, action, reward, next_obs, next_action, done):""" on-policyobs: 交互前的obs, s_taction: 本次交互选择的action, a_treward: 本次动作获得的奖励rnext_obs: 本次交互后的obs, s_t+1next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1done: episode是否结束"""predict_Q = self.Q[obs, action]if done:target_Q = reward # 没有下一个状态了else:target_Q = reward + self.gamma * self.Q[next_obs, next_action] # Sarsaself.Q[obs, action] += self.lr * (target_Q - predict_Q) # 修正q# 保存Q表格数据到文件def save(self):npy_file = './q_table.npy'np.save(npy_file, self.Q)print(npy_file + ' saved.')# 从文件中读取Q值到Q表格中def restore(self, npy_file='./q_table.npy'):self.Q = np.load(npy_file)print(npy_file + ' loaded.')
Training && Test(训练&&测试)
- run_episode():agent在一个episode中训练的过程,使用agent.sample()与环境交互,使用agent.learn()训练Q表格。
- test_episode():agent在一个episode中测试效果,评估目前的agent能在一个episode中拿到多少总reward。
def run_episode(env, agent, render=False):total_steps = 0 # 记录每个episode走了多少steptotal_reward = 0obs = env.reset() # 重置环境, 重新开一局(即开始新的一个episode)action = agent.sample(obs) # 根据算法选择一个动作while True:next_obs, reward, done, _ = env.step(env.actions[action]) # 与环境进行一个交互next_action = agent.sample(next_obs-1) # 根据算法选择一个动作# 训练 Sarsa 算法agent.learn(obs-1, action, reward, next_obs-1, next_action, done)action = next_actionobs = next_obs # 存储上一个观察值total_reward += rewardtotal_steps += 1 # 计算step数if render:env.render() #渲染新的一帧图形if done:breakreturn total_reward, total_stepsdef test_episode(env, agent):total_reward = 0obs = env.reset()while True:action = agent.predict(obs-1) # greedynext_obs, reward, done, _ = env.step(env.actions[action])total_reward += rewardobs = next_obstime.sleep(0.5)env.render()if done:breakreturn total_reward
创建环境和Agent,启动训练
env = GridEnv()agent = SarsaAgent(obs_n = len(env.states), # 状态维度act_n = len(env.actions), # 动作维度,有几个动作可选learning_rate = 0.1, # 学习率gamma = 0.9, # reward的衰减率e_greed = 0.1 # 按一定概率随机选动作
)# 训练500个episode,打印每个episode的分数
for episode in range(500):ep_reward, ep_steps = run_episode(env, agent, True)print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps, ep_reward))
print("保存权重")
agent.save()
# 全部训练结束,查看算法效果
# agent.restore()test_reward = test_episode(env, agent)
print('test reward = %.1f' % (test_reward))
env.close()
强化学习——Sarsa算法相关推荐
- 强化学习—— TD算法(Sarsa算法+Q-learning算法)
强化学习-- TD算法(Sarsa算法+Q-learning算法) 1. Sarsa算法 1.1 TD Target 1.2 表格形式的Sarsa算法 1.3 神经网络形式的Sarsa算法 2. Q- ...
- 强化学习常用算法总结
强化学习常用算法总结 本文为2020年6月参加的百度PaddlePaddle强化学习训练营总结 1. 表格型方法:Sarsa和Q-Learning算法 State-action-reward-stat ...
- 【人工智能II】实验2 强化学习Q-Learning算法
强化学习Q-Learning算法 核心思想 实验原理 实验流程图 实验分析 理解Q-Learning算法 GYM库 更换实验环境 实验代码 Q-Learning: Sarsa代码 搞不懂我一个本科生为 ...
- [PARL强化学习]Sarsa和Q—learning的实现
[PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...
- 深度强化学习-DDPG算法原理和实现
全文共3077个字,8张图,预计阅读时间15分钟. 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作.如果我们省略中间的步骤,即直接根据当前的状态来选 ...
- 基于强化学习SAC_LSTM算法的机器人导航
[前言]在人群之间导航的机器人通常使用避碰算法来实现安全高效的导航.针对人群中机器人的导航问题,本文采用强化学习SAC算法,并结合LSTM长短期记忆网络,提高移动机器人的导航性能.在我们的方法中,机器 ...
- 【强化学习PPO算法】
强化学习PPO算法 一.PPO算法 二.伪代码 三.相关的简单理论 1.ratio 2.裁断 3.Advantage的计算 4.loss的计算 四.算法实现 五.效果 六.感悟 最近再改一个代码, ...
- DRL:强化学习-Q-Learning算法
文章目录 强化学习 Q-Learning算法 1. 问题及原因 2. Estimator原理与思想 (1)单估计器方法(Single Estimator) (2)双估计器方法(Double Estim ...
- 深度强化学习主流算法介绍(二):DPG系列
之前的文章可以看这里 深度强化学习主流算法介绍(一):DQN系列 相关论文在这里 开始介绍DPG之前,先回顾下DQN系列 DQN直接训练一个Q Network 去估计每个离散动作的Q值,使用时选择Q值 ...
- 强化学习 五子棋算法
强化学习 五子棋算法 蒙特卡洛树搜索 MCTS 蒙特卡洛树搜索算法 上限置信区间算法 UCT Minimax算法与纳什均衡 alpha beta剪枝 估值函数 优化与总结 本文会以AI五子棋展开,讲解 ...
最新文章
- 每日一皮:当我突然有一个很棒的调试想法...
- 9.Windows线程切换_TSS
- linux下使用syslog日志调试程序快速的调试代码信息的过程
- 华为NIP网络***检测系统
- hdu 3068 最长回文(manacher算法)
- 数据链路层: HDLC
- Linux第二次作业
- NeurlPS 2019丨微软亚洲研究院 5 篇精选论文解读
- 如何通过 Tampermonkey 快速查找 JavaScript 加密入口
- orabbix监控oracle11g,orabbix 监控oracle
- 友元函数类图_要达到形式的公平,需要具备的前提条件是()。
- IEEE-SA董事刘东:开放+开源将带来新一波SDNFV创新
- 前端HTML银行管理系统界面部分实现
- 我从零开始学黑莓开发的过程
- 计算机设备维修与及日常保养,电脑主机日常的维护保养计划
- java for语句 实现一个功能:
- 【推荐】《Java 并发编程的艺术》迷你书
- vue单页应用首屏加载速度慢如何解决
- HashMap数据结构
- 数据结构 图(一)丛林中的路