ResNet-18 structure

ResNet-18 stacks a 7×7 stem convolution (followed by batch norm, ReLU, and max pooling), four stages of residual blocks (two blocks per stage, with 64, 128, 256, and 512 channels), global average pooling, and a final fully connected classifier.

Basic block

A basic block applies two 3×3 convolutions, each followed by batch normalization, and adds the block's input back to the result before the final ReLU. When a stage changes the spatial size or channel count, a 1×1 convolution on the shortcut path projects the input to the matching shape (the RestNetDownBlock in the code below).

Code implementation

import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
    """Residual block that keeps spatial size and channel count (used with stride=1)."""

    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = self.conv1(x)
        output = F.relu(self.bn1(output))
        output = self.conv2(output)
        output = self.bn2(output)
        # identity shortcut: input and output shapes match, so they can be added directly
        return F.relu(x + output)


class RestNetDownBlock(nn.Module):
    """Residual block that downsamples (stride[0]=2) and changes the channel count;
    a 1x1 convolution on the shortcut path projects the input to the new shape."""

    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)          # projected shortcut
        output = self.conv1(x)
        out = F.relu(self.bn1(output))
        out = self.conv2(out)
        out = self.bn2(out)
        return F.relu(extra_x + out)


class RestNet18(nn.Module):
    """ResNet-18 for 10-class classification (e.g., CIFAR-10)."""

    def __init__(self):
        super(RestNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1), RestNetBasicBlock(64, 64, 1))
        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]), RestNetBasicBlock(128, 128, 1))
        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]), RestNetBasicBlock(256, 256, 1))
        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]), RestNetBasicBlock(512, 512, 1))
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out = self.conv1(x)
        # bn1 and maxpool were defined in __init__ but never applied in the original
        # forward pass; they are applied here to match the standard ResNet-18 stem.
        out = F.relu(self.bn1(out))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out
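
As a quick sanity check (not part of the original post), a dummy CIFAR-sized batch can be pushed through the model to confirm that the output has shape [batch, 10]:

import torch

model = RestNet18()                  # the class defined above
x = torch.randn(2, 3, 32, 32)        # a dummy batch of two 32x32 RGB images
print(model(x).shape)                # expected: torch.Size([2, 10])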

Using the model to classify the CIFAR-10 dataset

Dataset

Official site: CIFAR-10 DATASET

Test code

import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from restnet18.restnet18 import RestNet18


# Experiment on the CIFAR-10 dataset
def main():
    batchsz = 128

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # peek at one batch; next(iter(...)) replaces the deprecated iter(...).next()
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = Lenet5().to(device)
    model = RestNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10], label: [b], loss: scalar tensor
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => number of correct predictions in the batch
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
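
The loop above never saves the weights it trains. Below is a minimal sketch of persisting and reloading the model with torch.save / load_state_dict; the filename resnet18_cifar10.pth is illustrative, not from the original code:

import torch
from restnet18.restnet18 import RestNet18

model = RestNet18()
# after (or during) training, save only the parameters, not the whole module
torch.save(model.state_dict(), 'resnet18_cifar10.pth')    # hypothetical filename

# reload for inference
restored = RestNet18()
restored.load_state_dict(torch.load('resnet18_cifar10.pth', map_location='cpu'))
restored.eval()    # switch BatchNorm layers to inference behaviour before predicting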

Results


Honestly it feels a bit underwhelming: after 50-odd epochs the test accuracy only reaches a little over 80%.
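
One common way to push past that plateau, not used in the script above, is standard CIFAR-10 training-time augmentation. A sketch of the usual transform (random crops plus horizontal flips) is shown below; the test-set transform should stay exactly as in the original script:

import torchvision.transforms as transforms

# Training-time augmentation for CIFAR-10 (assumed improvement, not from the original post)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),     # random 32x32 crops from a zero-padded image
    transforms.RandomHorizontalFlip(),        # mirror half of the images left-right
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])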
