目录

  • 0. 前言
  • 1. 超参数
  • 2. 训练
    • 2.1 初始化环境和智能体
    • 2.2 智能体选择动作
    • 2.3 环境接收动作并反馈下一个状态和奖励
    • 2.4 智能体进行策略更新(学习)
  • 3. 结果处理
    • 3.1 模型保存
    • 3.2 模型读取
    • 3.3 模型测试

0. 前言

本篇博客的代码来源于蘑菇书《Easy RL》Q学习部分的悬崖行走实战部分,本人在学习的同时对代码进行完整的解读,如有错误之处,烦请指正。
Easy-RL github :https://github.com/datawhalechina/easy-rl
注意,此为v.1.0.3分支
这部分代码有两个核心文件:

  • qlearning.py
  • task0.py

首先学习 task0 部分

1. 超参数

机器学习模型中一般有两类参数:一类需要从数据中学习和估计得到,称为模型参数(Parameter),即模型本身的参数。还有一类则是机器学习算法中的调优参数(tuning parameters),需要人为设定,称为超参数(Hyperparameter)

class Config:"""超参数"""def __init__(self):################################## 环境超参数 ###################################self.algo_name = 'Q-learning'  # 算法名称,我们使用Q学习算法self.env_name = 'CliffWalking-v0'  # 环境名称,悬崖行走self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu")  # 检测GPU,如果没装CUDA的话默认为CPUself.seed = 10  # 随机种子,置0则不设置随机种子。我们学习过程中的随机值都对应着一个随机种子,方便我们复现学习结果self.train_eps = 400  # 训练的回合数self.test_eps = 30  # 测试的回合数################################################################################################################## 算法超参数 ###################################self.gamma = 0.90  # 强化学习中的折扣因子self.epsilon_start = 0.95  # ε-贪心策略中的初始epsilon,减小此值可减少学习开始时的随机探索几率self.epsilon_end = 0.01  # ε-贪心策略中的终止epsilon,越小学习结果越逼近self.epsilon_decay = 300  # e-greedy策略中epsilon的衰减率,此值越大衰减的速度越快self.lr = 0.1  # 学习率################################################################################################################# 保存结果相关参数 ################################self.result_path = curr_path + "/outputs/" + self.env_name + \'/' + curr_time + '/results/'  # 保存结果的路径self.model_path = curr_path + "/outputs/" + self.env_name + \'/' + curr_time + '/models/'  # 保存模型的路径self.save_fig = True  # 是否保存图片,注意这里改为 save_fig################################################################################

2. 训练

def train(cfg, env, agent):print('开始训练!')print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')rewards = []  # 记录每回合的奖励,用来记录并分析奖励的变化ma_rewards = []  # 由于得到的奖励可能会产生振荡,使用一个滑动平均的量来反映奖励变化的趋势# 开始回合训练for i_ep in range(cfg.train_eps):ep_reward = 0  # 记录每个回合的奖励state = env.reset()  # 重置环境,开始新的回合# 开始当前回合的行走,直至走到终点while True:  action = agent.choose_action(state)  # 根据算法选择一个动作next_state, reward, done, _ = env.step(action)  # 与环境进行一次动作交互agent.update(state, action, reward, next_state, done)  # Q学习算法更新state = next_state  # 更新状态ep_reward += rewardif done:breakrewards.append(ep_reward)if ma_rewards:ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)else:ma_rewards.append(ep_reward)print("回合数:{}/{},奖励{:.1f}".format(i_ep + 1, cfg.train_eps, ep_reward))print('完成训练!')return rewards, ma_rewards

2.1 初始化环境和智能体

def env_agent_config(cfg, seed=1):"""创建环境和智能体Args:cfg ([type]): [description]seed (int, optional): 随机种子. Defaults to 1.Returns:env [type]: 环境agent : 智能体"""env = gym.make(cfg.env_name)env = CliffWalkingWapper(env)  # 使用自定义装饰器装饰环境env.seed(seed)  # 设置随机种子,每个种子对应一个随机结果,只是为了让结果可以精确复现,一般情况下可删去n_states = env.observation_space.n  # 状态维度,即 48 个状态n_actions = env.action_space.n  # 动作维度, 即 4 个动作agent = QLearning(n_states, n_actions, cfg)  # 为智能体设置参数return env, agent

2.2 智能体选择动作

对于上述代码中的action = agent.choose_action(state)
其方法实现如下:

    def choose_action(self, state):self.sample_count += 1self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \math.exp(-1. * self.sample_count / self.epsilon_decay)  # epsilon是会递减的,这里选择指数递减# e-greedy 策略if np.random.uniform(0, 1) > self.epsilon:action = np.argmax(self.Q_table[str(state)])  # 选择Q(s,a)最大对应的动作else:action = np.random.choice(self.n_actions)  # 随机选择动作return action

