笔者也是最近刚学不久的深度学习,也有很多地方不懂,下面给大家使用pytorch实现一个简单的Resnet网络(残差网络),并且训练MNIST数据集.话不多说,直接上代码.

  笔者认为最主要的地方就是网络模型,网络模型出来其实基本上就完成了大概了.首先是残差块.之后是残差网络,数据集,训练,测试.完整代码请下拉看最后.

残差块的结构如下所示:

#残差块
class ResidualBlock(nn.Module):def __init__(self,channel):super(ResidualBlock, self).__init__()self.channel=channelself.conv1=nn.Sequential(nn.Conv2d(in_channels=channel,out_channels=channel,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(channel),nn.ReLU(inplace=True))self.conv2=nn.Sequential(nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1),# nn.BatchNorm2d(channel))def forward(self,x):out=self.conv1(x)out=self.conv2(out)out+=xout=F.relu(out)return out

残差块完成之后,就是残差网络,如下图:


#残差网络
class ResNet(nn.Module):def __init__(self):super(ResNet, self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5), #(1,28,28)nn.BatchNorm2d(32),                                     #(32,24,24)nn.ReLU(),nn.MaxPool2d(2)                                         #(32,12,12))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5), #(16,8,8)nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2)                                           #(16,4,4))self.reslayer1=ResidualBlock(32)self.reslayer2=ResidualBlock(16)self.fc=nn.Linear(256,10)              #这里的输入256是因为16*4*4=256def forward(self,x):out=self.conv1(x)out=self.reslayer1(out)out=self.conv2(out)out=self.reslayer2(out)out=out.view(out.size(0),-1)out=self.fc(out)return  out

数据集:

Epoch=3
Batch_Size=50
LR=0.01#训练集
trainData=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=True,transform=torchvision.transforms.ToTensor(),download=False)train_loader=Data.DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=False,download=False)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:5000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.targets[:5000]

训练:


#关于训练
def Train(Res):# 损失函数,以及优化器loss_func = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(Res.parameters(), lr=LR)for epoch in range(Epoch):for step,(b_x,b_y)in enumerate(train_loader):output=Res(b_x)loss=loss_func(output,b_y)optimizer.zero_grad()loss.backward()optimizer.step()if(step%50==0):test_output=Res(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.6f' % accuracy)torch.save(Res, 'res_minist.pkl')print('res finish training')

测试:


# 测试
def Restest():res=torch.load('res_minist.pkl')test_output=res(test_x[:100])prediction=torch.max(test_output,1)[1].data.numpy()print(prediction, 'prediction number')print(test_y[:100].numpy(), 'real number')test_output1 = res(test_x)pred_y1 = torch.max(test_output1, 1)[1].data.numpy()accuracy = float((pred_y1 == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('accuracy', accuracy)

大体就是这样,然后,在一个py文件中全部写出来就OK了,下面是完整代码:

import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as DataEpoch=3
Batch_Size=50
LR=0.01#训练集
trainData=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=True,transform=torchvision.transforms.ToTensor(),download=False)train_loader=Data.DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=False,download=False)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:5000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.targets[:5000]#残差块
class ResidualBlock(nn.Module):def __init__(self,channel):super(ResidualBlock, self).__init__()self.channel=channelself.conv1=nn.Sequential(nn.Conv2d(in_channels=channel,out_channels=channel,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(channel),nn.ReLU(inplace=True))self.conv2=nn.Sequential(nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1),# nn.BatchNorm2d(channel))def forward(self,x):out=self.conv1(x)out=self.conv2(out)out+=xout=F.relu(out)return out#残差网络
class ResNet(nn.Module):def __init__(self):super(ResNet, self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5), #(1,28,28)nn.BatchNorm2d(32),                                     #(32,24,24)nn.ReLU(),nn.MaxPool2d(2)                                         #(32,12,12))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5), #(16,8,8)nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2)                                           #(16,4,4))self.reslayer1=ResidualBlock(32)self.reslayer2=ResidualBlock(16)self.fc=nn.Linear(256,10)              #这里的输入256是因为16*4*4=256def forward(self,x):out=self.conv1(x)out=self.reslayer1(out)out=self.conv2(out)out=self.reslayer2(out)out=out.view(out.size(0),-1)out=self.fc(out)return  out#关于训练
def Train(Res):# 损失函数,以及优化器loss_func = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(Res.parameters(), lr=LR)for epoch in range(Epoch):for step,(b_x,b_y)in enumerate(train_loader):output=Res(b_x)loss=loss_func(output,b_y)optimizer.zero_grad()loss.backward()optimizer.step()if(step%50==0):test_output=Res(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.6f' % accuracy)torch.save(Res, 'res_minist.pkl')print('res finish training')# x=torch.randn(16,1,28,28)
# res=ResNet()# 测试
def Restest():res=torch.load('res_minist.pkl')test_output=res(test_x[:100])prediction=torch.max(test_output,1)[1].data.numpy()print(prediction, 'prediction number')print(test_y[:100].numpy(), 'real number')test_output1 = res(test_x)pred_y1 = torch.max(test_output1, 1)[1].data.numpy()accuracy = float((pred_y1 == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('accuracy', accuracy)if __name__=='__main__':# Train(res)Restest()

测试或者训练的话,在  if __name__=='__main__':  部分注释掉相应的部分就好啦.另外训练或者测试的数据大小可以自己选择,毕竟MNIST的训练数据集有6W张来着......笔者自己试了accuracy是在98%99%左右,这个是能够跑通的,跑通了才放上来的.数据集的加载训练测试代码参考,还有简单残差网络的知识请看下面链接,个人觉得对于刚开始的新手还是很有帮助的.

知乎的一位大佬的pytorch搭建CNN用于图像识别

B站刘二大人老师的 pytorch深度学习实践P11课

有问题可以留言,不定期看到会回复~(另外因为是第一次另外我不知道这算原创还是算转载,感觉没有特别合适的.所以我发表出来的时候选的类型是原创,要是并不属于这个类型,请告知我,我会改掉的,谢谢)

pytorch实现简单的Resnet网络相关推荐

  1. PyTorch实现简单的残差网络

    一.实现过程 残差网络(Residual Network)的特点是容易优化,并且能够通过增加相当的深度来提高准确率.其内部的残差块使用了跳跃连接,缓解了在深度神经网络中增加深度带来的梯度消失问题. 本 ...

  2. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  3. ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

    1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...

  4. ResNet网络简单理解与代码

    ResNet网络提出的文章是<Deep Residual Learning for Image Recognition> 下载地址:https://arxiv.org/pdf/1512.0 ...

  5. Pytorch实现Deep Mutual Learning网络

    -Model(pytorch版本 参考资料: 信息熵是什么? 交叉熵和相对熵(KL散度), 极大似然估计求loss, softmax多分类 一文搞懂熵.相对熵.交叉熵损失 class torch.nn ...

  6. 【读点论文】A ConvNet for the 2020s,结合swin transformer的结构设计和训练技巧调整resnet网络,在类似的FLOPs和参数量取得更好一点的效果

    A ConvNet for the 2020s Abstract 视觉识别的"咆哮的20年代"始于视觉transformer(ViTs)的问世,它迅速取代ConvNets成为最先进 ...

  7. ResNet网络详解

    ResNet ResNet在2015年由微软实验室提出,斩获当年lmageNet竞赛中分类任务第一名,目标检测第一名.获得coco数据集中目标检测第一名,图像分割第一名. ResNet亮点 1.超深的 ...

  8. 【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet

    基于 PyTorch 实现残差神经网络 ResNet 文章目录 基于 PyTorch 实现残差神经网络 ResNet 0. 概述 1. 数据集介绍 1.1 数据集准备 1.2 分析分类难度:CIFAR ...

  9. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

最新文章

  1. c++矩阵作为函数输入变量_C++实现矩阵乘法
  2. ajax初试,获取数据
  3. python基础面试都问什么问题_基本 Python 面试问题
  4. python环境配置-windows版
  5. 神逸之作:国产快速启动软件神品ALTRun
  6. ASIC开发流程介绍
  7. 理财趣事:要想财富滚滚来 先学普京打野猪
  8. 【GitHub报错】You have not concluded your merge (MERGE_HEAD exists).解决方法
  9. 大班线描机器人_大班美术lbrack;漂亮的机器人rsqb;活动设计
  10. linux服务器IP伪造,Linux服务器间同网段IP伪装端口映射
  11. java 完全解耦_java-完全解耦
  12. 中国最懒城市,这里的人不想赚钱,只想躺平
  13. python画老虎_老虎证券量化API Python SDK
  14. torch norm() Formalize()
  15. Python pygame(GUI编程)模块最完整教程(1)
  16. 第二十一天Python之进程
  17. 声呐学习笔记之概念性理论
  18. Linux---Ubuntu18.04.03系统安装网易云音乐(解决2.5K屏网易云音乐界面字体过小问题)
  19. otrs软件_开源ITIL管理工具OTRS简单介绍
  20. 呼吸衰竭护理查房PPT模板

热门文章

  1. RK3288下如何实现虚拟摄像头。
  2. 微信小程序 动态添加view组件
  3. 潜在通路分析软件QA/SCAT V2.0用户使用说明书(目录)
  4. PyTorch入门到实战自然语言处理及计算机视觉01为什么选择Pytorch
  5. 力士乐变频器调试软件RDwin11V09
  6. 分享一款最新抽奖网站源码
  7. 转: C#的25个基础概念
  8. 如何更换天籁车钥匙电池
  9. golang调用c文件
  10. 基于双层蚂蚁算法和区域优化的机器人导航新算法 翻译+总结