基础知识

关于Q-learning 和 Sarsa 算法, 详情参见博客 强化学习(Q-Learning,Sarsa)
Sarsa 算法框架为
Q-learning 算法框架为

关于FrozenLake-v0环境介绍, 请参见https://copyfuture.com/blogs-details/20200320113725944awqrghbojzsr9ce

此图来自 强化学习FrozenLake求解

需要注意的细节

训练时

  • 采用 ϵ \epsilon ϵ 贪心算法;
# 贪婪动作选择,含嗓声干扰
a = np.argmax(Q_all[s, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))
  • 对 Q-learning 算法
# 更新Q表
# Q-learning
Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])
  • 对 Sarsa 算法
# sarsa
# 更新Q表
a_ = np.argmax(Q_all[s1, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))
Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * Q_all[s1, a_] - Q_all[s, a])

测试时

  • 不采用 ϵ \epsilon ϵ 贪心算法;
a = np.argmax(Q_all[s, :])
  • 不更新Q表
# # 不更新Q表
# Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])

寻找模型中最优的 α \alpha α, γ \gamma γ

我们计算一下不同参数下的学习率, 如下图所示


比较两种算法的准确率, 我们用Q-learning算法的准确率减掉Sarsa的准确率, 得到

从图中可以看到, 大于0的点均表明在此点对应的 α , γ \alpha,\gamma α,γ下, Q-learning 准确率高于Sarsa.

Python代码

import gym
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable# gym创建冰湖环境
env = gym.make('FrozenLake-v0')
env.render()  # 显示初始environment
# 初始化Q表格,矩阵维度为【S,A】,即状态数*动作数
Q_all = np.zeros([env.observation_space.n, env.action_space.n])
# 设置参数,
# 其中α\alpha 为学习速率(learning rate),γ\gamma为折扣因子(discount factor)
alpha = 0.8
gamma = 0.95
num_episodes = 2000
#
Alpha = np.arange(0.75, 1, 0.02)
Gamma = np.arange(0.1, 1, 0.05)
#Alpha = np.ones_like(Gamma)*0.97
# Training
correct_train = np.zeros([len(Alpha), len(Gamma)])
correct_test = np.zeros([len(Alpha), len(Gamma)])
for k in range(len(Alpha)):for p in range(len(Gamma)):alpha = Alpha[k]gamma = Gamma[p]# trainingrList = []for i in range(num_episodes):# 初始化环境,并开始观察s = env.reset()rAll = 0d = Falsej = 0# 最大步数while j < 99:j += 1# 贪婪动作选择,含嗓声干扰a = np.argmax(Q_all[s, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))# 从环境中得到新的状态和回报s1, r, d, _ = env.step(a)# 更新Q表# Q-learningQ_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])# sarsaa_ = np.argmax(Q_all[s1, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * Q_all[s1, a_] - Q_all[s, a])# 累加回报rAll += r# 更新状态s = s1# Game Overif d:breakrList.append(rAll)correct_train[k, p] = (sum(rList) / num_episodes)# testrList = []for i in range(num_episodes):# 初始化环境,并开始观察s = env.reset()rAll = 0d = Falsej = 0# 最大步数while j < 99:j += 1# 贪婪动作选择,含嗓声干扰a = np.argmax(Q_all[s, :])# 从环境中得到新的状态和回报s1, r, d, _ = env.step(a)# # 更新Q表# Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])# 累加回报rAll += r# 更新状态s = s1# Game Overif d:breakrList.append(rAll)correct_test[k, p] = sum(rList) / num_episodes# print("Score over time:" + str(sum(rList) / num_episodes))
# print("打印Q表:", Q_all)# Test
plt.figure()
ax = plt.subplot(1, 1, 1)
h = plt.imshow(correct_train, interpolation='nearest', cmap='rainbow',extent=[0.75, 1, 0, 1],origin='lower', aspect='auto')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(h, cax=cax)
plt.show()

参考文献

【1】https://blog.csdn.net/kyolxs/article/details/86693085
【2】 强化学习(Q-Learning,Sarsa)
【3】 强化学习FrozenLake求解
【4】https://copyfuture.com/blogs-details/20200320113725944awqrghbojzsr9ce