此处使用的ε-贪心算法公式:

随着学习过程的增加,epsilon 会进行指数级衰减,直到逼近 epsilon_end。
在随机选择的数大于 epsilon ,即值在 1-epsilon 范围内时,选择Q(s,a)最大对应的动作。
现在,我们来尝试打印一下当前的状态:print(self.Q_table[str(state)])
输出结果为:[ -7.45800334 -78.37958986 -7.46127197 -7.48193639]
以上数组中的四个数值即为各个动作会产生的价值。

2.3 环境接收动作并反馈下一个状态和奖励

动作选择完后,我们使用此动作与环境进行一次交互:

next_state, reward, done, _ = env.step(action)

通过给定动作,我们可以从地图中得到下一个状态和奖励。

  • 例如在起点格36执行动作UP=0,下一个状态为24,奖励为-1;
  • 我们还需要设置地图的边界,例如在起点执行动作 LEFT=1,下一个状态还是36,奖励为−1W;
  • 如果执行动作RIGHT=3,那么会掉入悬崖,下一个状态为36,奖励为 −100 。

具体的逻辑计算过程在C:\Python310\Lib\site-packages\gym\envs\toy_text\cliffwalking.py查看。
参数 done 用于判断是否抵达终点。

2.4 智能体进行策略更新(学习)

现在,我们得到了当前状态、选择的动作、奖励和下一个状态,就可以在智能体内使用Q学习算法更新Q表格:

agent.update(state, action, reward, next_state, done)  # Q学习算法更新

方法实现如下:

    def update(self, state, action, reward, next_state, done):Q_predict = self.Q_table[str(state)][action]  # 读取预测价值if done:  # 终止状态判断Q_target = reward  # 终止状态下获取不到下一个动作,直接将 Q_target 更新为对应的奖励else:Q_target = reward + self.gamma * np.max(self.Q_table[str(next_state)])self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)

其中涉及到的公式就是书中讲过的 Q学习的增量学习伪代码:

这样,就更新好了当前状态对应动作的价值,即策略更新。

3. 结果处理

在上文中,我们完成了一回合的学习,在每回合的学习结束后,我们需要将此回合的奖励记录下来,以便后续可视化使用:

     rewards.append(ep_reward)if ma_rewards:ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)else:ma_rewards.append(ep_reward)

由于得到的奖励可能会产生振荡,我们使用一个滑动平均的量来反映奖励变化的趋势,即使用新的奖励与上一个奖励计算出一个平均的奖励加入到列表中。

3.1 模型保存

等到所有回合都执行结束后,保存这个训练好的模型:

 make_dir(cfg.result_path, cfg.model_path)  # 创建保存结果和模型路径的文件夹agent.save(path=cfg.model_path)  # 保存模型

save的实现:

    def save(self, path):import dilltorch.save(obj=self.Q_table,f=path + "Qlearning_model.pkl",pickle_module=dill)print("保存模型成功!")

dill模块:https://pypi.org/project/dill/
dill extends python’s pickle module for serializing(序列化) and de-serializing(反序列化) python objects to the majority of the built-in python types. Serialization is the process of converting an object to a byte stream, and the inverse of which is converting a byte stream back to a python object hierarchy.
dill provides the user the same interface as the pickle module, and also includes some additional features. In addition to pickling python objects, dill provides the ability to save the state of an interpreter session in a single command. Hence, it would be feasable to save an interpreter session, close the interpreter, ship the pickled file to another computer, open a new interpreter, unpickle the session and thus continue from the ‘saved’ state of the original interpreter session.

我们用 pkl 文件(该存储方式,可以将python项目过程中用到的一些暂时变量、或者需要提取、暂存的字符串、列表、字典等数据保存起来)来保存这个训练好的模型,即 Q表格。打包的模块使用 dill模块。

