构造一个简单的神经网络,以DQN方式实现小游戏的自动控制
在之前的文章中,我们做了如下工作:
- 如何设计一个类flappy-bird小游戏:【python实战】使用pygame写一个flappy-bird类小游戏 | 设计思路+项目结构+代码详解|新手向
- DFS 算法是怎么回事,我是怎么应用于该小游戏的:【深度优先搜索】一个实例+两张动图彻底理解DFS|DFS与BFS的区别|用DFS自动控制我们的小游戏
- BFS 算法是怎么回事,我是怎么应用于该小游戏的:【广度优先搜索】一个实例+两张动图彻底理解BFS|思路+代码详解|用DFS自动控制我们的小游戏
- 强化学习为什么有用?其基本原理:无需公式或代码,用生活实例谈谈AI自动控制技术“强化学习”算法框架
- 构建一个简单的卷积神经网络,使用DRL框架tianshou匹配DQN算法
构造一个简单的卷积神经网络,实现 DQN
本文涉及的 .py
文件有:
DQN_train/gym_warpper.py
DQN_train/dqn_train2.py
DQN_train/dqn_render2.py
requirements
tianshou
pytorch > 1.40
gym
继续训练与测试
在本项目地址中,你可以使用如下文件对我训练的模型进行测试,或者继续训练。
继续训练该模型
python DQN_train/dqn_train2.py
如图,我已经训练了 53 次(每次10个epoch),输入上述命令,你将开始第 54 次训练,如果不使用任务管理器强制停止,计算机将一直训练下去,并自动保存最新一代的权重。
查看效果
python DQN_train/dqn_render2.py 0
注意参数 0 ,输入 0 代表使用最新的权重。
效果如图:
上图中,可以看到我们的 AI 已经学会了一些“知识”:比如如何前往下一层;它还需要多加练习,以学会如何避开这些小方块构成的障碍。
此外,我保留了一些历史权重。你还可以输入参数:7, 10, 13, 21, 37, 40, 47,查看训练次数较少时,神经网络的表现。
封装交互环境
强化学习算法有效,很大程度上取决于奖励机制设计的是否合理。
事件 | 奖励 |
---|---|
动作后碰撞障碍物、墙壁 | -1 |
动作后无事发生 | 0.1 |
动作后得分 | 1 |
封装代码在 gym_wrapper.py 中,使用类 AmazingBrickEnv2
。
强化学习机制与神经网络的构建
上节中,我们将 2 帧的数据输入到卷积层中,目的是:
- 让卷积层提取出“障碍物边缘”与“玩家位置”;
- 让 2 帧数据反映出“玩家速度”信息。
为了节省计算资源,同时加快训练速度,我们人为地替机器提取这些信息:
- 不再将巨大的 2 帧“图像矩阵”输入到网络中;
- 取而代之的是,输入 2 帧的位置信息;
- 即输入
玩家xy坐标
、左障碍物右上顶点xy坐标
、右障碍物左上顶点xy坐标
、4个障碍方块的左上顶点的xy坐标
(共14个数); - 如此, 2 帧数据共 28 个数字,我们的神经网络输入层只有 28 个神经元,比上一个模型(25600)少了不止一个数量级。
我设计的机制为:
- 每 2 帧进行一次动作决策;
- 状态的描述变量为 2 帧的图像。
线性神经网络的构建
class Net(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28, 128)self.fc2 = nn.Linear(128, 256)self.fc3 = nn.Linear(256, 128)self.fc4 = nn.Linear(128, 3)def forward(self, obs, state=None, info={}):if not isinstance(obs, torch.Tensor):obs = torch.tensor(obs, dtype=torch.float)x = F.relu(self.fc1(obs))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))x = self.fc4(x)return x, state
如上,共四层线性网络。
记录训练的微型框架
为了保存训练好的权重,且在需要时可以暂停并继续训练,我新建了一个.json
文件用于保存训练数据。
dqn2_path = osp.join(path, 'DQN_train/dqn_weights/')if __name__ == '__main__':round = 0try:# 此处 policy 采用 DQN# 具体 DQN 构建方法见下文policy.load_state_dict(torch.load(dqn2_path + 'dqn2.pth'))lines = []with open(dqn2_path + 'dqn2_log.json', "r") as f:for line in f.readlines():cur_dict = json.loads(line)lines.append(cur_dict)log_dict = lines[-1]print(log_dict)round = log_dict['round']del linesexcept FileNotFoundError as identifier:print('\n\nWe shall train a bright new net.\n')passwhile True:round += 1print('\n\nround:{}\n\n'.format(round))result = ts.trainer.offpolicy_trainer(policy, train_collector, test_collector,max_epoch=max_epoch, step_per_epoch=step_per_epoch,collect_per_step=collect_per_step,episode_per_test=30, batch_size=64,train_fn=lambda e: policy.set_eps(0.1 * (max_epoch - e) / round),test_fn=lambda e: policy.set_eps(0.05 * (max_epoch - e) / round), writer=None)print(f'Finished training! Use {result["duration"]}')torch.save(policy.state_dict(), dqn2_path + 'dqn2.pth')policy.load_state_dict(torch.load(dqn2_path + 'dqn2.pth'))log_dict = {}log_dict['round'] = roundlog_dict['last_train_time'] = datetime.datetime.now().strftime('%y-%m-%d %I:%M:%S %p %a')log_dict['result'] = json.dumps(result)with open(dqn2_path + 'dqn2_log.json', "a+") as f:f.write('\n')json.dump(log_dict, f)
DQN
import os.path as osp
import sys
dirname = osp.dirname(__file__)
path = osp.join(dirname, '..')
sys.path.append(path)from amazing_brick.game.wrapped_amazing_brick import GameState
from amazing_brick.game.amazing_brick_utils import CONST
from DQN_train.gym_wrapper import AmazingBrickEnv2import tianshou as ts
import torch, numpy as np
from torch import nn
import torch.nn.functional as F
import json
import datetimetrain_env = AmazingBrickEnv2()
test_env = AmazingBrickEnv2()state_shape = 28
action_shape = 1net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)'''args for rl'''
estimation_step = 3
max_epoch = 10
step_per_epoch = 300
collect_per_step = 50policy = ts.policy.DQNPolicy(net, optim,discount_factor=0.9, estimation_step=estimation_step,use_target_network=True, target_update_freq=320)train_collector = ts.data.Collector(policy, train_env, ts.data.ReplayBuffer(size=2000))
test_collector = ts.data.Collector(policy, test_env)
如图,采用这种方式训练了 53 个循环(共计 53 * 10 * 300 = 159000 个 step)效果还是一般。
下一节(也是本项目的最后一节),我们将探讨线性网络解决这个控制问题的相对成功的方案。
项目地址:https://github.com/PiperLiu/Amazing-Brick-DFS-and-DRL
构造一个简单的神经网络,以DQN方式实现小游戏的自动控制相关推荐
- 效果良好!构造一个输入速度的神经网络,以DQN方式实现小游戏的自动控制
在之前的文章中,我们做了如下工作: 如何设计一个类flappy-bird小游戏:[python实战]使用pygame写一个flappy-bird类小游戏 | 设计思路+项目结构+代码详解|新手向 DF ...
- 一个简单的神经网络,三种常见的神经网络
BP人工神经网络方法 (一)方法原理人工神经网络是由大量的类似人脑神经元的简单处理单元广泛地相互连接而成的复杂的网络系统.理论和实践表明,在信息处理方面,神经网络方法比传统模式识别方法更具有优势. 人 ...
- 使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络
使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档. JAX简介 JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本 ...
- filter hid_如何构造一个简单的USB过滤驱动程序
本文分三部分来介绍如何构造一个简单的USB过滤驱动程序,包括"基本原理"."程序的实现"."使用INF安装".此文的目的在于希望读者了解基本 ...
- 基于PyTorch,如何构建一个简单的神经网络
本文为 PyTorch 官方教程中:如何构建神经网络.基于 PyTorch 专门构建神经网络的子模块 torch.nn 构建一个简单的神经网络. 完整教程运行 codelab→ https://ope ...
- tensorflow学习笔记二——建立一个简单的神经网络拟合二次函数
tensorflow学习笔记二--建立一个简单的神经网络 2016-09-23 16:04 2973人阅读 评论(2) 收藏 举报 分类: tensorflow(4) 目录(?)[+] 本笔记目的 ...
- 实现一个简单的神经网络
实现一个简单的神经网络 import numpy as npdef tanh(x): return np.tanh(x)def tanh_deriv(x): return 1.0 - np.tanh( ...
- 构造一个简单的操作系统内核,详解进程切换细节
(1)基本功能介绍 如题,本文将介绍如何构造一个简单的操作系统内核(基于内核版本3.9.4 ).它有以下功能: 1:进程的管理 2:进程的初始化 3 : 进程基于时间片的调度 (2)实操步骤 1 安装 ...
- 《Linux内核分析》 第三周 构造一个简单的Linux系统MenuOS
Linux内核分析 第三周 构造一个简单的Linux系统MenuOS 张嘉琪 原创作品转载请注明出处 <Linux内核分析>MOOC课程http://mooc.study.163.com/ ...
最新文章
- 【C/C++ string】之strcpy函数
- 如何让 Hyper-V 和 VMware 虚拟机软件共存?
- 2009年4月计算机网络原理,全国2009年4月高等教育自学考试计算机网络原理
- nowcoderD Xieldy And His Password
- Redis之Redis的事务
- 根据表格中的数据长度自动调整表格宽度DBGrid
- FCKEditor v2.6.3 最新版-ASP.NET 演示程序
- 舵机控制原理/舵机内部电路原理
- vin码识别(车架号识别)的工具
- python 使用公司邮箱发邮件_python3使用腾讯企业邮箱发送邮件的实例
- ICCV 2017 论文解读集锦
- 论QQ如何发大菜狗表情
- 宝塔一键安装部署tipask登录出现错误:The email must be at least 8 characters怎么回事
- 动画对象(lv_anim_t)的应用
- Weakly Guiding Fibers(弱导光纤)
- patindex函数的用法介绍
- 解决win7 中powershell挖矿占用CPU100%
- PyGmae:有限状态机实践(十二)
- Windows磁盘管理概述
- Android使用MediaCodec硬解码播放H264格式视频文件
热门文章
- 【Oracle】详解10053事件
- 【PL/SQL】PL/SQL语言基础
- 解决RStudio(非conda安装)在使用Anaconda中的R环境时,缺失“ libbz2-1.dll ”而不能正常启动问题
- 解决Ajax异步请求中传数组参数,后台无法接收问题
- boost::bind with ros action,ros中SimpleActionServer用boost::bind绑定多个参数
- Makefile:Makefile 使用总结
- Makefile:GCC CFLAGS变量和LDFLAGS变量
- 关于现代计算机的知识,从资本经济到知识经济:现代计算机的知识革命
- CSS 的复合选择器
- 经典手眼标定算法之Tsai-Lenz的OpenCV实现