论文:QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning
参考博客:多智能体强化学习入门(五)——QMIX算法分析、多智能体强化学习入门QMIX

参考书籍:《深度强化学习学术前沿与实战应用》

MARL中如何表示和使用动作价值函数使得系统达到一个均衡稳态是多智能体系统的目标。

IQL让每个智能体单独定义一个函数QaQ_aQa​。这种方法不能明确表示智能体之间的相互作用,并且可能不会收敛,因为每个智能体的学习都被其他智能体的探索和学习混淆。

另一种是学习一个完全集中式的动作价值函数,即反事实多智能体(counterfactual Multi-Agent COMA),用它来指导actor-critic框架中的分布策略的优化,但需要on-policy学习,导致样本效率低下,并且存在多个智能体是,训练完全集中的critic是不切实际的。


因此这里采用了QMIX,和VDN一样采用集中式分解Q的方法,处于IQL和COMA之间,但可以表示更丰富的动作价值函数。由于VDN的完全因子分解对于获得分散策略并不是必须的。相反,QMIX只需要确保在Q上执行的全局argmax与在每个Q上执行的一组单独的argmax参数产生相同的结果。因此,只需要求对QπQ^{\pi}Qπ于每个QaQ_aQa​之间存在单调约束,即:
δQπδQa≥0,∀a\frac{\delta Q^{\pi}}{\delta Q_a} \geq 0, \ \forall a δQa​δQπ​≥0, ∀a
不同于VDN中简单的总和,QMIX由代表每个QaQ_aQa​的智能体网络和将它们组合到QπQ^{\pi}Qπ中的混合网络组成,以复杂的非线性方式确保集中式和分散式策略之间的一致性。同时,它通过限制混合网络具有正权重来强制执行上式的约束。因此,QMIX可以表示复杂的集中式动作价值函数,其中包含一个因式表示,可以很好扩展智能体的数量,并允许通过线性时间的argmax操作轻松得到分散策略。

QMIX提出为了确保一致性,只需要确保在QπQ^{\pi}Qπ上执行的全局argmax产生与在每个QaQ_aQa​上执行的一组单独argmax操作有相同的结果:
argmaxuQπ(τ,u)=(argmaxu1Q1(τ1,u1)...argmaxunQn(τn,un))argmax_u Q^{\pi} (\tau,u) = (argmax_{u_1}Q_1(\tau_1,u_1) ... argmax_{u_n}Q_n(\tau_n,u_n)) argmaxu​Qπ(τ,u)=(argmaxu1​​Q1​(τ1​,u1​)...argmaxun​​Qn​(τn​,un​))

QMIX使用由智能体网络、混合网络和一组超网络组成的体系结构来代表QπQ^{\pi}Qπ。它采用一个混合网络对单智能体局部值函数进行合并,并在训练学习过程中加入全局状态信息辅助,来提高算法性能。

图©表示每个智能体采用一个DRQN来拟合自身的Q值函数的到 Qi(τi,ai,θi)Q_i(\tau_i,a_i,\theta_i)Qi​(τi​,ai​,θi​) ,DRQN循环输入当前的观测 oi,to_{i,t}oi,t​ 以及上一时刻的动作 ai,t−1a_{i,t-1}ai,t−1​来得到Q值。

图(b)表示混合网络的结构。其输入为每个DRQN网络的输出。为了满足上述的单调性约束,混合网络的所有权值都是非负数,对偏移量不做限制,这样就可以确保满足单调性约束。

为了能够更多的利用到系统的状态信息 sts_tst​ ,采用一种超网络(hypernetwork),混合网络的权重由单独的超网络产生。将状态 sts_tst​ 作为输入,输出为混合网络的权值及偏移量。为了保证权值的非负性,采用一个线性网络以及绝对值激活函数保证输出不为负数。

对偏移量采用同样方式但没有非负性的约束,混合网络最后一层的偏移量通过两层网络以及ReLU激活函数得到非线性映射网络。由于状态信息 sts_tst​ 是通过超网络混合到 QtotQ_{tot}Qtot​ 中的,而不是仅仅作为混合网络的输入项,这样带来的一个好处是,如果作为输入项则 sts_tst​ 的系数均为正,这样则无法充分利用状态信息来提高系统性能,相当于舍弃了一半的信息量。

于DQN类似,最终的Loss函数计算为:
L(θ)=∑i=1b[(yiπ−Qπ(τ,a,s;θ))2]L(\theta) = \sum_{i=1}^b [(y_i^{\pi} - Q^{\pi}(\tau,a,s;\theta))^2] L(θ)=i=1∑b​[(yiπ​−Qπ(τ,a,s;θ))2]

yiπ=r+γmax⁡a′Qπ(τ′,a′,s′;θ−)y_i^{\pi} = r + \gamma \max_{a'}Q^{\pi} (\tau',a',s';\theta^-) yiπ​=r+γa′max​Qπ(τ′,a′,s′;θ−)

代码分析

QMIX主要是在模型上的改进,因此这里只列出模型的代码:

写法1:

