Deep Reinforcement Learning超简单入门项目 Pytorch实现接水果游戏AI
学习过传统的监督和无监督学习方法后,我们现在已经可以自行开发机器学习系统来解决一些实际问题了。我们能实现一些事件的预测,一些模式的分类,还有数据的聚类等项目。但是这些好像和我们心目中的人工智能仍有差距,我们可能会认为,人工智能是能理解人类语言,模仿人类行为,并做到人类难以完成的工作的机器。所谓KNN、决策树分类器,好像只是代替人类进行一些简单的工作。
但今天,我们似乎在强化学习的领域找到了通往真正的人工智能的大门。经过强化学习训练的AI,似乎已经可以做到人类做不到的事情。chat bot可以生成逼真的语言,GAN可以进行艺术创作,甚至有些AI可以在星际争霸上打赢人类玩家。这些工作都与强化学习分不开关系。
我们之前实现的学习井字棋强化学习似乎已经能在完全没有训练资料的情况下学会下棋,甚至可以在不犯任何失误的时候找到人类的破绽。虽然井字棋这个游戏非常简单,对20000不到的状态空间,我们可以直接在博弈树中搜索出最好决策,但是这依然体现了强化学习的潜力。
强化学习与深度学习
在之前的学习过程中,我们学习了表格型的Q-Learning,在表格型的Q-Learning方法的学习过程中,我们逐渐会形成一张表格。在许多简单的问题中,这种表格型的Q-Learning方法是比较实用的,但是当我们所处理的问题具有较大的状态集合动作集时,这种表格型的方法就显得十分的低效了。此时我们需要一种新的模型方法来处理这种问题,所以出现了结合了神经网络的Q-Learning方法,Deep Q-Learning(DQN),通过在探索的过程中训练网络,最后所达到的目标就是将当前状态输入,得到的输出就是对应它的动作值函数,也即f(s)=Q(s,a)f(s)=Q(s,a)f(s)=Q(s,a) ,这个f就是训练的网络.
Deep Q learning
我们前面学习井字棋使用的学习方法是让机器对弈,产生一条情节(episode)链,然后从后向前遍历序列并更新Q值。每个状态的更新公式都决定于该步action的奖赏和后续一个状态的Q值。
Q(x,a)=(1−α)Q(x,a)+α(R(x′,a′)+γQ(x′,a′))Q(x,a) = (1-\alpha)Q(x,a)+\alpha (R(x',a') + \gamma Q(x',a')) Q(x,a)=(1−α)Q(x,a)+α(R(x′,a′)+γQ(x′,a′))
使用神经网络来做deep的end to end Q学习时,这个问题就不是直接修改表,而是让模型做一个回归。回归使用的误差函数是MSE均方误差。就是把上面的公式计算出来的新Q值当作回归目标,计算网络输出和它的均方误差,然后用梯度方法更新一下就好了。
想让DQL顺利跑起来,还需要一些其他的工程技巧。
首先,我们在使用deep network拟合函数时,我们都是假定数据是独立同分布的,并在一个有一定规模的数据集上运行梯度下降优化网络。但是如果我们使用之前的方法,每进行一个对弈就用这个情节链去训练一次。先不说想让网络适应这个情节需要几次迭代更新,就算我们用很多次更新去让网络适应了这个经验,神经网络的特性也常常会出现在样本数过少时的过拟合。而且下一次我们再用这个网络去采样时,就会出现完全不同分布的采样轨迹,这个问题的性质是无法保证收敛的。
为此,2013年最早的DQL论文提出的方法是用两个network,一个叫evaluate一个叫target。我们使用target网络去对弈多次,采样出多条探索轨迹形成数据集(这个探索次数大小自行调整),再用这个数据集在evaluate网络上训练。这样的数据集是独立同分布的,网络不容易过拟合,而且训练震荡发散可能性降低,更加稳定。训练得差不多之后,我们再把evaluate网络的参数拷贝给target网络,然后进入下一个epoch。
接水果
我们设计一个简单的接水果游戏来体现DQL与AI的工作逻辑,也许你觉得让AI学习井字棋太简单了,那么现在我们真正让AI来学习打电玩。
接水果游戏的游戏界面为10x8,最上方每8个时间单位在随机位置出现水果,每个时间单位后水果下降一格,每个时间单位AI或玩家可以控制最下方的3宽度的盘子向左或向右移动一格。
AI控制游戏的方式是通过向前看一步,评估当前时刻进行任何操作获得的预期奖励,并做出最好决策。为了避免AI在探索游戏世界的过程中只探索当前模型给出的最优路径,导致泛化性不足;我们也会设置一些比较小的噪声,按照指数分布去选择优先级低一些的操作。
接水果游戏实现
首先导入绘图库和numpy、torch等矩阵运算库和深度学习库。
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
然后我们自己写一个environment类实现接水果游戏的逻辑。它和Gym中的env一样,有几个重要的接口。一个是reset初始化环境,另一个是step更新环境。
class FruitEnv:def __init__(self):self.reset()def reset(self):game = np.zeros((10,8))game[9,3:6] = 1.self.state = gameself.t = 0self.done = Falsereturn self.state[np.newaxis,:].copy()def step(self, action):reward = 0.game = self.stateif self.done:print('Call step after env is done')if self.t==200:self.done = Truereturn game[np.newaxis,:].copy(),10,self.done# 根据action移动盘子if action==0 and game[9][0]!=1:game[9][0:7] = game[9][1:8].copy()game[9][7] = 0elif action==1 and game[9][7]!=1:game[9][1:8] = game[9][0:7].copy()game[9][0] = 0# 判断果子落地还是落到盘子上if 1 in game[8]:fruit = np.where(game[8]==1)if game[9][fruit] != 1:reward = -1.self.done = Trueelse:reward = 1.game[8][fruit] = 0.game[1:9] = game[0:8].copy()game[0] = 0if self.t%8==0:idx = random.randint(a = 0, b = 7)game[0][idx] = 1.self.t += 1return game[np.newaxis,:].copy(),reward,self.done
模型
我们设计DQN的网络和Agent类,网络使用简单的卷积网络,为了不丢失边缘特征,使用两层3x3,1步长,1填充的卷积层。后接池化层和全连接层用于回归。事实上如果我们神经网络的输入可以不是一张图片,也可以是连续的n个时间单位的n张图片,这样的设计允许网络学习到时序的信息(比如在自动驾驶中,AI通过接收连续的几个帧就能分析出当前车辆的运行速度),但我们这个游戏是很符合MDP假设的游戏,一般来说只看一帧就可以做出决策,所以这里并没有使用这种输入。
Agent用于epsilon贪心决策的sample函数,以及一个按照Q-learning逻辑训练的learn函数。
class DQN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 8, 3, padding=1)self.conv2 = nn.Conv2d(8, 16, 3, padding=1)# 16*10*8self.maxpool = nn.MaxPool2d(2,2)# 16*5*4self.fc = nn.Sequential(nn.Linear(16*5*4, 32),nn.ReLU(),nn.Linear(32, 3),)def forward(self, x):x = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return xclass DQNAgent():def __init__(self, network, eps, gamma, lr):self.network = networkself.eps = epsself.gamma = gammaself.optimizer = optim.Adam(self.network.parameters(), lr=lr)def learn(self, batch):s0, a0, r1, s1, done = zip(*batch)n = len(s0)s0 = torch.FloatTensor(s0)s1 = torch.FloatTensor(s1)r1 = torch.FloatTensor(r1)a0 = torch.LongTensor(a0)done = torch.BoolTensor(done)increment = self.gamma * torch.max(self.network(s1).detach(), dim=1)[0]y_true = r1+incrementy_pred = self.network(s0)[range(n),a0]loss = F.mse_loss(y_pred, y_true)self.optimizer.zero_grad()loss.backward()self.optimizer.step()return loss.item()def sample(self, state):'''epsilon探索选择下一个action'''state = state[np.newaxis,:]action_value = self.network(torch.FloatTensor(state))if random.random()<self.eps:return random.randint(0, 2)else:max_action = torch.argmax(action_value,dim=1)return max_action.item()
然后我们就可以不断采样让模型自我训练,直到它学会怎样玩游戏。如果AI已经学会了玩游戏,他将会达到最大时间限制(200)
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 输入:nx1x10x8张量self.conv1 = nn.Conv2d(1, 8, 3, padding=1)self.conv2 = nn.Conv2d(8, 16, 3, padding=1)# nx16x10x8张量self.maxpool = nn.MaxPool2d(2,2)# nx16x5x4张量self.fc = nn.Sequential(nn.Linear(16*5*4, 32),nn.ReLU(),nn.Linear(32, 1),)def forward(self, x):x = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return x
训练
模型训练使用随机梯度下降,损失函数使用均方误差。均方误差很容易理解,因为我们的问题本质上是回归问题,MSE是很好的回归用损失函数。
我们使用上面的方法,训练500个episode,每个episode我们让模型玩一局游戏(进行完整的一轮采样)
gamma = 0.9
eps_high = 0.9
eps_low = 0.1
num_episodes = 500
LR = 0.001
batch_size = 5
decay = 200
net = DQN()
agent = DQNAgent(net,1,gamma,LR)replay_lab = deque(maxlen=5000)state = env.reset()for episode in range(num_episodes):agent.eps = eps_low + (eps_high-eps_low) *\(np.exp(-1.0 * episode/decay))s0 = env.reset()while True:a0 = agent.sample(s0)s1, r1, done = env.step(a0)replay_lab.append((s0.copy(),a0,r1,s1.copy(),done))if done:breaks0 = s1if replay_lab.__len__()>=batch_size:batch = random.sample(replay_lab,k=batch_size)loss = agent.learn(batch)if (episode+1)%50==0:print("Episode: %d, loss: %.3f"%(episode+1,loss))score = evaluate(agent,env,10)print("Score: %.1f"%(score))Episode: 50, loss: 0.009
Score: -0.8
Episode: 100, loss: 0.024
Score: -0.1
Episode: 150, loss: 0.042
Score: -0.1
Episode: 200, loss: 0.170
Score: 34.0
Episode: 250, loss: 0.011
Score: 34.0
Episode: 300, loss: 0.011
Score: 34.0
Episode: 350, loss: 0.005
Score: 34.0
Episode: 400, loss: 0.071
Score: 6.1
Episode: 450, loss: 0.027
Score: 34.0
Episode: 500, loss: 0.005
Score: 34.0
训练到收敛后,执行下面的代码就可以看到AI成功地在玩接水果游戏玩了200帧。
from matplotlib.colors import ListedColormapcmap_light = ListedColormap(['white','red'])
from IPython import displayagent.eps = 0
s = env.reset()
while True:a = agent.sample(s)s, r, done = env.step(a)if done:breakimg = s.squeeze()plt.imshow(img, cmap=cmap_light)plt.show()display.clear_output(wait=True)
总结
这个微型项目的初衷是简单的入门Deep Reinforcement learning,因为网络上的强化学习入门项目无不是对硬件要求很高,需要训练个几百万帧才能收敛的那种游戏。更不用说每一帧都是100x100以上的大小,没有硬件支持可能很难快速看到结果。这里设计了超轻量级的游戏,游戏逻辑简单,每一帧也只是10x8的大小。
尽管如此,cpu上要把这个游戏玩到收敛也需要几百次的采样,约10k帧的训练,花费2分钟以上。所以,在深度学习时代算力还是最重要的资源。
Deep Reinforcement Learning超简单入门项目 Pytorch实现接水果游戏AI相关推荐
- Deep Reinforcement Learning: Pong from Pixels翻译和简单理解
原文链接: http://karpathy.github.io/2016/05/31/rl/ 文章目录 原文链接: 前言 Policy-Gradient结构流程图 Deep Reinforcement ...
- 深度强化学习:入门(Deep Reinforcement Learning: Scratching the surface)
原文链接:https://blog.csdn.net/qq_32690999/article/details/78594220 本博客是对学习李宏毅教授在youtube上传的课程视频<Deep ...
- 深度强化学习—— 译 Deep Reinforcement Learning(part 0: 目录、简介、背景)
深度强化学习--概述 翻译说明 综述 1 简介 2 背景 2.1 人工智能 2.2 机器学习 2.3 深度学习 2.4 强化学习 2.4.1 Problem Setup 2.4.2 值函数 2.4.3 ...
- 基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现
之前读过这篇论文,导师说要复现,这里记录一下.废话不多说,再重读一下论文. 注:非一字一句翻译.个人理解,一定偏颇. 基于深度强化学习的车道检测和定位 官方源码下载:https://github.co ...
- 基于vue-cli、elementUI的Vue超简单入门小例子
基于vue-cli.elementUI的Vue超简单入门小例子 这个例子还是比较简单的,独立完成后,能大概知道vue是干嘛的,可以写个todoList的小例子. 开始写例子之前,先对环境的部署做点简单 ...
- 利用Deep Reinforcement Learning训练王者荣耀超强AI
Mastering Complex Control in MOBA Games with Deep Reinforcement Learning (一)知识背景 (二)系统架构 (三)算法结构 3.1 ...
- 【强化学习】Playing Atari with Deep Reinforcement Learning (2013)
Playing Atari with Deep Reinforcement Learning (2013) 这篇文章提出了第一个可以直接用强化学习成功学习控制policies的深度学习模型. 输入是r ...
- 【论文翻译】Playing Atari with Deep Reinforcement Learning
摘要:我们第一个提出了"利用强化学习从高维输入中直接学习控制策略"的深度学习模型.该模型是一个卷积神经网络,经过Q-learning训练,输入为原始像素,输出为:"用来估 ...
- 论文笔记(十六):Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning
Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning 文章概括 摘要 1 介绍 2 大规模并 ...
最新文章
- Entity Framework 4.3 中的新特性
- boostrap-table export 导出监听
- Knative Serving 进阶: Knative Serving SDK
- 日常生活小技巧 -- 文件对比工具 Beyond Compare
- springboot+shiro框架中上传到服务器的图片不能查看,访问404
- python爬去百度搜索结果_python实现提取百度搜索结果的方法
- 用vue实现模态框组件
- Mac系统下Homebrew的安装和使用Homebrew安装python
- 前端学习(77):css中常见margin塌陷问题之解决办法
- PHP实现一个轻量级容器
- 数据分析用这样的可视化报表,秒杀Excel,再也不怕被说low
- SDOI2010 代码拍卖会
- 怎样用sql语句复制表table1到表table2的同时复制主键
- 39-java 输入输出总结
- 计算机应用专业毕业感言,大学毕业感言一句话
- 商汤科技推出SenseCore AI大装置,打造物理世界的搜索引擎
- UIAutomatorViewer初体验
- 自定义联系人快速索引栏
- 张氏矢量化骨骼化细化算法
- Java实现拼图小游戏(7)—— 计步功能及菜单业务的实现