1.44.Cifar10与ResNet18实战

Pytorch工程中建立pytorch,在pytorch里面创建lenet5.py、main.py、resnet.py。

1.44.1.lenet5.py

# -*- coding: UTF-8 -*-import torch
from torch import nnclass Lenet5(nn.Module):"""for cifar10 dataset."""def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 16, ]nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)# flatten# fc unitself.fc_unit = nn.Sequential(nn.Linear(32 * 5 * 5, 32),nn.ReLU(),# nn.Linear(120, 84),# nn.ReLU(),nn.Linear(32, 10))# [b, 3, 32, 32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)# [b, 16, 5, 5]print('conv out:', out.shape)# # use Cross Entropy Loss# self.criteon = nn.CrossEntropyLoss()def forward(self, x):""":param x: [b, 3, 32, 32]:return:"""batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)# [b, 16, 5, 5] => [b, 16*5*5]x = x.view(batchsz, 32 * 5 * 5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)# # [b, 10]# pred = F.softmax(logits, dim=1)# loss = self.criteon(logits, y)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()

1.44.2.Resnet.py

# -*- coding: UTF-8 -*-import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):"""resnet block"""def __init__(self, ch_in, ch_out, stride=1):""":param ch_in::param ch_out:"""super(ResBlk, self).__init__()# we add stride support for resbok, which is distinct from tutorials.self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:# [b, ch_in, h, w] => [b, ch_out, h, w]self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):""":param x: [b, ch, h, w]:return:"""out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))# short cut.# extra module: [b, ch_in, h, w] => [b, ch_out, h, w]# element-wise add:out = self.extra(x) + outout = F.relu(out)return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64, 128, stride=2)# [b, 128, h, w] => [b, 256, h, w]self.blk2 = ResBlk(128, 256, stride=2)# # [b, 256, h, w] => [b, 512, h, w]self.blk3 = ResBlk(256, 512, stride=2)# # [b, 512, h, w] => [b, 1024, h, w]self.blk4 = ResBlk(512, 512, stride=2)self.outlayer = nn.Linear(512 * 1 * 1, 10)def forward(self, x):""":param x::return:"""x = F.relu(self.conv1(x))# [b, 64, h, w] => [b, 1024, h, w]x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)# print('after conv:', x.shape) #[b, 512, 2, 2]# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])# print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64, 128, stride=4)tmp = torch.randn(2, 64, 32, 32)out = blk(tmp)print('block:', out.shape)x = torch.randn(2, 3, 32, 32)model = ResNet18()out = model(x)print('resnet:', out.shape)if __name__ == '__main__':main()

1.44.3.main.py

# -*- coding: UTF-8 -*-import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optimfrom pytorch.lenet5 import Lenet5
from pytorch.resnet import ResNet18def main():batchsz = 128cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)x, label = iter(cifar_train).next()print('x:', x.shape, 'label:', label.shape)device = torch.device('cuda')# model = Lenet5().to(device)model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)print(model)for epoch in range(1000):model.train()for batchidx, (x, label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x, label = x.to(device), label.to(device)logits = model(x)# logits: [b, 10]# label:  [b]# loss: tensor scalarloss = criteon(logits, label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch, 'loss:', loss.item())model.eval()with torch.no_grad():# testtotal_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.to(device), label.to(device)# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

