效果良好!构造一个输入速度的神经网络,以DQN方式实现小游戏的自动控制
在之前的文章中,我们做了如下工作:
- 如何设计一个类flappy-bird小游戏:【python实战】使用pygame写一个flappy-bird类小游戏 | 设计思路+项目结构+代码详解|新手向
- DFS 算法是怎么回事,我是怎么应用于该小游戏的:【深度优先搜索】一个实例+两张动图彻底理解DFS|DFS与BFS的区别|用DFS自动控制我们的小游戏
- BFS 算法是怎么回事,我是怎么应用于该小游戏的:【广度优先搜索】一个实例+两张动图彻底理解BFS|思路+代码详解|用DFS自动控制我们的小游戏
- 强化学习为什么有用?其基本原理:无需公式或代码,用生活实例谈谈AI自动控制技术“强化学习”算法框架
- 构建一个简单的卷积神经网络,使用DRL框架tianshou匹配DQN算法
- 构造一个简单的神经网络,以DQN方式实现小游戏的自动控制
构造一个输入速度的神经网络,实现 DQN
本文涉及的 .py
文件有:
DQN_train/gym_warpper.py
DQN_train/dqn_train3.py
DQN_train/dqn_render3.py
requirements
tianshou
pytorch > 1.40
gym
继续训练与测试
在本项目地址中,你可以使用如下文件对我训练的模型进行测试,或者继续训练。
继续训练该模型
python DQN_train/dqn_train3.py
我已经训练了 40 次(每次5个epoch),输入上述命令,你将开始第 41 次训练,如果不使用任务管理器强制停止,计算机将一直训练下去,并自动保存每一代的权重。
查看效果
python DQN_train/dqn_render3.py 3
注意参数 3 ,输入 3 代表使用训练 3 次后的权重。
效果如图:
我保留了该模型的所有历史权重。你还可以输入参数:1-40,查看历代神经网络的表现。如果你继续训练了模型,你可以输入更大的参数,如 41 。
输入 10 则代表使用训练 10 次后的权重:
python DQN_train/dqn_render3.py 25
效果如图:
输入 30 则代表使用训练 30 次后的权重:
python DQN_train/dqn_render3.py 30
效果如图:
封装交互环境
上一个模型的效果并不好,这个模型的表现却很值得称道。我对这个模型做出了什么改进呢?
事件 | 奖励 |
---|---|
动作后碰撞障碍物、墙壁 | -1 |
动作后无事发生 | 0.0001 |
动作后得分 | 1 |
在第一层滞留过久(超过500步) | -10 |
可以看出,我将动作后无事发生
的奖励从 0.1 降低到了 -1 ,是为了:
- 突出
动作后得分
这项的奖励; - 如此,智能体第一次得分后,会很“欣喜地发现”上升一层的快乐远远大于在第一层苟命的快乐。
此外,如果智能体在第一层滞留过久
,也是会受到 -10 的惩罚的:
- 这是为了告诉智能体,在第一层过久是不被鼓励的;
- 因为状态是链式的,因此最后的惩罚会回溯分配到之前的“苟命”策略上。
封装代码在 gym_wrapper.py 中,使用类 AmazingBrickEnv3
。
强化学习机制与神经网络的构建
上节中,我们将 2 帧的数据输入到线性层中,效果并不理想。我进一步帮助机器提取了信息
,并且预处理了数据
:
- 不再将巨大的 2 帧数据输入到网络中;
- 取而代之的是,当前状态的速度向量
(velx, vely)
; - 再加上
玩家xy坐标
、左障碍物右上顶点xy坐标
、右障碍物左上顶点xy坐标
、4个障碍方块的左上顶点的xy坐标
(共14个数); - 如此,输入层只有 16 个神经网即可,且每 1 帧做一次决策。
我还放慢了 epsilon (探索概率)的收敛速度,让智能体更多地去探索动作,不局限在局部最优解中。
此外,我对输入数据进行了归一化处理比如,玩家的坐标 x, y 分别除以了屏幕的 宽、高。从结果和训练所需的代数更少来看,我认为这对于机器学习有极大的帮助。
线性神经网络的构建
class Net(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(16, 128)self.fc2 = nn.Linear(128, 256)self.fc3 = nn.Linear(256, 128)self.fc4 = nn.Linear(128, 3)def forward(self, obs, state=None, info={}):if not isinstance(obs, torch.Tensor):obs = torch.tensor(obs, dtype=torch.float)x = F.relu(self.fc1(obs))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))x = self.fc4(x)return x, state
如上,共四层线性网络。
记录训练的微型框架
为了保存训练好的权重,且在需要时可以暂停并继续训练,我新建了一个.json
文件用于保存训练数据。
dqn2_path = osp.join(path, 'DQN_train/dqn_weights/')if __name__ == '__main__':try:with open(dqn3_path + 'dqn3_log.json', 'r') as f:jlist = json.load(f)log_dict = jlist[-1]round = log_dict['round']policy.load_state_dict(torch.load(dqn3_path + 'dqn3round_' + str(int(round)) + '.pth'))del jlistexcept FileNotFoundError as identifier:print('\n\nWe shall train a bright new net.\n')# 第一次训练时,新建一个 .json 文件# 声明一个列表# 以后每次写入 json 文件,向列表新增一个字典对象with open(dqn3_path + 'dqn3_log.json', 'a+') as f:f.write('[]')round = 0while True:round += 1print('\n\nround:{}\n\n'.format(round))result = ts.trainer.offpolicy_trainer(policy, train_collector, test_collector,max_epoch=max_epoch, step_per_epoch=step_per_epoch,collect_per_step=collect_per_step,episode_per_test=30, batch_size=64,# 如下,每新一轮训练才更新 epsilontrain_fn=lambda e: policy.set_eps(0.1 / round),test_fn=lambda e: policy.set_eps(0.05 / round), writer=None)print(f'Finished training! Use {result["duration"]}')torch.save(policy.state_dict(), dqn3_path + 'dqn3round_' + str(int(round)) + '.pth')policy.load_state_dict(torch.load(dqn3_path + 'dqn3round_' + str(int(round)) + '.pth'))log_dict = {}log_dict['round'] = roundlog_dict['last_train_time'] = datetime.datetime.now().strftime('%y-%m-%d %I:%M:%S %p %a')log_dict['best_reward'] = result['best_reward']with open(dqn3_path + 'dqn3_log.json', 'r') as f:"""dqn3_log.json should be inited as []"""jlist = json.load(f)jlist.append(log_dict)with open(dqn3_path + 'dqn3_log.json', 'w') as f:json.dump(jlist, f)del jlist
DQN
import os.path as osp
import sys
dirname = osp.dirname(__file__)
path = osp.join(dirname, '..')
sys.path.append(path)from amazing_brick.game.wrapped_amazing_brick import GameState
from amazing_brick.game.amazing_brick_utils import CONST
from DQN_train.gym_wrapper import AmazingBrickEnv3import tianshou as ts
import torch, numpy as np
from torch import nn
import torch.nn.functional as F
import json
import datetimetrain_env = AmazingBrickEnv3()
test_env = AmazingBrickEnv3()state_shape = 16
action_shape = 1net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)'''args for rl'''
estimation_step = 3
max_epoch = 5
step_per_epoch = 300
collect_per_step = 50policy = ts.policy.DQNPolicy(net, optim,discount_factor=0.9, estimation_step=estimation_step,use_target_network=True, target_update_freq=320)train_collector = ts.data.Collector(policy, train_env, ts.data.ReplayBuffer(size=2000))
test_collector = ts.data.Collector(policy, test_env)
采用这种方式获得了不错的效果,在第 40 代训练后(共 40 * 5 * 300 = 6000 个 step),智能体已经能走 10 层左右。
相信继续的迭代会获得更好的成绩。
项目地址:https://github.com/PiperLiu/Amazing-Brick-DFS-and-DRL
本项目的说明文件到此结束。感谢你的阅读,欢迎提交更好的方案与意见!
效果良好!构造一个输入速度的神经网络,以DQN方式实现小游戏的自动控制相关推荐
- 构造一个简单的神经网络,以DQN方式实现小游戏的自动控制
在之前的文章中,我们做了如下工作: 如何设计一个类flappy-bird小游戏:[python实战]使用pygame写一个flappy-bird类小游戏 | 设计思路+项目结构+代码详解|新手向 DF ...
- 【Python游戏】基于pygame实现的一个Dino Rush 恐龙宝贝冲冲冲的小游戏 | 附源码
前言 halo,包子们晚上好 很久没有更新啦,主要是小编这边最近有点小忙 今天给大家整一个Dino Rush 恐龙宝贝冲冲冲的小游戏 还是一个比较记经典的小游戏,还记这可谷歌浏览器上没有网也能打发时间 ...
- 速度挑战 - 2小时完成HTML5拼图小游戏
概述 我用lufylegend.js开发了第一个HTML5小游戏--拼图游戏,还写了篇博文来炫耀一下:HTML5小游戏<智力大拼图>发布,挑战你的思维风暴. 详细 代码下载:http:// ...
- 【日常练习】一个可以试玩五次的猜数小游戏
练习要求: 请设计一个猜数字小游戏,可以试玩5次.试玩结束之后,给出提示:游戏试玩结束,请付费. 小游戏代码 import java.util.Scanner;public class GuessNu ...
- 主观不可见 一个非常有创意的动作解谜Flash小游戏
http://www.newgrounds.com/portal/view/480006 http://www.deadwhale.com/play.php?game=874 今天玩到了一个超级有意思 ...
- c#输入三个数选出最大的_C#写一个输入三个整数,按大到小顺序输出的小程序...
满意答案 pf48154968 2013.07.05 采纳率:52% 等级:12 已帮助:9321人 int[] arr = new int[3]; for (int i = 0; i < ...
- 实现用java做一个简易版《羊了个羊》小游戏(附源代码)
该项目是跟着这个b站视频一步一步写出来的,初学java有些地方我看不是很明白,但是讲解很仔细,大家可以看原视频,我没有添加背景音乐和背景图片,做出来的效果也勉勉强强. 代码已经上传到github上了, ...
- c语言编程游戏界面,震惊!!!一个关于c语言图形化界面编程的小游戏-Go语言中文社区...
关于C语言的图形化界面编程 第一个小程序<飞翔的小鸟> 效果图 本人也是小白,大家轻点喷!!!! 下面是源码 作者: @追风 #include #include #include #inc ...
- 震惊!!!一个关于c语言图形化界面编程的小游戏
关于C语言的图形化界面编程 第一个小程序<飞翔的小鸟> 效果图 本人也是小白,大家轻点喷!!!! 下面是源码 作者: @追风#include<graphics.h> #incl ...
最新文章
- MyGeneration学习笔记(1) : 使用MyGeneration生成存储过程和数据访问层代码
- 按属性对自定义对象的ArrayList进行排序
- Ubuntu 14 配置Android Studio的快捷启动方式
- Android窗口管理服务WindowManagerService切换Activity窗口(App Transition)的过程分析
- python print error 空_python笔记37:10分钟掌握异常处理,再也不担心程序挂了
- 微信小程序测试的策略和注意事项
- 用虚拟网卡(softether)共享局域网资源
- 调整jvm参数_JVM源码分析之MetaspaceSize和MaxMetaspaceSize的区别
- ios uitableview 积累
- 博客样式-bbsmax4风格V0.2
- Firefox常用扩展
- 公众号对接电影 输入电影名字即可
- 基础知识 字节、KB、MB、GB 之间的换算关系
- 关于人工智能的天马行空
- mysql数据库的超级管理员名称_MySQL数据库的超级管理员名称是______
- 产品与服务最大的卖点,可能是销售最大的坑!
- java实现验证邮箱有效性
- wav 转换到 flac
- percona-tool文档说明(5)- 复制类
- 【网络】计算机网络重点知识总结
热门文章
- 浙江计算机三级考试单片机试题,历年浙江省计算机三级单片机
- PlSqlDev中执行INSERT SQL语句包含符号导致数据异常
- 深入理解前端跨域问题的解决方案——前端面试
- Oracle存储过程的异常处理
- junit单元测试不通过报documentationPluginsBootstrapper相关异常
- 移动硬盘新建选项消失、不能新建文件夹和文件的解决方案
- 打开 VMware Workstation 14 Pro 中的虚拟机出现 “此主机支持 Intel VT-x,但 Intel VT-x 处于禁用状态” 解决方法
- win11开始菜单如何分组 Windows11开始菜单进行分组的设置方法
- linux用c++获取mac地址,网卡地址,网口地址,网卡序号ip地址,不使用 ioctl(sock, SIOCGIFCONF, ifc)获取网络接口名称,这个接口有时会返回-1获取不到,换方法获取
- activiti流程变量