深度强化学习笔记(二)——Q-learning学习与二维寻路demo实现

文章目录

  • 深度强化学习笔记(二)——Q-learning学习与二维寻路demo实现
    • 前言
    • 理论
      • 什么是Q-Learning
      • 算法
        • 学习率
        • 折扣因子
        • 初始条件
      • 例子
    • 代码
      • 基础版走迷宫示意图
      • 升级版走迷宫示意图
      • 完整代码

前言

这几天稍微闲下来,把原来漏的坑给补上,并做了一个Q-Learning的demo,因为Q-leraning的demo,目前我看到比较多的都是莫烦大佬讲的那个一纬寻路的demo,我觉得看起来没有那么有代表性,于是在此基础上,自己修改做了一个二维寻路的demo,奥利给

理论

什么是Q-Learning

Q-learning是一种无模式RL的形式,它也可以被视为异步DP的方法。它通过体验行动的后果,使智能体能够在马尔可夫域中学习以最优方式行动,而无须构建域的映射。智能体在特定状态下尝试行动,并根据其收到的即时奖励或触发以及对其所处状态的值得估计来评估其后果。 通过反复尝试所有状态的所有行动,它可以通过长期折扣奖励来判断总体上最好的行为。

算法

从状态Δt\Delta tΔt步进入未来步长的权重计算为γΔt\gamma^{\Delta t}γΔt,γ\gammaγ(折扣因子)是介于0和1,并且具有对较迟收到的奖励(反映出良好开端的价值)进行估值的效果。γ\gammaγ也可以被解释为在每一步Δt\Delta tΔt都成功的概率
Q:S×A−>RQ:S\times A->R Q:S×A−>R
首先把Q-learning状态表的动作表初始化为0,然后通过训练更新每个单元。在每个时间t智能体选择动作ata_tat​,观察奖励rtr_trt​,进入新状态st+1s_{t+1}st+1​(可能取决于先前状态sts_tst​和所选的动作),并对Q进行更新 ,该算法的核心是一个简单的值迭代更新过程,即使用旧值和新信息的加权平均值
Qnew(st,at)<−(1−α)Q(st,at)+α(rt+γmaxaQ(st+1,a))Q^{new}(s_t,a_t)<-(1-\alpha)Q(s_t,a_t)+\alpha(r_t+\gamma max_aQ(s_{t+1},a))Qnew(st​,at​)<−(1−α)Q(st​,at​)+α(rt​+γmaxa​Q(st+1​,a))
其中,rtr_trt​是从状态sts_tst​移动到状态st+1s_{t+1}st+1​时收到的奖励,α\alphaα是学习率,Q(st,qt)Q(s_t,q_t)Q(st​,qt​)为旧值,maxaQ(st+1m,a)max_aQ(s_{t+1}m,a)maxa​Q(st+1​m,a)为信息

学习率

α\alphaα确定了新获取的信息在多大程度上覆盖旧信息。因子0使得智能体什么都不学习(专门利用先验知识),而因子1使得智能体只考虑最新信息(忽略先验知识,以探索可能性),一般情况下,通常使用恒定的学习率,αt=0.1\alpha _t=0.1αt​=0.1

折扣因子

折扣因子\gamma$决定了未来奖励的重要性。因子0将通过仅考虑当前奖励使得智能体近视,而接近1的因子将智能体努力获得长期高奖励。这种情况下,从较低的折扣因子开始并将其增加到最终值会加速学习

初始条件

由于Q-learning是迭代算法嘛,因此它隐含地假定在第一次更新发生之前的初始条件。高初始值,也称为"乐观初始条件",也可以鼓励探索:无论选择何种动作,更新规则将其使具有比其他替代方案更低的值,从而增加其选择概率

例子

原文链接:点我
翻译复制的一个大佬的作品:A Painless Q-learning Tutorial (一个 Q-learning 算法的简明教程)





代码

因为有比较详细的注释,这里就不过多解释了

这里的代码最好在CMD里运行,由于jupyter notebook或者pycharm等IDE无法使用os.system('cls')清楚输出,会导致每次迷宫行动都是单独的一个,而不会在原基础更新,看起来比较难受~
升级版作图,由于IDE也不会更新,只会生成新的图,pycharm会存在生成25张图之后,就不会再生成了,无法看到后续变化,jupyter notebook会生成多张图,但是不容易看
如果有大佬知道解决方案,请私信或者评论里提出,谢谢大佬们

基础版走迷宫示意图


升级版走迷宫示意图

因为我暂时没有想到好的解决方法,大佬们可以自己想下,由于画图会增加内存,而当地图尺寸在5以上的时候,有部分由于随机种子的原因,在还前期训练的时候,有一两次的训练的轮数达到400次左右,导致内存占用过高,而训练图崩溃,大佬们可以想下怎么解决

完整代码

