Pytorch深度强化学习训练Agent走直线

问题提出

最近在学强化学习,想用强化学习的方法做一个机器人的运动路径规划,手头刚好有一个项目,问题具体是这样:机器人处在一个二维环境中,机器人的运动环境(Environment)我把它简化成了一个8x8(如图一)的棋盘格,而机器人就是占据1个格子的“棋子”。机器人刚开始在(0,0)点,机器人每次能采取的动作(Action)是“上、下、左、右”四个动作(当然不能跑出棋盘格外)。我最初想来做路径规划的目的是想让机器人遍历棋盘所有网格,而不重复走以前走过的点,就像扫地,最理想的情况当然是训练出一条“S”字型的扫地路径,这样每个格子都能扫到而且不重复。但在实际强化学习训练的过程样本中,我发现机器人一直在无规律的“踉跄”行走,连直线能没法走成,根本没有向标准“S”型路径收敛的迹象。注意到“S”型路径是由若干直线段组成的,为了让训练结果能更快速收敛到我想要的“S”型路径,也许先训练机器人走直线是一个更有效且关键的步骤。因此我打算先用强化学习来训练机器人走直线。

图一,8x8棋盘格与期望的机器人S形路径,图二,20x20棋盘格
用8x8棋盘格来做训练,机器人很容易碰到墙就停止,产生了一大堆没用的样本。可以试想人学会走路是在一个宽敞的环境,而不是一个狭窄的环境中,因此训练走直线并不一定得使用任务环境8x8棋盘格,完全可以用一个更大的20x20棋盘格(如图二)来训练。状态反馈(State)定义为机器人当前时刻的速度方向:上下左右。奖励函数(Reward)如下:若动作方向与当前速度方向相同,如速度上动作上、速度下动作下,则奖励为10;若动作方向与当前速度方向垂直,如速度上动作左、速度下动作右,则奖励为-5;若动作方向与当前速度方向相反,如速度上动作下、速度左动作右,则奖励为-10。Pytorch代码如下:

