CNN的理论部分可见机器学习笔记:CNN卷积神经网络_刘文巾的博客-CSDN博客

1 导入库

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

2 超参数定义

EPOCH=1
BATCH_SIZE=50
LR=0.001
DOWNLOAD_MNIST=False
#如果已经事先下载好了 mnist数据,那么DOWNLOAD_MNIST就是False,否则就是True

3 加载数据

train_data=torchvision.datasets.MNIST(
root='./mnist/',
#从这个路径找mnist数据/下载mnist数据到这个路径下
train=True,
#这时候数据是训练集(是训练集还是测试集对dropout等会有影响)
transform=torchvision.transforms.ToTensor()
#将mnist数据集中的数据类型转换为Tensor形式,
download=DOWNLOAD_MNIST)

4 数据集信息及可视化

print(train_data.data.shape)
#torch.Size([60000, 28, 28])
#60000条数据,每条数据是28*28的像素点train_data.targets.shape
#torch.Size([60000])
#训练集数据的标签,每条数据对应一个标签,代表这张图片是哪个数字#可视化
plt.imshow(train_data.data[1])
plt.title("{}".format(train_data.targets.data[1]))

5 dataloader生成

#生成dataloader
train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True
)

6 CNN模型定义

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,
#输入shape (1,28,28)out_channels=16,
#输出shape(16,28,28),16也是卷积核的数量kernel_size=5,stride=1,padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2nn.ReLU(),nn.MaxPool2d(kernel_size=2)#在2*2空间里面下采样,输出shape(16,14,14))self.conv2=nn.Sequential(nn.Conv2d(in_channels=16,
#输入shape (16,14,14)out_channels=32,
#输出shape(32,14,14)kernel_size=5,stride=1,padding=2),
#输出shape(32,7,7),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.fc=nn.Linear(32*7*7,10)
#输出一个十维的东西,表示我每个数字可能性的权重def forward(self,x):x=self.conv1(x)x=self.conv2(x)x=x.view(x.shape[0],-1)x=self.fc(x)return xcnn=CNN()
print(cnn)
'''
CNN((conv1): Sequential((0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): ReLU()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(conv2): Sequential((0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): ReLU()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc): Linear(in_features=1568, out_features=10, bias=True)
)
'''

7 定义优化函数和损失函数

optimizer=torch.optim.SGD(cnn.parameters(),lr=LR)loss_func=torch.nn.CrossEntropyLoss()
#损失函数定义为交叉熵loss_his=[]

8 训练模型

for epoch in range(EPOCH):for step,(b_x,b_y) in enumerate(train_loader):output=cnn(b_x)loss=loss_func(output,b_y)loss_his.append(loss)optimizer.zero_grad()
#清除上一次参数更新的残余梯度loss.backward()
#损失函数后向传播optimizer.step()
#参数更新

9 损失函数可视化

plt.figure(figsize=(10,5))
plt.plot(loss_his)

10 结果验证

tmp=train_data.data[3]
print(tmp.shape)
#torch.Size([28, 28])plt.imshow(tmp)print(train_data.targets[3])
#1tmp=tmp.reshape(1,1,28,28)
#reshape一下,这样可以送入模型中cnn(tmp.type(torch.FloatTensor))
#type那一部分必须要,否则报错“#RuntimeError: expected scalar type Byte but found Float”#tensor([[-74.7790, 140.8805,  54.7678,   2.0222, -28.6432, -62.3598, -20.4240,
#         -44.9795, 110.6059, -37.1277]], grad_fn=<AddmmBackward>)torch.max(cnn(tmp.type(torch.FloatTensor)),axis=1)[1]
#tensor([1])

pytorch笔记:搭建简易CNN相关推荐

  1. 【 卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10)】

    卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10) 本章主要进行卷积神经网络的相关数学原理和pytorch的对应模块进行推导分析 代码也是通过demo实 ...

  2. 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】

    卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...

  3. 李沐《动手学深度学习》第二版 pytorch笔记1 环境搭建

    李沐<动手学深度学习>第二版pytorch笔记1 搭建环境 文章目录 李沐<动手学深度学习>第二版pytorch笔记1 搭建环境 此时尚有耐心 虚拟环境搭建 创建虚拟环境 查看 ...

  4. pytorch 笔记:torchsummary

    作用:打印神经网络的结构 以pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-CSDN博客 中搭建的CNN为例 import torch from torchsummary import ...

  5. pytorch 笔记:使用Tune 进行调参

    自动进行调参,我们以pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-CSDN博客的代码为基础,进行output_channel和learning rate的调参 1 导入库 from f ...

  6. pytorch笔记:policy gradient

    本文参考了 策略梯度PG( Policy Gradient) 的pytorch代码实现示例 cart-pole游戏_李莹斌XJTU的博客-CSDN博客_策略梯度pytorch 在其基础上添加了注释和自 ...

  7. 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比

    作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...

  8. Pytorch笔记:风格迁移

    Pytorch笔记:风格迁移 训练模型:风格迁移网络+VGG16网络 生成网络:风格迁移网络 代码如下(根据陈云<深度学习框架:Pytorch入门与实践>的代码改动) main.py im ...

  9. Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图

    Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图 前言 一.案例要求 二.训练数据准备 1.下载手写英文字母数据集 2.构建自己的数据集 三.AlexNet实现 1.Al ...

最新文章

  1. mvc:annotation-driven/
  2. Windows域的管理
  3. Vue学习(入门实例、常用指令)-学习笔记
  4. mongoose日期 时间 范围查询
  5. 路由复用器--gorilla/mux
  6. TypeError: Unexpected keyword argument passed to optimizer: learning_rate解决方法
  7. html css 表头,css固定表格表头(各浏览器通用)
  8. [LibTorch Win] 各版本 LibTorch 下载
  9. asp.net gridview 模板列 弹出窗口编辑_如何使用极速PDF编辑器的注释工具?
  10. Linux 基础——权限管理命令chown、chgrp
  11. VS 中PageLayout 属性设置
  12. 青岛高新职业学校计算机专业,“把灯光调亮 让我最后再好好看看你们” 青岛高新职业学校举行2021届毕业典礼...
  13. 浪涌保护器ant120_菲尼克斯浪涌保护器
  14. 正则——16进制颜色
  15. 测试睡眠的软件是,MOORING睡眠检测
  16. Creator+微信小游戏:(3)微信openID获取(https、wss问题)
  17. C Sharp编写缓和曲线计算应用程序
  18. 机器人系统数学建模(现代控制理论1)
  19. 谁再说“游戏没用”,就拿这个回怼他!
  20. 安装使用Animate动画库【Animate.css下载安装教程】

热门文章

  1. zabbix如何监控WEB应用性能
  2. ubuntu下的第一个脚本file.sh
  3. docker深入1-导入导出images和container的方式
  4. win2003 shutdown命令
  5. Tasklets 机制浅析
  6. PAT甲级1101 Quick Sort:[C++题解]DP、快速排序划分个数、快排
  7. php流调签名,微信接口签名及调用流程详解 - 黎明互联-官方博客 - 黎明互联 - 区块链培训,PHP培训,IT培训,职业技能培训,追求极致!改变您的职业生涯!...
  8. xampp php源码的路径,php – XAMPP中的根路径
  9. mysql设计经纬度表_MySQL经纬度表设置
  10. MySQL中的视图操作