Pytorch落地实践

  • 2:pytorch深度强化学习落地:以打乒乓小游戏为例
  • 一、需求分析
  • 二、动作空间设计
  • 三、状态空间设计
  • 四、回报函数设计
  • 五、算法选择
  • 六、训练调试
  • 总结

2:pytorch深度强化学习落地:以打乒乓小游戏为例

前几天买的新书《深度强化学习落地指南》,今天收到了,迫不及待的阅读起来。作者将深度强化学习应用落地分为七步,1、需求分析;2、动作空间设计;3、状态空间设计;4、汇报函数设计;5、算法选择;6、训练调试;7、性能冲刺。读罢,深受启发,总想找个例子实践一下。

思来想去,记得小时候在掌机上玩过一个打乒乓球的小游戏:从上面以一个角度落下来一个小球,遇到屏幕边界后会反弹,底部有一个由方格组成的球拍,玩家移动球拍,接住球得分,接不住则游戏结束。于是,我就想,能不能设计一个神经网络,通过训练使这个智能体学会玩这个游戏?查了一下,网上的素材,尤其是算法方面的素材非常多,典型的如:
Deep Reinforcement Learning超简单入门项目 Pytorch实现接水果游戏AI
实现起来难度不大,于是决定按照《深度强化学习落地指南》这本书的指导,一步一步的实践利用pytorch实现深度强化学习的过程。

一、需求分析

一问是不是:打乒乓的小游戏,是一个典型的单智能体和环境交互的强化学习问题。乒乓球的状态随时间按一定规律运动,玩家每个时间间隔可以做出一个动作,每个动作有左、右两个选择。二问值不值:这个一个运算量要求不大的小例子,以学习和实践为主,主要是体验强化学习、深度学习落地的过程。三问能不能:场景固定、数据廉价。四问边界在哪里:这样一个简单的问题,不存在模块划分的问题,主要决策也只有控制动作这一个问题。

二、动作空间设计

动作空间在事实上决定了任何算法所能达到的性能上限:对于打乒乓这个小游戏而言,动作就两个,移动底下的球拍,朝左运动或者朝右运动。这是一个离散的取值,分别定义为0和1。

三、状态空间设计

状态信息代表了Agent所感知的环境信息及其动态变化:对于打乒乓这个小游戏而言,完整的状态信息就是以像素为代表的整个方格。在这里,我们采用留空式空间编码,可以参考用10*8的一个矩阵来表示,前面9层代表球落的空间,最后一层用3个连续的1代表球拍的位置。如下图所示:

四、回报函数设计

回报信号是人与算法沟通的桥梁:对于打乒乓这个小游戏,回报就是要让底下的三个球拍接住落下来的小球,接住得1分,接不住扣1分。

读到这里,我们就可以开始动手编程了。结合前面写的面向对象编程,我们可以设计一个PingpongEnv类,在这个类里面,我们要首先定义reset方法,用来初始化状态空间和球拍的位置。接着,定义个step方法,用来定义当输入一个动作之后,对状态空间的变化。然后,通过判断球落地还是落到拍子上,确定本次移动的回报值。

class PingpongEnv:def __init__(self):self.reset()def reset(self):game = np.zeros((10,8))game[9,3:6] = 1.self.state = gameself.t = 0self.done = Falsereturn self.state[np.newaxis,:].copy()def step(self, action):reward = 0.game = self.stateif self.done:print('Call step after env is done')if self.t==200:self.done = Truereturn game[np.newaxis,:].copy(),10,self.done# 根据action移动盘子if action  == 0 and game[9][0]!=1:game[9][0:7] = game[9][1:8].copy()#向左移动game[9][7] = 0elif action == 1 and game[9][7]!=1:game[9][1:8] = game[9][0:7].copy()#向右移动game[9][0] = 0# 判断球落地还是落到拍子上if 1 in game[8]:ball = np.where(game[8]==1)if game[9][ball] != 1:reward = -1.self.done = Trueelse:reward = 1.game[8][ball] = 0.#game[1:9] = game[0:8].copy()game[0:9] = 0if self.t%9==0:self.idx = random.randint(a = 0, b = 7)#随机生成位置self.idirect = random.randint(0,1) #随机生成方位 game[0][self.idx] = 1.              else:if self.idirect == 0 and self.idx != 0:  #方位向左,非最左边self.idx -= 1elif self.idirect == 0 and self.idx == 0: #方位向左,最左边self.idx += 1self.idirect = 1               #之后的方向变为右边elif self.idirect == 1 and self.idx != 7: #方位向右,非最右边self.idx += 1elif self.idirect == 1 and self.idx == 7: #方位向右,最右边self.idx -= 1self.idirect = 0game[self.t%9][self.idx] =1.self.t += 1return game[np.newaxis,:].copy(),reward,self.done

五、算法选择

