深度学习stride_深度强化学习成名作——DQN
前言:其实很早之前就想开始写写深度强化学习(Deep reinforcement learning)了,但是一年前DQN没调出来,没好意思写哈哈,最近呢无意中把打砖块游戏Breakout训练到平均分接近40分,最高分随便上50(虽说也不算太好,但好歹也体现了DRL的优势),于是就写写吧~
提到深度强化学习的成名作,很多人可能会觉得是2016年轰动一时的AlphaGo,从大众来看是这样的,但真正让深度强化学习火起来并获得学术界蹭蹭往上涨关注度的,当属Deep Q-learning Network(DQN),最早见于2013年的论文《Playing Atari with Deep Reinforcement Learning》。
2012年,深度学习刚在ImageNet比赛大获全胜,紧接着DeepMind团队就想到把深度网络与强化学习结合起来,思想是基于强化学习领域很早就出现的值函数逼近(function approximation),但是通过深度神经网络这一神奇的工具,巧妙地解决了状态维数爆炸的问题!
怎么解决的呢?让我们走进DQN,一探究竟。
CNN实现Q(s, a)
如果我们以纯数学的角度来看动作值函数
就是胜利。于是用CNN完成这种映射的做法应运而生,先上一幅架构图:
DQN
通过gym模块输出Atari环境的游戏,状态空间都是(210, 160, 3),即210*160的图片大小,3个通道,在输入CNN之前需要通过图像处理二值化并缩小成84*84。由于如果将一张图片作为状态输入信息,很多隐藏信息就会忽略(比如球往哪边飞),于是论文中把连续的4帧图片作为状态输入。所以在pytorch中,CNN的输入就是
- 第一层卷积核8*8,stride=4,输出通道为32,ReLU
- 第二层卷积核4*4,stride=2,输出通道为64,ReLU
- 第三层卷积核3*3,stride=1,输出通道为64,ReLU
- 第三层输出经过flat之后维度为3136,然后第四层连接一个512大小的全连接层
- 第五层为动作空间大小的输出层,Breakout游戏中为4,表示每种动作的概率
搭建CNN的py文件在q_model.py中,并手动随机生成torch.randn(32, 4, 84, 84)向量用于测试网络架构的正确性:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import gym
import matplotlib.pyplot as pltclass QNetwork(nn.Module):"""Actor (Policy) Model."""def __init__(self, state_size, action_size, seed):"""Initialize parameters and build model.Params======state_size (int): Dimension of each stateaction_size (int): Dimension of each actionseed (int): Random seed"""super(QNetwork, self).__init__()self.seed = torch.manual_seed(seed)"*** YOUR CODE HERE ***"self.conv = nn.Sequential(nn.Conv2d(state_size[1], 32, kernel_size=8, stride=4),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=4, stride=2),nn.ReLU(),nn.Conv2d(64, 64, kernel_size=3, stride=1),nn.ReLU())self.fc = nn.Sequential(nn.Linear(64*7*7, 512),nn.ReLU(),nn.Linear(512, action_size))def forward(self, state):"""Build a network that maps state -> action values."""conv_out = self.conv(state).view(state.size()[0], -1)return self.fc(conv_out)def pre_process(observation):"""Process (210, 160, 3) picture into (1, 84, 84)"""x_t = cv2.cvtColor(cv2.resize(observation, (84, 84)), cv2.COLOR_BGR2GRAY)ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)return np.reshape(x_t, (1, 84, 84)), x_tdef stack_state(processed_obs):"""Four frames as a state"""return np.stack((processed_obs, processed_obs, processed_obs, processed_obs), axis=0)if __name__ == '__main__':env = gym.make('Breakout-v0')print('State shape: ', env.observation_space.shape)print('Number of actions: ', env.action_space.n)obs = env.reset()x_t, img = pre_process(obs)state = stack_state(img)print(np.shape(state[0]))# plt.imshow(img, cmap='gray')# 用cv2模块显示# cv2.imshow('Breakout', img)# cv2.waitKey(0)state = torch.randn(32, 4, 84, 84) # (batch_size, color_channel, img_height,img_width)state_size = state.size()cnn_model = QNetwork(state_size, action_size=4, seed=1)outputs = cnn_model(state)print(outputs)
完成了解决维数灾难的一步,接下来我们就应该考虑训练的稳定性和效率了,这也是深度学习领域常考虑的问题。而在DQN算法中,作者提出了两大技巧来解决,就是著名的replay buffer和target network,我们一一讨论。
Replay Buffer
replay这个词很形象,在英语中用于影视剧之类的回放;buffer这个词则是计算机里的术语;两个词合起来,形象地体会一下,就是把过去的数据从一个缓存中又拿出来用,这样一用,就比较好地解决了困扰Q-learning算法的样本效率以及相关性问题。从Q-learning的原始公式和算法流程来看,每一次更新Q值的样本都只能用一次,而且在连续获取游戏画面的情景下,状态样本存在极高的相关性。针对这两个问题,如果我们使用一个较大的buffer来储存这些样本,每次随机均匀采样,既能多次使用样本,还能打破样本之间的相关性。
在dqn_agent.py中以ReplayBuffer来实现:
class ReplayBuffer:"""Fixed-size buffer to store experience tuples."""def __init__(self, action_size, buffer_size, batch_size, seed):"""Initialize a ReplayBuffer object.Params======action_size (int): dimension of each actionbuffer_size (int): maximum size of bufferbatch_size (int): size of each training batchseed (int): random seed"""self.action_size = action_sizeself.memory = deque(maxlen=buffer_size) self.batch_size = batch_sizeself.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])self.seed = random.seed(seed)def add(self, state, action, reward, next_state, done):"""Add a new experience to memory."""e = self.experience(state, action, reward, next_state, done)self.memory.append(e)def sample(self):"""Randomly sample a batch of experiences from memory."""experiences = random.sample(self.memory, k=self.batch_size)states = torch.from_numpy(np.stack([e.state for e in experiences if e is not None])).float().to(device)actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(device)dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)return (states, actions, rewards, next_states, dones)def __len__(self):"""Return the current size of internal memory."""return len(self.memory)
Target Network
假设真实的动作值函数为
import numpy as np
import random
from collections import namedtuple, dequefrom q_model import QNetworkimport torch
import torch.nn.functional as F
import torch.optim as optimBUFFER_SIZE = int(1e6) # replay buffer size
BATCH_SIZE = 32 # minibatch size
GAMMA = 0.99 # discount factor
TAU = 1e-3 # for soft update of target parameters
LR = 1e-5 # learning rate
UPDATE_EVERY = 4 # how often to update the networkdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)class Agent():"""Interacts with and learns from the environment."""def __init__(self, state_size, action_size, seed):"""Initialize an Agent object.Params======state_size (int): dimension of each stateaction_size (int): dimension of each actionseed (int): random seed"""self.state_size = state_sizeself.action_size = action_sizeself.seed = random.seed(seed)# Q-Networkself.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) # behavior networkself.qnetwork_target = QNetwork(state_size, action_size, seed).to(device) # target networkself.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)# Replay memoryself.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)# Initialize time step (for updating every UPDATE_EVERY steps)self.t_step = 0def step(self, state, action, reward, next_state, done):# Save experience in replay memoryself.memory.add(state, action, reward, next_state, done)# Learn every UPDATE_EVERY time steps.self.t_step = (self.t_step + 1) % UPDATE_EVERYif self.t_step == 0:# If enough samples are available in memory, get random subset and learnif len(self.memory) > BATCH_SIZE:experiences = self.memory.sample()self.learn(experiences, GAMMA)def act(self, state, eps=0.):"""Returns actions for given state as per current policy.Params======state (array_like): current stateeps (float): epsilon, for epsilon-greedy action selection"""state = torch.from_numpy(state).float().unsqueeze(0).to(device)self.qnetwork_local.eval()with torch.no_grad():action_values = self.qnetwork_local(state)self.qnetwork_local.train()# Epsilon-greedy action selectionif random.random() > eps:return np.argmax(action_values.cpu().data.numpy())else:return random.choice(np.arange(self.action_size))def learn(self, experiences, gamma):"""Update value parameters using given batch of experience tuples.Params======experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples gamma (float): discount factor"""states, actions, rewards, next_states, dones = experiences## TODO: compute and minimize the loss# Get max predicted Q values (for next states) from target modelQ_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)# Compute Q targets for current statesQ_targets = rewards + (gamma * Q_targets_next * (1 - dones))Q_expected = self.qnetwork_local(states).gather(1, actions) # 固定行号,确认列号# Compute lossloss = F.mse_loss(Q_expected, Q_targets)# Minimize the lossself.optimizer.zero_grad()loss.backward()self.optimizer.step()# ------------------- update target network ------------------- #self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) def soft_update(self, local_model, target_model, tau):"""Soft update model parameters.θ_target = τ*θ_local + (1 - τ)*θ_targetParams======local_model (PyTorch model): weights will be copied fromtarget_model (PyTorch model): weights will be copied totau (float): interpolation parameter """for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
DQN算法流程
我们以打砖块游戏Breakout来测试整套算法流程,整个流程的代码在Deep_Q_network.py中,超参数设置在dqn_agent.py中,其中最为重要的超参数设置是replay buffer的大小、迭代次数、学习率和
import gym
import random
import torch
import numpy as np
from collections import deque
from dqn_agent import Agent
import matplotlib.pyplot as plt
import cv2
import timeenv = gym.make('Breakout-v0')
state_size = env.observation_space.shape
action_size = env.action_space.n
print('Original state shape: ', state_size)
print('Number of actions: ', env.action_space.n)agent = Agent((32, 4, 84, 84), action_size, seed=1) # state size (batch_size, 4 frames, img_height, img_width)
TRAIN = False # train or test flag def pre_process(observation):"""Process (210, 160, 3) picture into (1, 84, 84)"""x_t = cv2.cvtColor(cv2.resize(observation, (84, 84)), cv2.COLOR_BGR2GRAY)ret, x_t = cv2.threshold(x_t, 1, 255, cv2.THRESH_BINARY)return x_tdef init_state(processed_obs):return np.stack((processed_obs, processed_obs, processed_obs, processed_obs), axis=0)def dqn(n_episodes=30000, max_t=40000, eps_start=1.0, eps_end=0.01, eps_decay=0.9995):"""Deep Q-Learning.Params======n_episodes (int): maximum number of training episodesmax_t (int): maximum number of timesteps per episode, maximum frameseps_start (float): starting value of epsilon, for epsilon-greedy action selectioneps_end (float): minimum value of epsiloneps_decay (float): multiplicative factor (per episode) for decreasing epsilon"""scores = [] # list containing scores from each episodescores_window = deque(maxlen=100) # last 100 scoreseps = eps_start # initialize epsilonfor i_episode in range(1, n_episodes + 1):obs = env.reset()obs = pre_process(obs)state = init_state(obs)score = 0for t in range(max_t):action = agent.act(state, eps)next_state, reward, done, _ = env.step(action)# last three frames and current frame as the next statenext_state = np.stack((state[1], state[2], state[3], pre_process(next_state)), axis=0)agent.step(state, action, reward, next_state, done)state = next_statescore += rewardif done:breakscores_window.append(score) # save most recent scorescores.append(score) # save most recent scoreeps = max(eps_end, eps_decay * eps) # decrease epsilonprint('tEpsilon now : {:.2f}'.format(eps))print('rEpisode {}tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")if i_episode % 1000 == 0:print('rEpisode {}tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))print('rEpisode {}tThe length of replay buffer now: {}'.format(i_episode, len(agent.memory)))if np.mean(scores_window) >= 50.0:print('nEnvironment solved in {:d} episodes!tAverage Score: {:.2f}'.format(i_episode - 100,np.mean(scores_window)))torch.save(agent.qnetwork_local.state_dict(), 'checkpoint/dqn_checkpoint_solved.pth')breaktorch.save(agent.qnetwork_local.state_dict(), 'checkpoint/dqn_checkpoint_8.pth')return scoresif __name__ == '__main__':if TRAIN:start_time = time.time()scores = dqn()print('COST: {} min'.format((time.time() - start_time)/60))print("Max score:", np.max(scores))# plot the scoresfig = plt.figure()ax = fig.add_subplot(111)plt.plot(np.arange(len(scores)), scores)plt.ylabel('Score')plt.xlabel('Episode #')plt.show()else:# load the weights from fileagent.qnetwork_local.load_state_dict(torch.load('checkpoint/dqn_checkpoint_8.pth'))rewards = []for i in range(10): # episodes, play ten timestotal_reward = 0obs = env.reset()obs = pre_process(obs)state = init_state(obs)for j in range(10000): # frames, in case stuck in one frameaction = agent.act(state)env.render()next_state, reward, done, _ = env.step(action)state = np.stack((state[1], state[2], state[3], pre_process(next_state)), axis=0)total_reward += reward# time.sleep(0.01)if done:rewards.append(total_reward)breakprint("Test rewards are:", *rewards)print("Average reward:", np.mean(rewards))env.close()
在Titan XP显卡运行,分布经过1万次和3万次迭代,统计回报曲线如下两幅图。据图中分析可知,回报曲线波动性较大,但整体趋势是在上升;在三万次游戏中,最高的一次接近350,但这一次结果十分良好的训练并不会对后面造成过多影响,这也是DQN饱受困扰的“灾难性遗忘”问题(catastrophic forgetting)。
最后的最后,上效果视频和GitHub仓库(代码是基于Udacity课程代码改的)!
自动打砖块https://www.zhihu.com/video/1178626355478061056zhengsizuo/DRL_udacitygithub.com
深度学习stride_深度强化学习成名作——DQN相关推荐
- 叶梓老师人工智能培训之强化学习与深度强化学习提纲(强化学习讲师培训)
强化学习与深度强化学习提纲(强化学习讲师培训) 第一天 强化学习 第一课 强化学习综述 1.强化学习要解决的问题 2.强化学习方法的分类 3.强化学习方法的发展趋势 4.环境搭建实验(Gym,Te ...
- AI内训讲师叶梓-强化学习与深度强化学习提纲(强化学习讲师培训)
叶梓老师更多教程资料可点击个人主业查看 第一天 强化学习 第一课 强化学习综述 1.强化学习要解决的问题 2.强化学习方法的分类 3.强化学习方法的发展趋势 4.环境搭建实验(Gym ...
- 机器学习深度学习加强学习_加强强化学习背后的科学
机器学习深度学习加强学习 机器学习 ,强化学习 (Machine Learning, Reinforcement Learning) You're getting bore stuck in lock ...
- 强化学习应用简述---强化学习方向优秀科学家李玉喜博士创作
强化学习 (reinforcement learning) 经过了几十年的研发,在一直稳定发展,最近取得了很多傲人的成果,后面会有越来越好的进展.强化学习广泛应用于科学.工程.艺术等领域. 下面简单列 ...
- 【强化学习知识】强化学习简介
文章目录 前言 1. Q learning 2. Sarsa 3. Deep Q Network(DQN) 4. 总结 前言 强化学习是机器学习中的一大类,它可以让机器学着如何在环境中拿到高分, 表现 ...
- 强化学习q学习求最值_Q学习简介:强化学习
强化学习q学习求最值 by ADL 通过ADL Q学习简介:强化学习 (An introduction to Q-Learning: reinforcement learning) This arti ...
- 《强化学习周刊》第44期:RL-CoSeg、图强化学习、安全强化学习
No.44 智源社区 强化学习组 强 化 学 习 研究 观点 资源 活动 周刊订阅 告诉大家一个好消息,<强化学习周刊>已经开启"订阅功能",以后我们会向您自动推送最 ...
- 学习笔记:强化学习与最优控制(Chapter 2)
Approximation in Value Space 学习笔记:强化学习与最优控制(Chapter 2) Approximation in Value Space 1. 综述 2. 基于Value ...
- 强化学习笔记-01强化学习介绍
本文是博主对<Reinforcement Learning- An introduction>的阅读笔记,不涉及内容的翻译,主要为个人的理解和思考. 1. 强化学习是什么?解决什么样的问题 ...
- 强化学习-什么是强化学习?白话文告诉你!
目录 1.强化学习简介 2.强化学习的概念: 3.马尔可夫决策过程 4.Bellman方程 5.Q-Learning基本原理实例讲解 1.强化学习简介 世石与AlphaGo的这场人机世纪巅峰对决,不但 ...
最新文章
- 各大型邮箱smtp服务器及端口收集:
- 数据链路层差错检测:CRC(循环冗余检验)
- 威纶触摸屏与电脑连接_PLC与这7种设备的连接方式,一看就懂!
- 计算机网络-信道复用技术
- 【Python】文件夹的常用操作
- linux如何查看磁盘剩余空间
- C++中序列化对象并存储到mysql
- Quartz.Net 学习随手记之03 配置文件
- .NET平台开源项目速览(6)FluentValidation验证组件介绍与入门(一)
- Ubuntu下远程访问MySQL数据库
- Java RMI(远程方法调用)入门
- 一条语句引发的思考:装箱和拆箱,空指针的类型转换
- Okhttp之CallServerInterceptor简单分析
- 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_08 转换流_3_转换流的原理...
- linux 磁盘序列号修改,linux 获取硬盘序列号解决思路
- 【星门跳跃】解题报告
- 为知笔记数据备份方法
- 报错:SLF4J: Failed to load class “org.slf4j.impl.StaticLoggerBinder“.
- HTML练习—东风破
- 查询一列不同值的数据 mysql_怎样查询两个表中同一字段的不同数据值