Cifar10与ResNet18实战、lenet5、resnet(学习笔记)相关推荐

  1. 1.5 Hello, world! 解剖 -JSF实战 -hxzon -jsf学习笔记

    为什么80%的码农都做不了架构师?>>>    1.5 Hello, world! 解剖 -JSF实战 -hxzon -jsf学习笔记 既然已经对JSF能够解决什么问题有了初步理解, ...

  2. python3《机器学习实战系列》学习笔记----3.2 决策树实战

    前言 一.ID3算法构造决策树 1.1 背景 1.2 信息增益计算 1.3 递归生成决策树 二.使用Matplotlib注解绘制树形图 2.1 Matplotlib注解 2.2 构造注解树 三.测试和 ...

  3. 《崔庆才Python3网络爬虫开发实战教程》学习笔记(5):将爬虫爬取到的数据存储到TXT,Word,Excel,Json等文件中

    本篇博文是自己在学习崔庆才的<Python3网络爬虫开发实战教程>的学习笔记系列,此套教程共5章,加起来共有34节课,内容非常详细丰富!如果你也要这套视频教程的话,关注我公众号[小众技术] ...

  4. MySQL实战45讲学习笔记

    文章目录 MySQL实战45讲-学习笔记 01 基础架构:一条SQL查询语句是如何执行的? mysql逻辑架构 连接器 查询缓存 分析器 优化器 执行器 02 日志系统:一条SQL更新语句如何执行 r ...

  5. 《机器学习实战》kNN学习笔记

    <机器学习实战>kNN学习笔记 文章目录 <机器学习实战>kNN学习笔记 概述 优缺点 k-近邻算法的一般流程 简单案例kNN.py 在约会网站上使用k-近邻算法 归一化特征值 ...

  6. 【1】机器学习实战peter Harrington——学习笔记

    机器学习实战peter Harrington--学习笔记 综述 数据挖掘十大算法 本书结构 一.机器学习基础 1.1 机器学习 1.2 关键术语 1.3 机器学习主要任务 1.4 如何选择合适的算法 ...

  7. vn.py全实战进阶课程学习笔记(零)

    目录 写在前面 MySQL数据库配置 安装mysq 创建数据库 vnpy数据库配置 rqdata数据服务配置 申请rqdata试用权限 vnpy参数配置 simnow仿真环境配置 准备账号 接口登录 ...

  8. 《姜承尧的MySQL实战宝典》学习笔记

    <姜承尧的MySQL实战宝典>学习笔记 1 表结构设计 1.1 数字类型 1.1.1 整形类型 1.1.2 浮点类型和高精度型 1.1.3 实战--整型类型与自增设计 1.1.4 实战-- ...

  9. 从零开始带你成为MySQL实战优化高手学习笔记(一)

    重复是有必要的. 很多新入职的小朋友可能和现在的我一样,对数据库的了解仅仅停留在建库建表增删改查这些操作,日常工作也都是用封装好的代码,别说底层原理了,数据库和系统之间是如何工作都不是很懂. 长此以往 ...

最新文章

  1. socket sock inet_sock 等关系
  2. 身为Java程序员,这些开源工具你一定要学会!
  3. VTK:可视化之StreamLines
  4. Linux 系统应用编程——进程间通信(上)
  5. 为什么不能睁一只眼闭一只眼_自媒体人上哪里找非常多的原创短视频素材?我为什么一定要你做原创?...
  6. 如何对待基金评审负面意见?
  7. ExtJS6 Grid的日期编辑栏位处理
  8. Comet:基于HTTP长连接的“服务器推”技术
  9. kubeflow fairing详解
  10. 不要迷失在技术的海洋中(转)
  11. Origin绘制热重TG和微分热重DTG曲线
  12. C语言删除字符串中的单词
  13. 浅析GPU通信技术(上)-GPUDirect P2P
  14. [题解] 洛谷 P3603 雪辉
  15. Jupyter lab add kernel Python+Julia+R 【jupyter Notebook 切换Python环境】and【在jupyter Notebook中安装第三方库】
  16. GTX 770 (GK 104)
  17. 常见的两种python编译器的安装
  18. Linux下安装配置Cobra教程
  19. 7.3 进程管理之暂停、归档和策略
  20. 高级文秘、高级行政助理职业化训练

热门文章

  1. centos 关机命令_Docker 常用命令速查手册
  2. C语言简单计算器考虑优先级,利用你现有的c语言知识 设计开发一个简易计算器,可进行加、减、乘、除、求余运算。...
  3. VTK:数据结构比较用法实战
  4. JavaScript实现返回数字的二进制表示中使用的位数bitLength算法(附完整源码)
  5. wxWidgets:wxHyperlinkCtrl类用法
  6. 在并发中练习 Boost.Multiprecision多线程环境相关的测试程序
  7. boost::maximum_weighted_matching用法的测试程序
  8. boost::geometry::dot_product用法的测试程序
  9. boost::is_output_streamable用法的测试程序
  10. ITK:过滤图像而没有复制其数据