明确任务需求并初步完成问题定义之后,就可以为相关任务选择合适的深度强化学习算法了。:对于算法落地而言,还需要根据任务自身的特点从DRL本源出发进行由浅入深、粗中有细的筛选和迭代。
在本案例中,直接选用DQN算法即可,具体代码如下,其主要包括两个类,一个DQN,定义网络的层数和形状;一个DQNAgent,定义学习和采样方法。

class DQN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 8, 3, padding=1)self.conv2 = nn.Conv2d(8, 16, 3, padding=1)# 16*10*8self.maxpool = nn.MaxPool2d(2,2)# 16*5*4self.fc = nn.Sequential(nn.Linear(16*5*4, 32),nn.ReLU(),nn.Linear(32, 3),)def forward(self, x):x = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return xclass DQNAgent():def __init__(self, network, eps, gamma, lr):self.network = networkself.eps = epsself.gamma = gammaself.optimizer = optim.Adam(self.network.parameters(), lr=lr)def learn(self, batch):s0, a0, r1, s1, done = zip(*batch)n = len(s0)s0 = torch.FloatTensor(s0)s1 = torch.FloatTensor(s1)r1 = torch.FloatTensor(r1)a0 = torch.LongTensor(a0)done = torch.BoolTensor(done)increment = self.gamma * torch.max(self.network(s1).detach(), dim=1)[0]y_true = r1+incrementy_pred = self.network(s0)[range(n),a0]loss = F.mse_loss(y_pred, y_true)self.optimizer.zero_grad()loss.backward()self.optimizer.step()return loss.item()def sample(self, state):'''epsilon探索选择下一个action'''state = state[np.newaxis,:]action_value = self.network(torch.FloatTensor(state))if random.random()<self.eps:return random.randint(0, 2)else:max_action = torch.argmax(action_value,dim=1)return max_action.item()

在这之后,还需要定义个评估函数,用来评判训练效果:

def evaluate(agent,env,times):eval_reward = []for i in range(times):obs = env.reset()episode_reward = 0while True:action = agent.sample(obs)#选取最优动作obs,reward,isOver = env.step(action)episode_reward += rewardif isOver:breakeval_reward.append(episode_reward)return np.mean(eval_reward)

六、训练调试

具体说来,就是设置不同的训练参数,尝试是否能够收敛,在DRL落地实践中,这种第一训练“心里没底”的忐忑体验无论对入门小白还是学术明星都是公平的。我们这个算例倒是不存在这个问题,直接进行训练即可。

if __name__ == "__main__":gamma = 0.9eps_high = 0.9eps_low = 0.1num_episodes = 500LR = 0.001batch_size = 5decay = 200net = DQN()agent = DQNAgent(net,1,gamma,LR)replay_lab = deque(maxlen=5000)env = PingpongEnv()state = env.reset()for episode in range(num_episodes):agent.eps = eps_low + (eps_high-eps_low) *\(np.exp(-1.0 * episode/decay))s0 = env.reset()while True:a0 = agent.sample(s0)s1, r1, done = env.step(a0)replay_lab.append((s0.copy(),a0,r1,s1.copy(),done))if done:breaks0 = s1if replay_lab.__len__() >= batch_size:batch = random.sample(replay_lab,k=batch_size)loss = agent.learn(batch)if (episode+1)%50==0:print("Episode: %d, loss: %.3f"%(episode+1,loss))score = evaluate(agent,env,10)print("Score: %.1f"%(score))

训练完成后,利用以下代码,对训练的Agent运行效果进行观察:


import matplotlib.pyplot as plt
import torch
from matplotlib.colors import ListedColormap
from train import PingpongEnv,DQNAgent,DQN cmap_light = ListedColormap(['white','red'])
from IPython import displayagent.eps = 0
env = PingpongEnv()
s = env.reset()
while True:a = agent.sample(s)s, r, done = env.step(a)if done:breakimg = s.squeeze()plt.imshow(img, cmap=cmap_light)plt.show()display.clear_output(wait=True)

由此,我们便可以观察到AI玩游戏的效果。

总结

通过这个打乒乓小游戏,结合深度强化学习落地指南,算是对如何将强化学习算法用到一个具体的事务上有了一个初步的认识。感觉pytorch确实是一个非常容易上手的神经网络建模工具。随着后面对《深度学习》认识的深入,将逐步细化理论方面的细节内容。

