本人正在学习 贪心科技高性能神经网络与AI芯片应用研修课程,在此做学习笔记,欢迎一起交流学习,共同进步

一、序言

   

本文承接第一部分,基于对卷积神经网络网络组成的认识,开始学习如何去使用卷积神经网络进行对应的训练。模型评估作为优化部分,我们将放在第三个部分中再好好讲他的作用以及意义~

   训练的基本流程主要是数据集引入、训练及参数设置、验证及反馈这三个步骤,我们现在分三个步骤来认识一下这个训练的基本流程。

   PS:我更新真是快啊~

## 二、训练流程

## 1、数据集引入

   本文根据对应的实验要求,主要采用的是Pytorch中自带的MNIST数据集。MNIST数据集由于比较基础,历年来都是被各种玩坏的主要对象~

   引入数据集的时候主要需要注意的是预处理的一个操作,在这里主要用的是ToTensor和Normalize两个函数进行归一化处理。其实也不一定需要Normalize这个函数,因为训练其实都是可以进行的。

   但是这里需要注意一下,因为导入数据集的时候操作是固定的。所以为了保证这个操作固定,就最好是用Compose把他们固定起来,不然在后续操作中可能就会添麻烦。

   如果你在做自己的手写图像识别,并且老是正确率比较低,那么一定注意一下这几个点。

   第一个是图像的前后的前后处理的时候是不一样的,很容易直接用自己的图像直接拿去识别了,但是因为之前训练集中的都是经过Compose结合后的组合处理后的图像。但是你直接拿去处理的图像是没有经过处理的,输入到模型中的和此前的格式是不一样的。

   第二个就是因为你手写的时候,导出的文件无论是png还是jpg,他们基本都是彩色图片。(是的,哪怕你看到的都是黑色,但他们本身还都是彩色图片)这个时候可以使用​transforms.Grayscale​函数先将你的图像灰度处理,不然在用Normalize的时候还是会带来问题。由于预处理不同,所以你在前后训练的素材和你最后手写的素材不是一个格式,难免会导致你的准确率很低。预处理函数的设置是后面新增自定义素材时的必要保障。

   (2020.11.22补充:识别率和笔触的关系较大,可以参考训练集中图像的大小和笔触进行书写;在一定程度上,黑底白字比起白底黑字来说,准确率更高——by绚佬)

   关于transforms中包含有多少函数,有什么对应的作用,可以参考:​​

2、构建网络

   我们在第一部分的基础上,我们再重新定义一个网络,这里我们分别定义一个全连接层网络,再定义一个三层卷积神经网络。也借此复习一下网络定义的相关注意事项。

