点击上方↑↑↑“视学算法”关注我

来源:公众号 机器之心 授权

还是熟悉的树莓派!训练 RL agent 打 Atari 不再需要 GPU 集群,这个项目让你在边缘设备上也能进行实时训练。

自从 DeepMind 团队提出 DQN,在 Atari 游戏中表现出超人技巧,已经过去很长一段时间了。在此期间持续有新的方法被提出,不断创造出 Deep RL 领域新 SOTA。然而,目前不论是同策略或异策略强化学习方法(此处仅比较无模型 RL),仍然需要强大的算力予以支撑。即便研究者已将 Atari 游戏的分辨率降低到 84x84,一般情况下仍然需要使用 GPU 进行策略的训练。

如今,来自 Ogma Intelligent Systems Corp. 的研究人员突破了这一限制。他们在稀疏预测性阶层机制(Sparse Predictive Hierarchies)的基础上,提出一种不需要反传机制的策略搜索框架,使得实时在树莓派上训练 Atari 游戏的控制策略成为可能。下图展示了使用该算法在树莓派上进行实时训练的情形。

可以看到,agent 学会了如何正确调整滑块位置来接住小球,并发动进攻的策略。值得注意的是,观测输入为每一时刻产生的图片。

也就是说,该算法做到了在树莓派这样算力较小的边缘设备上,实时学习从像素到策略的映射关系。

研究者开源了他们的 SPH 机制实现代码,并提供了相应 Python API。这是一个结合了动态系统应用数学、计算神经科学以及机器学习的扩展库。他们的方法曾经还被 MIT 科技评论列为「Best of the Physics arXiv」。

项目地址:

https://github.com/ogmacorp/OgmaNeo2

OgmaNeo2

研究者所提出的 SPH 机制不仅在 Pong 中表现良好,在连续策略领域也有不错的表现。下图分别是使用该算法在 OpenAI gym 中 Lunar Lander 环境与 PyBullet 中四足机器人环境的训练结果。

在 Lunar Lander 环境中,训练 1000 代之后,每个 episode 下 agent 取得了平均 100 分左右的 reward。如果训练时间更长(3000 代以上),agent 的平均 reward 甚至能达到 200。在 PyBullet 的 Minitaur 环境中,agent 的训练目标是在其自身能量限制条件下,跑得越快越好。从图中可以看到,经过一段时间训练,这个四足机器人学会了保持身体平衡与快速奔跑(虽然它的步态看起来不是那么地自然)。看起来效果还是很棒的,机器之心也上手测试了一番。

算法框架

OgmaNeo2 用来学习 Pong 控制策略的整体框架如下图所示。图像观测值通过图像编码器输入两层 exponential memory 结构中,计算结果输出到之后的 RL 层产生相应动作策略。

项目实测

在安装 PyOgmaNeo2 之前,我们需要先编译安装其对应的 C++库。将 OgmaNeo2 克隆到本地:

!git clone https://github.com/ogmacorp/OgmaNeo2.git

之后将工作目录切换到 OgmaNeo2 下,并在其中创建一个名为 build 的文件夹,用于存放编译过程产生的文件。

import os
os.chdir('OgmaNeo2')
!mkdir build
os.chdir('build')

接下来我们对 OgmaNeo2 进行编译。这里值得注意的是,我们需要将-DBUILD_SHARED_LIBS=ON 命令传入 cmake 中,这样我们才能在之后的 PyOgmaNeo2 扩展库里使用它。

!cmake .. -DBUILD_SHARED_LIBS=ON
!make
!make install

当 OgmaNeo2 安装成功后,安装 SWIG v3 及 OgmaNeo2 的相应 Python 扩展库:

!apt-get install swig3.0
os.chdir('/content')
!git clone https://github.com/ogmacorp/PyOgmaNeo2
os.chdir('PyOgmaNeo2')
!python3 setup.py install --user

接下来输入 import pyogmaneo,如果没有错误提示就说明已经成功安装了 PyOgmaNeo2。

我们先用一个官方提供的时间序列回归来测试一下,在 notebook 中输入:

import numpy as np
import pyogmaneo
import matplotlib.pyplot as plt
# Set the number of threads
pyogmaneo.ComputeSystem.setNumThreads(4)
# Create the compute system
cs = pyogmaneo.ComputeSystem()
# This defines the resolution of the input encoding - we are using a simple single column that represents a bounded scalar through a one-hot encoding. This value is the number of "bins"
inputColumnSize = 64
# The bounds of the scalar we are encoding (low, high)
bounds = (-1.0, 1.0)
# Define layer descriptors: Parameters of each layer upon creation
lds = []
for i in range(5): # Layers with exponential memoryld = pyogmaneo.LayerDesc()# Set the hidden (encoder) layer size: width x height x columnSizeld.hiddenSize = pyogmaneo.Int3(4, 4, 16)ld.ffRadius = 2 # Sparse coder radius onto visible layersld.pRadius = 2 # Predictor radius onto sparse coder hidden layer (and feed back)ld.ticksPerUpdate = 2 # How many ticks before a layer updates (compared to previous layer) - clock speed for exponential memoryld.temporalHorizon = 2 # Memory horizon of the layer. Must be greater or equal to ticksPerUpdate, usually equal (minimum required)lds.append(ld)
# Create the hierarchy: Provided with input layer sizes (a single column in this case), and input types (a single predicted layer)
h = pyogmaneo.Hierarchy(cs, [ pyogmaneo.Int3(1, 1, inputColumnSize) ], [ pyogmaneo.inputTypePrediction ], lds)
# Present the wave sequence for some timesteps
iters = 2000
for t in range(iters):# The value to encode into the input columnvalueToEncode = np.sin(t * 0.02 * 2.0 * np.pi) * np.sin(t * 0.035 * 2.0 * np.pi + 0.45) # Some wavy linevalueToEncodeBinned = int((valueToEncode - bounds[0]) / (bounds[1] - bounds[0]) * (inputColumnSize - 1) + 0.5)# Step the hierarchy given the inputs (just one here)h.step(cs, [ [ valueToEncodeBinned ] ], True) # True for enabling learning# Print progressif t % 100 == 0:print(t)
# Recall the sequence
ts = [] # Time step
vs = [] # Predicted value
trgs = [] # True value
for t2 in range(300):t = t2 + iters # Continue where previous sequence left off# New, continued value for comparison to what the hierarchy predictsvalueToEncode = np.sin(t * 0.02 * 2.0 * np.pi) * np.sin(t * 0.035 * 2.0 * np.pi + 0.45) # Some wavy line# Bin the value into the column and write into the input buffer. We are simply rounding to the nearest integer location to "bin" the scalar into the columnvalueToEncodeBinned = int((valueToEncode - bounds[0]) / (bounds[1] - bounds[0]) * (inputColumnSize - 1) + 0.5)# Run off of own predictions with learning disabledh.step(cs, [ [ valueToEncodeBinned ] ], False) # Learning disabledpredIndex = h.getPredictionCs(0)[0] # First (only in this case) input layer prediction# Decode value (de-bin)value = predIndex / float(inputColumnSize - 1) * (bounds[1] - bounds[0]) + bounds[0]# Append to plot datats.append(t2)vs.append(value)trgs.append(valueToEncode)# Show predicted valueprint(value)
# Show plot
plt.plot(ts, vs, ts, trgs)

可得到如下结果。图中橙色曲线为真实值,蓝色曲线为预测值。可以看到,该方法以极小的误差拟合了真实曲线。

最后是该项目在 CartPole 任务中的表现。运行!python3 ./examples/CartPole.py,得到如下训练结果。可以看到,其仅用 150 个 episode 左右即解决了 CartPole 任务。

