学习笔记,仅供参考,有错必纠


文章目录

  • 理论
    • 卷积神经网络CNN
      • 局部感受野和权值共享
      • 卷积计算
      • 池化Pooling
      • Padding
    • LeNET-5
  • 代码
    • 初始设置
    • 导包
    • 载入数据
    • 模型

理论

卷积神经网络CNN

卷积神经网络是近年发展起来,并广泛应用于图像处理,NLP等领域的一种多层神经网络。

局部感受野和权值共享

CNN通过局部感受野和权值共享减少了神经网络需要训练的参数个数,从而解决了传统BP权值太多,计算量太大,需要大量样本进行训练的问题.

卷积计算

卷积核也叫滤波器,不同的卷积核 对 同样的图片做卷积之后会提取出不同的信息. 以下图的卷积核为例,我们可以对示例Image进行卷积操作.


需要注意的是,卷积核里的参数不是人为设定的,而是算法优化得到的.

池化Pooling

Pooling常用的三种方式:

  • max-pooling
  • mean-pooling
  • stochastic pooling

Padding

  • SAME PADDING

给平面外部补0,卷积窗口采样后可能会得到一个跟原来大小相同的平面.

  • VALID PADDING

不会超出平面外部,卷积窗口采样后得到一个比原来平面小的平面。

LeNET-5

LeNET-5是最早的卷积神经网络之一. 下图为LeNET-5的网络结构.

我们可以看到通过对第3层进行卷积后,第4层得到了16幅图. 那么第4层的16幅图是如何计算的呢,操作如下图所示.

代码

初始设置

# 支持多行输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all' #默认为'last'

导包

# 导入常用的包
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

载入数据

