在之前的文章中,我们做了如下工作:

  • 如何设计一个类flappy-bird小游戏:【python实战】使用pygame写一个flappy-bird类小游戏 | 设计思路+项目结构+代码详解|新手向
  • DFS 算法是怎么回事,我是怎么应用于该小游戏的:【深度优先搜索】一个实例+两张动图彻底理解DFS|DFS与BFS的区别|用DFS自动控制我们的小游戏
  • BFS 算法是怎么回事,我是怎么应用于该小游戏的:【广度优先搜索】一个实例+两张动图彻底理解BFS|思路+代码详解|用DFS自动控制我们的小游戏
  • 强化学习为什么有用?其基本原理:无需公式或代码,用生活实例谈谈AI自动控制技术“强化学习”算法框架

本节开始,我们将讨论如何用深度强化学习实现小游戏的自动控制。

构造一个简单的卷积神经网络,实现 DQN

本文涉及的 .py 文件有:

DQN_train/gym_warpper.py
DQN_train/dqn_train.py

requirements

tianshou
pytorch > 1.40
gym
openCV

封装交互环境

强化学习算法有效,很大程度上取决于奖励机制设计的是否合理。

事件 奖励
动作后碰撞障碍物、墙壁 -1
动作后无事发生 0.1
动作后得分 1

封装代码在 gym_wrapper.py 中,使用类 AmazingBrickEnv

强化学习机制与神经网络的构建

我设计的机制为:

  • 每 2 帧进行一次动作决策;
  • 状态的描述变量为 2 帧的图像。

对于每帧的图像处理如下。

# 首先把图像转换成 RGB 矩阵
pygame.surfarray.array3d(pygame.display.get_surface())
# 使用 openCV 将 RGB 矩阵矩阵转换成 100*100 的灰度0-1矩阵
x_t = cv2.cvtColor(cv2.resize(obs, (100, 100)), cv2.COLOR_BGR2GRAY)

最后使用 np.stack() 将两帧数据合并,我们就得到了一个 2 通道的图像矩阵数据。

卷积神经网络的构建

class Net(nn.Module):def __init__(self):super().__init__()# nn.Conv2d(通道数, 输出通道数, 卷积核大小, 步长)self.conv1 = nn.Conv2d(2, 32, 8, 4, padding=1)self.conv2 = nn.Conv2d(32, 64, 4, 2, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, 1, padding=1)self.fc1 = nn.Linear(64, 64)self.fc2 = nn.Linear(64, 3)def forward(self, obs, state=None, info={}):if not isinstance(obs, torch.Tensor):obs = torch.tensor(obs, dtype=torch.float)# turn NHWC to NCHWobs = obs.permute(0, 3, 1, 2)x = F.max_pool2d(F.relu(self.conv1(obs)), 2)x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = F.max_pool2d(F.relu(self.conv3(x)), 2)x = x.view(-1, 64)x = F.relu(self.fc1(x))x = self.fc2(x)return x, state

神经网络解构如上述代码。

卷积训练过程如上图右。

DQN 构建

import os.path as osp
import sys
dirname = osp.dirname(__file__)
path = osp.join(dirname, '..')
sys.path.append(path)from amazing_brick.game.wrapped_amazing_brick import GameState
from amazing_brick.game.amazing_brick_utils import CONST
from DQN_train.gym_wrapper import AmazingBrickEnv# 使用了清华开源深度强化学习框架
import tianshou as ts
import torch, numpy as np
from torch import nn
import torch.nn.functional as Ftrain_env = AmazingBrickEnv(fps=1000)
test_env = AmazingBrickEnv(fps=1000)state_shape = (80, 80, 4)
action_shape = 1net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)policy = ts.policy.DQNPolicy(net, optim,discount_factor=0.9, estimation_step=3,use_target_network=True, target_update_freq=320)train_collector = ts.data.Collector(policy, train_env, ts.data.ReplayBuffer(size=200))
test_collector = ts.data.Collector(policy, test_env)result = ts.trainer.offpolicy_trainer(policy, train_collector, test_collector,max_epoch=10, step_per_epoch=1000, collect_per_step=10,episode_per_test=100, batch_size=64,train_fn=lambda e: policy.set_eps(0.1),test_fn=lambda e: policy.set_eps(0.05), writer=None)
print(f'Finished training! Use {result["duration"]}')

由于我还没有开始系统学习 NNs ,不了解 CNNs ,因此不是很信任自己建立的这个网络,没有投入资源与时间训练。

下两节(也是本项目的最后两节),我们将探讨线性网络解决这个控制问题的,其中将涉及到简单的建模与奖励机制设计讨论,会很有趣。

项目地址:https://github.com/PiperLiu/Amazing-Brick-DFS-and-DRL