不需要借助GPU的力量,用树莓派也能实时训练agent玩Atari相关推荐

  1. 借助资本的力量,雷军仅花10年时间成为中国第九大富豪

    据媒体报道指随着金山云的上市,雷军投资的上市公司已有十多家,由此他的财富猛增,已成为中国第九大富豪,回顾他的创富史,看看他是如何借助资本的力量成就了自己的财富梦想. 雷军在金山公司工作了15年时间,成 ...

  2. APP刚上线没有用户,如何借助渠道的力量推广

    一.什么是APP推广的渠道. APP想要获得用户,需要借助渠道的力量.以下是几个主流推广渠道: 途径一:应用商店促销. AppStore促销是目前APP推广的主要渠道之一,也是最大用户下载渠道.重点包 ...

  3. 用树莓派做一个实时垃圾分类器|超实用!!

    此开源项目由树莓派爱好者基地人工智能部门谢远伦.任剑杰.沈超,共同协作完成.在此感谢各位成员的付出与努力.正是有各位的付出,树莓派生态才能越来越丰富! 代码仓库 1.码云Gitee:https://g ...

  4. 采用keras深度学习框架搭建卷积神经网络模型实现垃圾分类,基于树莓派上进行实时视频流的垃圾识别源代码

    一.项目概述 简介:该垃圾分类项目主要在于对各种垃圾进行所属归类,本次项目采用keras深度学习框架搭建卷积神经网络模型实现图像分类,最终移植在树莓派上进行实时视频流的垃圾识别. 前期:主要考虑PC端 ...

  5. TVM:在树莓派上部署预训练的模型

    TVM:在树莓派上部署预训练的模型 之前我们已经介绍如何通过Python接口(AutoTVM)来编译和优化模型.本文将介绍如何在远程(如本例中的树莓派)上部署预训练的模型. 在设备上构建 TVM Ru ...

  6. 树莓派 摄像头 VLC实时监控

    这两天在捣鼓树莓派摄像头通过电脑实时监控,有一枚官方鱼眼摄像头,本来是打算实现在任何网络中都可以直接访问,看网上很多教程都是VLC,于是就按照教程来,后来发现VLC只能在局域网中,那就先局域网吧. 网 ...

  7. 12.树莓派mjpg-streamer实现实时监控(树莓派摄像头的安装)

    树莓派mjpg-streamer实现实时监控 准备工作 树莓派扩容 使用raspi-config扩容(推荐) 安装依赖库 安装git及git源码 编译安装mjpeg 使能摄像头 启动摄像头(验证) 参 ...

  8. 品牌软文营销借助故事的力量打动用户

    故事型文案在引起读者情感共鸣以及促进用户产生购买行动上起着重要的作用.这是因为随着经济水平的提升,人们不再满足于简单的物质需要,更加追求情感层次的需求.其实,说白了就是当人们不再为生存问题困扰时,就会 ...

  9. 借助 GPU 和容器支持,在 Amazon Robomaker 中运行任何高保真模拟

    点击上方[凌云驭势 重塑未来] 一起共赴年度科技盛宴! 本博客引用了 Amazon RoboMaker 集成式开发环境(IDE),这是一项已弃用的功能.要继续阅读这篇博客,请使用 Amazon Clo ...

最新文章

  1. 【转】ibatis的简介与初步搭建应用
  2. python语法syntaxerror怎么修改-Python 语法错误
  3. 0 开场白元素项类的设计
  4. WPF 获取鼠标屏幕位置、窗口位置、控件位置
  5. SMOTE/SMOTEEN 处理不平衡数据集
  6. [Leedcode][JAVA][第22题括号生成][DFS][BFS][动态规划]
  7. 数据库基础知识——MySQL服务的启动和停止
  8. php 获取远程大文件上传,PHP 获取远程文件大小的3种解决方法
  9. PyTorch系列入门到精通——张量操作线性回归
  10. [ An Ac a Day ^_^ ] CodeForces 680A Bear and Five Cards
  11. centos 7安装zabbix 3.0
  12. python和nodejs哪个写爬虫好_PythonNodejs 哪个比较适合写爬虫
  13. vc中format用法以及c++中Format用法
  14. Java语言程序设计(基础篇)课后答案
  15. 网关支付、银联代扣通道、快捷支付、银行卡支付分别是怎么样进行支付的?...
  16. 最新凌风云支付系统网站源码全解无后门V4.1.1版本
  17. Python-----函数详解(上篇)(附小项目实战)
  18. 使用CS发送钓鱼邮件
  19. 请教统计对应表字段为空的字段数
  20. 苹果16g不够用怎么办_孩子不够自信怎么办?父母学会用这4个方法,孩子长大更优秀自信...

热门文章

  1. Matlab与线性代数 -- 矩阵的转置
  2. java ee不能运行_Java9+移除 Java EE,导致我的 groovy 脚本无法运行
  3. 真香!Vision Transformer 快速实现 Mnist 识别
  4. 机器模拟共情,情感AI正踏足诸多行业
  5. 算法鼻祖高德纳,82 岁仍在写《计算机程序设计的艺术》
  6. 无需训练RNN或生成模型,我写了一个AI来讲故事
  7. 一览群智胡健:在中国完全照搬Palantir模式,这不现实
  8. 技术不错的程序员,为何面试却“屡战屡败”
  9. 网易有道周枫:AI正带来革命性变化,但在线教育的核心是内容
  10. 贾跃亭晒FF 91新图,“生态化反”到底凉没凉?