pytorch实现简单的Resnet网络
笔者也是最近刚学不久的深度学习,也有很多地方不懂,下面给大家使用pytorch实现一个简单的Resnet网络(残差网络),并且训练MNIST数据集.话不多说,直接上代码.
笔者认为最主要的地方就是网络模型,网络模型出来其实基本上就完成了大概了.首先是残差块.之后是残差网络,数据集,训练,测试.完整代码请下拉看最后.
残差块的结构如下所示:
#残差块
class ResidualBlock(nn.Module):def __init__(self,channel):super(ResidualBlock, self).__init__()self.channel=channelself.conv1=nn.Sequential(nn.Conv2d(in_channels=channel,out_channels=channel,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(channel),nn.ReLU(inplace=True))self.conv2=nn.Sequential(nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1),# nn.BatchNorm2d(channel))def forward(self,x):out=self.conv1(x)out=self.conv2(out)out+=xout=F.relu(out)return out
残差块完成之后,就是残差网络,如下图:
#残差网络
class ResNet(nn.Module):def __init__(self):super(ResNet, self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5), #(1,28,28)nn.BatchNorm2d(32), #(32,24,24)nn.ReLU(),nn.MaxPool2d(2) #(32,12,12))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5), #(16,8,8)nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2) #(16,4,4))self.reslayer1=ResidualBlock(32)self.reslayer2=ResidualBlock(16)self.fc=nn.Linear(256,10) #这里的输入256是因为16*4*4=256def forward(self,x):out=self.conv1(x)out=self.reslayer1(out)out=self.conv2(out)out=self.reslayer2(out)out=out.view(out.size(0),-1)out=self.fc(out)return out
数据集:
Epoch=3
Batch_Size=50
LR=0.01#训练集
trainData=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=True,transform=torchvision.transforms.ToTensor(),download=False)train_loader=Data.DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=False,download=False)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:5000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.targets[:5000]
训练:
#关于训练
def Train(Res):# 损失函数,以及优化器loss_func = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(Res.parameters(), lr=LR)for epoch in range(Epoch):for step,(b_x,b_y)in enumerate(train_loader):output=Res(b_x)loss=loss_func(output,b_y)optimizer.zero_grad()loss.backward()optimizer.step()if(step%50==0):test_output=Res(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.6f' % accuracy)torch.save(Res, 'res_minist.pkl')print('res finish training')
测试:
# 测试
def Restest():res=torch.load('res_minist.pkl')test_output=res(test_x[:100])prediction=torch.max(test_output,1)[1].data.numpy()print(prediction, 'prediction number')print(test_y[:100].numpy(), 'real number')test_output1 = res(test_x)pred_y1 = torch.max(test_output1, 1)[1].data.numpy()accuracy = float((pred_y1 == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('accuracy', accuracy)
大体就是这样,然后,在一个py文件中全部写出来就OK了,下面是完整代码:
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as DataEpoch=3
Batch_Size=50
LR=0.01#训练集
trainData=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=True,transform=torchvision.transforms.ToTensor(),download=False)train_loader=Data.DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="/home/sunrui/zqtstudy/卷积网络/ResNetsimple/data",train=False,download=False)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:5000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.targets[:5000]#残差块
class ResidualBlock(nn.Module):def __init__(self,channel):super(ResidualBlock, self).__init__()self.channel=channelself.conv1=nn.Sequential(nn.Conv2d(in_channels=channel,out_channels=channel,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(channel),nn.ReLU(inplace=True))self.conv2=nn.Sequential(nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1),# nn.BatchNorm2d(channel))def forward(self,x):out=self.conv1(x)out=self.conv2(out)out+=xout=F.relu(out)return out#残差网络
class ResNet(nn.Module):def __init__(self):super(ResNet, self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5), #(1,28,28)nn.BatchNorm2d(32), #(32,24,24)nn.ReLU(),nn.MaxPool2d(2) #(32,12,12))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5), #(16,8,8)nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2) #(16,4,4))self.reslayer1=ResidualBlock(32)self.reslayer2=ResidualBlock(16)self.fc=nn.Linear(256,10) #这里的输入256是因为16*4*4=256def forward(self,x):out=self.conv1(x)out=self.reslayer1(out)out=self.conv2(out)out=self.reslayer2(out)out=out.view(out.size(0),-1)out=self.fc(out)return out#关于训练
def Train(Res):# 损失函数,以及优化器loss_func = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(Res.parameters(), lr=LR)for epoch in range(Epoch):for step,(b_x,b_y)in enumerate(train_loader):output=Res(b_x)loss=loss_func(output,b_y)optimizer.zero_grad()loss.backward()optimizer.step()if(step%50==0):test_output=Res(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.6f' % accuracy)torch.save(Res, 'res_minist.pkl')print('res finish training')# x=torch.randn(16,1,28,28)
# res=ResNet()# 测试
def Restest():res=torch.load('res_minist.pkl')test_output=res(test_x[:100])prediction=torch.max(test_output,1)[1].data.numpy()print(prediction, 'prediction number')print(test_y[:100].numpy(), 'real number')test_output1 = res(test_x)pred_y1 = torch.max(test_output1, 1)[1].data.numpy()accuracy = float((pred_y1 == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('accuracy', accuracy)if __name__=='__main__':# Train(res)Restest()
测试或者训练的话,在 if __name__=='__main__': 部分注释掉相应的部分就好啦.另外训练或者测试的数据大小可以自己选择,毕竟MNIST的训练数据集有6W张来着......笔者自己试了accuracy是在98%99%左右,这个是能够跑通的,跑通了才放上来的.数据集的加载训练测试代码参考,还有简单残差网络的知识请看下面链接,个人觉得对于刚开始的新手还是很有帮助的.
知乎的一位大佬的pytorch搭建CNN用于图像识别
B站刘二大人老师的 pytorch深度学习实践P11课
有问题可以留言,不定期看到会回复~(另外因为是第一次另外我不知道这算原创还是算转载,感觉没有特别合适的.所以我发表出来的时候选的类型是原创,要是并不属于这个类型,请告知我,我会改掉的,谢谢)
pytorch实现简单的Resnet网络相关推荐
- PyTorch实现简单的残差网络
一.实现过程 残差网络(Residual Network)的特点是容易优化,并且能够通过增加相当的深度来提高准确率.其内部的残差块使用了跳跃连接,缓解了在深度神经网络中增加深度带来的梯度消失问题. 本 ...
- 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络
Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...
- ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练
1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...
- ResNet网络简单理解与代码
ResNet网络提出的文章是<Deep Residual Learning for Image Recognition> 下载地址:https://arxiv.org/pdf/1512.0 ...
- Pytorch实现Deep Mutual Learning网络
-Model(pytorch版本 参考资料: 信息熵是什么? 交叉熵和相对熵(KL散度), 极大似然估计求loss, softmax多分类 一文搞懂熵.相对熵.交叉熵损失 class torch.nn ...
- 【读点论文】A ConvNet for the 2020s,结合swin transformer的结构设计和训练技巧调整resnet网络,在类似的FLOPs和参数量取得更好一点的效果
A ConvNet for the 2020s Abstract 视觉识别的"咆哮的20年代"始于视觉transformer(ViTs)的问世,它迅速取代ConvNets成为最先进 ...
- ResNet网络详解
ResNet ResNet在2015年由微软实验室提出,斩获当年lmageNet竞赛中分类任务第一名,目标检测第一名.获得coco数据集中目标检测第一名,图像分割第一名. ResNet亮点 1.超深的 ...
- 【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet
基于 PyTorch 实现残差神经网络 ResNet 文章目录 基于 PyTorch 实现残差神经网络 ResNet 0. 概述 1. 数据集介绍 1.1 数据集准备 1.2 分析分类难度:CIFAR ...
- ResNet网络的训练和预测
ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...
最新文章
- c++矩阵作为函数输入变量_C++实现矩阵乘法
- ajax初试,获取数据
- python基础面试都问什么问题_基本 Python 面试问题
- python环境配置-windows版
- 神逸之作:国产快速启动软件神品ALTRun
- ASIC开发流程介绍
- 理财趣事:要想财富滚滚来 先学普京打野猪
- 【GitHub报错】You have not concluded your merge (MERGE_HEAD exists).解决方法
- 大班线描机器人_大班美术lbrack;漂亮的机器人rsqb;活动设计
- linux服务器IP伪造,Linux服务器间同网段IP伪装端口映射
- java 完全解耦_java-完全解耦
- 中国最懒城市,这里的人不想赚钱,只想躺平
- python画老虎_老虎证券量化API Python SDK
- torch norm() Formalize()
- Python pygame(GUI编程)模块最完整教程(1)
- 第二十一天Python之进程
- 声呐学习笔记之概念性理论
- Linux---Ubuntu18.04.03系统安装网易云音乐(解决2.5K屏网易云音乐界面字体过小问题)
- otrs软件_开源ITIL管理工具OTRS简单介绍
- 呼吸衰竭护理查房PPT模板