# 这里把两个网络模型合起来计算了
class QMIXNet(nn.Module):def __init__(self,num_agents,action_space,state_shape,agent_shape,agent_hidden_size,mixing_hidden_size):super(QMIX,self).__init__()self.num_agents = num_agentsself.action_space = action_spaceself.state_shape = state_shapeself.agent_shape = agent_shapeself.agent_hidden_size = agent_hidden_sizeself.mixing_hidden_size = mixing_hidden_sizeself.agent_ff_in = nn.Linear(self.agent_shape,self.agent_shape)self.agent_net = nn.GRU(self.agent_shape,self.agent_hidden_size)self.agent_ff_out = nn.Linear(self.agent_hidden_shape,self.aciton_space)self.hyper_net1 = nn.Linear(self.state_shape,self.num_agents * self.mixing_hidden_size)self.hyper_net2 = nn.Linear(self.state_shape, self.mixing_hidden_size)def forward(self,global_state,agent_obs):# 计算单智能体的Q值q_n = self.agent_ff_in(agent_obs)q_n = self.agent_net(q_n)q_n = self.agent_ff_out(q_n).max(dim = 1)[0]# 输入状态s到网络中,输出计算单智能体Q值在总Q值的权重参数,这里没计算偏置bw1 = self.hyper_net1(global_state).abs()w2 = self.hyper_net2(global_state).abs()w1 = w1.view(self.num_agents,self.mixing_hidden_size)w2 = w2.view(self.mixing_hidden_size,1)# ELU激活函数:ELU(x) = max(0,x) + min(0,α*(exp(x) - 1))# torch.mm为两个矩阵相乘q_tot = F.elu(torch.mm(q_n,w1))q_tot = F.elu(torch.mm(q_tot,w2))return q_tot

写法2:

# 这里分开成两个网络模型,一个分给各个智能体用,一个组成总Q值
# RNN类对应DRQN,可以计算各个智能体的Q值,forward的输出第一个是Q值,第二个是传给下一个RNN的隐藏值。
class RNN(nn.Module):# Because all the agents share the same network, input_shape=obs_shape+n_actions+n_agentsdef __init__(self, input_shape, args):super(RNN, self).__init__()self.args = argsself.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)def forward(self, obs, hidden_state):x = f.relu(self.fc1(obs))# print(hidden_state.shape,"xxxxx")h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)# print(h_in.shape,"uuu")h = self.rnn(x, h_in)q = self.fc2(h)print(q)print(h)return q, h
class QMixNet(nn.Module):def __init__(self, args):super(QMixNet, self).__init__()self.args = args# 因为生成的hyper_w1需要是一个矩阵,而pytorch神经网络只能输出一个向量,# 所以就先输出长度为需要的 矩阵行*矩阵列 的向量,然后再转化成矩阵# args.n_agents是使用hyper_w1作为参数的网络的输入维度,args.qmix_hidden_dim是网络隐藏层参数个数# 从而经过hyper_w1得到(经验条数,args.n_agents * args.qmix_hidden_dim)的矩阵if args.two_hyper_layers:self.hyper_w1 = nn.Sequential(nn.Linear(args.state_shape, args.hyper_hidden_dim),nn.ReLU(),nn.Linear(args.hyper_hidden_dim, args.n_agents * args.qmix_hidden_dim))# 经过hyper_w2得到(经验条数, 1)的矩阵self.hyper_w2 = nn.Sequential(nn.Linear(args.state_shape, args.hyper_hidden_dim),nn.ReLU(),nn.Linear(args.hyper_hidden_dim, args.qmix_hidden_dim))else:self.hyper_w1 = nn.Linear(args.state_shape, args.n_agents * args.qmix_hidden_dim)# 经过hyper_w2得到(经验条数, 1)的矩阵self.hyper_w2 = nn.Linear(args.state_shape, args.qmix_hidden_dim * 1)# hyper_w1得到的(经验条数,args.qmix_hidden_dim)矩阵需要同样维度的hyper_b1self.hyper_b1 = nn.Linear(args.state_shape, args.qmix_hidden_dim)# hyper_w2得到的(经验条数,1)的矩阵需要同样维度的hyper_b1self.hyper_b2 =nn.Sequential(nn.Linear(args.state_shape, args.qmix_hidden_dim),nn.ReLU(),nn.Linear(args.qmix_hidden_dim, 1))def forward(self, q_values, states):  # states的shape为(episode_num, max_episode_len, state_shape)# 传入的q_values是三维的,shape为(episode_num, max_episode_len, n_agents)episode_num = q_values.size(0)q_values = q_values.view(-1, 1, self.args.n_agents)  # (episode_num * max_episode_len, 1, n_agents) = (1920,1,5)states = states.reshape(-1, self.args.state_shape)  # (episode_num * max_episode_len, state_shape)w1 = torch.abs(self.hyper_w1(states))  # (1920, 160)b1 = self.hyper_b1(states)  # (1920, 32)w1 = w1.view(-1, self.args.n_agents, self.args.qmix_hidden_dim)  # (1920, 5, 32)b1 = b1.view(-1, 1, self.args.qmix_hidden_dim)  # (1920, 1, 32)hidden = F.elu(torch.bmm(q_values, w1) + b1)  # (1920, 1, 32)w2 = torch.abs(self.hyper_w2(states))  # (1920, 32)b2 = self.hyper_b2(states)  # (1920, 1)w2 = w2.view(-1, self.args.qmix_hidden_dim, 1)  # (1920, 32, 1)b2 = b2.view(-1, 1, 1)  # (1920, 1, 1)q_total = torch.bmm(hidden, w2) + b2  # (1920, 1, 1)q_total = q_total.view(episode_num, -1, 1)  # (32, 60, 1)return q_total

