1. 1x1的卷积核的作用

  1. 在width和height不变的基础上改变通道的数量
  2. 减少计算量

2. GoogLeNet中Inception Module的实现

2.1 Inception块的代码实现

import torch
import torch.nn.functional as Fclass InceptinA(torch.nn.Module):def __init__(self,channels):super(InceptinA, self).__init__()self.branch_pool = torch.nn.Conv2d(channels,24,kernel_size=1)self.branch1x1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch5x5_1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size=5,padding=2)#使用了5x5的卷积核,为保证w和h不变,使用padding=2self.branch3x3_1 = torch.nn.Conv2d(channels,16,kernel_size=1)self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size=3,padding=1)self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size=3,padding=1)def forward(self,x):branch_pool = F.avg_pool2d(x,kernel_size=3,padding=1,stride=1) #本来默认stride就是1branch_pool = self.branch_pool(branch_pool)branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_2(self.branch5x5_1(x))branch3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))outputs = [branch_pool,branch1x1,branch5x5,branch3x3]return torch.cat(outputs,dim=1) #BxCxWxH,dim=1按照通道数进行拼接

2.2 使用模块构建卷积网络训练Minist数据集

2.3  整体代码实现

import torch
from Inception import InceptinA
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets,transforms#追踪日志
writer = SummaryWriter(log_dir='../LEDR')#准备数据集
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3801,))])
train_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=True,transform=trans,download=True)
test_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=False,transform=trans,download=True)#下载数据集
train_data = DataLoader(dataset=train_set,batch_size=64,shuffle=True)
test_data = DataLoader(dataset=test_set,batch_size=64,shuffle=False)#构建模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv_1 = torch.nn.Conv2d(1,10,kernel_size=5)#输出变成 10x24x24self.conv_2 = torch.nn.Conv2d(88,20,kernel_size=5)# 输出变成 20x12x12self.mp = torch.nn.MaxPool2d(2)self.incept1 = InceptinA(channels=10)self.incept2 = InceptinA(channels=20)self.fc = torch.nn.Linear(1408,10)def forward(self,x):x = F.relu(self.mp(self.conv_1(x)))# 输出为 10x12x12x = self.incept1(x) #输出是88x12x12x = F.relu(self.mp(self.conv_2(x)))# 输出是 20x4x4x = self.incept2(x) #输出是 88x4x4x = x.view(-1,1408)x = self.fc(x)return x#实例化模型
huihui = Net()#定义损失函数和优化函数
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=huihui.parameters(),lr=0.01,momentum=0.5)#开始训练
def train(epoch):run_loss = 0.0for batch_id , data in enumerate(train_data,0):inputs , targets = dataoutputs = huihui(inputs)loss = criterion(outputs, targets)#归零,反馈,更新optimizer.zero_grad()loss.backward()optimizer.step()run_loss += loss.item()if batch_id % 300 == 299:print("[%d,%d] loss:%.3f" %(epoch+1,batch_id+1,run_loss/300))run_loss = 0.0def test():total = 0correct = 0with torch.no_grad():for data in test_data:inputs , labels = dataoutputs = huihui(inputs)_,predict = torch.max(outputs,dim=1)total += labels.size(0)correct += (predict==labels).sum().item()writer.add_scalar("The Accuracy1",correct/total,epoch)print('[Accuracy] %d %%' % (100*correct/total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()writer.close()

2.4 结果展示(正确率还是98%)

D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Inception_model.py
[1,300] loss:0.961
[1,600] loss:0.207
[1,900] loss:0.143
[Accuracy] 96 %
[2,300] loss:0.115
[2,600] loss:0.095
[2,900] loss:0.098
[Accuracy] 97 %
[3,300] loss:0.083
[3,600] loss:0.081
[3,900] loss:0.071
[Accuracy] 98 %
[4,300] loss:0.068
[4,600] loss:0.066
[4,900] loss:0.069
[Accuracy] 98 %
[5,300] loss:0.063
[5,600] loss:0.055
[5,900] loss:0.054
[Accuracy] 98 %
[6,300] loss:0.054
[6,600] loss:0.053
[6,900] loss:0.050
[Accuracy] 98 %
[7,300] loss:0.047
[7,600] loss:0.050
[7,900] loss:0.048
[Accuracy] 98 %
[8,300] loss:0.043
[8,600] loss:0.041
[8,900] loss:0.050
[Accuracy] 98 %
[9,300] loss:0.041
[9,600] loss:0.040
[9,900] loss:0.043
[Accuracy] 98 %
[10,300] loss:0.037
[10,600] loss:0.038
[10,900] loss:0.040
[Accuracy] 98 %

