Dueling Double Deep Q Network(D3QN)算法结合了Double DQN和Dueling DQN算法的思想,进一步提升了算法的性能。如果对Doubel DQN和Dueling DQN算法还不太了解的话,可以参考我的这两篇博文:深度强化学习-Double DQN算法原理与代码和深度强化学习-Dueling DQN算法原理与代码,分别详细讲述了这两个算法的原理以及代码实现。本文就带领大家了解一下D3QN算法,代码链接见下方。

代码:https://github.com/indigoLovee/D3QN

喜欢的话可以点个star呢。

1 D3QN算法简介

Dueling Double Deep Q Network(D3QN)算法是在Dueling DQN算法的基础上融入了Doubel DQN算法的思想,它与Dueling DQN算法唯一的区别在于计算目标值的方式。在Dueling DQN算法中,目标值的计算方式为

即利用目标网络获取状态下所有动作的动作价值,然后基于最优动作价值计算目标值。由于这里的最大化操作,导致算法存在“过估计”问题,影响决策的准确性。其中表示目标网络参数。

在D3QN算法中,目标值的计算方式为

即利用评估网络获取状态下最优动作价值对应的动作,然后利用目标网络计算该动作的动作价值,从而得到目标值。通过两个网络的交互,有效避免了算法的“过估计”问题。其中分别表示评估网络和目标网络的参数。

这其实就是D3QN算法的核心所在啦,如果已经熟悉Dueling DQN和Doubel DQN算法的话,这个算法其实是非常容易理解的。

2 D3QN算法代码

经验回放采用集中式均匀回放,代码如下(脚本buffer.py):

import numpy as npclass ReplayBuffer:def __init__(self, state_dim, action_dim, max_size, batch_size):self.mem_size = max_sizeself.batch_size = batch_sizeself.mem_cnt = 0self.state_memory = np.zeros((self.mem_size, state_dim))self.action_memory = np.zeros((self.mem_size, ))self.reward_memory = np.zeros((self.mem_size, ))self.next_state_memory = np.zeros((self.mem_size, state_dim))self.terminal_memory = np.zeros((self.mem_size, ), dtype=np.bool)def store_transition(self, state, action, reward, state_, done):mem_idx = self.mem_cnt % self.mem_sizeself.state_memory[mem_idx] = stateself.action_memory[mem_idx] = actionself.reward_memory[mem_idx] = rewardself.next_state_memory[mem_idx] = state_self.terminal_memory[mem_idx] = doneself.mem_cnt += 1def sample_buffer(self):mem_len = min(self.mem_size, self.mem_cnt)batch = np.random.choice(mem_len, self.batch_size, replace=False)states = self.state_memory[batch]actions = self.action_memory[batch]rewards = self.reward_memory[batch]states_ = self.next_state_memory[batch]terminals = self.terminal_memory[batch]return states, actions, rewards, states_, terminalsdef ready(self):return self.mem_cnt > self.batch_size

目标网络的更新方式为软更新,D3QN算法的实现代码如下(脚本D3QN.py):

