(刘二大人)PyTorch深度学习实践-卷积网络(Advance)
1. 1x1的卷积核的作用
- 在width和height不变的基础上改变通道的数量
- 减少计算量
2. GoogLeNet中Inception Module的实现
2.1 Inception块的代码实现
import torch
import torch.nn.functional as Fclass InceptinA(torch.nn.Module):def __init__(self,channels):super(InceptinA, self).__init__()self.branch_pool = torch.nn.Conv2d(channels,24,kernel_size=1)self.branch1x1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch5x5_1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size=5,padding=2)#使用了5x5的卷积核,为保证w和h不变,使用padding=2self.branch3x3_1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size=3,padding=1)self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size=3,padding=1)def forward(self,x):branch_pool = F.avg_pool2d(x,kernel_size=3,padding=1,stride=1) #本来默认stride就是1branch_pool = self.branch_pool(branch_pool)branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_2(self.branch5x5_1(x))branch3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))outputs = [branch_pool,branch1x1,branch5x5,branch3x3]return torch.cat(outputs,dim=1) #BxCxWxH,dim=1按照通道数进行拼接
2.2 使用模块构建卷积网络训练Minist数据集
2.3 整体代码实现
import torch
from Inception import InceptinA
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets,transforms#追踪日志
writer = SummaryWriter(log_dir='../LEDR')#准备数据集
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3801,))])
train_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=True,transform=trans,download=True)
test_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=False,transform=trans,download=True)#下载数据集
train_data = DataLoader(dataset=train_set,batch_size=64,shuffle=True)
test_data = DataLoader(dataset=test_set,batch_size=64,shuffle=False)#构建模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv_1 = torch.nn.Conv2d(1,10,kernel_size=5)#输出变成 10x24x24self.conv_2 = torch.nn.Conv2d(88,20,kernel_size=5)# 输出变成 20x12x12self.mp = torch.nn.MaxPool2d(2)self.incept1 = InceptinA(channels=10)self.incept2 = InceptinA(channels=20)self.fc = torch.nn.Linear(1408,10)def forward(self,x):x = F.relu(self.mp(self.conv_1(x)))# 输出为 10x12x12x = self.incept1(x) #输出是88x12x12x = F.relu(self.mp(self.conv_2(x)))# 输出是 20x4x4x = self.incept2(x) #输出是 88x4x4x = x.view(-1,1408)x = self.fc(x)return x#实例化模型
huihui = Net()#定义损失函数和优化函数
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=huihui.parameters(),lr=0.01,momentum=0.5)#开始训练
def train(epoch):run_loss = 0.0for batch_id , data in enumerate(train_data,0):inputs , targets = dataoutputs = huihui(inputs)loss = criterion(outputs, targets)#归零,反馈,更新optimizer.zero_grad()loss.backward()optimizer.step()run_loss += loss.item()if batch_id % 300 == 299:print("[%d,%d] loss:%.3f" %(epoch+1,batch_id+1,run_loss/300))run_loss = 0.0def test():total = 0correct = 0with torch.no_grad():for data in test_data:inputs , labels = dataoutputs = huihui(inputs)_,predict = torch.max(outputs,dim=1)total += labels.size(0)correct += (predict==labels).sum().item()writer.add_scalar("The Accuracy1",correct/total,epoch)print('[Accuracy] %d %%' % (100*correct/total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()writer.close()
2.4 结果展示(正确率还是98%)
D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Inception_model.py
[1,300] loss:0.961
[1,600] loss:0.207
[1,900] loss:0.143
[Accuracy] 96 %
[2,300] loss:0.115
[2,600] loss:0.095
[2,900] loss:0.098
[Accuracy] 97 %
[3,300] loss:0.083
[3,600] loss:0.081
[3,900] loss:0.071
[Accuracy] 98 %
[4,300] loss:0.068
[4,600] loss:0.066
[4,900] loss:0.069
[Accuracy] 98 %
[5,300] loss:0.063
[5,600] loss:0.055
[5,900] loss:0.054
[Accuracy] 98 %
[6,300] loss:0.054
[6,600] loss:0.053
[6,900] loss:0.050
[Accuracy] 98 %
[7,300] loss:0.047
[7,600] loss:0.050
[7,900] loss:0.048
[Accuracy] 98 %
[8,300] loss:0.043
[8,600] loss:0.041
[8,900] loss:0.050
[Accuracy] 98 %
[9,300] loss:0.041
[9,600] loss:0.040
[9,900] loss:0.043
[Accuracy] 98 %
[10,300] loss:0.037
[10,600] loss:0.038
[10,900] loss:0.040
[Accuracy] 98 %Process finished with exit code 0
2.5 图像展示
(刘二大人)PyTorch深度学习实践-卷积网络(Advance)相关推荐
- 【刘二大人 - PyTorch深度学习实践】学习随手记(一)
目录 1. Overview 1.Human Intelligence 2.Machine Learning 3.How to develop learning system? 4.Tradition ...
- 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归
刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...
- PyTorch 深度学习实践 GPU版本B站 刘二大人第11讲卷积神经网络(高级篇)GPU版本
第11讲 卷积神经网络(高级篇) GPU版本源代码 原理是基于B站 刘二大人 :传送门PyTorch深度学习实践--卷积神经网络(高级篇) 这篇基于博主错错莫:传送门 深度学习实践 第11讲博文 仅在 ...
- 【刘二大人】PyTorch深度学习实践
文章目录 一.overview 1 机器学习 二.Linear_Model(线性模型) 1 例子引入 三.Gradient_Descent(梯度下降法) 1 梯度下降 2 梯度下降与随机梯度下降(SG ...
- 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)
从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...
- 【Pytorch深度学习实践】B站up刘二大人之 Gradient Descend-代码理解与实现(2/9)
开篇几句题外话: 以往的代码,都是随便看看就过去了,没有这样较真过,以至于看了很久的深度学习和Python,都没有能够形成编程能力: 这次算是废寝忘食的深入进去了,踏实地把每一个代码都理解透,包括其中 ...
- 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)
这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...
- 笔记|(b站)刘二大人:pytorch深度学习实践(代码详细笔记,适合零基础)
pytorch深度学习实践 笔记中的代码是根据b站刘二大人的课程所做的笔记,代码每一行都有注释方便理解,可以配套刘二大人视频一同使用. 用PyTorch实现线性回归 # 1.算预测值 # 2.算los ...
- 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)
这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...
最新文章
- 在阿里干了五年,面试个小公司挂了…
- VTK:小部件之TexturedButtonWidget
- 异常检测机器学习_使用机器学习检测异常
- 计算机网络实验vc6实现串口通信,用vc的串口通信实验报告.docx
- CCF202104-1 灰度直方图
- 构建安全的计算机网络报告,计算机网络与安全实践设计报告 矿大资料.doc
- 95-190-642-源码-窗口操作符-EvictingWindowOperator
- spring boot2.0.4集成druid,用jmeter并发测试工具调用接口,druid查看监控的结果
- 四阶龙格库塔方法求解一次常微分方程组
- 车企Tier1的日子不好过
- 决定使用JBPM3、JBPM4、Drools Folw 还是等待JBPM5?
- C++播放wav音乐和音效
- 哈马德国际机场在全球最佳机场评选中排名第一;合肥君悦酒店浪漫呈现“悦-七夕”限定晚宴 | 全球旅报...
- 支付宝面试:什么是序列化,怎么序列化,为什么序列化,反序列化会遇到什么问题,如何解决?...
- 质量检验中那些不为人所知的事儿
- android关机闹钟设计思路
- python模块相互引用_python导入模块交叉引用的方法
- 2014年计算机考研,2014年计算机考研大纲原文
- 【转】关于linux中wps出现系统字体缺失的解决方法
- U盘文件被病毒破坏的常见迹象和数据恢复方法