CliffWalking

  • 如下图所示,S是起点,C是障碍,G是目标
  • agent从S开始走,目标是找到到G的最短路径
  • 这里reward可以建模成-1,最终目标是让return最大,也就是路径最短

代码和解释

import gym
import numpy as np
import time# agent.py
class SarsaAgent(object):def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):self.act_n = act_n      # 动作维度,有几个动作可选self.lr = learning_rate # 学习率self.gamma = gamma      # reward的衰减率self.epsilon = e_greed  # 按一定概率随机选动作self.Q = np.zeros((obs_n, act_n))# 根据输入观察值,采样输出的动作值,带探索def sample(self, obs):if np.random.uniform(0, 1) < (1.0 - self.epsilon): #根据table的Q值选动作action = self.predict(obs)else:action = np.random.choice(self.act_n) #有一定概率随机探索选取一个动作return action# 根据输入观察值,预测输出的动作值def predict(self, obs):Q_list = self.Q[obs, :]maxQ = np.max(Q_list)action_list = np.where(Q_list == maxQ)[0]  # maxQ可能对应多个actionaction = np.random.choice(action_list)return action# 学习方法,也就是更新Q-table的方法def learn(self, obs, action, reward, next_obs, next_action, done):""" on-policyobs: 交互前的obs, s_taction: 本次交互选择的action, a_treward: 本次动作获得的奖励rnext_obs: 本次交互后的obs, s_t+1next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1done: episode是否结束"""predict_Q = self.Q[obs, action]if done:target_Q = reward # 没有下一个状态了else:target_Q = reward + self.gamma * self.Q[next_obs, next_action] # Sarsaself.Q[obs, action] += self.lr * (target_Q - predict_Q) # 修正q# 保存Q表格数据到文件def save(self):npy_file = './q_table.npy'np.save(npy_file, self.Q)print(npy_file + ' saved.')# 从文件中读取Q值到Q表格中def restore(self, npy_file='./q_table.npy'):self.Q = np.load(npy_file)print(npy_file + ' loaded.')def run_episode(env, agent, render=False):total_steps = 0 # 记录每个episode走了多少steptotal_reward = 0obs = env.reset() # 重置环境, 重新开一局(即开始新的一个episode)action = agent.sample(obs) # 根据算法选择一个动作while True:next_obs, reward, done, _ = env.step(action) # 与环境进行一个交互next_action = agent.sample(next_obs) # 根据算法选择一个动作# 训练 Sarsa 算法agent.learn(obs, action, reward, next_obs, next_action, done)action = next_actionobs = next_obs  # 存储上一个观察值total_reward += rewardtotal_steps += 1 # 计算step数if render:env.render() #渲染新的一帧图形if done:breakreturn total_reward, total_stepsdef test_episode(env, agent):total_reward = 0obs = env.reset()while True:action = agent.predict(obs) # greedynext_obs, reward, done, _ = env.step(action)total_reward += rewardobs = next_obs# time.sleep(0.5)# env.render()if done:breakreturn total_reward# 使用gym创建悬崖环境
env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left# 创建一个agent实例,输入超参数
agent = SarsaAgent(obs_n=env.observation_space.n,act_n=env.action_space.n,learning_rate=0.1,gamma=0.9,e_greed=0.1)# 训练500个episode,打印每个episode的分数
for episode in range(500):ep_reward, ep_steps = run_episode(env, agent, False)print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps, ep_reward))# 全部训练结束,查看算法效果
test_reward = test_episode(env, agent)
print('test reward = %.1f' % (test_reward))
agent.save()

运行结果

  • 只保留最后一部分
  • Q值表的部分
  • 读取方法见此link

