训练环境

使用Movan写的机械臂环境:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/experiments/Robot_arm/arm_env.py

这个环境真的挺有意思的,主要可以和用户交互,真真切切感受到训练后智能体的聪明程度。

提醒:python不要用3.8的,可能会和他的环境不兼容,我用的3.6的。

这个环境主要采用pyglet包写的,详见Movan的教程:Movan教你如何从0写强化学习环境(机械臂)

采用算法

既然是连续动作那就无脑PPO算法吧,当然,PPO最大的劣势就是训练慢,每一个episode都要重新收集buffer,既然这样,那就采用多进程吧,关于多进程的实现方法见我的博客:DPPO实现
不用多进程在GPU上跑CPU很难全负荷运行,多进程可以最大化CPU负荷:

Rewad Curve

这个环境是严格按照gym的接口写的,因此兼容规范的RL代码。

交互展示训练效果

在机械臂的环境下加入如下代码:

if __name__ == "__main__":env = ArmEnv()o = env.reset()sys.path.append('./PPO/multi_processing_ppo')from PPO.multi_processing_ppo.PPOModel import *net = GlobalNet(env.state_dim,env.action_dim)net.act.load_state_dict(torch.load('./PPO/TrainedModel/act.pkl'))while 1:env.render()a = net.act(torch.tensor(o, dtype=torch.float32, device="cpu")).detach().numpy()o2,r,d,_ = env.step(a)o = o2

大体就是加载训练好的模型参数后写一个while死循环不断执行策略。

运行后鼠标指哪小方块跑哪,小方块代表机械臂需要移动到的位置:


这个模型还可以选择训练难度,我选择的是hard,训练了几分钟就有如上的效果了,有兴趣可以用各种算法来挑战上面这个环境。

环境代码

github有些时候很难上,我就把环境代码贴上来:

#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Environment for Robot Arm.
You can customize this script in a way you want.
View more on [莫烦Python] : https://morvanzhou.github.io/tutorials/
Requirement:
pyglet >= 1.2.4
numpy >= 1.12.1
"""
import numpy as np
import pygletpyglet.clock.set_fps_limit(10000)class ArmEnv(object):action_bound = [-1, 1]action_dim = 2state_dim = 7dt = .1  # refresh ratearm1l = 100arm2l = 100viewer = Noneviewer_xy = (400, 400)get_point = Falsemouse_in = np.array([False])point_l = 15grab_counter = 0def __init__(self, mode='easy'):# node1 (l, d_rad, x, y),# node2 (l, d_rad, x, y)self.mode = modeself.arm_info = np.zeros((2, 4))self.arm_info[0, 0] = self.arm1lself.arm_info[1, 0] = self.arm2lself.point_info = np.array([250, 303])self.point_info_init = self.point_info.copy()self.center_coord = np.array(self.viewer_xy)/2def step(self, action):# action = (node1 angular v, node2 angular v)action = np.clip(action, *self.action_bound)self.arm_info[:, 1] += action * self.dtself.arm_info[:, 1] %= np.pi * 2arm1rad = self.arm_info[0, 1]arm2rad = self.arm_info[1, 1]arm1dx_dy = np.array([self.arm_info[0, 0] * np.cos(arm1rad), self.arm_info[0, 0] * np.sin(arm1rad)])arm2dx_dy = np.array([self.arm_info[1, 0] * np.cos(arm2rad), self.arm_info[1, 0] * np.sin(arm2rad)])self.arm_info[0, 2:4] = self.center_coord + arm1dx_dy  # (x1, y1)self.arm_info[1, 2:4] = self.arm_info[0, 2:4] + arm2dx_dy  # (x2, y2)s, arm2_distance = self._get_state()r = self._r_func(arm2_distance)return s, r, self.get_point, Nonedef reset(self):self.get_point = Falseself.grab_counter = 0if self.mode == 'hard':pxy = np.clip(np.random.rand(2) * self.viewer_xy[0], 100, 300)self.point_info[:] = pxyelse:arm1rad, arm2rad = np.random.rand(2) * np.pi * 2self.arm_info[0, 1] = arm1radself.arm_info[1, 1] = arm2radarm1dx_dy = np.array([self.arm_info[0, 0] * np.cos(arm1rad), self.arm_info[0, 0] * np.sin(arm1rad)])arm2dx_dy = np.array([self.arm_info[1, 0] * np.cos(arm2rad), self.arm_info[1, 0] * np.sin(arm2rad)])self.arm_info[0, 2:4] = self.center_coord + arm1dx_dy  # (x1, y1)self.arm_info[1, 2:4] = self.arm_info[0, 2:4] + arm2dx_dy  # (x2, y2)self.point_info[:] = self.point_info_initreturn self._get_state()[0]def render(self):if self.viewer is None:self.viewer = Viewer(*self.viewer_xy, self.arm_info, self.point_info, self.point_l, self.mouse_in)self.viewer.render()def sample_action(self):return np.random.uniform(*self.action_bound, size=self.action_dim)def set_fps(self, fps=30):pyglet.clock.set_fps_limit(fps)def _get_state(self):# return the distance (dx, dy) between arm finger point with blue pointarm_end = self.arm_info[:, 2:4]t_arms = np.ravel(arm_end - self.point_info)center_dis = (self.center_coord - self.point_info)/200in_point = 1 if self.grab_counter > 0 else 0return np.hstack([in_point, t_arms/200, center_dis,# arm1_distance_p, arm1_distance_b,]), t_arms[-2:]def _r_func(self, distance):t = 50abs_distance = np.sqrt(np.sum(np.square(distance)))r = -abs_distance/200if abs_distance < self.point_l and (not self.get_point):r += 1.self.grab_counter += 1if self.grab_counter > t:r += 10.self.get_point = Trueelif abs_distance > self.point_l:self.grab_counter = 0self.get_point = Falsereturn rclass Viewer(pyglet.window.Window):color = {'background': [1]*3 + [1]}fps_display = pyglet.clock.ClockDisplay()bar_thc = 5def __init__(self, width, height, arm_info, point_info, point_l, mouse_in):super(Viewer, self).__init__(width, height, resizable=False, caption='Arm', vsync=False)  # vsync=False to not use the monitor FPSself.set_location(x=80, y=10)pyglet.gl.glClearColor(*self.color['background'])self.arm_info = arm_infoself.point_info = point_infoself.mouse_in = mouse_inself.point_l = point_lself.center_coord = np.array((min(width, height)/2, ) * 2)self.batch = pyglet.graphics.Batch()arm1_box, arm2_box, point_box = [0]*8, [0]*8, [0]*8c1, c2, c3 = (249, 86, 86)*4, (86, 109, 249)*4, (249, 39, 65)*4self.point = self.batch.add(4, pyglet.gl.GL_QUADS, None, ('v2f', point_box), ('c3B', c2))self.arm1 = self.batch.add(4, pyglet.gl.GL_QUADS, None, ('v2f', arm1_box), ('c3B', c1))self.arm2 = self.batch.add(4, pyglet.gl.GL_QUADS, None, ('v2f', arm2_box), ('c3B', c1))def render(self):pyglet.clock.tick()self._update_arm()self.switch_to()self.dispatch_events()self.dispatch_event('on_draw')self.flip()def on_draw(self):self.clear()self.batch.draw()# self.fps_display.draw()def _update_arm(self):point_l = self.point_lpoint_box = (self.point_info[0] - point_l, self.point_info[1] - point_l,self.point_info[0] + point_l, self.point_info[1] - point_l,self.point_info[0] + point_l, self.point_info[1] + point_l,self.point_info[0] - point_l, self.point_info[1] + point_l)self.point.vertices = point_boxarm1_coord = (*self.center_coord, *(self.arm_info[0, 2:4]))  # (x0, y0, x1, y1)arm2_coord = (*(self.arm_info[0, 2:4]), *(self.arm_info[1, 2:4]))  # (x1, y1, x2, y2)arm1_thick_rad = np.pi / 2 - self.arm_info[0, 1]x01, y01 = arm1_coord[0] - np.cos(arm1_thick_rad) * self.bar_thc, arm1_coord[1] + np.sin(arm1_thick_rad) * self.bar_thcx02, y02 = arm1_coord[0] + np.cos(arm1_thick_rad) * self.bar_thc, arm1_coord[1] - np.sin(arm1_thick_rad) * self.bar_thcx11, y11 = arm1_coord[2] + np.cos(arm1_thick_rad) * self.bar_thc, arm1_coord[3] - np.sin(arm1_thick_rad) * self.bar_thcx12, y12 = arm1_coord[2] - np.cos(arm1_thick_rad) * self.bar_thc, arm1_coord[3] + np.sin(arm1_thick_rad) * self.bar_thcarm1_box = (x01, y01, x02, y02, x11, y11, x12, y12)arm2_thick_rad = np.pi / 2 - self.arm_info[1, 1]x11_, y11_ = arm2_coord[0] + np.cos(arm2_thick_rad) * self.bar_thc, arm2_coord[1] - np.sin(arm2_thick_rad) * self.bar_thcx12_, y12_ = arm2_coord[0] - np.cos(arm2_thick_rad) * self.bar_thc, arm2_coord[1] + np.sin(arm2_thick_rad) * self.bar_thcx21, y21 = arm2_coord[2] - np.cos(arm2_thick_rad) * self.bar_thc, arm2_coord[3] + np.sin(arm2_thick_rad) * self.bar_thcx22, y22 = arm2_coord[2] + np.cos(arm2_thick_rad) * self.bar_thc, arm2_coord[3] - np.sin(arm2_thick_rad) * self.bar_thcarm2_box = (x11_, y11_, x12_, y12_, x21, y21, x22, y22)self.arm1.vertices = arm1_boxself.arm2.vertices = arm2_boxdef on_key_press(self, symbol, modifiers):if symbol == pyglet.window.key.UP:self.arm_info[0, 1] += .1print(self.arm_info[:, 2:4] - self.point_info)elif symbol == pyglet.window.key.DOWN:self.arm_info[0, 1] -= .1print(self.arm_info[:, 2:4] - self.point_info)elif symbol == pyglet.window.key.LEFT:self.arm_info[1, 1] += .1print(self.arm_info[:, 2:4] - self.point_info)elif symbol == pyglet.window.key.RIGHT:self.arm_info[1, 1] -= .1print(self.arm_info[:, 2:4] - self.point_info)elif symbol == pyglet.window.key.Q:pyglet.clock.set_fps_limit(1000)elif symbol == pyglet.window.key.A:pyglet.clock.set_fps_limit(30)def on_mouse_motion(self, x, y, dx, dy):self.point_info[:] = [x, y]def on_mouse_enter(self, x, y):self.mouse_in[0] = Truedef on_mouse_leave(self, x, y):self.mouse_in[0] = False

利用深度强化学习训练机械臂环境相关推荐

  1. 【强化学习与机器人控制论文 1】基于深度强化学习的机械臂避障

    基于深度强化学习的机械臂避障 1. 引言 2. 论文解读 2.1 背景 2.2 将NAF算法用在机器人避障中 3. 总结 1. 引言 本文介绍一篇2018年发表在 European Control C ...

  2. 基于深度强化学习训练《街头霸王·二:冠军特别版》通关关底 BOSS -智能 AI 代理项目上手

    文章目录 SFighterAI项目简介 实现软件环境 项目文件结构 运行指南 环境配置 验证及调整gym环境: gym-retro 游戏文件夹 错误提示及解决 Could not initialize ...

  3. 【经验】深度强化学习训练与调参技巧

    来源:知乎(https://zhuanlan.zhihu.com/p/482656367) 作者:岳小飞 天下苦 RL 久矣,其中最苦的地方莫过于训练和调参了,人人欲"调"之而后快 ...

  4. 利用AI强化学习训练50级比卡超单挑70级超梦!

    强化学习(Reinforcement Learning, RL),是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题. ...

  5. 强化学习UR机械臂仿真环境搭建(一) - 为UR3机械臂添加robotiq ft300力传感器

    为UR3机械臂添加robotiq ft300力传感器 ```建议参考这篇```,[ur机械臂 + robotiq gripper + robotiq ft sensor + gazebo + 连接真实 ...

  6. 写一个强化学习训练的gym环境

    需求 要用强化学习(Reinforcement Learning)算法解决问题,需要百千万次的训练,真实环境一般不允许这么多次训练(时间太长.试错代价太大),需要开发仿真环境.OpenAI的gym环境 ...

  7. 参加Matlab与AI讲座:使用深度强化学习训练走路机器人观后感

    时间:2023年4月12日,周三,天气晴 地址:大连理工大学研教楼303 前言:Matlab其实有很多功能,我们所用的只是最基础最简单的部分,例如矩阵计算,画图等等. 随着强化学习的发展,matlab ...

  8. 深度强化学习(资源篇)(更新于2020.11.22)

    理论 1种策略就能控制多类模型,华人大二学生提出RL泛化方法,LeCun认可转发 | ICML 2020 AlphaGo原来是这样运行的,一文详解多智能体强化学习的基础和应用 [DeepMind总结] ...

  9. 【医疗人工智能论文】使用深度强化学习的腹腔镜机器人辅助训练

    Article 作者:Xiaoyu Tan , Chin-Boon Chng, Ye Su, Kah-Bin Lim, and Chee-Kong Chui 文献题目:Robot-Assisted T ...

最新文章

  1. AMD依然yes!官宣锐龙5000系列CPU,单核性能首次超越英特尔,苏妈:最好的游戏CPU!...
  2. (转)OpenCV版本的摄像机标定
  3. 如果通过当前元素知道父元素、同级元素
  4. linux ftp做yum源,在RedHat5下架设yum源服务器(FTP)
  5. springMVC 前台向后台传数组
  6. Android笔记 Application对象的使用-数据传递以及内存泄漏问题
  7. Eclipse中看java源代码
  8. Python中表示偶数_蒙特卡洛模拟(Python)深入教程
  9. An HTML5 presentation builder — Read more
  10. 19什么情况下会帮助他人
  11. python考试有什么用_Python有什么用?2020年学习Python的10个理由
  12. PHP获取服务器端的相关信息
  13. 绕过waf mysql爆库_iwebsec刷题记录-SQL注入漏洞
  14. 打包jar文件后的spring部署及hibernate自动建表经验总结
  15. 支持tls的tcp服务器,TCP+TLS
  16. centos 日志切割_CentOS Linux使用logrotate分割管理日志
  17. MegaWizard Plug-in Manager产生的目录结构及关键文件
  18. Zephyr单元测试框架:ztest/twister的使用和介绍
  19. CF235C-Cyclical Quest
  20. 看看别人怎么学习的。

热门文章

  1. 图像修复(Image Restoration)
  2. 解决更换电池引发的乐视2手机(lex620)不进系统问题
  3. 【Java入门】--键盘输入月份,控制台返回对应英文月份。
  4. Scrapy爬虫框架,爬取小说网的所有小说
  5. 【项目总结】论文复现与改进:一般选择模型的产品组合优化算法(Research@收益管理)
  6. c语言语法要素,第6章DSP_C语言程序设计要素.ppt
  7. 龙迅LT8612UX 是一款 HDMI 至 HDMIVGA 转换器
  8. EPICS记录参考--sub-Array记录(subArray)
  9. LPC1768 IAR环境下使用完整64K内存的方法
  10. 云服务器oa系统,oa系统放到云服务器云服务器