Cifar10与ResNet18实战、lenet5、resnet(学习笔记)
1.44.Cifar10与ResNet18实战
Pytorch工程中建立pytorch,在pytorch里面创建lenet5.py、main.py、resnet.py。
1.44.1.lenet5.py
# -*- coding: UTF-8 -*-import torch
from torch import nnclass Lenet5(nn.Module):"""for cifar10 dataset."""def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 16, ]nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)# flatten# fc unitself.fc_unit = nn.Sequential(nn.Linear(32 * 5 * 5, 32),nn.ReLU(),# nn.Linear(120, 84),# nn.ReLU(),nn.Linear(32, 10))# [b, 3, 32, 32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)# [b, 16, 5, 5]print('conv out:', out.shape)# # use Cross Entropy Loss# self.criteon = nn.CrossEntropyLoss()def forward(self, x):""":param x: [b, 3, 32, 32]:return:"""batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)# [b, 16, 5, 5] => [b, 16*5*5]x = x.view(batchsz, 32 * 5 * 5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)# # [b, 10]# pred = F.softmax(logits, dim=1)# loss = self.criteon(logits, y)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()
1.44.2.Resnet.py
# -*- coding: UTF-8 -*-import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):"""resnet block"""def __init__(self, ch_in, ch_out, stride=1):""":param ch_in::param ch_out:"""super(ResBlk, self).__init__()# we add stride support for resbok, which is distinct from tutorials.self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:# [b, ch_in, h, w] => [b, ch_out, h, w]self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):""":param x: [b, ch, h, w]:return:"""out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# short cut.# extra module: [b, ch_in, h, w] => [b, ch_out, h, w]# element-wise add:out = self.extra(x) + outout = F.relu(out)return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64, 128, stride=2)# [b, 128, h, w] => [b, 256, h, w]self.blk2 = ResBlk(128, 256, stride=2)# # [b, 256, h, w] => [b, 512, h, w]self.blk3 = ResBlk(256, 512, stride=2)# # [b, 512, h, w] => [b, 1024, h, w]self.blk4 = ResBlk(512, 512, stride=2)self.outlayer = nn.Linear(512 * 1 * 1, 10)def forward(self, x):""":param x::return:"""x = F.relu(self.conv1(x))# [b, 64, h, w] => [b, 1024, h, w]x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)# print('after conv:', x.shape) #[b, 512, 2, 2]# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])# print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64, 128, stride=4)tmp = torch.randn(2, 64, 32, 32)out = blk(tmp)print('block:', out.shape)x = torch.randn(2, 3, 32, 32)model = ResNet18()out = model(x)print('resnet:', out.shape)if __name__ == '__main__':main()
1.44.3.main.py
# -*- coding: UTF-8 -*-import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optimfrom pytorch.lenet5 import Lenet5
from pytorch.resnet import ResNet18def main():batchsz = 128cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)x, label = iter(cifar_train).next()print('x:', x.shape, 'label:', label.shape)device = torch.device('cuda')# model = Lenet5().to(device)model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)print(model)for epoch in range(1000):model.train()for batchidx, (x, label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x, label = x.to(device), label.to(device)logits = model(x)# logits: [b, 10]# label: [b]# loss: tensor scalarloss = criteon(logits, label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch, 'loss:', loss.item())model.eval()with torch.no_grad():# testtotal_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.to(device), label.to(device)# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()
Cifar10与ResNet18实战、lenet5、resnet(学习笔记)相关推荐
- 1.5 Hello, world! 解剖 -JSF实战 -hxzon -jsf学习笔记
为什么80%的码农都做不了架构师?>>> 1.5 Hello, world! 解剖 -JSF实战 -hxzon -jsf学习笔记 既然已经对JSF能够解决什么问题有了初步理解, ...
- python3《机器学习实战系列》学习笔记----3.2 决策树实战
前言 一.ID3算法构造决策树 1.1 背景 1.2 信息增益计算 1.3 递归生成决策树 二.使用Matplotlib注解绘制树形图 2.1 Matplotlib注解 2.2 构造注解树 三.测试和 ...
- 《崔庆才Python3网络爬虫开发实战教程》学习笔记(5):将爬虫爬取到的数据存储到TXT,Word,Excel,Json等文件中
本篇博文是自己在学习崔庆才的<Python3网络爬虫开发实战教程>的学习笔记系列,此套教程共5章,加起来共有34节课,内容非常详细丰富!如果你也要这套视频教程的话,关注我公众号[小众技术] ...
- MySQL实战45讲学习笔记
文章目录 MySQL实战45讲-学习笔记 01 基础架构:一条SQL查询语句是如何执行的? mysql逻辑架构 连接器 查询缓存 分析器 优化器 执行器 02 日志系统:一条SQL更新语句如何执行 r ...
- 《机器学习实战》kNN学习笔记
<机器学习实战>kNN学习笔记 文章目录 <机器学习实战>kNN学习笔记 概述 优缺点 k-近邻算法的一般流程 简单案例kNN.py 在约会网站上使用k-近邻算法 归一化特征值 ...
- 【1】机器学习实战peter Harrington——学习笔记
机器学习实战peter Harrington--学习笔记 综述 数据挖掘十大算法 本书结构 一.机器学习基础 1.1 机器学习 1.2 关键术语 1.3 机器学习主要任务 1.4 如何选择合适的算法 ...
- vn.py全实战进阶课程学习笔记(零)
目录 写在前面 MySQL数据库配置 安装mysq 创建数据库 vnpy数据库配置 rqdata数据服务配置 申请rqdata试用权限 vnpy参数配置 simnow仿真环境配置 准备账号 接口登录 ...
- 《姜承尧的MySQL实战宝典》学习笔记
<姜承尧的MySQL实战宝典>学习笔记 1 表结构设计 1.1 数字类型 1.1.1 整形类型 1.1.2 浮点类型和高精度型 1.1.3 实战--整型类型与自增设计 1.1.4 实战-- ...
- 从零开始带你成为MySQL实战优化高手学习笔记(一)
重复是有必要的. 很多新入职的小朋友可能和现在的我一样,对数据库的了解仅仅停留在建库建表增删改查这些操作,日常工作也都是用封装好的代码,别说底层原理了,数据库和系统之间是如何工作都不是很懂. 长此以往 ...
最新文章
- socket sock inet_sock 等关系
- 身为Java程序员,这些开源工具你一定要学会!
- VTK:可视化之StreamLines
- Linux 系统应用编程——进程间通信(上)
- 为什么不能睁一只眼闭一只眼_自媒体人上哪里找非常多的原创短视频素材?我为什么一定要你做原创?...
- 如何对待基金评审负面意见?
- ExtJS6 Grid的日期编辑栏位处理
- Comet:基于HTTP长连接的“服务器推”技术
- kubeflow fairing详解
- 不要迷失在技术的海洋中(转)
- Origin绘制热重TG和微分热重DTG曲线
- C语言删除字符串中的单词
- 浅析GPU通信技术(上)-GPUDirect P2P
- [题解] 洛谷 P3603 雪辉
- Jupyter lab add kernel Python+Julia+R 【jupyter Notebook 切换Python环境】and【在jupyter Notebook中安装第三方库】
- GTX 770 (GK 104)
- 常见的两种python编译器的安装
- Linux下安装配置Cobra教程
- 7.3 进程管理之暂停、归档和策略
- 高级文秘、高级行政助理职业化训练
热门文章
- centos 关机命令_Docker 常用命令速查手册
- C语言简单计算器考虑优先级,利用你现有的c语言知识 设计开发一个简易计算器,可进行加、减、乘、除、求余运算。...
- VTK:数据结构比较用法实战
- JavaScript实现返回数字的二进制表示中使用的位数bitLength算法(附完整源码)
- wxWidgets:wxHyperlinkCtrl类用法
- 在并发中练习 Boost.Multiprecision多线程环境相关的测试程序
- boost::maximum_weighted_matching用法的测试程序
- boost::geometry::dot_product用法的测试程序
- boost::is_output_streamable用法的测试程序
- ITK:过滤图像而没有复制其数据