import numpy as np
import pandas as pd
import time
from matplotlib import pyplot as plt
import osnp.random.seed(3) #固定随机种子,方便调试
N_STATES = 3#迷宫的边长
ACTIONS = ['up','down','left', 'right']#动作,上下左右移动
EPSILON = 0.9#随机率,每10次,选择一次新的随机动作
ALPHA = 0.1#学习率
GAMMA = 0.9#折扣率,当折扣率为0时,则只关注当前奖励,而接近1则关注长期奖励
MAX_EPISODES = 20#最大EPISODERS次数
FRESH_TIME = 0.1#移动时间
X=1#初始坐标x
Y=1#初始坐标Y
target_x=3 #宝藏坐标x
target_y=3 #宝藏坐标y
NOW_STATE=0 #当前位置对应的Q表值#生成初始Q表
#因为边长为N,每个点有上下左右四个选择,则数量为N*N*4
#N:N
#ACTIONS:动作
def build_q_table(N, actions):table = pd.DataFrame(np.zeros((N*N, len(actions))),columns=actions,)#(table)return table#选择动作
#EPSILON:选择新动作的概率为1-EPSILON
#state:当前的状态位置(已经转换为一维表上的值)
#q_table:Q表
def choose_action(state, q_table):state_actions = q_table.iloc[state, :]if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):#当前可以选择的动作表得奖励值都为0(即初始表)时,或者当随机数大于EPSILON时,随机选择一个新动作action_name = np.random.choice(ACTIONS)else:#否则选则当前可选动作表里奖励最大的值action_name = state_actions.idxmax()#返回要选择的动作的名字return action_name#更新当前位置
def get_env_feedback(x,y,A):if A == 'up':  #向上移动#达到上边界,不做变化if y==1:y=yelse:y=y-1elif A=='down':#到达下边界,不做变化if y==N_STATES:y=yelse:y=y+1elif A=='left':#到达左边界,不做变化if x==1:x=xelse:x=x-1elif A=='right':#到达右边界,不做变化if x==N_STATES:x=xelse:x=x+1R=0# 判断是否达到终点,到达则将奖励置于1if (x == target_x) and (y==target_y):R=1return x,y,R# 若到达宝藏位置,则打印本回合的序号和经历的步数。
# 否则打印本次移动后小人的位置(二维世界的当前状态)
def update_env(x,y,target_x, target_y, episode, step_counter):#做一个以+为点的坐标,*为宝藏env_list = np.array(['+']*(N_STATES*N_STATES))env_list=env_list.reshape(N_STATES,N_STATES)#确定是否达到终点if (x == target_x) and (y==target_y):interaction = 'Episode %s: total_steps = %s' % (episode+1, step_counter)print(interaction+'\n', end='')time.sleep(2)step_counter=0return step_counterelse:#若未达到终点#老版本0.01显示图像env_list[target_x-1][target_y-1]='*'env_list[x-1][y-1] = 'o'interaction=''for a in range(N_STATES):interaction1 = ''.join(env_list[a,:])interaction1=interaction1+'\n'interaction=interaction+interaction1print(interaction)time.sleep(0.3)os.system('cls')#升级版迷宫图# label_x=np.array(range(N_STATES))# label_y=np.array(range(N_STATES))# label_x = label_x.reshape(N_STATES, 1)# label_x = np.tile(label_x, (1, N_STATES))# label_x = label_x.reshape(N_STATES*N_STATES)# label_y = np.tile(label_y, (N_STATES, 1))# label_y = label_y.reshape(N_STATES*N_STATES)# plt.ion()# plt.cla()# plt.plot(label_x, label_y, 'x', markersize=10)# plt.plot(x-1, y-1, 'o', markersize=20)# plt.plot(target_x-1, target_y-1, 's', markersize=20)# plt.show()# plt.pause(0.01)return step_counter-1# 强化学习主要的控制器
#step_counter:记录该次episode运行了多少次
#is_terminated:是否是终点
def rl(X_1,Y_1,target_x,target_y):#c创建初始化零值表q_table = build_q_table(N_STATES, ACTIONS)#循环最大次数step_counter = 0for episode in range(MAX_EPISODES):X=X_1Y=Y_1is_terminated = Falsestep_counter=update_env(X,Y,target_x,target_y,episode, step_counter)#如果不是终点,则进行循环while not is_terminated:#当前状态为第x列第y行对应的值NOW_STATE=int((Y*EPSILON-1)+X-1)A = choose_action(NOW_STATE, q_table)#更新当前位置,进行行为并获取奖励和下一次的状态X_,Y_,R = get_env_feedback(X,Y, A)q_predict = q_table.loc[NOW_STATE, A]#确定是否达到终点if (X_ != target_x) or (Y_!=target_y):NOW_STATE_ = int((Y_ * EPSILON-1) + X_-1)q_target = R + GAMMA * q_table.iloc[NOW_STATE_, :].max()else:q_target = Ris_terminated = True#更新Q表q_table.loc[NOW_STATE, A] += ALPHA * (q_target - q_predict)#移动到下一个状态X=X_Y=Y_#打印状态step_counter=update_env(X,Y,target_x,target_y, episode, step_counter+1)step_counter += 1return q_table
if __name__ == "__main__":q_table = rl(X,Y,target_x,target_y)print('\r\nQ-table:\n')print(q_table)