'''
训练环境
'''
import torch  #使用Pytorch
import time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tkinter as tk  #使用仿真环境tkinter
UNIT = 40   # 像素
WP_H = 20  # 棋盘格高
WP_W = 20  # 棋盘格宽
direction= 4 #速度方向
class Run(tk.Tk, object):def __init__(self):super(Run, self).__init__()self.action_space = ['u', 'd', 'l', 'r']   #上下左右self.n_actions = len(self.action_space)self.n_features = directionself.title('Run')self.geometry('{0}x{1}'.format(WP_H * UNIT, WP_H * UNIT))self._build_WP()def _build_WP(self):self.canvas = tk.Canvas(self, bg='white',height=WP_H * UNIT,width=WP_W * UNIT)# create gridsfor c in range(0, WP_W * UNIT, UNIT):x0, y0, x1, y1 = c, 0, c, WP_H * UNITself.canvas.create_line(x0, y0, x1, y1)for r in range(0, WP_H * UNIT, UNIT):x0, y0, x1, y1 = 0, r, WP_W * UNIT, rself.canvas.create_line(x0, y0, x1, y1)# create originorigin = np.array([20, 20])# create red rectself.rect = self.canvas.create_rectangle(origin[0] - 15, origin[1] - 15,origin[0] + 15, origin[1] + 15,fill='red')# pack allself.canvas.pack()self.acculR=0   #累积奖赏self.ToolCord=np.array([0,0])self.speed = torch.zeros(1,4)def reset(self):self.update()time.sleep(0.5)self.canvas.delete(self.rect)origin = np.array([20, 20])self.acculR=0self.speed = torch.zeros(4)self.speed[np.random.randint(0, 4, 1)]=1        #初始速度是随机的self.ToolCord = np.random.randint(0, WP_H, 2)   #初始出发点是随机的self.rect = self.canvas.create_rectangle(self.ToolCord[0] * UNIT + origin[0] - 15, self.ToolCord[1] * UNIT + origin[1] - 15,self.ToolCord[0] * UNIT + origin[0] + 15, self.ToolCord[1] * UNIT + origin[1] + 15,fill='red')s=self.speed.flatten()return s  #reset返回初始状态def step(self, action):outside = 0base_action = np.array([0, 0])if action == 0:   # upif self.speed[0]==1:reward=10elif self.speed[1]==1:reward=-10elif self.speed[2]== 1 or self.speed[3]== 1:reward = -5self.speed = torch.zeros(4)if self.ToolCord[1] > 0:base_action[1] -= UNITself.speed[0] = 1self.ToolCord = np.add(self.ToolCord, (0, -1))elif self.ToolCord[1] == 0:base_action[1] = base_action[1]outside = 1self.speed[0] = 1elif action == 1:   # downif self.speed[1]==1:reward=10elif self.speed[0]==1:reward=-10elif self.speed[2]==1 or self.speed[3]==1:reward = -5self.speed = torch.zeros(4)if self.ToolCord[1] < (WP_H - 1):base_action[1] += UNITself.speed[1] = 1self.ToolCord = np.add(self.ToolCord, (0, 1))elif self.ToolCord[1] == (WP_H - 1):base_action[1] = base_action[1]outside = 1self.speed[1] = 1elif action == 2:   # rightif self.speed[2]==1:reward=10elif self.speed[3]==1:reward=-10elif self.speed[0]==1 or self.speed[1]==1:reward = -5self.speed = torch.zeros(4)if self.ToolCord[0] < (WP_W - 1):base_action[0] += UNITself.speed[2] = 1self.ToolCord = np.add(self.ToolCord, (1, 0))elif self.ToolCord[0] == (WP_W - 1):base_action[0] = base_action[0]outside = 1self.speed[2] = 1elif action == 3:   # leftif self.speed[3]==1:reward=10elif self.speed[2]==1:reward = -10elif self.speed[0]==1 or self.speed[1]==1:reward = -5self.speed = torch.zeros(4)if self.ToolCord[0] > 0:base_action[0] -= UNITself.ToolCord = np.add(self.ToolCord, (-1, 0))self.speed[3] = 1elif self.ToolCord[0] == 0:base_action[0] = base_action[0]outside = 1self.speed[3] = 1self.canvas.move(self.rect, base_action[0], base_action[1])  # move agents_ = self.speed.flatten()#end flagself.acculR += rewardif self.acculR < -25 or outside==1:done = Trueelse:done = Falsereturn s_, reward, done, outsidedef render(self):time.sleep(0.1)self.update()def update():for t in range(2):s = env.reset()while True:env.render()a = 1s_, r, done, ifoutside = env.step(a)print(s_)if done:break
if __name__ == '__main__':env = Run()env.after(100, update)env.mainloop()
"""
学习过程
"""import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import Learn_straight_env
import loc2glb_env
# Hyper Parametersenv=Learn_straight_env.Run()
N_ACTIONS = env.n_actions
N_STATES = env.n_features
MEMORY_CAPACITY = 100 #记忆容量
BATCH_SIZE = 20
# ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape# Deep Q Network off-policy
class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.fc1 = nn.Linear(N_STATES, 4)self.fc1.weight.data.normal_(0, 0.1)   # initializationself.out = nn.Linear(4, N_ACTIONS)self.out.weight.data.normal_(0, 0.1)   # initializationdef forward(self, x):x = self.fc1(x)  #网络前向传播x = F.relu(x)actions_value = self.out(x)return actions_valueclass DQN(object):def __init__(self):self.eval_net, self.target_net = Net(), Net()self.learn_step_counter = 0                                     # for target updatingself.LR = 0.1  # learning rateself.memory_counter = 0                                         # for storing memoryself.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memoryself.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.LR)self.loss_func = nn.MSELoss()self.EPSILON = 0.5  # greedy policyself.GAMMA = 0.9  # reward discountself.TARGET_REPLACE_ITER = 1  # target update frequencydef choose_action(self, x):x = torch.unsqueeze(torch.FloatTensor(x), 0)# input only one sampleif np.random.uniform() < self.EPSILON:   # greedyactions_value = self.eval_net.forward(x)action = torch.max(actions_value, 1)[1].data.numpy()# action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax indexelse:   # randomaction = np.random.randint(0, N_ACTIONS)# action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)return actiondef store_transition(self, s, a, r, s_):transition = np.hstack((s, [a, r], s_))# replace the old memory with new memoryindex = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter += 1def learn(self):# target parameter updateif self.learn_step_counter % self.TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter += 1# sample batch transitionssample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = torch.FloatTensor(b_memory[:, :N_STATES])b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])# q_eval w.r.t the action in experienceq_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagateq_target = b_r + self.GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)loss = self.loss_func(q_eval, q_target)self.optimizer.zero_grad()loss.backward()self.optimizer.step()print('\nCollecting experience...')
EPI=200
dqn=DQN()
for i_episode in range(EPI):s = env.reset()ep_r = 0dqn.EPSILON=0.6+0.3*i_episode/EPI  #学习参数的调整dqn.LR=0.1-0.09*i_episode/EPIdqn.GAMMA=0.5dqn.optimizer = torch.optim.Adam(dqn.eval_net.parameters(), lr=dqn.LR)dqn.TARGET_REPLACE_ITER=1+int(10*i_episode/EPI)while True:env.render()a = dqn.choose_action(s)# take actions_, r, done, outside = env.step(a)dqn.store_transition(s, a, r, s_)ep_r += rif dqn.memory_counter > MEMORY_CAPACITY:dqn.learn()if done:print('Ep: ', i_episode,'reward',ep_r)if done:print(s_)breaks = s_# train eval network

最后训练出的机器人可以走直线了,虽然我用的网络隐层只有一层,神经元数也特别少,但也总算是用强化学习完成了机器人自主走直线这一功能。

