NeurIPS论文解读|Decision Transformer: 通过序列建模解决离线强化学习问题
今天为大家推荐一篇2021年被NeurIPS收录的一篇论文。
《Decision Transformer: reinforcement learning via sequence modeling》
推荐读者将本博客结合原论文食用。如有谬误偏颇烦请指出!
论文链接:
https://openreview.net/forum?id=a7APmM4B9d
1. 论文概览
先谈谈我的看法:在我看来,Decision Transformer与传统的RL算法最大的区别在于它训练的目标不再是为了最大化累计折扣奖励,而是学习从 ,到的映射。为什么在训练的时候给medium级别的示例序列,而推断的时候我们调大,就能输出更好的动作?其中还是存在着“拼接”,由于DT是以自回归的方式逐步地成动作以及给定奖励,那么在某一个状态下,采取什么样的动作能得到什么样的奖励,是网络完全有可能学习到的。最后体现在整条序列上的就是,网络完全有可能学习到如何“拼接”出能产生给定的动作序列。
Decision Transformer(DT)[1]是纯监督学习,用来解决Offline Reinforcement Learning的问题。它不再将强化学习建模为马尔科夫决策过程(MDP),具体表现在网络在训练时拿到了非常long-term的信息,完全不符合马尔科夫性了。
具体而言,DT将RL当作一个自回归的序列建模问题,建模回报序列(return-to-go)、状态序列(state)与动作序列(action)之间的关系。与一般认为的行为克隆(behavior cloning)只建模状态和动作关系相比,额外考虑了回报以及过去的三元组(,,)序列。最后的效果非常好,击败了一众当时顶尖的离线强化学习方法。
2. 具体做法
如Figure 1所示,
网络输入是
输出是
,是以自回归(autoregression)的方式生成动作。网络结构可以认为是Transformer[2] 的 Decoder 部分的修改(GPT),主要是masked multi-head self-attention。
2.1 网络输入
先从训练的时候讲起,如大家所知,基于时序差分算法的强化学习方法输入通常是四元组:来完成一次更新。
而DT是以一条序列(trajectory)作为输入的:
代表游戏从开始到结束的一整条序列,但是在实际训练过程中,我们往往只会截取K个时间步作为输入,这一点之后再说。
其中需要额外注意的是,和以往r代表奖励(reward)不同,这里作者采用的是 returns-to-go:
即从当前时刻开始,到这条序列结束的所有奖励 reward 的和,且没有折扣(折扣系数)。
为啥这样做呢?
这是因为,DT的目标是基于未来希望得到的回报来生成当前的动作,所以用 reward当然过于短视了,因为reward不具备未来的信息。
这里需要解释一下,作者希望DT达到一个什么地步呢?比如训练的时候,用return-to-go是90分的数据来训练,在测试的时候,如果我们输入的 return-to-go是100分,那么DT能够生成比训练时候更好的动作。事实上确实也做到了这一点,实验结果显示,在medium-expert, medium以及 medium-replay上是比行为克隆要强的。为啥能做到这一点是令人十分好奇的。
在测试的时候,我们给定希望的性能,比如设定归一化后
,然后输入环境的初始状态,网络就给出动作。(事实上还可以输出r和s,但本篇文章没有利用这些,不予考虑),让智能体执行动作,环境给出奖励与下一个状态,计算得到
,再将,,加入到输入序列中,网络前传,得到动作......以这种自回归的方式来进行推断直到游戏结束。
还有一个另外的疑问是,我们能否交换和的顺序呢?
很明显a应该在这两者之后,我猜测和的顺序也是可以交换的。为此我在d4rl的hopper-medium-expert数据集上训练了DT,并每经过一段时间进行evaluation,比较两个顺序的性能,结果如下:
经过我个人的实验,交换两者顺序之后得到的性能几乎相同,当然如果需要更确切地结论可能需要更多的实验,这里不予验证。
2.2 网络结构
和Transformer中有positional encoding来提供位置信息一样,在DT 中也需要一个timestep的信息来告知网络当前的步数。不同于位置编码的是,这里一个timestep对应了三个token。
再来说说前文提到的每次只建模个timesteps的序列:文中作者表示,在推断时,将序列中将最新的个timesteps送入网络。因为序列长度(一局游戏)是不确定的,可能高达 1000 个时间步,而模型是需要固定一个建模的时间步长来进行训练的,所以在这里设置一个代表建模的序列长度的超参数。在训练的过程中,将offline dataset trajectory中间的随机个连续的时间步数据送入进行训练。在我的实验中,选择的超参数=20。
数据处理相关代码[3]如下:
网络前传过程中,transformer模块输入的序列长度为3,即。timesteps的信息输入表现为其embedding被加到这三个模态数据的embedding上。
网络前传的代码[4]如下
B, T, _ = states.shapetime_embeddings = self.embed_timestep(timesteps) # shape: (B,context_len/T,h_dim)# time embeddings are treated similar to positional embeddings
# shape: (B,context_len,h_dim)
state_embeddings = self.embed_state(states) + time_embeddings
action_embeddings = self.embed_action(actions) + time_embeddings
returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings# stack rtg, states and actions and reshape sequence as
# (r1, s1, a1, r2, s2, a2 ...)
# after stack shape: (B, 3, context_len/T, h_dim)
h = torch.stack((returns_embeddings, state_embeddings, action_embeddings),dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)h = self.embed_ln(h)# transformer and prediction
# 这个 transformer 具体而言,是多组 attention + mlp + layernorm + resnet 结构块,
# 具体可以参照文末给出的代码仓库,或者任意一个 transformer 的实现。
h = self.transformer(h)# get h reshaped such that its size = (B x 3 x T x h_dim) and
# h[:, 0, t] is conditioned on r_0, s_0, a_0 ... r_t
# h[:, 1, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t
# h[:, 2, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t, a_t
h = h.reshape(B, T, 3, self.h_dim)# get predictions
return_preds = self.predict_rtg(h[..., 2, :]) # predict next rtg given r, s, a
state_preds = self.predict_state(h[..., 2, :]) # predict next state given r, s, a
action_preds = self.predict_action(h[..., 1, :]) # predict action given r, sreturn state_preds, action_preds, return_preds
对照代码和结构图,可以发现,网络的输入和输出是一一对应的,不同模态的内容只在self-attention层进行交互。图中在上方,与对应,由于mask,它的预测只利用了及
2.3 损失函数
如果是连续动作,那么是 MSE,如果是离散动作,就是交叉熵:
之前的 token 的信息。
if self._cfg.model.continuous:action_loss = F.mse_loss(action_preds, action_target)
else:action_loss = F.cross_entropy(action_preds, action_target)
我们当然也可以预测序列下一步的和,但是作者在文中表示,这样的做法没有显著增加性能,所以就没管了。后来的一些文章中有考虑状态和奖励的信息,但这超出了本文的范围。
下图是论文中的DT的伪代码,以供参考:
需要额外注意的是,在实际推断过程中,虽然我们使用的序列模型对于任何序列长度的输入都可以处理,但是由于我们在训练时让网络学习的是建模序列长度为的序列,因此我们推断时的输入长度也应该是。具体实现上是先用占位符placeholder(0.)填充满个序列长度,这个是超参数,代表游戏最长timesteps,在我的实验中设定为1000。每次选择长度为的序列作为输入。随着自回归的进行,逐渐用真实的值来替换掉 placeholder。这部分代码[5]如下:
for t in range(self.max_eval_ep_len):total_timestpes += 1# 自回归,用真实的值替换掉作为占位符的0# add statestates[0, t] = torch.from_numpy(running_state).to(device)states[0, t] = (states[0, t] - state_mean) / state_std# add rtgrunning_rtg = running_rtg - (running_reward / rtg_scale)rewards_to_go[0, t] = running_rtg# 如果长度不到K,则输入前K个,最后预测的动作索引为t,保证只用了前t个timesteps的信息if t < args.context_len:_, act_preds = model.forward(timesteps[:, :args.context_len], states[:, :args.context_len],actions[:, :args.context_len], rewards_to_go[:, :context_len])act = act_preds[0, t].detach()# 如果序列长度超过了K,则输入最近K个timesteps的内容else:_, act_preds, _ = model.forward(timesteps[:, t-args.context_len+1: t+1], states[:, t-args.context_len+1: t+1],actions[:, t-args.context_len+1: t+1], rewards_to_go[:, t-args.context_len+1:t+1])act = act_preds[0, -1].detach()
在时,如何保证输出的动作只用到了前个timesteps的信息?
1.attention mask:通过attention mask来保证时刻的attention score对于时刻之后的 value 赋分都是0;
2.全链接层怎么办?对于shape为 的输入,mlp会对于个序列的向量分别做映射,每一个时间步输出的结果 不会包含其它时间步的信息。
那么总结一下,Decision Transformer的策略可以表示为
,其中
3. 实验
一篇论文的好坏一定要去关注它的实验,否则只看前面的方法啥的,容易被作者的故事套进去,而实验数据是不会造假的。
下文的BC方法的网络结构和DT一样,唯一不同是去除了Return-to-go这一个模态的信息(相当于输入序列变成2了)。可以理解为序列建模版本的BC。
3.1 Atari(离散动作空间)
除了CQL在Qbert中取得的超模成绩之外,DT都取得了非常具有竞争力的成绩。
3.2 D4RL-Mujoco(连续动作空间)
DT在绝大多数的项目中都取得了最好的成绩。
4. 消融实验
作者提出了一些问题并用设计实验进行验证:
1.DT是否比普通的行为克隆(BC)强?✅
2.DT建模return-to-go的效果好吗?✅
3.上下文长度的选取对性能的影响如何?✅
如果K=1相当于还是服从马尔可夫性,可惜实验结果表明这样效果极差。作者认为,序列建模对于上下文依赖程度很高,过去的信息对于Atari游戏是有用的(在DRQN 中,作者认为当有移动物体出现时,每一帧确实都是POMDP,因为只观测单帧的信息无法判定移动物体的速度和方向)。上下文信息使得transformer去搞清楚,生成哪些动作,能够导致更高的回报。
4. 是否DT在long-term credit assignment上表现良好?✅
作者在Key-to-Door这个变种环境上进行了验证,在这个环境中,必须在第一个房间中捡起钥匙,才能在第三个房间到达门拿到奖励,因此非常考验算法的长期置信分配能力——明白第一个房间的拿钥匙这个动作,对于很久之后的开门拿奖励的贡献巨大。
5. DT是否在稀疏奖励情况下准确预测奖励?✅
如果将输出action改为输出return-to-go,那么DT就可以由一个actor 转变为一个critic。
6. DT是否在稀疏奖励情况下也表现良好?✅
7. 为什么DT不需要像其他offline算法一样进行悲观值估计或者动作约束?原因正是因为DT不需要借助TD算法的Bootstrapping 来进行值估计或策略提升,不需要优化学习到的Q函数,避免了不准确的值函数估计,所以不需要进行悲观值估计或者动作约束。
8. 如何将DT用于online RL ?
未来的工作,实际上从现在的视角看已经有文章了,可见Online Decision Transformer[6],解决的是offline2online问题。
5. 总结
之所以DT效果这么好,作者认为是attention机制赋予了DT极佳的 long-term credit assignment的能力,能够在很长的序列中捕获动作对于奖励的作用。这一点在最后的实验部分(Key-to-Door 环境)得到了验证。但究竟为啥DT效果能这么卓越,我想至少文中是没有讲得很明白的。
In my opinion,decision transformer与传统的RL算法最大的区别在于它训练的目标不再是为了最大化累计奖励,而是学习从,到的映射。为什么在训练的时候给medium的序列,而推断的时候我们调大,就能输出更好的动作?其中还是存在着“拼接”,由于是自回归的方式逐步的生成动作以及给定奖励,那么在同一个状态下,采取什么样的动作能得到什么样的奖励,是网络完全有可能学习到的。最后体现在整条序列上的就是,网络完全有可能学习到如何“拼接”出能产生给定的动作序列。
值得一提的是,同期也有一篇用transformer来做Offline RL的文章Trajectory Transformer[7]同样值得一看。
参考文献
[1]Decision Transformer https://arxiv.org/abs/2106.01345
[2]Attention is All You Need https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html
[3]https://github.com/opendilab/DI-engine/blob/main/ding/utils/data/dataset.py#L332
[4]https://github.com/opendilab/DI-engine/blob/main/ding/model/template/decision_transformer.py#L88
[5]https://github.com/opendilab/DI-engine/blob/main/ding/policy/decision_transformer.py#L254
[6]Online DT https://arxiv.org/abs/2202.05607
[7]Trajectory Transformer https://trajectory-transformer.github.io/
NeurIPS论文解读|Decision Transformer: 通过序列建模解决离线强化学习问题相关推荐
- 《强化学习周刊》第69期:ICLR2023强化学习论文推荐、MIT实现自动调整内在奖励的强化学习...
No.69 智源社区 强化学习组 强 化 学 习 研究 观点 资源 活动 周刊订阅 告诉大家一个好消息,<强化学习周刊>已经开启"订阅功能",以后我们会向您自动推送最 ...
- 【强化学习论文】Decision Transformer:通过序列建模进行强化学习
Article 文献题目:Decision Transformer: Reinforcement Learning via Sequence Modeling 文献时间:2021 摘要 我们引入了一个 ...
- 《强化学习周刊》第37期:视觉深层框架、Transformer World模型、注意力增强强化学习...
No.37 智源社区 强化学习组 强 化 学 习 研究 观点 资源 活动 关于周刊 强化学习作为人工智能领域研究热点之一,其研究进展与成果也引发了众多关注.为帮助研究与工程人员了解该领域的相关进展和 ...
- 快手+何向南团队最新论文Counterfactual Interactive Recommender System ,反事实推理融入离线强化学习,解决filter bubble/信息茧房问题
论文名:CIRS: Bursting Filter Bubbles by Counterfactual Interactive Recommender System 目录 链接 1. 相关背景 1.1 ...
- 论文推荐:StarCraft II Unplugged 离线强化学习
星际争霸 II 是暴雪开发的一款真正的战略游戏,它是一个挑战,因为它从机器学习的角度展示了一些有趣的属性:实时.部分可观察性以及广阔的行动和观察空间.掌握游戏需要时间策略规划,实时控制宏观和微观层面, ...
- 离线强化学习论文学习 Critic Regularized Regression
Critic Regularized Regression 1.摘要 离线强化学习(RL),也被称为批量RL,提供了在没有在线环境交互的情况下从大型预记录数据集进行策略优化的前景.它解决了数据收集成本 ...
- 从顶会论文看2022年推荐系统序列建模的趋势
©作者 | 猫的薛定谔 最近看了 22 年几篇顶会的序列建模的文章,模型无不复杂高深,但细细看后,发现这些文章本质上均是输入的变化,模型只为了配合输入.看看最近的顶会是怎么玩的吧. 背景 序列建模的目 ...
- 结构体实验报告总结_解读!清华、谷歌等10篇强化学习论文总结
强化学习(Reinforcement Learning,RL)正成为当下机器学习中最热门的研究领域之一.与常见的监督学习和非监督学习不同,强化学习强调智能体(agent)与环境(environment ...
- 通过图注意力神经网络进行多智能体游戏抽象_[读论文] AttnPath: 将图注意力机制融入基于深度强化学习的知识图谱推理中...
论文原文:Incorporating Graph Attention Mechanism into Knowledge Graph Reasoning Based on Deep Reinforcem ...
最新文章
- keras.utils.to_categorical用法
- 教你用Python解决非平衡数据问题(附代码)
- HDU 1233 还是畅通工程。
- Linux的find -print 和 -print0区别:换行不换行
- I/O多路转接之poll——基于TCP协议
- maven 打包指定依赖包_[Maven]-Maven基础-01-基础概念
- MySQL笔记-简单配置主从库
- MySql error 2003 Can't connect to MySQL server on 'localhost' (0)
- Golang实践录:开篇
- python怎么修改界面颜色_python中颜色设置
- 约瑟夫环c语言代码加注释,约瑟夫环c语言代码原创.doc
- 基于SSM的网上购物商城管理系统
- Python函数嵌套
- 回答朋友关于美团点评业务线工作体验
- Tajima‘s D群体遗传
- unity 摄像机导入 3dmax 带fov
- Camunda流程引擎事件(Events)介绍
- 如何快速去除图片上的水印
- Cascading介绍
- 一分二网线连接器正确使用方法
热门文章
- win10解压安装mysql方法及遇见的问题(缺少MSVCR120.dll文件、服务无法启动)
- 英语的简洁著名商品经典广告主题句
- 知乎高赞:什么能力很重要,但大多数人都没有?
- 串口编程 - 串口简介
- Java--进制转换(一)
- 2分钟学会一个简单USB插入检测电路
- 微信整人假红包图片_微信假红包生成器安卓-微信红包图片截图整人交友方法 红包显示一会出现你的交友宣言...
- 干货 | 科研决策怎么做?四个步骤一招解决!
- 初识Java调用百度API实现图像识别
- 无法启动Microsoft Office Outlook 无法打开 Outlook窗口