构建一个简单的卷积神经网络,使用DRL框架tianshou匹配DQN算法相关推荐

  1. atm取款机的简单程序代码_LeNet:一个简单的卷积神经网络PyTorch实现

    前两篇文章分别介绍了卷积层和池化层,卷积和池化是卷积神经网络必备的两大基础.本文我们将介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet[1].LeNet名字来源于论文的第一作者Yann Le ...

  2. 神经网络学习笔记2.2 ——用Matlab写一个简单的卷积神将网络图像分类器

    配套视频讲解 10分钟学会matlab实现cnn图像分类_哔哩哔哩_bilibili 10分钟学会matlab实现cnn图像分类 整体代码 链接:https://pan.baidu.com/s/1bt ...

  3. ubuntu16.04 简单的卷积神经网络 cpu和gpu训练时间对比

    我的电脑配置: cpu:i5-4200H gpu:gtx 950M 昨天测试了训练一般的神经网络使用cpu和gpu各自的速度,使用gpu比使用cpu大概能节省42%的时间,当时我以为这么个程度已经很不 ...

  4. 基于PyTorch,如何构建一个简单的神经网络

    本文为 PyTorch 官方教程中:如何构建神经网络.基于 PyTorch 专门构建神经网络的子模块 torch.nn 构建一个简单的神经网络. 完整教程运行 codelab→ https://ope ...

  5. 简单的卷积神经网络,实现手写英文字母识别

    简单的卷积神经网络,实现手写英文字母识别 1 搭建Python运行环境(建议用Anaconda),自学Python程序设计 安装Tensorflow.再安装Pycharm等环境.(也可用Pytorch ...

  6. 简单的卷积神经网络编程,卷积神经网络算法代码

    关于AlphaGo的一些错误说法 最近看了一些关于alphago围棋对弈的一些人工智能的文章,尤其是美国人工智能方面教授的文章,发现此前媒体宣传的东西几乎都是错的,都是夸大了alpha狗.我做了一个阅 ...

  7. 快速构建一个简单的对话+问答AI (上)

    文章目录 前言 part0 资源准备 基本功能 语料 停用词 问答 闲聊语料 获取 part01句的表达 表达 one-hot编码 词嵌入 大致原理 实现 简单版 复杂版 如何训练 转换后的形状 pa ...

  8. 3.2 实战项目二(手工分析错误、错误标签及其修正、快速地构建一个简单的系统(快速原型模型)、训练集与验证集-来源不一致的情况(异源问题)、迁移学习、多任务学习、端到端学习)

    手工分析错误 手工分析错误的大多数是什么 猫猫识别,准确率90%,想提升,就继续猛加材料,猛调优?     --应该先做错误分析,再调优! 把识别出错的100张拿出来, 如果发现50%是"把 ...

  9. Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务

    关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...

最新文章

  1. 斗争程序猿(三十八)——历史朝代大学(两)——我与数据库的故事
  2. 60. Permutation Sequence
  3. “中序表达式”转“后续表达式”
  4. 从传递函数到差分方程的转换
  5. Web前端开发CSS基础(2)
  6. authinfo.php,【nginxphp】后台权限认证方式
  7. centos7.5 安装配置supervisor管理python进程(也就是服务)
  8. 4.8_adapter_结构型模式:适配器模式
  9. 终于可以和 QQ 彻底说再见了!
  10. UmiJs(v3.x版本)
  11. HDU - 6437
  12. 2021年全球与中国木材采伐设备行业市场规模及发展前景分析
  13. 国内顶尖网页游戏制作人和主策划名单(转)
  14. php tp gii,TP电商项目:使用GII制作品牌管理
  15. 学习java的第5天
  16. AI绘画初体验(6pen平台)
  17. 清华软院、清华计科、南大计算机、中科院自动化所夏令营保研过程贴
  18. 头牌知产介绍减肥药商标注册属于哪一类?
  19. python用pandas读取excel_Python 中利用Pandas处理复杂的Excel数据
  20. 腾讯云数据库联手宇信科技发布联合方案,全面助力金融科技安全可控

热门文章

  1. 关于postgre中的pg_hba.conf 文件
  2. 遥感导论网课_甘肃农业大学2019年地理信息科学专业专升本招生 专业课考试大纲...
  3. 安装slide后Powerpoint 不自动退出的解决方案
  4. 监控mysql主从同步状态是否异常
  5. 从Python字符串中剥离字母数字字符以外的所有内容
  6. 如何在回调中访问正确的“ this”?
  7. win11 c4d如何安装 Windows11安装c4d的步骤方法
  8. AutopilotSim2驾驶模拟器使用
  9. (Activiti6.0.0)SpringProcessEngineConfiguration配置bean时属性注入不了,问题已经找到
  10. 可见的轮廓线用虚线绘制_CAD制图初学入门教程:CAD软件中如何绘制轴测图?