### (1)四层卷积神经网络

   我们在第一部分的基础上,我们再重新定义一个网络,这里我们分别定义一个全连接层网络,再定义一个三层卷积神经网络。也借此复习一下网络定义的相关注意事项。

   在定义的时候我们只需要注意几个点,一个是我们在定义的时候,务必保证我们的每一个Linear之间存在着输入输出通道对应的关系要相对应。第一个Linear函数的输入需要符合 ​深度x高度x宽度​ 的相关信息。

   其实这里还有几个没有解决的问题:Linear函数的数量该如何确定,他们数目会不会影响训练效果;log_softmax函数对于整体效果影响有多大等~(如果之后解决了我再写上去(嗯!

### (2)两层全连接层网络

   同卷积神经网络不太一样的是,全连接层网络中就只含有Linear映射。从我们此前的文字,我们可以知道:全连接层是不含Conv2d、relu这些函数的,它的组成仅是简单的Linear映射而已。所以我们定义全连接网络如下:

   该网络包含的参数有三个,第一个是输入图像的大小,第二个是中间层,最后一个是输出。很明显,输入的大小就是28*28,并不需要我们再做过多的设计,输出也是十通道输出,所以也是固定的。中间层则是根据自己的需求进行定义的。

## 3、模型训练

   我们在第一部分的基础上,我们再重新定义一个网络,这里我们分别定义一个全连接层网络,再定义一个三层卷积神经网络。也借此复习一下网络定义的相关注意事项。这一部分,也可以参考链接:​​进一步了解一下~博主写的也是真的好

   首先,按照国际惯例,我们先用一个流程图来展示一下每一次训练过程。

Created with Raphaël 2.2.0

开始将训练集输入到模型进行训练对结果采用交叉熵巡视计算模型误差,并将预测结果提取出来预测结果等于实际标签成功预测数n += 1反向传播,更新参数输出准确率和实验误差yes

   如果是想要利用已经有的参数进行多次训练,还可以使用如下语句。

torch.save(model.state_dict(), ‘./params.pth’)

   为了加深对于整段代码的理解,我们可以先了解一下其中比较重要但是又不太常见的几个语句块和函数: ​_ , pred = torch.max(out, 1)​:这句话需要先了解torch.max的用法,不太熟悉的可以参考​​先看一下。torch.max的定义格式为:

out = torch.max(input, dim)

   输入为input以及一个dim。dim指的是维度,0代表索引每列的最大值,1代表索引每行的最大值。他的输出为最大值以及其索引。在这里的作用就是,在多分类问题的类别取概率最大的类别。

   对于我们而言,经过模型输出后,我们需要的是结果的第二列,也就是预测值。所以用 ​_ , pred​ 就可以只存下pred。除了这种方式以外,也可以用如下语句表示同样的意思:

pred = torch.max(out, 1)[1] ​torch.cuda.is_available()​:看你的电脑的GPU是否可以被PyTorch调用​item()​:得到一个元素张量里面的元素值,常用于将一个零维张量转换成浮点数。​optimizer.zero_grad()​:遍历模型的所有参数,将上一次的梯度记录被清空。​loss.backward()​:进行误差反向传播。​optimizer.step()​:执行一次优化步骤,通过梯度下降法来更新参数的值。以上三个函数均为反向传播当中的必要函数,详细可以参考链接​​进一步了解,这三个函数之间是相辅相成的。

4、模型评估

   模型评估大体上的效果和步骤同模型训练一致,只需要将部分代码进行替换即可~这里就不贴代码了,就将评估当成是基于以上的又一次训练即可。

贪心高性能神经网络与AI芯片~学习笔记总计1相关推荐

  1. 贪心高性能神经网络与AI芯片应用

    ⽬前深度学习理论研究主要的⼀部分围绕在使⽤常微分⽅程,随机微分⽅程,偏微分⽅程,动⼒学系统,等来进⾏,例如neural ODE.可以关注鄂维南组,等的⼯作,他们都有傲⼈的title,与应⽤数学背景,在 ...

  2. 2022年薪百万赛道:高性能神经网络与AI芯片应用

    随着大数据的发展,计算机芯片算力的提升,人工智能近两年迎来了新一轮的爆发.而人工智能实现超级算力的核心就是AI芯片.AI芯片也被称为人工智能加速器,即专门用于处理人工智能应用中的大量计算任务的模块. ...

  3. 三种256MB SPIFLASH的高性能模式和软复位学习笔记

    三种256MB SPIFLASH的高性能模式和软复位学习笔记 WINBONE CONTINUE READ MODE The Fast Read Dual I/O The Fast Read Quad ...

  4. 深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习

    深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习 文章目录 深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习 表示学习 表示学习的意义 离散表示与分布式表示 端到端学习 基于重构损 ...

  5. 图神经网络-图与图学习笔记-1

    图神经网络-图与图学习 笔记-1 目录 一. 图是什么? 图的基本表示方法 计算图的每个节点的度 计算边的数量 可视化 二. 如何存储图? 存储为边列表 使用邻接矩阵 使用邻接列表 三. 图的类型和性 ...

  6. 阿里云趣味视觉AI训练营学习笔记Day 5

    阿里云趣味视觉AI训练营学习笔记Day 5 学习目标 学习内容 前言 一.创建人像卡通化应用 二.应用配置 三.后端服务开发部署 四.小程序前端开发 阿里云高校计划,陪伴两千多所高校在校生云上实践.云 ...

  7. 新唐芯片学习笔记——概要

    ##新唐芯片学习笔记--概要 特性 编号信息列表与管脚定义 NuMicroNUC029 命名规则 特性 ARM®Cortex®-M0 内核 – 运行频率可达50MHz – 一个 24位系统定时器 – ...

  8. 新唐芯片学习笔记——UART

    新唐芯片学习笔记--UART 概述 NuMicro®NUC029 提供2个通用异步收/发器(UART)通道,UART支持普通速度UART,并支持流控制.UART控制器对从外设收到的数据执行串到并的转换 ...

  9. 新唐芯片学习笔记——GPIO

    新唐芯片学习笔记--GPIO 概述 NuMicro®NUC029 最多有40个通用I/O引脚,这些引脚和其它功能共享.40个引脚分为6个端口,分别命名为P0, P1, P2, P3, P4和P5,每个 ...

  10. 新唐芯片学习笔记——ADC

    新唐芯片学习笔记--ADC 概述 NuMicro®NUC029xAN 包含一个12位逐次逼近型模数转换器(SAR A/D转换器) ,包含8个输入通道:NuMicro®NUC029FAE 包含一个10位 ...

最新文章

  1. mysql中文无法显示
  2. 岳阳机器人餐厅在哪_从机器人咖啡看未来餐饮行业大方向,如何才能活下去?...
  3. IdentityServer4 之Client Credentials走起来
  4. 【数据库】Window环境安装MySQL Server 5.7.21
  5. 抓包分析360浏览器和360搜索配对使用的安全性-WEB服务端分析
  6. web前端入门到实战:CSS3两大实用属性,以及网页制作技巧
  7. MPC(模型预测控制)之二(路径规划)
  8. 最常使用的css 工具_使用这些非常有用CSS工具更快地实施网站设计
  9. python对于一元线性回归模型_Python|机器学习-一元线性回归
  10. Java程序员面试宝典(第4版)
  11. matlab一维haar信号塔式分解,matlab小波分解与重构
  12. 80286 与 80386,实模式与保护模式切换编程
  13. 第09篇:Spring处理Mybatis事务
  14. Foxmail设置雅虎邮箱的方法
  15. findfont: Font family [‘sans-serif‘] not found解决方法
  16. centos下espeak文本转语音的代码实现
  17. python网页自动填写_Windows下使用python3 + selenium实现网页自动填表功能
  18. Linux LVS 负载均衡群集
  19. 提问的智慧(smart questions)
  20. php开源 rss订阅_5个开源RSS feed阅读器

热门文章

  1. 基于stm32的智能家居语音控制系统
  2. 马里兰帕克分校计算机科学,马里兰大学帕克分校管理信息系统(MIS)专业详解...
  3. homelede软路由设置方法_软路由LEDE折腾overlay分区扩容之路
  4. 中国四大运营商2G/3G/4G/5G工作频率以及网络制式
  5. python绘制折线图显示点数据_Python_散点图与折线图绘制
  6. HOG特征,LBP特征,Haar特征(图像特征提取)
  7. android投屏功能开发,Android PC投屏功能实现的示例代码
  8. google license key格式不对
  9. 如何利用Python开发一款快手加抖音自动刷视频脚本!
  10. IDEA工具-鼠标滚轮调整字体大小