2:pytorch深度强化学习落地:以打乒乓小游戏为例相关推荐

  1. 【赠书】深度强化学习落地指南,来自一线工程师的经验!

    ‍‍ 今天要给大家介绍的书是深度强化学习落地指南,本书是海康威视研究院任算法专家工作总结,对强化学习落地实践中的工程经验和相关方法进行了系统归纳. 本书内容 本书一共分为7章,包括强化学习的需求分析和 ...

  2. 深度强化学习落地指南:弥合DRL算法原理和落地实践之间的断层 | 文末送书

    魏宁 著 电子工业出版社-博文视点 2021-08-01 ISBN: 9787121416446 定价: 109.00 元 新书推荐 ????今日福利 |关于本书| 本书从工业界一线算法工作者的视角, ...

  3. 深度强化学习落地指南总结(二)-动作空间设计

    本系列是对<深度强化学习落地指南>全书的总结,这本书是我市面上看过对深度 强化学习落地讲的最好的一本书,大大拓宽了自己对RL落地思考的维度,形成了强化学习落地分析的一套完整框架,本文内容基 ...

  4. 深度强化学习落地方法论(8)——新书推荐《深度强化学习落地指南》

    知乎原文链接 文章目录 记一次成功的Exploration DRL落地中的"武德"问题 本书的创作理念 关于强化学习 结语 大家好,已经很久没有更新这个专栏了,希望当初关注它的知友 ...

  5. 深度强化学习制作森林冰火人游戏AI(一)下载游戏

    概述 首先先把游戏环境搭建起来 下载游戏 这部分的介绍可以看python 从4399获取小游戏,我就不重新介绍一遍了 import os import requests# 基础url host_url ...

  6. 深度强化学习制作森林冰火人游戏AI(四)获取窗口部分界面

    概述 这篇主要讲述如何用python获取森林冰火人窗口部分界面 在获取部分界面的图片之后通过图片识别/分类来判断当前游戏所属的状态 前篇:深度强化学习制作森林冰火人游戏AI(三)向游戏输出键盘控制信息 ...

  7. 深度强化学习制作森林冰火人游戏AI(五)识别游戏状态

    深度强化学习制作森林冰火人游戏AI(五)识别游戏状态 概述 游戏状态切换图 游戏状态识别原理 界面区域选择 保存界面 识别方法 识别游戏状态 概述 通过图片识别来对分析游戏当前状态 前篇:深度强化学习 ...

  8. 深度强化学习制作森林冰火人游戏AI(三)向游戏输出键盘控制信息

    概述 本文讲如何通过python发送键盘控制命令控制游戏 前篇:深度强化学习制作森林冰火人游戏AI(二)获取游戏屏幕 后篇:深度强化学习制作森林冰火人游戏AI(四)获取窗口部分界面 获取窗口句柄 窗口 ...

  9. 深度强化学习制作森林冰火人游戏AI(二)获取游戏屏幕

    概述 前篇:深度强化学习制作森林冰火人游戏AI(一)下载游戏 后篇:深度强化学习制作森林冰火人游戏AI(三)向游戏输出键盘控制信息 游戏有了,接下来是程序的输入了 获取窗口名称 windows里面的所 ...

最新文章

  1. 测试思想 什么是软件测试(摘录)
  2. 粒子物理学有了新的基础数学理论
  3. [LeetCode] 5. Longest Palindromic Substring
  4. 自言自语(三)--部分中文字体
  5. [云炬创业基础笔记]第七张创业团队测试3
  6. 快速判断一个数是否是4的幂次方,若是,并判断出来是多少次方! .
  7. 拥有一亿会员的爱奇艺如何搭建大数据实时分析平台
  8. ElasticSearch(笔记)
  9. JavaScript(js)的replace问题的解决
  10. vue实现数字“滚动式增加”效果 【插件化封装】
  11. asp sql ip地址排序_SQL必知必会读书笔记,30分钟入门SQL!
  12. 不敢穷,不敢病,不敢死……我们是独生子女
  13. mysql求数据库平均成绩视图_MySQL数据库视图
  14. shell编程四剑客之 grep
  15. Android studio 六大基本布局详解
  16. 各种绩效考核方法的区别
  17. 解决nrm不能使用问题
  18. r语言把两个折线图图像放到一个图里_图像目标检测算法总结(从R-CNN到YOLO v3)...
  19. 20条职场潜规则!小心那些城府很深的人(建议收藏)
  20. html中怎么把h标签左移,基础标签--h、p、a、hr、br、img、base

热门文章

  1. 博世BOSCH EDI项目案例
  2. 计算机科学与技术专业自考自学资料
  3. 没有了 main 函数,程序还能跑吗?
  4. 论文笔记(五)【DENSITY ESTIMATION USING REAL NVP】
  5. 【K3S 一】部署K3S集群(单Master)
  6. STM32F103你学不会系列(十七)电容触摸按键实现
  7. NaN == NaN , NaN === NaN 为啥是false?
  8. 微信小程序倒计时,购物车,向左滑删除 左拉删除
  9. 从零开始,搭建Windows 10+Ubuntu 18.04双系统及Anaconda3+CUDA10.1+cuDNN7.6+Tensorflow2.1等开发环境
  10. django 基本user列子