Mastering StarCraft II Micro Battles with QMIX

  • Preface
  • QMIX Paper Walkthrough
    • Related Work in Multi-Agent RL
      • Dec-POMDP
      • DRQN
      • IQL (Independent Q-Learning)
      • VDN (Value Decomposition Networks)
    • QMIX
    • QMIX Architecture
    • The Core Idea: the Monotonicity Constraint
  • Running QMIX on AI Studio
    • Environment Setup
    • Testing the Game Environment
    • QMIX Code
      • Execution Part
      • Defining the Agent Network
      • Defining the Mixing Network
      • Policy and Training
      • Main Function (setup and scheduling)
      • Start Training
        • It is recommended to monitor the training curves with VisualDL; when training has converged and is stable (reward at its peak of about 20), stop the training manually
  • Training Results
    • 8 Marines vs 8 Marines
      • Visualizing the Results
  • About the Author

Preface

Multi-agent plus distributed training is extremely memory-hungry. It is strongly recommended to run the code attached to this article on a machine with more than 32 GB of RAM.

The project has two parts: a paper walkthrough and a code implementation. The walkthrough tries to explain the ideas with as few formulas as possible, and the implementation section discusses how the code is organized.

QMIX Paper Walkthrough

Related Work in Multi-Agent RL

Dec-POMDP

The decentralized partially observable Markov decision process (Dec-POMDP) is an important model for cooperative multi-agent decision making under uncertainty.

Solving it is NEXP-complete, and to date no efficient algorithm can compute exact optimal solutions, but reinforcement learning can be used to approximate them.

A typical paradigm in multi-agent RL is centralized training with decentralized execution: during training, information shared among the agents is used to learn policies that are then executed in a decentralized way. However, many challenges remain around how best to exploit centralized training.

One of them is how to represent and use the action-value function that most RL methods learn. On one hand, correctly capturing the effect of the agents' actions requires a centralized action-value function that conditions on the global state and the joint action.

On the other hand, such a function is hard to learn when there are many agents, and even if it can be learned, it offers no obvious way to extract decentralized policies that let each agent pick its own action based only on its individual observation.
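For reference, this setting is usually written as a tuple (the standard Dec-POMDP formulation used in the QMIX paper):

$$G=\langle S,U,P,r,Z,O,n,\gamma\rangle$$

where $S$ is the set of global states, each of the $n$ agents chooses an action $u_a\in U$ at every step so that the joint action is $u\in U^n$, $P(s'\mid s,u)$ is the state transition function, all agents share the same reward $r(s,u)$, each agent receives a partial observation $z\in Z$ according to the observation function $O(s,a)$, and $\gamma$ is the discount factor.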

DRQN

DRQN combines DQN with a recurrent network: the fully connected layer is replaced with an LSTM (or a similar recurrent cell), which gives DQN the ability to learn from long time sequences.
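As a minimal sketch in Paddle (mirroring what rnn_agent.py later in this article does, with illustrative names), the change is simply to insert a recurrent cell that carries a hidden state between time steps:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class TinyDRQN(nn.Layer):
    # Illustrative DRQN head: FC -> GRU cell -> FC, so the Q values depend on history
    def __init__(self, obs_dim, hidden_dim, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)   # replaces a plain FC layer
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, obs, hidden):
        x = F.relu(self.fc1(obs))
        h, _ = self.rnn(x, hidden)   # the new hidden state is carried to the next step
        return self.fc2(h), h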

IQL (Independent Q-Learning)

IQL simply runs a separate Q-learning algorithm for every agent. Because the agents share the environment, and the environment changes whenever any agent's policy or state changes, the environment is non-stationary from each agent's point of view, so the algorithm has no convergence guarantee; even so, it works reasonably well in some applications.

In other words, IQL gives up on learning a centralized $Q_{tot}$ and instead learns an independent $Q_a$ for each agent. This approach cannot represent interactions between agents, and each agent's training is disturbed by the exploration and learning of the others.
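As a minimal, self-contained illustration (tabular and deliberately simplified; the systems discussed here use neural networks and the SMAC environment instead), IQL is nothing more than one ordinary Q-learning update per agent:

import numpy as np

# Minimal tabular IQL sketch: each agent keeps its own independent Q-table and
# updates it as if it were alone in the environment (illustrative only).
n_agents, n_states, n_actions = 2, 5, 3
q_tables = [np.zeros((n_states, n_actions)) for _ in range(n_agents)]
alpha, gamma, epsilon = 0.1, 0.99, 0.1

def iql_update(agent_id, s, a, r, s_next):
    # standard single-agent Q-learning target, applied independently per agent
    td_target = r + gamma * q_tables[agent_id][s_next].max()
    q_tables[agent_id][s, a] += alpha * (td_target - q_tables[agent_id][s, a])

def iql_act(agent_id, s):
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)
    return int(q_tables[agent_id][s].argmax())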

VDN (Value Decomposition Networks)

VDN also combines the agents' individual value functions into a joint action-value function. It assumes that the centralized $Q_{tot}$ decomposes into a linear sum of the individual $Q_a$.

Let $\tau=(\tau_1,\ldots,\tau_n)$ denote the joint action-observation history, where $\tau_i=(a_{i,0},o_{i,0},\ldots,a_{i,t-1},o_{i,t-1})$ is agent $i$'s action-observation history, and let $a=(a_1,\ldots,a_n)$ denote the joint action.

$Q_{tot}$ is the joint action-value function, and $Q(\tau_i,a_i;\theta_i)$ is agent $i$'s local action-value function; each local value function depends only on that agent's local observations.

VDN combines them by direct summation:

$$Q_{tot}=\sum_{i=1}^{n}Q(\tau_i,a_i;\theta_i)$$

Although $Q(\tau_i,a_i;\theta_i)$ is not an estimate of an expected cumulative return, it is still called a value function here.

A decentralized policy is obtained simply by greedily maximizing each $Q(\tau_i,a_i;\theta_i)$.

VDN assumes that the centralized $Q_{tot}$ is a linear sum of the individual $Q_a$; QMIX can be viewed as an extension of VDN.
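A minimal sketch of the VDN "mixer" (not the QMIX code of this project): the mixing step is literally a sum over the per-agent Q values of the chosen actions. Compare this with the QMixer class later, which replaces the plain sum with a state-conditioned monotonic mixing network.

import paddle

# VDN mixing: Q_tot is just the sum of the chosen per-agent Q values.
# agent_qs has shape [batch, n_agents].
def vdn_mix(agent_qs):
    return agent_qs.sum(axis=1, keepdim=True)   # shape [batch, 1]

# usage example with dummy numbers
agent_qs = paddle.to_tensor([[1.0, 2.0, 0.5],
                             [0.3, 0.7, 1.2]])
print(vdn_mix(agent_qs))   # [[3.5], [2.2]]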

QMIX

Unlike the two approaches above, QMIX (which can be seen as an extension of VDN) learns a centralized but factorizable $Q_{tot}$, so that decentralized policies can be trained in a centralized, end-to-end fashion.

QMIX Architecture

This is the QMIX model structure, shown schematically in the paper's figure: $Q_{tot}$ is represented by an architecture consisting of the agent networks, a mixing network, and a set of hypernetworks (networks that generate the parameters of another network).

Each agent has its own DRQN network whose inputs are its own observation sequence $o$ and its previous action $u$, and which outputs that agent's $Q$ value (the right part of the figure).

The mixing network is a feed-forward network that takes all the $Q_a$ as input and outputs $Q_{tot}$. To satisfy the monotonicity constraint, the mixing network's parameters are generated by separate hypernetworks (the red blocks on the left of the figure): they take the global state as input and output the mixing network's weights and biases. The weights must be non-negative, so they are passed through an absolute-value activation; only when the weights are $\ge 0$ is the derivative of $Q_{tot}$ with respect to each $Q_a$ guaranteed to be $\ge 0$.
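Written out, this is a compact restatement of what the QMixer class later in this article computes (one hidden layer of width equal to the embedding dimension):

$$Q_{tot}(\tau,u,s)=w_2(s)^{\top}\,\mathrm{ELU}\!\big(W_1(s)\,q+b_1(s)\big)+V(s),\qquad q=\big(Q_1(\tau_1,u_1),\ldots,Q_n(\tau_n,u_n)\big)^{\top}$$

where $W_1(s)=\lvert\,\mathrm{hyper}_{w_1}(s)\,\rvert$ and $w_2(s)=\lvert\,\mathrm{hyper}_{w_{final}}(s)\,\rvert$ are produced by hypernetworks and made element-wise non-negative with an absolute value, while the bias $b_1(s)$ and the state-value term $V(s)$ are unconstrained. Since ELU is monotonically increasing and all weights are non-negative, $\partial Q_{tot}/\partial Q_a\ge 0$ holds by construction.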

The Core Idea: the Monotonicity Constraint

Require that the centralized $Q_{tot}$, which evaluates the quality of the joint policy, has a non-negative derivative with respect to each agent's $Q_a$:

$$\frac{\partial Q_{tot}}{\partial Q_{a}}\geq 0$$

If this requirement is met, maximizing $Q_{tot}$ is equivalent to maximizing each $Q_a$ individually (written out explicitly right after the list below).

The benefits of doing this:

  1. Maximizing $Q_{tot}$ becomes straightforward: it suffices to maximize each $Q_a$ separately.
  2. The decentralized policy of each agent can be extracted explicitly from $Q_{tot}$.
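Concretely, this is the consistency property stated in the QMIX paper: a global $\arg\max$ over $Q_{tot}$ yields the same joint action as letting every agent take the $\arg\max$ of its own $Q_a$,

$$\arg\max_{u}Q_{tot}(\tau,u)=\begin{pmatrix}\arg\max_{u_1}Q_1(\tau_1,u_1)\\ \vdots\\ \arg\max_{u_n}Q_n(\tau_n,u_n)\end{pmatrix}$$

which is exactly what makes decentralized greedy execution possible.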

Running QMIX on AI Studio

PARL's examples also include an implementation of QMIX with a worked case; interested readers can refer to it:

gitee: https://gitee.com/paddlepaddle/PARL/tree/develop/examples/qmix

github: https://github.com/PaddlePaddle/PARL/tree/develop/examples/qmix

Environment Setup

Project page of this article (forks and stars welcome):
https://aistudio.baidu.com/aistudio/projectdetail/1698077

Three steps: unpack the game binaries, install the game interface, and unpack and place the maps.

# Unpack only once
!unzip -P iagreetotheeula -q /home/aistudio/data/data72721/SC2.4.10.zip
# Install again whenever the environment is restarted
!pip install git+https://gitee.com/wzduang/smac_copy.git
!mkdir StarCraftII/Maps
!unzip /home/aistudio/data/data76150/SMAC_Maps.zip
!mv SMAC_Maps StarCraftII/Maps

Testing the Game Environment

!python -m smac.examples.random_agents

QMIX Code

Execution Part

This part runs multiple worker processes: it interacts with the environment in a distributed way and sends the collected data back to the training part through pipes.
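Before reading the full runner.py below, here is a stripped-down sketch of the pattern it uses (illustrative only; the messages and payloads here are made up, not the real protocol): each worker process owns one environment and talks to the trainer through a Pipe.

from multiprocessing import Process, Pipe

# Stripped-down sketch of the actor/trainer pattern used in runner.py below.
def worker(child_conn):
    while True:
        msg = child_conn.recv()
        if msg == 'close':
            break
        # in the real code: step the StarCraft II env and send observations back
        child_conn.send({'obs': msg})

if __name__ == '__main__':
    parent_conn, child_conn = Pipe()
    p = Process(target=worker, args=(child_conn,))
    p.start()
    parent_conn.send('step')      # trainer -> worker
    print(parent_conn.recv())     # worker -> trainer: {'obs': 'step'}
    parent_conn.send('close')
    p.join()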

%%writefile runner.py
import qmix
from smac.env import StarCraft2Env
import numpy as np
import paddle
import paddle.nn.functional as F
from multiprocessing import Process, Lock, Pipe, Value
from threading import Thread
import timeclass Transform:def transform(self, tensor):raise NotImplementedErrordef infer_output_info(self, vshape_in, dtype_in):raise NotImplementedErrorclass OneHot(Transform):def __init__(self, out_dim):self.out_dim = out_dimdef transform(self, tensor):y_onehot = dense_to_onehot(tensor, num_classes=self.out_dim).squeeze()return y_onehot.astype("float32")def infer_output_info(self, vshape_in, dtype_in):return (self.out_dim,), "float32"def env_run(scenario, id, child_conn, locker, replay_buffer_size):# 定义地图以及回放路径env = StarCraft2Env(map_name=scenario, replay_dir="./replay/")env_info = env.get_env_info()# 进程 idprocess_id = id# 获取信息action_n = env_info["n_actions"]agent_nb = env_info["n_agents"]state_shape = env_info["state_shape"]obs_shape = env_info["obs_shape"] + agent_nb + action_n#self.episode_limit = env_info['episode_limit']agent_id_one_hot = OneHot(agent_nb)actions_one_hot = OneHot(action_n)agent_id_one_hot_array = []for agent_id in range(agent_nb):agent_id_one_hot_array.append(agent_id_one_hot.transform(np.array([agent_id])))agent_id_one_hot_array = np.array(agent_id_one_hot_array)actions_one_hot_reset = np.zeros((agent_nb, action_n), dtype="float32")state_zeros = np.zeros(state_shape)obs_zeros = np.zeros((agent_nb, obs_shape))actions_zeros = np.zeros([agent_nb, 1])reward_zeros = 0agents_available_actions_zeros = np.zeros((agent_nb, action_n))agents_available_actions_zeros[:,0] = 1child_conn.send(id)while True:while True:data = child_conn.recv()if data == 'save':env.save_replay()child_conn.send('save ok.')elif data == 'close':env.close()exit()else:breaklocker.acquire()env.reset()locker.release()episode_reward = 0episode_step = 0obs = np.array(env.get_obs())obs = np.concatenate([obs, actions_one_hot_reset, agent_id_one_hot_array], axis=-1)state = np.array(env.get_state())terminated = Falsewhile not terminated:agents_available_actions = []for agent_id in range(agent_nb):agents_available_actions.append(env.get_avail_agent_actions(agent_id))child_conn.send(["actions", obs, agents_available_actions])actions = child_conn.recv()reward, terminated, _ = env.step(actions)agents_available_actions2 = []for agent_id in range(agent_nb):agents_available_actions2.append(env.get_avail_agent_actions(agent_id))obs2 = np.array(env.get_obs())actions_one_hot_agents = []for action in actions:actions_one_hot_agents.append(actions_one_hot.transform(np.array(action)))actions_one_hot_agents = np.array(actions_one_hot_agents)obs2 = np.concatenate([obs2, actions_one_hot_agents, agent_id_one_hot_array], axis=-1)state2 = np.array(env.get_state())child_conn.send(["replay_buffer", state, actions, [reward], [terminated], obs, agents_available_actions, 0])episode_reward += rewardepisode_step += 1obs = obs2state = state2for _ in range(episode_step, replay_buffer_size):child_conn.send(["actions", obs_zeros, agents_available_actions_zeros])child_conn.send(["replay_buffer", state_zeros, actions_zeros, [reward_zeros], [True], obs_zeros, agents_available_actions_zeros, 1])child_conn.recv()child_conn.send(["episode_end", episode_reward, episode_step, env.win_counted])class Runner:def __init__(self, arglist, scenario, actors):env = StarCraft2Env(map_name=scenario, replay_dir="./replay/")env_info = env.get_env_info()self.actors = actorsself.scenario = scenarioself.n_actions = env_info["n_actions"]self.n_agents = env_info["n_agents"]self.state_shape = env_info["state_shape"]self.obs_shape = env_info["obs_shape"] + self.n_agents + self.n_actionsself.episode_limit = env_info['episode_limit']self.qmix_algo = 
qmix.QMix(arglist.train, self.n_agents, self.obs_shape, self.state_shape, self.n_actions, 0.0005, replay_buffer_size=1000)# 验证模式if arglist.train == False:self.qmix_algo.load_model('./saved/agents_' + str(arglist.load_episode_saved))print('Load model agent ', str(arglist.load_episode_saved))self.episode_global_step = 0self.episode = 0self.process_com = []self.locker = Lock()for idx in range(self.actors):parent_conn, child_conn = Pipe()Process(target=env_run, args=[self.scenario, idx, child_conn, self.locker, self.episode_limit]).start()self.process_com.append(parent_conn)for process_conn in self.process_com:process_id = process_conn.recv()print(process_id, " is ready !")passdef reset(self):self.qmix_algo.on_reset(self.actors)self.episodes = []self.episode_reward = []self.episode_step = []self.replay_buffers = []self.win_counted_array = []episode_managed = self.episodefor _ in range(self.actors):self.episodes.append(episode_managed)self.episode_reward.append(0)self.episode_step.append(0)self.win_counted_array.append(False)self.replay_buffers.append(qmix.ReplayBuffer(self.episode_limit))episode_managed += 1for process_conn in self.process_com:process_conn.send("Go !")def run(self):episode_done = 0process_size = len(self.process_com)available_to_send = np.array([True for _ in range(self.actors)])while True:obs_batch = []available_batch = []actions = Nonefor idx, process_conn in enumerate(self.process_com):data = process_conn.recv()if data[0] == "actions":obs_batch.append(data[1])available_batch.append(data[2])if idx == process_size - 1:obs_batch = np.concatenate(obs_batch, axis=0)available_batch = np.concatenate(available_batch, axis=0)actions = self.qmix_algo.act(self.actors, paddle.to_tensor(obs_batch), paddle.to_tensor(available_batch))elif data[0] == "replay_buffer":self.replay_buffers[idx].add(data[1], data[2], data[3], data[4], data[5], data[6], data[7])elif data[0] == "episode_end":self.episode_reward[idx] = data[1]self.episode_step[idx] = data[2]self.win_counted_array[idx] = data[3]available_to_send[idx] = Falseepisode_done += 1if actions is not None:for idx_proc, process in enumerate(self.process_com):if available_to_send[idx_proc]:process.send(actions[idx_proc])if episode_done >= self.actors:breakself.episode += self.actorsself.episode_global_step += max(self.episode_step)self.qmix_algo.decay_epsilon_greddy(self.episode_global_step)return self.replay_buffersdef save(self):for process in self.process_com:process.send('save')data = process.recv()print(data)def close(self):for process in self.process_com:process.send('close')def dense_to_onehot(labels_dense, num_classes=10):num_labels = labels_dense.shape[0]index_offset = np.arange(num_labels) * num_classeslabels_onehot = np.zeros((num_labels,num_classes))# 展平的索引值对应相加,然后得到精确索引并修改 labels_onehot 中的每一个值labels_onehot.flat[index_offset + labels_dense.ravel()] = 1return labels_onehot
Overwriting runner.py

Defining the Agent Network

%%writefile rnn_agent.py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class RNNAgent(nn.Layer):
    def __init__(self, input_shape, rnn_hidden_dim=64, n_actions=1):
        super(RNNAgent, self).__init__()
        self.rnn_hidden_dim = rnn_hidden_dim
        print('input_shape: ', input_shape)
        self.fc1 = nn.Linear(input_shape, rnn_hidden_dim)
        self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
        self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)

    def init_hidden(self):
        return paddle.zeros([1, self.rnn_hidden_dim])

    def forward(self, inputs, hidden_state):
        x = F.relu(self.fc1(inputs))
        h_in = hidden_state.reshape([-1, self.rnn_hidden_dim])
        h, _ = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, h

    def update(self, agent):
        self.load_dict(agent.state_dict())
Overwriting rnn_agent.py

Defining the Mixing Network

%%writefile qmixer.py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

class QMixer(nn.Layer):
    def __init__(self, n_agents, state_shape, mixing_embed_dim=64):
        super(QMixer, self).__init__()
        self.n_agents = n_agents
        self.state_dim = int(np.prod(state_shape))
        self.embed_dim = mixing_embed_dim
        self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
        # Bias of the hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
        # V(s) for the last layer
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        bs = agent_qs.shape[0]
        states = states.reshape([-1, self.state_dim])
        agent_qs = agent_qs.reshape([-1, 1, self.n_agents])
        # First layer
        w1 = paddle.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.reshape([-1, self.n_agents, self.embed_dim])
        b1 = b1.reshape([-1, 1, self.embed_dim])
        hidden = F.elu(paddle.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = paddle.abs(self.hyper_w_final(states))
        w_final = w_final.reshape([-1, self.embed_dim, 1])
        # State-dependent bias
        v = self.V(states).reshape([-1, 1, 1])
        # Compute final output
        y = paddle.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.reshape([bs, -1, 1])
        return q_tot

    def update(self, agent):
        self.load_dict(agent.state_dict())
Overwriting qmixer.py

Policy and Training

This part contains the $\epsilon$-greedy exploration, the experience replay buffers, and the QMIX agent itself.
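The EpsilonGreedy class in qmix.py decays $\epsilon$ linearly from 1.0 down to 0.05 over final_step (50000 by default) environment steps. A minimal standalone version of that schedule and of the masked action selection looks like this (simplified sketch, not the project code itself):

import numpy as np

# Linear epsilon schedule: decay from epsilon_start to epsilon_end over
# final_step environment steps, then stay at epsilon_end.
def linear_epsilon(step, final_step=50000, epsilon_start=1.0, epsilon_end=0.05):
    decay = epsilon_start - step / final_step
    return max(decay, epsilon_end)

def epsilon_greedy_action(q_values, avail_mask, epsilon):
    # q_values: [n_actions] array; avail_mask: 1 for available actions, 0 otherwise
    if np.random.random() > epsilon:
        masked = np.where(avail_mask == 1, q_values, -np.inf)
        return int(masked.argmax())
    # explore: sample uniformly among the available actions
    return int(np.random.choice(np.flatnonzero(avail_mask)))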

%%writefile qmix.py
import rnn_agent
import qmixer
import paddle
import paddle.nn.functional as F
from paddle.distribution import Categorical
import numpy as np
import random
from collections import deque# 探索
class EpsilonGreedy:def __init__(self, action_nb, agent_nb, final_step, epsilon_start=float(1), epsilon_end=0.05):self.epsilon = epsilon_startself.initial_epsilon = epsilon_startself.epsilon_end = epsilon_endself.action_nb = action_nbself.final_step = final_stepself.agent_nb = agent_nbdef act(self, value_action, avail_actions):if np.random.random() > self.epsilon:action = value_action.argmax(axis=-1).detach().numpy()else:action = Categorical(avail_actions.cast("float32")).sample([1]).squeeze().cast("int64").detach().numpy()return actiondef epislon_decay(self, step):progress = step / self.final_stepdecay = self.initial_epsilon - progressif decay <= self.epsilon_end:decay = self.epsilon_endself.epsilon = decay# 经验回放
class ReplayBuffer(object):def __init__(self, buffer_size, random_seed=314):"""The right side of the deque contains the most recent experiences """self.buffer_size = buffer_sizeself.count = 0self.buffer = deque()def add(self, s, a, r, t, obs, available_actions, filled):experience = [s, a, r, t, obs, available_actions, np.array([filled])]if self.count < self.buffer_size: self.buffer.append(experience)self.count += 1else:self.buffer.popleft()self.buffer.append(experience)def size(self):return self.countdef sample_batch(self, batch_size):batch = []for idx in range(batch_size):batch.append(self.buffer[idx])batch = np.array(batch)s_batch = np.array([_[0] for _ in batch], dtype='float32')a_batch = np.array([_[1] for _ in batch], dtype='float32')r_batch = np.array([_[2] for _ in batch])t_batch = np.array([_[3] for _ in batch])obs_batch = np.array([_[4] for _ in batch], dtype='float32')available_actions_batch = np.array([_[5] for _ in batch], dtype='float32')filled_batch = np.array([_[6] for _ in batch], dtype='float32')return s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batchdef clear(self):self.buffer.clear()self.count = 0# 批量采样
class EpisodeBatch:def __init__(self, buffer_size, random_seed=314):self.buffer_size = buffer_sizeself.count = 0self.buffer = deque()def reset(self):passdef add(self, replay_buffer):if self.count < self.buffer_size: self.buffer.append(replay_buffer)self.count += 1else:self.buffer.popleft()self.buffer.append(replay_buffer)def _get_max_episode_len(self, batch):max_episode_len = 0for replay_buffer in batch:_, _, _, t, _, _, _ = replay_buffer.sample_batch(replay_buffer.size())for idx, t_idx in enumerate(t):if t_idx == True:if idx > max_episode_len:max_episode_len = idx + 1breakreturn max_episode_lendef sample_batch(self, batch_size):batch = []if self.count < batch_size:batch = random.sample(self.buffer, self.count)else:batch = random.sample(self.buffer, batch_size)episode_len = self._get_max_episode_len(batch)s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch = [], [], [], [], [], [], []for replay_buffer in batch:s, a, r, t, obs, available_actions, filled = replay_buffer.sample_batch(episode_len)s_batch.append(s)a_batch.append(a)r_batch.append(r)t_batch.append(t)obs_batch.append(obs)available_actions_batch.append(available_actions)filled_batch.append(filled)filled_batch = np.array(filled_batch)r_batch = np.array(r_batch)t_batch = np.array(t_batch)a_batch = np.array(a_batch)obs_batch = np.array(obs_batch)available_actions_batch = np.array(available_actions_batch)return s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch, episode_lendef size(self):return self.count# Qmix agent
class QMix:def __init__(self, training, agent_nb, obs_shape, states_shape, action_n, lr, gamma=0.99, batch_size=32, replay_buffer_size=10000, update_target_network=200, final_step=50000): #32self.training = trainingself.gamma = gammaself.batch_size = batch_sizeself.update_target_network = update_target_networkself.hidden_states = Noneself.target_hidden_states = Noneself.agent_nb = agent_nbself.action_n = action_nself.state_shape = states_shapeself.obs_shape = obs_shapeself.epsilon_greedy = EpsilonGreedy(action_n, agent_nb, final_step)self.episode_batch = EpisodeBatch(replay_buffer_size)self.agents = rnn_agent.RNNAgent(obs_shape, n_actions=action_n)self.target_agents = rnn_agent.RNNAgent(obs_shape, n_actions=action_n)self.qmixer = qmixer.QMixer(agent_nb, states_shape, mixing_embed_dim=32)self.target_qmixer = qmixer.QMixer(agent_nb, states_shape, mixing_embed_dim=32)self.target_agents.update(self.agents)self.target_qmixer.update(self.qmixer)self.params = list(self.agents.parameters())self.params += self.qmixer.parameters()clip_grad = paddle.nn.ClipGradByNorm(clip_norm=10)self.optimizer = paddle.optimizer.RMSProp(parameters=self.params, learning_rate=lr, rho=0.99, epsilon=0.00001, grad_clip=clip_grad)def save_model(self, filename):paddle.save(self.agents.state_dict(), filename)def load_model(self, filename):self.agents.load_dict(paddle.load(filename))self.agents.eval()def _init_hidden_states(self, batch_size):self.hidden_states = self.agents.init_hidden().unsqueeze(0).expand([batch_size, self.agent_nb, -1])self.target_hidden_states = self.target_agents.init_hidden().unsqueeze(0).expand([batch_size, self.agent_nb, -1])def decay_epsilon_greddy(self, global_steps):self.epsilon_greedy.epislon_decay(global_steps)def on_reset(self, batch_size):self._init_hidden_states(batch_size)def update_targets(self, episode):if episode % self.update_target_network == 0 and self.training:self.target_agents.update(self.agents)self.target_qmixer.update(self.qmixer)passdef train(self):if self.training and self.episode_batch.size() > self.batch_size:for _ in range(2):self._init_hidden_states(self.batch_size)s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch, episode_len = self.episode_batch.sample_batch(self.batch_size)r_batch = r_batch[:, :-1]a_batch = a_batch[:, :-1]t_batch = t_batch[:, :-1]filled_batch = filled_batch[:, :-1]mask = (1 - filled_batch) * (1 - t_batch)r_batch = paddle.to_tensor(r_batch, dtype="float32")t_batch = paddle.to_tensor(t_batch, dtype="float32")mask = paddle.to_tensor(mask, dtype="float32")a_batch = paddle.to_tensor(a_batch, dtype="int64")mac_out = []for t in range(episode_len):obs = obs_batch[:, t]obs = np.concatenate(obs, axis=0)obs = paddle.to_tensor(obs, dtype="float32")agent_actions, self.hidden_states = self.agents(obs, self.hidden_states)agent_actions = agent_actions.reshape([self.batch_size, self.agent_nb, -1])mac_out.append(agent_actions)mac_out = paddle.stack(mac_out, axis=1)_a_batch = F.one_hot(a_batch.detach(), mac_out[:, :-1].shape[-1]).squeeze(-2)chosen_action_qvals = mac_out[:, :-1]chosen_action_qvals = chosen_action_qvals.multiply(_a_batch).sum(-1)target_mac_out = []for t in range(episode_len):obs = obs_batch[:, t]obs = np.concatenate(obs, axis=0)obs = paddle.to_tensor(obs, dtype="float32")agent_actions, self.target_hidden_states = self.target_agents(obs, self.target_hidden_states)agent_actions = agent_actions.reshape([self.batch_size, self.agent_nb, -1])target_mac_out.append(agent_actions)target_mac_out = paddle.stack(target_mac_out[1:], 
axis=1)available_actions_batch = paddle.to_tensor(available_actions_batch)_condition_ = paddle.zeros(target_mac_out.shape)_condition_ = _condition_ - 9999999target_mac_out = paddle.where(available_actions_batch[:, 1:] == 0, _condition_, target_mac_out)target_max_qvals = target_mac_out.max(axis=3)states = paddle.to_tensor(s_batch, dtype="float32")chosen_action_qvals = self.qmixer(chosen_action_qvals, states[:, :-1])target_max_qvals = self.target_qmixer(target_max_qvals, states[:, 1:])yi = r_batch + self.gamma * (1 - t_batch) * target_max_qvalstd_error = (chosen_action_qvals - yi.detach())mask = mask.expand_as(td_error)masked_td_error = td_error * maskloss = (masked_td_error ** 2).sum() / mask.sum()print('loss:', loss.numpy().item())self.optimizer.clear_grad()loss.backward()self.optimizer.step()def act(self, batch, obs, agents_available_actions):value_action, self.hidden_states = self.agents(paddle.to_tensor(obs, dtype="float32"), self.hidden_states)condition = paddle.zeros(value_action.shape)condition = condition - int(1e10)value_action = paddle.where(agents_available_actions == 0, condition, value_action)if self.training:value_action = self.epsilon_greedy.act(value_action, agents_available_actions)else:value_action = np.argmax(value_action.numpy(), -1)value_action = value_action.reshape([batch, self.agent_nb, -1])return value_action
Overwriting qmix.py

Main Function (setup and scheduling)

To train, run
python main.py --train

To evaluate, run
python main.py --load-episode-saved xxx --scenario xxx

where the two arguments are the checkpoint number to load and the map name, respectively.

%%writefile main.py
from smac.env import StarCraft2Env
import numpy as np
import qmix
import paddle
import os
import argparse
from time import gmtime, strftime
from visualdl import LogWriter
import runner

def main(arglist):
    # Threshold used for checkpoint saving
    max_reward = 15
    current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    writer = LogWriter(log_dir='./logs/' + current_time + '-snake', comment="Star Craft II")
    # Number of rollout workers
    actors = 15
    if arglist.train == False:
        actors = 1
    env_runner = runner.Runner(arglist, arglist.scenario, actors)
    while arglist.train or env_runner.episode < 1:
        env_runner.reset()
        replay_buffers = env_runner.run()
        for replay_buffer in replay_buffers:
            env_runner.qmix_algo.episode_batch.add(replay_buffer)
        env_runner.qmix_algo.train()
        for episode in env_runner.episodes:
            env_runner.qmix_algo.update_targets(episode)
        for idx, episode in enumerate(env_runner.episodes):
            if episode % int(1e6) == 0 and arglist.train:
                env_runner.qmix_algo.save_model('./saved/agents_' + str(episode))
            if env_runner.episode_reward[idx] >= max_reward:
                max_reward = env_runner.episode_reward[idx]
                env_runner.qmix_algo.save_model('./saved/agents_reward_' + str(env_runner.episode_reward[idx]) + "_" + str(episode))
            pass
        print(env_runner.win_counted_array)
        for idx, episode in enumerate(env_runner.episodes):
            print("Total reward in episode {} = {} and global step: {}".format(episode, env_runner.episode_reward[idx], env_runner.episode_global_step))
            if arglist.train:
                writer.add_scalar('Reward', paddle.to_tensor(env_runner.episode_reward[idx]), episode)
                writer.add_scalar('Victory', paddle.to_tensor(env_runner.win_counted_array[idx]), episode)
    if arglist.train == False:
        env_runner.save()
    env_runner.close()

def parse_args():
    parser = argparse.ArgumentParser('SCII parser for QMIX')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--load-episode-saved', type=int, default=105500)
    parser.add_argument('--scenario', type=str, default="6h_vs_8z")
    return parser.parse_args()

if __name__ == "__main__":
    try:
        os.mkdir('./saved')
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")
    arglist = parse_args()
    main(arglist)
Overwriting main.py

Start Training

It is recommended to open the run folder with VisualDL to monitor the training curves. When training has converged and is stable (i.e., the reward reaches its peak of around 20), stop the training manually.

!python main.py --train
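If you are running locally rather than through the AI Studio VisualDL panel, the curves written to ./logs can also be viewed with the VisualDL command-line tool (assuming visualdl is installed; the port number is arbitrary):

!visualdl --logdir ./logs --port 8040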

Training Results

8 Marines vs 8 Marines

# Training results demo

Visualizing the Results

Install StarCraft II and open the replay file to watch the full battle directly.

Here is the result of my training:

Playing StarCraft II with QMIX (Marines 8v8)

About the Author

Name: Wang Zirui (王子瑞)

Undergraduate student (class of 2018), Automation major, College of Electrical Engineering, Sichuan University.

Interests: game AI (imperfect-information games and deep reinforcement learning), robotics applications, visual SLAM, deep learning for computer vision.

Current focus: control algorithms, classical image processing, reinforcement learning, lightweight models and edge deployment.

AgentMaker repository: https://github.com/AgentMaker
