参考网址:

http://www.cnblogs.com/darkknightzh/p/9017854.html

https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d

https://www.cnblogs.com/chuantingSDU/p/8120065.html

https://blog.csdn.net/chaolei3/article/details/79374563

1x1卷积

https://blog.csdn.net/u014114990/article/details/50767786

https://www.quora.com/How-are-1x1-convolutions-used-for-dimensionality-reduction

https://www.reddit.com/r/MachineLearning/comments/3oln72/1x1_convolutions_why_use_them/?st=is9xc9jn&sh=7b774d4d

理解错误的地方敬请谅解。

  1. 卷积

才发现一直理解错了CNN中的卷积操作。

假设输入输出大小不变,输入是NCinHW,输出是NCoHW。其中N为batchsize。卷积核的大小是kk。实际上共有CinCo个kk的卷积核,总共的参数是CinkkCo(无bias)或者Cinkk*Co+Co(有bias)。

pytorch中给出了conv2d的计算公式

(https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d):

out(Ni,Coj)=bias(Coj)+∑k=0Cin−1weight(Coj,k)∗input(Ni,k)out({{N}_{i}},C{{o}_{j}})=bias(C{{o}_{j}})+\sum\limits_{k=0}^{Cin-1}{weight(C{{o}_{j}},k)*input({{N}_{i}},k)}out(Ni​,Coj​)=bias(Coj​)+k=0∑Cin−1​weight(Coj​,k)∗input(Ni​,k)

其中weight即为卷积核,上式中输出的batch中的第Ni个特征图的第Coj个特征,即为输入的第Ni个特征图的第k个特征,和第Coj个卷积核中的第k个核进行卷积(cross-correlation)。

如下图所示,对于某个输入特征图,其某局域分别于Co个卷积核进行卷积,得到对应的特征Coi,而后将这些特征拼接起来,得到最终的特征图。实际上每个卷积核都是kkCin的大小


经过上面的卷积,就可以将输入的不同的通道的信息融合了(权重不同,类似于加权融合)。

如果输出Co数量大于输入Cin数量,输出特征数量就多于输入特征。否则输出就少于输入特征数量。
回到顶部(go to top)
2. 1*1卷积

上面的卷积理解了,1*1卷积就好理解了。

11主要用于降维或者升维(看Cin和Co哪个更大),其核大小为11。

实际上卷积核的数量为Cin11Co=CinCo(无bias)或者Cin*Co+Co(有bias)。

计算时,通道方向上每个卷积核将输入按照通道进行加权,得到对应的输出特征,之后将这些特征拼接起来,即可得到最终的特征图。
回到顶部(go to top)
3. pytorch中的验证

代码:

from __future__ import print_function
from __future__ import divisionimport torch.nn as nn
import numpy as npclass testNet(nn.Module):def __init__(self):super(testNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5, stride=1, padding=1, bias=True)def forward(self, x):x = self.conv1(x)return xdef get_total_params(model):model_parameters = filter(lambda p: p.requires_grad, model.parameters())num_params = sum([np.prod(p.size()) for p in model_parameters])return num_paramsdef main():net = testNet()print(get_total_params(net))if __name__ == '__main__':main()

上面代码中get_total_params用于得到模型总共的参数。

当kernel_size=5,bias=True时,参数共计760个:355*10+10=760。

当kernel_size=5,bias=False时,参数共计750个:355*10=750。

当kernel_size=1,bias=True时,参数共计40个:311*10+10=40。

当kernel_size=1,bias=False时,参数共计30个:311*10=30。

