理论简介

Double Deep Q-Learning Netwok (DQN),基础理论来自于这篇论文。基础理论部分,参考这篇笔记和这篇笔记。下面给出最核心的强化学习公式:
YtDoubleQ=Rt+1+γQ^(St+1,argmaxaQ(St+1,a))Y_{t}^{DoubleQ} = R_{t+1}+\gamma \hat{Q}\left(S_{t+1},\mathop{argmax}_{a}Q\left(S_{t+1},a\right)\right) YtDoubleQ​=Rt+1​+γQ^​(St+1​,argmaxa​Q(St+1​,a))
算法利用了两个结构相同,但是参数不同的神经网络

首先是QQQ网络,这就是DQN中的QQQ网络,是为了用来训练的神经网络。Q^\hat{Q}Q^​网络与QQQ的架构相同,只不过参数是某几步之前的,这是为了计算评估分数使用的。上面的公式含义如下:

  • 利用QQQ选择出St+1S_{t+1}St+1​状态下,分数最大的一步的行动索引
  • 利用Q^\hat{Q}Q^​评估这一步的分数
  • 把这一步的分数,与之前QQQ选择的行动的作比较,注意可能不是同一个行动,然后进行误差反向传播

代码实现

代码基础框架来自于这篇博客。

Agent.py强化学习

import tensorflow as tf
from tensorflow import keras
from collections import deque
import numpy as np
import randomMAX_LEN = 10000
BATCH_SIZE = 64
GAMMA = 0.95
EXPLORATION_DECAY = 0.995
EXPLORATION_MIN = 0.1class Agent(object):def __init__(self, input_space, output_space, lr=0.001, exploration=0.9, update_model_step=10):self._model = keras.Sequential()self._model.add(keras.layers.Dense(input_shape=(input_space,), units=24, activation=tf.nn.relu))self._model.add(keras.layers.Dense(units=24, activation=tf.nn.relu))self._model.add(keras.layers.Dense(units=output_space, activation='linear'))self._model.compile(loss='mse', optimizer=keras.optimizers.Adam(lr))self._replayBuffer = deque(maxlen=MAX_LEN)self._exploration = explorationself._target_model = keras.models.clone_model(self._model)self._update_model_step = update_model_step  # 更新模型需要的最少步数self._cur_step = 0  # 当前使用模型计算的次数def update_target_model(self):self._target_model.set_weights(self._model.get_weights())@propertydef exploration(self):return self._explorationdef add_data(self, state, action, reward, state_next, done):self._replayBuffer.append((state, action, reward, state_next, done))def act(self, state):if np.random.uniform() <= self._exploration:return np.random.randint(0, 2)action = self._model.predict(state)return np.argmax(action[0])def train_from_buffer(self):if len(self._replayBuffer) < BATCH_SIZE:returnbatch = random.sample(self._replayBuffer, BATCH_SIZE)for state, action, reward, state_next, done in batch:new_action = np.argmax(self._model.predict(state_next)[0])q_update = rewardif not done:# 这是DDQN公式q_update = reward + GAMMA * self._target_model.predict(state_next)[0][new_action]# q_update += GAMMA * np.amax(self._model.predict(state_next)[0])  # 注释掉的是DQNq_values = self._model.predict(state)q_values[0][action] = q_updateself._model.fit(state, q_values, verbose=0)self._exploration *= EXPLORATION_DECAYself._exploration = max(EXPLORATION_MIN, self._exploration)

train.py训练

import gym
from Agent import Agent
import numpy as np
import matplotlib.pyplot as pltdef train():env = gym.make("CartPole-v1")input_space = env.observation_space.shape[0]output_space = env.action_space.nprint(input_space, output_space)agent = Agent(input_space, output_space)run = 0x = []y = []while run < 100:run += 1state = env.reset()state = np.reshape(state, [1, -1])step = 0while True:step += 1# env.render()action = agent.act(state)state_next, reward, done, _ = env.step(action)reward = reward if not done else -rewardstate_next = np.reshape(state_next, [1, -1])agent.add_data(state, action, reward, state_next, done)state = state_nextif done:print("Run: " + str(run) + ", exploration: " +str(agent.exploration) + ", score:" + str(step))# 这里是每个episode更新一次,也可以根据实际调整agent.update_target_model()  x.append(run)y.append(step)breakagent.train_from_buffer()plt.plot(x, y)plt.show()if __name__ == "__main__":train()

训练结果

学习率是0.001,100个批次的训练,batch-size是64。每个批次执行一次参数更新。波动属于正常现象,基本能在15步之后取得较大的优势。