深度强化学习笔记(二)——Q-learning学习与二维寻路demo实现相关推荐

  1. [PARL强化学习]Sarsa和Q—learning的实现

    [PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...

  2. 深度学习笔记(17) 误差分析(二)

    深度学习笔记(17) 误差分析(二) 1. 使用来自不同分布的数据进行误差分析 2. 数据分布不匹配时的偏差与方差 3. 处理数据不匹配问题 1. 使用来自不同分布的数据进行误差分析 越来越多的团队都 ...

  3. MATLAB学习笔记(一):绘制二维箭头图

    MATLAB学习笔记(一):绘制二维箭头图 MATLAB矢量图绘制 1 quiver函数 2 应用:绘制某一曲线的切向量和法向量 MATLAB罗盘图绘制 1 compass函数 2 应用:绘制相量图 ...

  4. 【机器学习笔记】可解释机器学习-学习笔记 Interpretable Machine Learning (Deep Learning)

    [机器学习笔记]可解释机器学习-学习笔记 Interpretable Machine Learning (Deep Learning) 目录 [机器学习笔记]可解释机器学习-学习笔记 Interpre ...

  5. 台大李宏毅Machine Learning 2017Fall学习笔记 (16)Unsupervised Learning:Neighbor Embedding

    台大李宏毅Machine Learning 2017Fall学习笔记 (16)Unsupervised Learning:Neighbor Embedding

  6. 台大李宏毅Machine Learning 2017Fall学习笔记 (14)Unsupervised Learning:Linear Dimension Reduction

    台大李宏毅Machine Learning 2017Fall学习笔记 (14)Unsupervised Learning:Linear Dimension Reduction 本博客整理自: http ...

  7. 台大李宏毅Machine Learning 2017Fall学习笔记 (13)Semi-supervised Learning

    台大李宏毅Machine Learning 2017Fall学习笔记 (13)Semi-supervised Learning 本博客参考整理自: http://blog.csdn.net/xzy_t ...

  8. 【学习笔记】C++ 核心编程(二)类和对象——封装

    内容来自小破站<黑马程序员C++>复习自用 [学习笔记]C++ 核心编程(二)类和对象--封装 4 类和对象 4.1 封装 4.1.1 封装的意义(一) 4.1.1 封装的意义(二) 4. ...

  9. AR学习笔记(七):阈值二值化优化与颜色分割的优化

    AR学习笔记(七):阈值二值化优化与颜色分割的优化 阈值二值化的优化 当前方案 图像预处理 阈值二值化 优化方案 otsu法 顶帽变换 分块阈值法 颜色分割的优化 当前方案 优化方案 HSV模型分割 ...

最新文章

  1. python字典内存分析_Python减少字典对象占用的七成内存
  2. 21行代码AC——习题3-7 DNA序列(UVa-1368)_解题报告
  3. python输入end退出循环_4.学习python获取用户输入和while循环及if判断语句
  4. python利用myqr库生成二维码
  5. 可变悬挂调节软硬_国六最亲民的豪车,丐版2.0T纯进口,全系可变悬架+8气囊,才23万...
  6. C++ getline在VC6.0的一个bug(处理方法)(转)
  7. php页面怎么改造mip,WordPress MIP 改造之 a 标签替换为 mip-link 跳转链接
  8. NLP自然语言处理-文本摘要简述
  9. url在html中的作用,所谓的URL到底是什么意思,URL有什么作用
  10. 服务器tcp协议安装不了,win2008 R2提示错误“请安装TCP/IP协议 error=10106
  11. 安卓面试中高级安卓开发工程师总结之——大公司面试的方向和套路以及应对方法
  12. 快速画圆切线lisp_autolisp中画两圆公切线的程序怎么写?
  13. Word 安全模式可以启动,正常模式不能启动
  14. 大型养猪场智能监控系统开发
  15. java之Mybatis(实训笔记)
  16. 不要随便设置随机种子
  17. leaflet绘制具有虚线框的多边形(125)
  18. VM虚拟机局域网搭建
  19. C语言第十九讲——函数(2)
  20. 轻松解决网络广播风暴

热门文章

  1. zookeeper四字监控命令
  2. 012 相关性与线性表示总结;向量组的秩、向量组等价
  3. Unity技术手册 - 粒子发射和生命周期内速度子模块
  4. jQuery基础教程
  5. [SWPUCTF 2021 新生赛]easyrce
  6. php社区twig,twig模板简单实用介绍
  7. Oracle 循环插入数据
  8. win8电脑打不开html文件,Win8网页打不开qq能上_Win8能上qq打不开网页怎么办?-192路由网...
  9. 数据分析入门——美国各州人口分析
  10. Python爬取国家数据中心环境数据(全国城市空气质量小时报)并导入csv文件