CNN中的卷积、1x1卷积及在pytorch中的验证相关推荐

  1. CNN 常用网络结构解析 1x1 卷积运算 示意图

    AlexNet 网络结构: VGG : conv3x3.conv5x5.conv7x7.conv9x9和conv11x11,在224x224x3的RGB图上(设置pad=1,stride=4,outp ...

  2. Lesson 16.5 在Pytorch中实现卷积网络(上):卷积核、输入通道与特征图在PyTorch中实现卷积网络(中):步长与填充

    卷积神经网络是使用卷积层的一组神经网络.在一个成熟的CNN中,往往会涉及到卷积层.池化层.线性层(全连接层)以及各类激活函数.因此,在构筑卷积网络时,需从整体全部层的需求来进行考虑. 1 二维卷积层n ...

  3. pytorch中的卷积操作详解

    首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...

  4. 深度学习/联邦学习笔记(六)卷积神经及相关案例+pytorch

    深度学习/联邦学习笔记(六) 卷积神经及相关案例+pytorch 卷积神经网络不同于一般的全连接神经网络,卷积神经网络是一个3D容量的神经元,即神经元是以三个维度来排列的:宽度.高度和深度 卷积神经网 ...

  5. Pytorch中的Conv1d()和Conv2d()函数

    文章目录 一.Pytorch中的Conv1d()函数 二.Pytorch中的Conv2d()函数 三.Pytorch中的MaxPool1d()函数 四.pytorch中的MaxPool2d()函数 参 ...

  6. 循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用

    循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用 1. Pytorch中LSTM和GRU模块使用 1.1 LSTM介绍 LSTM和GRU都是由torch.nn提供 通过观察文档, ...

  7. PyTorch中的C++扩展

    今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...

  8. Lesson 15.2 学习率调度在PyTorch中的实现方法

    Lesson 15.2 学习率调度在PyTorch中的实现方法   学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...

  9. Pytorch 学习(7):Pytorch中的Non-linear Activations (非线性层)实现

    Pytorch 学习(7):Pytorch中的Non-linear Activations (非线性层)实现 Pytorch中的Non-linear Activations (非线性层)包括以下激活函 ...

  10. PyTorch学习笔记(15) ——PyTorch中的contiguous

    本文转载自栩风在知乎上的文章<PyTorch中的contiguous>.我觉得很好,特此转载. 0. 前言 本文讲解了pytorch中contiguous的含义.定义.实现,以及conti ...

最新文章

  1. 格式工厂软件处理视频
  2. ABAP 7.53 中的ABAP SQL(原Open SQL)新特性
  3. AQS理解之三,由刚才写的锁转变成一个公平锁
  4. MTFlexbox自动化埋点探索
  5. iOS学习笔记---oc语言第三天
  6. datetime 索引_MySQL 性能优化:MySQL 中的隐式转换造成的索引失效
  7. Visual Studio中Debug和Release的区别
  8. C# 连接SQL 连接字符串
  9. 使用yarn运行react项目指令_Jenkins | 使用yarn构建前端项目
  10. 如何使用Super Vectorizer在 Mac 上将 PDF 转换为 SVG 矢量?
  11. 'gbk' codec can't decode byte 0x9d in position 7674: illegal multibyte sequence
  12. 卖家如何利用关键词进行SEO优化以提高排名?
  13. 软件测试 — 面试题
  14. 明解C语言(入门篇)第十章
  15. Linux:Redis搭建集群
  16. 1521 一维战舰(区间)
  17. 消费新品周报 | AWE海尔推出无尘洗衣机;卡西欧F1红牛车队合作新款运动手表...
  18. 记一次智能灯泡的破解
  19. 转贴: 辞职日记----记录31岁的程序员跳槽心态
  20. CD网站用户消费行为的分析报告

热门文章

  1. 阿里云服务器部署GeoServer以及跨域处理
  2. ArcMAP TIN与栅格DEM的坡度坡向对比分析
  3. ENVI-IDL基础学习(1)
  4. java中enum怎么用_java 中enum的使用方法详解
  5. 一种边播边下的播放策略
  6. 工作两年和研究生两年(专业硕士)有什么差异?
  7. 小红帽linux操作教程_linux入门教程 Redhat使用指南
  8. 保护计算机系统与数据有什么方法,电脑数据保护方法 看完保你不后悔
  9. js 正则匹配邮箱_比较正宗的验证邮箱的正则表达式js代码详解
  10. html 文本框 p,Javascript实现HTML表单form多个HttpPost请求