torch.save()
保存一个序列化(serialized)的目标到磁盘。函数使用了Python的pickle程序用于序列化。模型(models),张量(tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型。

3.2 模型读取

    def load(self, path):import dillself.Q_table = torch.load(f=path + 'Qlearning_model.pkl', pickle_module=dill)print("加载模型成功!")

与模型保存类似,使用torch.load()进行模型的读取操作,从而加载训练好的 Q表格。

3.3 模型测试

模型测试与训练的方法基本一致,唯一的区别只是不用再进行 Q表格的更新,即没有下面这行代码:

agent.update(state, action, reward, next_state, done)  # Q学习算法更新

【强化学习】《Easy RL》- Q-learning - CliffWalking(悬崖行走)代码解读相关推荐

  1. [PARL强化学习]Sarsa和Q—learning的实现

    [PARL强化学习]Sarsa和Q-learning的实现 Sarsa和Q-learning都是利用表格法再根据MDP四元组<S,A,P,R>:S: state状态,a: action动作 ...

  2. 强化学习(RL)——Reinforcement learning

    强化学习 一.强化学习简介 二.强化学习发展历程 三.深度强化学习DRL 四.马尔可夫决策过程 五.值函数 六.Q值 七.蒙特卡洛(MC)与时序差分(TD) 八.强化学习的代表算法 1.Q-learn ...

  3. 深度强化学习(Deep Reinforcement Learning)的资源

    深度强化学习(Deep Reinforcement Learning)的资源 2015-04-08 11:21:00|  分类: Torch |  标签:深度强化学习   |举报 |字号 订阅 Goo ...

  4. 深度强化学习—— 译 Deep Reinforcement Learning(part 0: 目录、简介、背景)

    深度强化学习--概述 翻译说明 综述 1 简介 2 背景 2.1 人工智能 2.2 机器学习 2.3 深度学习 2.4 强化学习 2.4.1 Problem Setup 2.4.2 值函数 2.4.3 ...

  5. 离线强化学习(Offline RL)系列3: (算法篇) IQL(Implicit Q-learning)算法详解与实现

    [更新记录] 论文信息:Ilya Kostrikov, Ashvin Nair, Sergey Levine: "Offline Reinforcement Learning with Im ...

  6. 离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BRAC算法原理详解与实现(经验篇)

    论文原文:[Yifan Wu, George Tucker, Ofir Nachum: "Behavior Regularized Offline Reinforcement Learnin ...

  7. 离线强化学习(Offline RL)系列3: (算法篇) AWAC算法详解与实现

    [更新记录] 论文信息:AWAC: Accelerating Online Reinforcement Learning with Offline Datasets [Code] 本文由UC Berk ...

  8. 离线强化学习(Offline RL)系列3: (算法篇) Onestep 算法详解与实现

    [更新记录] 论文信息: David Brandfonbrener, William F. Whitney, Rajesh Ranganath, Joan Bruna: "Offline R ...

  9. 离线强化学习(Offline RL)系列3: (算法篇)策略约束 - BEAR算法原理详解与实现

    论文信息:Stabilizing Off-Policy Q-Learning via Bootstrapping Error Reduction 本文由UC Berkeley的Sergey Levin ...

  10. 离线强化学习(Offline RL)系列4:(数据集) 经验样本复杂度(Sample Complexity)对模型收敛的影响分析

    [更新记录] 文章信息:Samin Yeasar Arnob, Riashat Islam, Doina Precup: "Importance of Empirical Sample Co ...

最新文章

  1. 当我们拿到数据进行建模时,如何选择更合适的算法?
  2. hdu - 2512 一卡通大冒险 (斯特灵数 贝尔数)
  3. 教你使用百度深度学习框架PaddlePaddle完成波士顿房价预测(新手向)
  4. opencv进阶学习笔记14:分水岭算法 实现图像分割
  5. 前端笔记—第15篇js中的DOM操作
  6. 往文件中写数据--增量
  7. angular使用Md5加密
  8. html间隔怎么打_iPhone手机便签内容怎么设为重要事项?
  9. 怎么在eclipse中安装properties插件使其显示中文
  10. 中山计算机辅助设计报考,中山模具设计与CNC数控编程专业
  11. 一个JAVA WEB伪全栈的VUE入坑随笔,从零点零五学起
  12. 多重继承有时候确实有必要
  13. Linux的ip_conntrack半景
  14. 8位计算机的八位代表什么,八位二进制是什么意思
  15. 什么是域名?什么网站名?什么是URL?
  16. 密码框password调用数字键盘
  17. html铺满整个页面_html 怎么让背景图铺满整个页面?
  18. 华盛顿大学计算机语言学,华盛顿大学人工智能专业排名2020年
  19. 毕业论文小论文查重吗?
  20. cygwin的坑坑洼洼

热门文章

  1. php socket 模拟post,用PHP的Socket编程模拟Post来提交数据 | 学步园
  2. 三天学会MySQL - MySQL数据库章节练习
  3. 计算机桌面无法解锁,电脑win10系统锁屏后解锁却无法进入桌面的解决方法
  4. MySQL 规范数据库设计
  5. 我的世界服务器如何修改权限设置,我的世界设置成员权限 | 手游网游页游攻略大全...
  6. java中使用poi导出ppt(图片和表格)
  7. win10 1073linux密码,Linux Bash on Win10 忘记密码解决
  8. 比亚迪元EV汽车拆解报告
  9. Java基础重点总结
  10. 一加7使用adb强制90hz时遇到的问题