import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from buffer import ReplayBufferdevice = T.device("cuda:0" if T.cuda.is_available() else "cpu")class DuelingDeepQNetwork(nn.Module):def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim):super(DuelingDeepQNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, fc1_dim)self.fc2 = nn.Linear(fc1_dim, fc2_dim)self.V = nn.Linear(fc2_dim, 1)self.A = nn.Linear(fc2_dim, action_dim)self.optimizer = optim.Adam(self.parameters(), lr=alpha)self.to(device)def forward(self, state):x = T.relu(self.fc1(state))x = T.relu(self.fc2(x))V = self.V(x)A = self.A(x)Q = V + A - T.mean(A, dim=-1, keepdim=True)return Qdef save_checkpoint(self, checkpoint_file):T.save(self.state_dict(), checkpoint_file)def load_checkpoint(self, checkpoint_file):self.load_state_dict(T.load(checkpoint_file))class D3QN:def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim, ckpt_dir,gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.01, eps_dec=5e-7,max_size=1000000, batch_size=256):self.gamma = gammaself.tau = tauself.epsilon = epsilonself.eps_min = eps_endself.eps_dec = eps_decself.batch_size = batch_sizeself.checkpoint_dir = ckpt_dirself.action_space = [i for i in range(action_dim)]self.q_eval = DuelingDeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,fc1_dim=fc1_dim, fc2_dim=fc2_dim)self.q_target = DuelingDeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,fc1_dim=fc1_dim, fc2_dim=fc2_dim)self.memory = ReplayBuffer(state_dim=state_dim, action_dim=action_dim,max_size=max_size, batch_size=batch_size)self.update_network_parameters(tau=1.0)def update_network_parameters(self, tau=None):if tau is None:tau = self.taufor q_target_params, q_eval_params in zip(self.q_target.parameters(), self.q_eval.parameters()):q_target_params.data.copy_(tau * q_eval_params + (1 - tau) * q_target_params)def remember(self, state, action, reward, state_, done):self.memory.store_transition(state, action, reward, state_, done)def decrement_epsilon(self):self.epsilon = self.epsilon - self.eps_dec \if self.epsilon > self.eps_min else self.eps_mindef choose_action(self, observation, isTrain=True):state = T.tensor([observation], dtype=T.float).to(device)q_vals = self.q_eval.forward(state)action = T.argmax(q_vals).item()if (np.random.random() < self.epsilon) and isTrain:action = np.random.choice(self.action_space)return actiondef learn(self):if not self.memory.ready():returnstates, actions, rewards, next_states, terminals = self.memory.sample_buffer()batch_idx = T.arange(self.batch_size, dtype=T.long).to(device)states_tensor = T.tensor(states, dtype=T.float).to(device)actions_tensor = T.tensor(actions, dtype=T.long).to(device)rewards_tensor = T.tensor(rewards, dtype=T.float).to(device)next_states_tensor = T.tensor(next_states, dtype=T.float).to(device)terminals_tensor = T.tensor(terminals).to(device)with T.no_grad():q_ = self.q_target.forward(next_states_tensor)max_actions = T.argmax(self.q_eval.forward(next_states_tensor), dim=-1)q_[terminals_tensor] = 0.0target = rewards_tensor + self.gamma * q_[batch_idx, max_actions]q = self.q_eval.forward(states_tensor)[batch_idx, actions_tensor]loss = F.mse_loss(q, target.detach())self.q_eval.optimizer.zero_grad()loss.backward()self.q_eval.optimizer.step()self.update_network_parameters()self.decrement_epsilon()def save_models(self, episode):self.q_eval.save_checkpoint(self.checkpoint_dir + 'Q_eval/D3QN_q_eval_{}.pth'.format(episode))print('Saving Q_eval network successfully!')self.q_target.save_checkpoint(self.checkpoint_dir + 'Q_target/D3QN_Q_target_{}.pth'.format(episode))print('Saving Q_target network successfully!')def load_models(self, episode):self.q_eval.load_checkpoint(self.checkpoint_dir + 'Q_eval/D3QN_q_eval_{}.pth'.format(episode))print('Loading Q_eval network successfully!')self.q_target.load_checkpoint(self.checkpoint_dir + 'Q_target/D3QN_Q_target_{}.pth'.format(episode))print('Loading Q_target network successfully!')

算法仿真环境为gym库中的LunarLander-v2,因此需要先配置好gym库。进入Anaconda3中对应的Python环境中,执行下面的指令

pip install gym

但是,这样安装的gym库只包括少量的内置环境,如算法环境、简单文字游戏和经典控制环境,无法使用LunarLander-v2。因此还需要安装一些其他依赖项,具体可以参考我的这篇博文:AttributeError: module ‘gym.envs.box2d‘ has no attribute ‘LunarLander‘ 解决办法

让智能体在环境中训练500轮,训练代码如下(脚本train.py):

import gym
import numpy as np
import argparse
from utils import create_directory, plot_learning_curve
from D3QN import D3QNparser = argparse.ArgumentParser()
parser.add_argument('--max_episodes', type=int, default=500)
parser.add_argument('--ckpt_dir', type=str, default='./checkpoints/D3QN/')
parser.add_argument('--reward_path', type=str, default='./output_images/reward.png')
parser.add_argument('--epsilon_path', type=str, default='./output_images/epsilon.png')args = parser.parse_args()def main():env = gym.make('LunarLander-v2')agent = D3QN(alpha=0.0003, state_dim=env.observation_space.shape[0], action_dim=env.action_space.n,fc1_dim=256, fc2_dim=256, ckpt_dir=args.ckpt_dir, gamma=0.99, tau=0.005, epsilon=1.0,eps_end=0.05, eps_dec=5e-4, max_size=1000000, batch_size=256)create_directory(args.ckpt_dir, sub_dirs=['Q_eval', 'Q_target'])total_rewards, avg_rewards, epsilon_history = [], [], []for episode in range(args.max_episodes):total_reward = 0done = Falseobservation = env.reset()while not done:action = agent.choose_action(observation, isTrain=True)observation_, reward, done, info = env.step(action)agent.remember(observation, action, reward, observation_, done)agent.learn()total_reward += rewardobservation = observation_total_rewards.append(total_reward)avg_reward = np.mean(total_rewards[-100:])avg_rewards.append(avg_reward)epsilon_history.append(agent.epsilon)print('EP:{} Reward:{} Avg_reward:{} Epsilon:{}'.format(episode+1, total_reward, avg_reward, agent.epsilon))if (episode + 1) % 50 == 0:agent.save_models(episode+1)episodes = [i+1 for i in range(args.max_episodes)]plot_learning_curve(episodes, avg_rewards, title='Reward', ylabel='reward',figure_file=args.reward_path)plot_learning_curve(episodes, epsilon_history, title='Epsilon', ylabel='epsilon',figure_file=args.epsilon_path)if __name__ == '__main__':main()

训练时还会用到画图函数和创建文件夹函数,它们均放置在utils.py脚本中,具体代码如下:

import os
import matplotlib.pyplot as pltdef create_directory(path: str, sub_dirs: list):for sub_dir in sub_dirs:if os.path.exists(path + sub_dir):print(path + sub_dir + 'is already exist!')else:os.makedirs(path + sub_dir, exist_ok=True)print(path + sub_dir + 'create successfully!')def plot_learning_curve(episodes, records, title, ylabel, figure_file):plt.figure()plt.plot(episodes, records, linestyle='-', color='r')plt.title(title)plt.xlabel('episode')plt.ylabel(ylabel)plt.show()plt.savefig(figure_file)

仿真结果如下图所示:

平均累积奖励曲线

epsilon变化曲线

通过平均累积奖励可以看出,D3QN算法大约在300步左右时趋于收敛。

深度强化学习-D3QN算法原理与代码相关推荐

  1. 深度强化学习-DQN算法原理与代码

    DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,论文和代码的链接见下方. 论文:Human-level ...

  2. 深度强化学习-DDPG算法原理和实现

    全文共3077个字,8张图,预计阅读时间15分钟. 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作.如果我们省略中间的步骤,即直接根据当前的状态来选 ...

  3. 深度强化学习——DQN算法原理

    DQN算法原理 一.DQN算法是什么 二.DQN训练过程 三.经验回放 (Experience Replay) 四.目标网络(Target Network) 1.自举(Bootstrapping) 2 ...

  4. 深度强化学习-Actor-Critic算法原理和实现

    全文共2543个字,2张图,预计阅读时间15分钟. 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作.如果我们省略中间的步骤,即直接根据当前的状态来选 ...

  5. 深度强化学习DDPG算法高性能Pytorch代码(改写自spinningup,低环境依赖,低阅读障碍)

    写在前面 DRL各种算法在github上各处都是,例如莫凡的DRL代码.ElegantDRL(推荐,易读性NO.1) 很多代码不是原算法的最佳实现,在具体实现细节上也存在差异,不建议直接用在科研上. ...

  6. 深度强化学习主流算法介绍(二):DPG系列

    之前的文章可以看这里 深度强化学习主流算法介绍(一):DQN系列 相关论文在这里 开始介绍DPG之前,先回顾下DQN系列 DQN直接训练一个Q Network 去估计每个离散动作的Q值,使用时选择Q值 ...

  7. 深度强化学习探索算法最新综述,近200篇文献揭示挑战和未来方向

    ©作者 | 杨天培.汤宏垚等 来源 | 机器之心 强化学习是在与环境交互过程中不断学习的,⽽交互中获得的数据质量很⼤程度上决定了智能体能够学习到的策略的⽔平.因此,如何引导智能体探索成为强化学习领域研 ...

  8. 深度强化学习——actor-critic算法(4)

    一.本文概要: actor是策略网络,用来控制agent运动,你可以把他看作是运动员,critic是价值网络,用来给动作打分,你可以把critic看作是裁判,这节课的内容就是构造这两个神经网络,然后通 ...

  9. 【强化学习】Q-Learning原理及代码实现

    最近工作是在太忙了,无奈,也没空更新博客,职业上也从研发变成了产品,有小半年没写代码了,怕自己手生的不行,给自己两天时间,写了点东西,之前做搞机器学习,搞深度学习,但一直对依赖全场景数据喂模型的方向有 ...

最新文章

  1. linux内核 DebugFS
  2. java对象流定义_Java 对象流的用法,将自定义类数组写入文件中
  3. Nginx基本功能及其原理
  4. MySQL8.0.11的安装和Navicat连接mysql
  5. 列举计算机组装所需的各个硬件,计算机组装和维修期中考试.doc
  6. ChainOfResponsibilityPattern(23种设计模式之一)
  7. Hibernate 连接数据库,数据库返回数据超过限制报错
  8. Android基于mAppWidget实现手绘地图(九)–如何处理地图对象的touch事件
  9. 面试时,如何巧妙回答跳槽问题
  10. MATLAB寻找数据最大值
  11. 重装系统原来这么简单,最详细的win7安装教程
  12. 翟菜花:从美团配送新品牌发布,看即时配送行业奇点何时到来
  13. 关于python语言、下列说法不正确的是-模拟试卷C【单项选择题】
  14. 文字烫金效果html,ps如何制作烫金效果 PS制作logo烫黄金效果教程
  15. 如何设置行间距和字间距?
  16. 程序员大阳--所有教程、项目、源码导航
  17. Oracle11G完全卸载步骤
  18. 百度网盘搜索工具_2019
  19. 程序猿小白应该注意什么 1
  20. linux ioctl root权限,Linux系统调用设备的ioctl函数

热门文章

  1. VCSA6.7的磁盘扩容与备份、还原
  2. UITextView 打入中文时,输入拼音会调用 textViewDidChange: 回调的解决方案
  3. C语言中exit的简单用法
  4. 高新计算机考试题库安装,计算机信息高新技术考试模块题库教材配套一览表.doc...
  5. PT2272-M4--4键无线遥控器(STM32)
  6. 安装python38_RHEL8 安装 python3.6.9(191023更改,加入Python3.8安装)
  7. 利用网络现有资源 制作 swf动画
  8. 胎压监测系统TPMS
  9. favicon图标修改_7个方便,免费的图标和Favicon编辑器
  10. APP和小程序有什么区别?