强化学习训练Agent走直线相关推荐

  1. 田渊栋的2021年终总结:多读历史!历史就是一个大规模强化学习训练集

      视学算法报道   作者:田渊栋 编辑:好困 LRS [新智元导读]田渊栋博士最近又在知乎上发表了他的2021年度总结,成果包括10篇Paper和1部长篇小说及续集.文章中还提到一些研究心得和反思, ...

  2. 在Unity环境中使用强化学习训练Donkey Car(转译)

    在Unity环境中使用强化学习训练Donkey Car 1.Introduction 简介 2. Train Donkey Car with Reinforcement Learning 使用强化学习 ...

  3. 利用AI强化学习训练50级比卡超单挑70级超梦!

    强化学习(Reinforcement Learning, RL),是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题. ...

  4. 基于深度强化学习训练《街头霸王·二:冠军特别版》通关关底 BOSS -智能 AI 代理项目上手

    文章目录 SFighterAI项目简介 实现软件环境 项目文件结构 运行指南 环境配置 验证及调整gym环境: gym-retro 游戏文件夹 错误提示及解决 Could not initialize ...

  5. 【经验】深度强化学习训练与调参技巧

    来源:知乎(https://zhuanlan.zhihu.com/p/482656367) 作者:岳小飞 天下苦 RL 久矣,其中最苦的地方莫过于训练和调参了,人人欲"调"之而后快 ...

  6. 上海交大开源训练框架,支持大规模基于种群多智能体强化学习训练

    机器之心专栏 作者:上海交大和UCL多智能体强化学习研究团队 基于种群的多智能体深度强化学习(PB-MARL)方法在星际争霸.王者荣耀等游戏AI上已经得到成功验证,MALib 则是首个专门面向 PB- ...

  7. 谷歌造了个虚拟足球场,让AI像打FIFA一样做强化学习训练丨开源有API

    郭一璞 发自 苏州街  量子位 报道 | 公众号 QbitAI 除了下棋.雅达利游戏和星际,AI终于把"魔爪"伸向了粉丝众多的体育竞技活动: 足球. 今天,谷歌开源了足球模拟环境G ...

  8. 深度学习8-加速强化学习训练的方法

    # 2022.6.2 rl-9 ### 加速强化学习训练的方法 ▪  使用第8章的Pong环境,并试图尽可能快地解决它. ▪  使用完全相同的硬件,逐步解决Pong问题并将速度提升3.5倍. ▪  讨 ...

  9. MedicalGPT:基于LLaMA-13B的中英医疗问答模型(LoRA)、实现包括二次预训练、有监督微调、奖励建模、强化学习训练[LLM:含Ziya-LLaMA]。

    项目设计集合(人工智能方向):助力新人快速实战掌握技能.自主完成项目设计升级,提升自身的硬实力(不仅限NLP.知识图谱.计算机视觉等领域):汇总有意义的项目设计集合,助力新人快速实战掌握技能,助力用户 ...

最新文章

  1. 为什么 Python被Google选为TensorFlow的开发语言呢?使用 Python比C++语言进行机器学习有什么优势?
  2. Java开发必会的反编译知识
  3. iis6 配置python CGI
  4. 【Excel】使用VLOOKUP+IF实现多列条件匹配查询
  5. VB讲课笔记02:VB程序开发环境
  6. 随想录(程序员怎么用英文查资料)
  7. html¥符号代码是什么,html怎么特殊符号赋
  8. 网吧服务器维护工具,某某网吧专用维护工具(网吧维护管理助手)V5.1 最新版
  9. python读坐标像素_python如何读取像素值
  10. tp php websocket教程,tp6 websocket方法详解
  11. SpringCloud Day05---服务网关(Gateway)
  12. 厦门大学904数据结构与机器学习资料与辅导
  13. ‘VBE6EXT.OLB’ 不能被加载
  14. LSVGlobal Mapper应用----影像裁剪
  15. 王者荣耀服务器维护9月27,王者荣耀9月27日更新维护公告 修复夏洛特技能bug等...
  16. 微信小程序盲盒系统源码 附带教程
  17. ES与传统数据库,为什么用ES?
  18. expdp impdp
  19. 通俗的解释一下什么是 RPC 框架?
  20. 蒙特梭利素材-【彩色圆柱体1】蒙氏教具 三段卡 蒙氏素材

热门文章

  1. python儿童编程培训
  2. 使用vuepress搭建一个完全免费的个人网站
  3. 网上书城—登录、书籍管理
  4. Excel教程之什么是好的仪表板工具
  5. 兔子繁殖为例 c语言,用斐波那契数列解答兔子的繁殖
  6. PTA求100以内的素数
  7. CPU 基本工作原理和概念
  8. 无线路由器经常掉线断网的可能的原因
  9. GBase 8a管理集群gcware的日志-vote leader、flower、candidate部分
  10. thinkphp3.2 微信 Native扫码支付功能