强化学习 Sarsa 实战GYM下的CliffWalking爬悬崖游戏相关推荐

  1. 【零基础强化学习】100行代码教你训练——基于SARSA的CliffWalking爬悬崖游戏

    基于SARSA的CliffWalking爬悬崖游戏

  2. 强化学习环境库 Gym 发布首个社区发布版,全面兼容 Python 3.9

    作者:肖智清 来源:AI科技大本营 强化学习环境库Gym于2021年8月中旬迎来了首个社区志愿者维护的发布版Gym 0.19.该版本全面兼容Python 3.9,增加了多个新特性. 强化学习环境库的事 ...

  3. [PARL强化学习]Sarsa和Q—learning的实现

    [PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...

  4. 【深入浅出强化学习-编程实战】 7 基于策略梯度的强化学习-Cartpole(小车倒立摆系统)

    [深入浅出强化学习-编程实战] 7 基于策略梯度的强化学习-Cartpole 小车倒立摆MDP模型 代码 代码解析 小车倒立摆MDP模型 状态输入:s=[x,x˙,θ,θ˙]s = [x,\dot{x ...

  5. 爬虫实战2(下):爬取豆瓣影评

       上篇笔记我详细讲诉了如何模拟登陆豆瓣,这次我们将记录模拟登陆+爬取影评(复仇者联盟4)实战.本文行文结构如下: 模拟登陆豆瓣展示 分析网址和源码爬取数据 进行面对对象重构 总结   一.模拟登陆 ...

  6. 【深度学习入门到精通系列】 深入浅出强化学习 Sarsa

    文章目录 1 什么是 Sarsa 2 Sarsa 算法更新 3 Sarsa 思维决策 4 什么是 Sarsa(lambda) 5 Sarsa-lambda 1 什么是 Sarsa 同样, 我们会经历正 ...

  7. 强化学习基础05——gym

    OpenAI gym OpenAI gym是强化学习最常用的标准库,如果研究强化学习,肯定会用到gym. gym有几大类控制问题,第一种是经典控制问题,比如cart pole和pendulum. Ca ...

  8. 写一个强化学习训练的gym环境

    需求 要用强化学习(Reinforcement Learning)算法解决问题,需要百千万次的训练,真实环境一般不允许这么多次训练(时间太长.试错代价太大),需要开发仿真环境.OpenAI的gym环境 ...

  9. 强化学习--Sarsa

    系列文章目录 强化学习 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 系列文章目录 前言 一.强化学习是什么? 二.核心算法(免模型学习) Sarsa 1.学习心得 前言 ...

最新文章

  1. 引入css的四种方式
  2. Kubernetes CRD开发汇总
  3. 【Servlet】Servlet与MVC分层开发
  4. asp实现注册登录界面_python app (kivy)-与小型数据库连接,实现注册登录操作
  5. [Redux/Mobx] 什么是单一数据源?
  6. 数组用法以及引用类型和值类型
  7. vice versa VS the other way around
  8. python入门指南txt-【杂谈】爬虫基础与快速入门指南
  9. 100: cf 878C set+并查集+链表
  10. pandas某一列去重获取唯一值
  11. 秦纪二 秦始皇帝二十年(甲戌,公元前227年)——摘要
  12. python头像变二维码_学了Python之后,美化二维码如此简单
  13. html before table,Use greasemonkey to add HTML before table
  14. 【无标题】警惕利用个人收款码升级套路诈骗
  15. java 华容道游戏下载_Java初学之华容道游戏
  16. JavScript简介
  17. 专业 DAW 音频插件:Voxengo Plug-ins Bundle for Mac(Voxengo系列音频插件合集)
  18. Ubuntu中的Load/Unload Cycle Count问题及解决方案
  19. android 微信支付问题总结
  20. MySQL_基础+高级篇- 数据库 -sql -mysql教程_mysql视频_mysql入门_尚硅谷2

热门文章

  1. Hello Shader之Hello Trangle
  2. 网易笔试题(java 、 c++ 、软件测试等)
  3. keras之权重初始化
  4. Python List 包含关系判定
  5. SQL语句中,为什么where子句不能使用列别名,而order by却可以?
  6. ffmpeg视频中提取语音
  7. Cleartext vs. Plaintext vs. Ciphertext vs. Plaintext vs. Clear Text
  8. 研发人员一些比较重要的能力指标参考
  9. Docker修改镜像源为阿里云
  10. Elasticsearch(六)了解全文搜索