强化学习DQN算法实战之CartPole
简介
这篇笔记主要是记录了Deep Q-Learning Network的开发过程。开发环境是:Ubuntu18.04 、tensorflow-gpu 1.13.1 和 OpenAI gym
其中,这篇笔记记录了深度学习的开发环境。安装完成后,在虚拟环境执行pip install gym
安装界面环境。
强化学习的一个困难的地方,在于数据收集和环境描述。而 OpenAI的gym给我们提供了一个非常强大的虚拟环境,这样我们就可以专注于算法本身的开发了。
这篇笔记主要参考了:
- 论文: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
- 博客:https://towardsdatascience.com/cartpole-introduction-to-reinforcement-learning-ed0eb5b58288
环境描述
基本环境可以参考:https://gym.openai.com/envs/CartPole-v1/
学习的目标是使得木棍在小车上树立的时间尽量长。action的选择只有向左或者是向右。环境会自动给出给出反馈,每一步后的得分,下一个局面的描述的状态,是否是结束。环境状态被gym自动封装成一个np.array
,可以通过有关的API获取信息。
在这个例子中,环境的描述是一个4维的向量,我们不必管这4维向量的意义,只需要知道有这个描述即可(当然,如果你感兴趣,可以深究)。每个环境,gym都封装了一分数reward。而且,如果是结束状态,gym会给出描述符。这些在下面的代码中会有说明。
算法介绍和说明
先给出基本算法描述,算法来自上面的参考连接:
这是一个最基本的Off-Policy借助Replay-Buffer和神经网络实现的算法。上面的ϕ\phiϕ,是表示一个连贯的输入,因为上述的算法是输入了一系列的图片。不过在这个例子中,可以把ϕ\phiϕ理解成仅仅输入当前的局面,即st=ϕ(st)s_t=\phi(s_t)st=ϕ(st)。之后会有exploration
的操作,这是为了随机的选取那些评估分数比较低,但是可能会有较好表现的行动。Q(s,a)Q(s,a)Q(s,a)表示一个Q-function,它的作用是给状态sss下的每个行动aaa一个评估分数。实际操作中,QQQ是一个神经网络,每个状态作为神经网络的输入,神经网络的输出是所有的行动aaa的评估分数。
算法给出了yiy_iyi的计算法则。对神经网络进行BP的时候,就根据这个公式来即可。每次从buffer中选取一个批次的数据,执行随机梯度下降SGD算法,即可进行修正。
建议仔细阅读原文。
一个说明点:在Deep-Q-Learning Nerwork (DQN)中,最后的输出层的激活函数都是线性的,而且损失函数是Mean Square Error (MSE)。即下面代码中提到的linear
,原因参考自这篇博客:
因为DQN需要计算的是采取某个行动后的评估值,所以最好是线性输出;而且使用MSE可以可以描述两个数据之间的差距,所以MSE是最佳的选择。这就处理一个回归问题。而对于分类问题,我们需要的是处理非线性的,因为最终的结果需要把数据进行概率的划分,一般使用softmax等的分类函数处理。对于DQN的内部隐藏层,一般采取Relu函数进行非线性映射,以增加函数可以表示的状态空间。
还有一个说明的地方,在损失函数中,我们实际需要反向传播的,只有QDN选择的那个动作与评估之间的误差,而其他动作的误差不需要传播。
先给出MSE公式:
MSE=1N∑i=1N(Yi−Y^i)2MSE = \frac{1}{N}\sum_{i=1}^N \left(Y_i-\hat{Y}_i\right)^2 MSE=N1i=1∑N(Yi−Y^i)2
举个例子:Q(s,a)Q(s,a)Q(s,a)输出[1.0, 0.8],记为q_values
。那么,这个表示动作0的分数是1.0,动作1的分数是0.8。这样来看,agent肯定会采取动作0。假设现在经过计算的该动作的分数是0.7,那么误差的绝对值是0.3,这是需要进行反向传播的。那么,问题是怎样进行反向传播呢?采取的技巧是这样的:新给出一个向量q_update
,等于Q(s,a)Q(s,a)Q(s,a)的输出[1.0, 0.8],因为我们之前存储了action,即知道是采取行动0,那么更新q_update
参数为[0.7, 0.8]。
因为反向传递使用MSE,计算方式是:
Err=12[(0.7−1.0)2+(0.8−0.8)2]=12(0.7−1.0)2Err= \frac{1}{2}\left[\left(0.7-1.0\right)^2+\left(0.8-0.8\right)^2\right]=\frac{1}{2}(0.7-1.0)^2 Err=21[(0.7−1.0)2+(0.8−0.8)2]=21(0.7−1.0)2
这个方式很巧妙,只标出了具体行动的误差,其他方式在MSE计算中,都成为0了。可以参考代码好好理解。
代码实例
代码借助深度学习框架tensorflow进行实现。在1.13.1以及将来要发布的版本中,继承了keras接口。个人的看法是,尽量使用高阶API,开发效率高,不易出错,代码简单易懂。
这个例子中,使用了一个全连接神经网络。输入层是4维向量,表示当前状态。两个隐藏层,都包含24个神经元。输出层是2维的,表示向左或者向右采取行动。
强化学习算法Agent.py介绍
import tensorflow as tf
from tensorflow import keras
from collections import deque
import numpy as np
import randomMAX_LEN = 2000
BATCH_SIZE = 64
GAMMA = 0.95
EXPLORATION_DECAY = 0.995
EXPLORATION_MIN = 0.1class Agent(object):def __init__(self, input_space, output_space, lr=0.001, exploration=0.9):self._model = keras.Sequential()self._model.add(keras.layers.Dense(input_shape=(input_space,), units=24, activation=tf.nn.relu))self._model.add(keras.layers.Dense(units=24, activation=tf.nn.relu))# 注意这里输出层的激活函数是线性的!!!self._model.add(keras.layers.Dense(units=output_space, activation='linear'))self._model.compile(loss='mse', optimizer=keras.optimizers.Adam(lr))self._replayBuffer = deque(maxlen=MAX_LEN) # replay buffer,最大200的容量self._exploration = exploration@propertydef exploration(self):return self._explorationdef add_data(self, state, action, reward, state_next, done):self._replayBuffer.append((state, action, reward, state_next, done))def act(self, state):if np.random.uniform() <= self._exploration: # 随机走出一步return np.random.randint(0, 2)action = self._model.predict(state) # 使用神经网络评估的选择return np.argmax(action[0])def train_from_buffer(self):if len(self._replayBuffer) < BATCH_SIZE:returnbatch = random.sample(self._replayBuffer, BATCH_SIZE) # 随机选取一个批次的数据for state, action, reward, state_next, done in batch:if done: # 对应论文中的分数更新q_update = rewardelse:q_update = reward + GAMMA * np.amax(self._model.predict(state_next)[0])q_values = self._model.predict(state) # 先赋值,为了减去不相关的行动得分q_values[0][action] = q_update # 把采取了的行动的分数更新,那么只有这项在MSE中有效果self._model.fit(state, q_values, verbose=0) # SGD训练模型self._exploration *= EXPLORATION_DECAYself._exploration = max(EXPLORATION_MIN, self._exploration)
train.py训练
import gym
from Agent import Agent
import numpy as np
import matplotlib.pyplot as pltdef train():env = gym.make("CartPole-v1")input_space = env.observation_space.shape[0]output_space = env.action_space.nprint(input_space, output_space)agent = Agent(input_space, output_space)run = 0x = []y = []while run < 100:run += 1state = env.reset()state = np.reshape(state, [1, -1])step = 0while True:step += 1 # 步数越多,相当于站立的时间越长,比较容易理解。# env.render()action = agent.act(state)state_next, reward, done, _ = env.step(action)reward = reward if not done else -reward # 棍子倒了,分数肯定是负数了state_next = np.reshape(state_next, [1, -1])agent.add_data(state, action, reward, state_next, done)state = state_nextif done:print("Run: " + str(run) + ", exploration: " +str(agent.exploration) + ", score:" + str(step))x.append(run)y.append(step)breakagent.train_from_buffer() # 每次都要执行训练plt.plot(x, y)plt.show()if __name__ == "__main__":train()
训练结果
首先声明一个地方,强化学习不同于监督学习,曲线是下降的。RL的曲线会波动的很厉害,不过如果模型好的话,大体上会是上升的。这里训练,我测试了多种激活函数。每个进行100局的测试。硬件条件限制,结果肯定是不太准确的,只能看一下大体趋势。
初始情况 batch-size到32,buffer-size到10000。然后把学习速率设置到0.001,看一下效果:
很明显的上升趋势,出现波动属于正常现象,而且20步之后平均分数能到150/。。所以是之前的学习速率设定的太高了。。。。。
在上述的基础上,使用batch-size=64,在此测试结果,得到下面的图像:
平均值更好一些,而且收敛速度相对较快
强化学习DQN算法实战之CartPole相关推荐
- 用强化学习DQN算法玩合成大西瓜游戏!(提供Keras版本和Paddlepaddle版本)
本文禁止转载,违者必究! 用强化学习玩合成大西瓜 代码地址:https://github.com/Sharpiless/play-daxigua-using-Reinforcement-Learnin ...
- 深度强化学习——DQN算法原理
DQN算法原理 一.DQN算法是什么 二.DQN训练过程 三.经验回放 (Experience Replay) 四.目标网络(Target Network) 1.自举(Bootstrapping) 2 ...
- 深度强化学习-DQN算法原理与代码
DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,论文和代码的链接见下方. 论文:Human-level ...
- 强化学习—— TD算法(Sarsa算法+Q-learning算法)
强化学习-- TD算法(Sarsa算法+Q-learning算法) 1. Sarsa算法 1.1 TD Target 1.2 表格形式的Sarsa算法 1.3 神经网络形式的Sarsa算法 2. Q- ...
- 深度强化学习-DDPG算法原理和实现
全文共3077个字,8张图,预计阅读时间15分钟. 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作.如果我们省略中间的步骤,即直接根据当前的状态来选 ...
- 强化学习常用算法总结
强化学习常用算法总结 本文为2020年6月参加的百度PaddlePaddle强化学习训练营总结 1. 表格型方法:Sarsa和Q-Learning算法 State-action-reward-stat ...
- 深度强化学习主流算法介绍(二):DPG系列
之前的文章可以看这里 深度强化学习主流算法介绍(一):DQN系列 相关论文在这里 开始介绍DPG之前,先回顾下DQN系列 DQN直接训练一个Q Network 去估计每个离散动作的Q值,使用时选择Q值 ...
- 强化学习DQN(Deep Q-Learning)、DDQN(Double DQN)
强化学习DQN(Deep Q-Learning).DDQN(Double DQN) _学习记录-有错误感谢指出 Deep Q-Learning 的主要目的在于最小化以下目标函数: J ( ω ) = ...
- 强化学习经典算法笔记(十四):双延迟深度确定性策略梯度算法TD3的PyTorch实现
强化学习经典算法笔记(十四):双延迟深度确定性策略梯度算法TD3的PyTorch实现 TD3算法简介 TD3是Twin Delayed Deep Deterministic policy gradie ...
最新文章
- R语言dplyr包对数据进行超前或者之后处理(lead、lag)实战
- 维塔与 Magic Leap 的MR游戏发布概念片
- ASP.NET MVC2+MSSQL+Godaddy
- vs2008【断点无效】解决方法
- 【GAN的应用】基于对抗学习的图像美学增强方法
- 修改左侧导航显示样式(转载自Sunmoonfire's artistic matrix)
- 551. Student Attendance Record I 从字符串判断学生考勤
- 初一数学计算机教案,初一数学教案
- 计算机专业使用的工具,电子投标工具使用手册计算机软件及应用it计算机专业资料.doc...
- 永磁同步电机转子磁链_永磁同步电机转子初始位置检测、增量式光电编码器对位调零思路解析...
- tensorflow 张量
- C#万年历dll插件
- ET7.0+HybridCLR(huatuo)热更教程
- 计算机的超级登录用户名和密码,登录到windows用户名和密码
- Python提取视频帧图片
- 主播必备超萌代打猫咪,超人气全键盘版资源下载~
- 微信公众号自定义菜单如何添加emoji表情图标?
- Java之CompletableFuture异步、组合计算基本用法
- arduino dht11 传感器实现
- java用wasd_涨知识:游戏默认WASD原来是这么来的
热门文章
- 剑指 Offer 43. 1~n 整数中 1 出现的次数
- pytorch学习笔记(二十五):VGG
- 《南溪的目标检测学习笔记》——目标检测的评价指标(mAP)
- 【干货】如何删除“自豪地采用WordPress“
- 元素增删事件DOMNodeInserted和DOMNodeRemoved
- hihocoder第196周
- LeetCode 415. 字符串相加 (逢十进一模版字符处理)
- 敏捷开发生态系统系列之一:序言及需求管理生态(客户价值导向-可工作软件-响应变化)...
- 敏捷开发般若敏捷系列之六:如何推广敏捷(下)(以无我之心,行无住之法)...
- 【编程珠玑】内联函数和宏