多智能体强化学习之QMIX相关推荐

  1. 多智能体强化学习入门

    参考文章:万字长文:详解多智能体强化学习的基础和应用 .多智能体强化学习入门(一)--基础知识与博弈 推荐文章:多智能体强化学习路线图 (MARL Roadmap) 推荐综述论文:An Overvie ...

  2. 上海交大开源训练框架,支持大规模基于种群多智能体强化学习训练

    机器之心专栏 作者:上海交大和UCL多智能体强化学习研究团队 基于种群的多智能体深度强化学习(PB-MARL)方法在星际争霸.王者荣耀等游戏AI上已经得到成功验证,MALib 则是首个专门面向 PB- ...

  3. 【四】多智能体强化学习(MARL)近年研究概览 {Learning cooperation(协作学习)、Agents modeling agents(智能体建模)}

    相关文章: [一]最新多智能体强化学习方法[总结] [二]最新多智能体强化学习文章如何查阅{顶会:AAAI. ICML } [三]多智能体强化学习(MARL)近年研究概览 {Analysis of e ...

  4. 多智能体强化学习:基本概念,通信方式,IPPO,MADDPG

    1,基本概念 1.1,简介 单个RL智能体通过与外界的交互来学习知识,具体过程是根据当前环境的状态,智能体通过策略给出的动作来对环境进行响应,相应地,智能体会得到一个奖励值以反馈动作的好坏程度.RL最 ...

  5. 多智能体强化学习(MARL)训练环境总结

    目前开源的多智能体强化学习项目都是需要在特定多智能体环境下交互运行,为了更好的学习MARL code,需要先大致了解一些常见的MARL环境以及库 文章目录 1.Farama Foundation 2. ...

  6. 多智能体强化学习:鼓励共享多智能体强化学习中的多样性

    题目:Celebrating Diversity in Shared Multi-Agent Reinforcement Learning 出处:Neural Information Processi ...

  7. 多智能体强化学习思路整理

    多智能体强化学习算法思路整理 目录 摘要 背景和意义 研究背景 强化学习 多智能体强化学习与博弈论基础 研究意义 问题与挑战 问题分类 问题分析 环境的不稳定性与可扩展性的平衡 部分可观测的马尔可夫决 ...

  8. 多智能体强化学习—QPLEX

    多智能体强化学习-QPLEX 论文地址:QPLEX: Duplex Dueling Multi-Agent Q-Learning 视频效果:Experiments on StarCraft II 建议 ...

  9. 多智能体强化学习环境【星际争霸II】SMAC环境配置

    多智能体强化学习这个领域中,很多Paper都使用的一个环境是--星际争多智能体挑战(StarCraft Multi-Agent Challenge, SMAC).最近也配置了这个环境,把中间一些步骤记 ...

最新文章

  1. input子系统分析(转)
  2. linux apache两种工作模式详解
  3. 两种交换排序算法:冒泡排序和快速排序
  4. 【杂谈】为什么你学了AI,企业却不要你
  5. The method setClass(Context, Class?) in the type Intent is not applicable for the arguments (GameV
  6. kali mysql停止服务器_从零开始:手把手教你黑客入门攻破服务器并获取ROOT权限...
  7. c++实验七-—项目2
  8. 广工十四届校赛 count 矩阵快速幂
  9. ArcGIS Pro 3.0最新消息
  10. windows下常见php集成环境安装包介绍
  11. Postman测试Soap协议接口
  12. QQ文件路径,QQ图片保存地址
  13. ASP.NET Core 2.1 开发跨平台应用教程
  14. 中国石油大学《物理化学》第一阶段在线作业
  15. 购物商城网站建设费用到底贵不贵?
  16. [渝粤教育] 郑州工程技术学院 食品微生物学 参考 资料
  17. 缓解过拟合(overfitting)的方法
  18. 和李兄之《定风波· 冬峦轻寒桂落香》一首
  19. USI环旭电子推出信用卡大小的SiPSet笔记本电脑主板
  20. 某bobo在线视频APP下载暴力流逆向

热门文章

  1. Node.js安装及常见问题解决办法
  2. 龙尚3G、4G模块嵌入式Linux系统使用说明【转】
  3. DRF 序列化器的使用
  4. 中国企业做大也要做强做优
  5. C++ (opaque) handle
  6. Ceph学习(1)---Ceph入门
  7. 淘宝手机端-selenium破解过程详解
  8. iOS app调用打电话功能
  9. Springboot的Restful
  10. 细节和真实:刘韧谈采访与写作