Process finished with exit code 0

2.5 图像展示

(刘二大人)PyTorch深度学习实践-卷积网络(Advance)相关推荐

  1. 【刘二大人 - PyTorch深度学习实践】学习随手记(一)

    目录 1. Overview 1.Human Intelligence 2.Machine Learning 3.How to develop learning system? 4.Tradition ...

  2. 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归

    刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...

  3. PyTorch 深度学习实践 GPU版本B站 刘二大人第11讲卷积神经网络(高级篇)GPU版本

    第11讲 卷积神经网络(高级篇) GPU版本源代码 原理是基于B站 刘二大人 :传送门PyTorch深度学习实践--卷积神经网络(高级篇) 这篇基于博主错错莫:传送门 深度学习实践 第11讲博文 仅在 ...

  4. 【刘二大人】PyTorch深度学习实践

    文章目录 一.overview 1 机器学习 二.Linear_Model(线性模型) 1 例子引入 三.Gradient_Descent(梯度下降法) 1 梯度下降 2 梯度下降与随机梯度下降(SG ...

  5. 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)

    从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...

  6. 【Pytorch深度学习实践】B站up刘二大人之 Gradient Descend-代码理解与实现(2/9)

    开篇几句题外话: 以往的代码,都是随便看看就过去了,没有这样较真过,以至于看了很久的深度学习和Python,都没有能够形成编程能力: 这次算是废寝忘食的深入进去了,踏实地把每一个代码都理解透,包括其中 ...

  7. 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)

    这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...

  8. 笔记|(b站)刘二大人:pytorch深度学习实践(代码详细笔记,适合零基础)

    pytorch深度学习实践 笔记中的代码是根据b站刘二大人的课程所做的笔记,代码每一行都有注释方便理解,可以配套刘二大人视频一同使用. 用PyTorch实现线性回归 # 1.算预测值 # 2.算los ...

  9. 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)

    这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...

最新文章

  1. 在阿里干了五年,面试个小公司挂了…
  2. VTK:小部件之TexturedButtonWidget
  3. 异常检测机器学习_使用机器学习检测异常
  4. 计算机网络实验vc6实现串口通信,用vc的串口通信实验报告.docx
  5. CCF202104-1 灰度直方图
  6. 构建安全的计算机网络报告,计算机网络与安全实践设计报告 矿大资料.doc
  7. 95-190-642-源码-窗口操作符-EvictingWindowOperator
  8. spring boot2.0.4集成druid,用jmeter并发测试工具调用接口,druid查看监控的结果
  9. 四阶龙格库塔方法求解一次常微分方程组
  10. 车企Tier1的日子不好过
  11. 决定使用JBPM3、JBPM4、Drools Folw 还是等待JBPM5?
  12. C++播放wav音乐和音效
  13. 哈马德国际机场在全球最佳机场评选中排名第一;合肥君悦酒店浪漫呈现“悦-七夕”限定晚宴 | 全球旅报...
  14. 支付宝面试:什么是序列化,怎么序列化,为什么序列化,反序列化会遇到什么问题,如何解决?...
  15. 质量检验中那些不为人所知的事儿
  16. android关机闹钟设计思路
  17. python模块相互引用_python导入模块交叉引用的方法
  18. 2014年计算机考研,2014年计算机考研大纲原文
  19. 【转】关于linux中wps出现系统字体缺失的解决方法
  20. U盘文件被病毒破坏的常见迹象和数据恢复方法

热门文章

  1. Cadence Allegro过电阻电容的XNET等长图文视频演示
  2. Casual inference 综述框架
  3. Adlik在深度学习异构计算上的实践
  4. 视频教程-嵌入式读图基础-智能硬件
  5. 如何判断一个男人将来是穷还是富?
  6. 最新红旗linux系统,红旗Linux10系统下载
  7. 关于Java你不知道的那些事之等等与equals的区别
  8. 使用navicat新建sqlite数据库
  9. 蔡氏电路混沌同步Multisim实现
  10. Redis网站热搜关键词加载实践,建议收藏