Double Deep Q-Learning Netwok的理解与实现相关推荐

  1. Deep Q learning: DQN及其改进

    Deep Q Learning Generalization Deep Reinforcement Learning 使用深度神经网络来表示 价值函数 策略 模型 使用随机梯度下降(SGD)优化los ...

  2. CNNs and Deep Q Learning

    前面的一篇博文介绍了函数价值近似,是以简单的线性函数来做的,这篇博文介绍使用深度神经网络来做函数近似,也就是Deep RL.这篇博文前半部分介绍DNN.CNN,熟悉这些的读者可以跳过,直接看后半部分的 ...

  3. Deep Q Learning伪代码分析及翻译

    伪代码 代码翻译及分析 初始化记忆体D中的记忆N 初始化随机权重θaction值的函数Q(Q估计) 初始化权重θ-=θ target-action值的函数^Q(Q现实) 循环:初始化第一个场景s1=x ...

  4. 零基础10分钟运行DQN图文教程 Playing Flappy Bird Using Deep Reinforcement Learning (Based on Deep Q Learning DQN

    文件下载 链接:http://pan.baidu.com/s/1jH9ItTW  密码:0pmq 文件列表 Anaconda3-4.2.0-Windows-x86_64.exe  (python3.5 ...

  5. Deep Reinforcement Learning: Pong from Pixels翻译和简单理解

    原文链接: http://karpathy.github.io/2016/05/31/rl/ 文章目录 原文链接: 前言 Policy-Gradient结构流程图 Deep Reinforcement ...

  6. 【阅读笔记】Falsification of Cyber-Physical Systems Using Deep Reinforcement Learning

    FM2018 Falsification of Cyber-Physical Systems Using Deep Reinforcement Learning (International Symp ...

  7. 论文记载: Deep Reinforcement Learning for Traffic LightControl in Vehicular Networks

    强化学习论文记载 论文名: Deep Reinforcement Learning for Traffic LightControl in Vehicular Networks ( 车辆网络交通信号灯 ...

  8. Deep Reinforcement Learning: Pong from Pixels

    这是一篇迟来很久的关于增强学习(Reinforcement Learning, RL)博文.增强学习最近非常火!你一定有所了解,现在的计算机能不但能够被全自动地训练去玩儿ATARI(译注:一种游戏机) ...

  9. Deep Reinforcement Learning超简单入门项目 Pytorch实现接水果游戏AI

    学习过传统的监督和无监督学习方法后,我们现在已经可以自行开发机器学习系统来解决一些实际问题了.我们能实现一些事件的预测,一些模式的分类,还有数据的聚类等项目.但是这些好像和我们心目中的人工智能仍有差距 ...

  10. 深度强化学习 Deep Reinforcement Learning 学习整理

    这学期的一门机器学习课程中突发奇想,既然卷积神经网络可以识别一副图片,解决分类问题,那如果用神经网络去控制'自动驾驶',在一个虚拟的环境中不停的给网络输入车周围环境的图片,让它去选择前后左右中的一个操 ...

最新文章

  1. Git命令比较两个分支commit 差异
  2. SQLite基本操作
  3. c语言i o编程,C 语言输入输出 (I/O)
  4. CSS学习17之动画
  5. 算法竞赛入门经典(第二版) | 例题5-1 大理石在哪 (普适查找)(UVa10474,Where is the Marble?)
  6. python3.6入门到高阶(全栈) day015 初识面向对象
  7. ie6 7下 relative absolute无法冲破的等级问题解决办法
  8. php header 文件大小,php获取远程文件大小及信息的函数(head_php
  9. 机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记...
  10. 原生的强大DOM选择器querySelector - querySelector和querySelectorAll
  11. php设置http请求头信息和响应头信息
  12. opencv进行5种图像变化:
  13. .NET Hacks Tips
  14. CCF NOI1142 质数
  15. 【三维路径规划】基于matlab A_star算法无人机三维路径规划【含Matlab源码 003期】
  16. 平面设计中的网格系统pdf_一本好书 | 排版圣经:设计中的网格系统
  17. 终身教职让美国研究型大学称霸世界,却把中国「青椒」卷怕了!
  18. Web服务器群集——LVS-DR+Keepalived高可用集群
  19. linux文件分隔符
  20. Python写接口api

热门文章

  1. Python从list删除元素
  2. Day1:360培训学习重点笔记(7.13)
  3. 【实用】Putty常见错误汇总
  4. Python3.x中数据随机重排基本方法
  5. Python3.x的print()输出问题
  6. 常用类 (二) ----- Math类
  7. 敏捷开发“松结对编程”系列之七:问题集之一
  8. 【编程珠玑】第五章 编程小事
  9. 敏捷开发免费管理工具——火星人预览之七:自定义字段
  10. 信息传递(luogu 2661)