# 载入数据
train_dataset = datasets.MNIST(root = './data/', # 载入的数据存放的位置train = True, # 载入训练集数据transform = transforms.ToTensor(), # 将载入进来的数据变成Tensordownload = True) # 是否下载数据
test_dataset = datasets.MNIST(root = './data/', # 载入的数据存放的位置train = False, # 载入测试集数据transform = transforms.ToTensor(), # 将载入进来的数据变成Tensordownload = True) # 是否下载数据
# 批次大小
batch_size = 64# 装载训练集
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)# 装载训练集
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

模型

这里我们使用具有多层网络结构的模型,并加入Dropout操作.

# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义卷积层和池化# in_channels:int, 因为是黑白图片,所以输入通道设置为1,如果为彩色图像则这里为3# out_channels:int, 这里的输出通道数也为生成的特征图的数量,这里我们设置为32# kernel_size:int, 卷积核大小,我们设置为5# stride=1, 步长我们设置为1# padding=0, 我们设置padding为2,也就是在图片的外围补2圈0,这里我们要按照自己的需求自己计算# 如果想要卷积后的大小和原始图像大小相同,则卷积核大小为3*3则填充1圈0,5*5填充2圈,7*7填充3圈.# 因为卷积不是非线性操作,所以我们在卷积后增加非线性激活函数nn.ReLU()# 在卷积后,我们增加一个2*2的池化操作self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2))# 再定义一个卷积和池化self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2))# 全连接# 全连接的输入为64个大小为(7*7)的特征图# 输出为1000self.fc1 = nn.Sequential(nn.Linear(64*7*7, 1000), nn.Dropout(p = 0.4), nn.ReLU())# 全连接self.fc2 = nn.Sequential(nn.Linear(1000, 10),nn.Softmax(dim = 1))def forward(self, x):# ([64, 1, 28, 28])# 卷积要求的数据格式就是4维的([图片数量, 图片通道数, 图片维度1, 图片维度2])x = self.conv1(x)x = self.conv2(x)# 进入全连接层时,需要reshape# ([64, 64, 7, 7]) -> ([64, 64*7*7])x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.fc2(x)return x
LR = 0.0003
# 定义模型
model = Net()
# 定义代价函数为交叉熵代价函数
mse_loss = nn.CrossEntropyLoss()
# 定义优化器Adam
optimizer = optim.Adam(model.parameters(), LR)

在自定义训练和测试函数中,我们分别增加两个方法,model.train()model.eval() ,这model.train()方法可以使训练集中的Dropout在训练模型时发挥作用,而model.eval()则可以使模型在测试过程中不工作.

def train():model.train()for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 计算loss,交叉熵代价函数out(batch,C), labels(batch)loss = mse_loss(out, labels)# 梯度清0optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():model.eval()# 计算训练集准确率correct = 0for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_, predicted = torch.max(out, 1)# 预测正确的数量correct += (predicted == labels).sum()print("Train acc:{0}".format(correct.item()/len(train_dataset)))# 计算测试集准确率correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_, predicted = torch.max(out, 1)# 预测正确的数量correct += (predicted == labels).sum()print("Test acc:{0}".format(correct.item()/len(test_dataset)))
for epoch in range(5):print('epoch:',epoch)train()test()
epoch: 0
Train acc:0.9728166666666667
Test acc:0.9755
epoch: 1
Train acc:0.9827666666666667
Test acc:0.983
epoch: 2
Train acc:0.9863
Test acc:0.9863
epoch: 3
Train acc:0.98665
Test acc:0.9842
epoch: 4
Train acc:0.99075
Test acc:0.9896

PyTorch基础(part7)--CNN相关推荐

  1. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  2. 基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

    基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络 所用工具 文件结构: 数据: 代码: 结果: 改进思路 拓展 本文是一个基于pytorch使用CNN在生物信息学上进行位 ...

  3. 深入浅出Pytorch:02 PyTorch基础知识

    深入浅出Pytorch 02 PyTorch基础知识 内容属性:深度学习(实践)专题 航路开辟者:李嘉骐.牛志康.刘洋.陈安东 领航员:叶志雄 航海士:李嘉骐.牛志康.刘洋.陈安东 开源内容:http ...

  4. 第02章 PyTorch基础知识

    文章目录 第02章 Pytorch基础知识 2.1 张量 2.2 自动求导 2.3 并行计算简介 2.3.1 为什么要做并行计算 2.3.2 CUDA是个啥 2.3.3 做并行的方法 补充:通过股票数 ...

  5. 深度学习之Pytorch基础教程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...

  6. python cnn_使用python中pytorch库实现cnn对mnist的识别

    使用python中pytorch库实现cnn对mnist的识别 1 环境:Anaconda3 64bit https://www.anaconda.com/download/ 2 环境:pycharm ...

  7. 【深度学习】基础知识--CNN:图像分类(上)

    作者信息: 华校专,曾任阿里巴巴资深算法工程师.智易科技首席算法研究员,现任腾讯高级研究员,<Python 大战机器学习>的作者. 编者按: 算法工程师必备系列更新啦!继上次推出了算法工程 ...

  8. 【深度学习】深度学习之Pytorch基础教程!

    作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现.尤其是近两年,Google.Facebook.Microsoft等巨头都围绕深度学习重点投资了一系 ...

  9. PyTorch基础(part5)--交叉熵

    学习笔记,仅供参考,有错必纠 文章目录 原理 代码 初始设置 导包 载入数据 模型 原理 交叉熵(Cross-Entropy) Loss=−(t∗ln⁡y+(1−t)ln⁡(1−y))Loss =-( ...

最新文章

  1. 语义分割--ParseNet: Looking Wider to See Better
  2. 示波器测485串口波特率的使用方法
  3. Linux排障必备命令
  4. Kubernetes 稳定性保障手册:洞察+预案
  5. python315题的漫漫通关之路
  6. 穷举 百文百鸡
  7. 数组——寄包柜(洛谷 P3613)
  8. KeyboardEvent keyMap
  9. php获取当天 天气预报,PHP获取当天和72小时天气预报,并生成接口
  10. 11、进入保护模式-V
  11. leetcode每日一练(第一天)
  12. 无损音乐ape格式怎么转为ogg格式
  13. Java基础-运算符
  14. 尤雨溪 6 月 4 日的 Vue 技术分享
  15. websocket 给服务端发送太长数据处理(The decoded text message was too big for the output buffer and the endpoint )
  16. ITRON入门之实时操作系统的特点
  17. ios html背景音乐,iOS音频篇:使用AVPlayer播放网络音乐
  18. 分析linux启动内核源码
  19. javascript教程完整版,JavaScript视频教程
  20. 360极速浏览器内核切换设置

热门文章

  1. firefox安装adobe flash插件
  2. History of pruning algorithm development and python implementation(finished)
  3. 机器学习(二十三)——Beam Search, NLP机器翻译常用评价度量, 模型驱动 vs 数据驱动
  4. oracle动态采样超时,解决 ORACLE 11.2 动态采样导致的性能问题
  5. 2020年408真题_2020年408真题和参考解析
  6. 平方环法_2019环法挑战赛加速诸暨“运动之城”建设 推动“体育+旅游”新热潮...
  7. Chrome的console
  8. 为何2018年中国自然灾害损失大幅下降?官方回应
  9. 18.12.04 有品面试小记
  10. 《Orange’s 一个操作系统的实现》1.搭建操作系统开发环境