强化学习用 Sarsa 算法与 Q-learning 算法实现FrozenLake-v0相关推荐

  1. 强化学习(二):Q learning 算法

    强化学习(一):基础知识 强化学习(二):Q learning算法 Q learning 算法是一种value-based的强化学习算法,Q是quality的缩写,Q函数 Q(state,action ...

  2. 强化学习 补充笔记(TD算法、Q学习算法、SARSA算法、多步TD目标、经验回放、高估问题、对决网络、噪声网络)

    学习目标: 深入了解马尔科夫决策过程(MDP),包含TD算法.Q学习算法.SARSA算法.多步TD目标.经验回放.高估问题.对决网络.噪声网络.基础部分见:强化学习 马尔科夫决策过程(价值迭代.策略迭 ...

  3. 【强化学习笔记】从 “酒鬼回家” 认识Q Learning算法

    1.背景 现在笔者来讲一个利用Q-learning 方法帮助酒鬼回家的一个小例子, 例子的环境是一个一维世界, 在世界的右边是酒鬼的家.这个酒鬼因为喝多了,根本不记得回家的路,只是根据自己的直觉一会向 ...

  4. 通过代码学 Sutton 强化学习:SARSA、Q-Learning 时序差分算法训练 CartPole

    来源 | MyEncyclopedia TD Learning本质上是加了bootstrapping的蒙特卡洛(MC),也是model-free的方法,但实践中往往比蒙特卡洛收敛更快.我们选取Open ...

  5. (学习用1)调用用RRT算法进行笛卡尔空间轨迹规划和关节空间轨迹规划

    在MoveIt中,可以通过调用computeCartesianPath()函数来使用RRT算法进行笛卡尔空间轨迹规划,可以通过调用computeJointSpacePath()函数来使用RRT算法进行 ...

  6. 强化学习入门 : 一文入门强化学习 (Sarsa、Q learning、Monte-carlo learning、Deep-Q-Network等)

    最近博主在看强化学习的资料,找到这两个觉得特别适合入门,一个是"一文入门深度学习",一个是"莫烦PYTHON". 建议:看资料的时候可以多种资料一起参考,一边调 ...

  7. Deep Q Network 算法

     Deep Q Network 算法前置基础知识: Reinforcement Learning 基本概念 Q Leaning算法原理 深度学习神经网络知识 Tensorflow.Pytorch.Py ...

  8. [PARL强化学习]Sarsa和Q—learning的实现

    [PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...

  9. [强化学习实战]出租车调度-Q learning SARSA

    出租车调度-Q learning & SARSA 案例分析 实验环境使用 同策时序差分学习调度 异策时序差分调度 资格迹学习调度 结论 代码链接 案例分析 本节考虑Gym库里出租车调度问题(T ...

最新文章

  1. K-最近邻法(KNN)简介
  2. 7.类的访问控制和继承
  3. mysql基础(全,必看)
  4. Restful Service 中 DateTime 在 url 中传递
  5. 使用bat向文件的第一行中写入内容
  6. Eclipse 乱码 解决方案总结(UTF8 -- GBK)
  7. 10 个学习iOS开发的最佳网站(转)
  8. 网易云音乐(电脑版)网络连接不上,救命啊!!!
  9. Java全系列教程:『Java学习指南』
  10. js中this指向的四种规则+ 箭头函数this指向
  11. 会声会影2020迅雷磁力链接bt搜索种子百度云网盘下载及有效序列号
  12. 小米电视联网后显示无法解析小米电视服务器,小米电视机功能详解 教你轻松使用...
  13. SG Input 软件安全分析之逆向分析
  14. GSM Hacking:如何对GSM/GPRS网络测试进行测试
  15. 迅歌KTV服务器各型号,2017年ktv必点歌曲排行榜(4页)-原创力文档
  16. 如何剪辑QQ酷狗下载的音乐?
  17. 迪赛智慧数——柱状图(象形标识图):全国历年结婚登记数
  18. 【LuoguP4233】射命丸文的笔记-多项式求逆
  19. EXCEL技巧——EXCEL如何实现隔行隔列求和
  20. 西门子博途v16系统要求_博途V16安装TIA Portal v16

热门文章

  1. 创建订单【项目 商城】
  2. 一个Android开发者眼中的微信小程序
  3. css如何让文字不换行,css如何让文字不换行显示?
  4. 腾讯云服务器如何重装系统
  5. oracle 中的 NVL2() 函数
  6. 连续空间和离散空间的距离基础
  7. 前端播放rtmp协议的视频流文件
  8. Win10 串口编程
  9. AssetsManager下载类
  10. 里氏代换原则——与多态的辩证关系