CNN中的卷积、1x1卷积及在pytorch中的验证
参考网址:
http://www.cnblogs.com/darkknightzh/p/9017854.html
https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
https://www.cnblogs.com/chuantingSDU/p/8120065.html
https://blog.csdn.net/chaolei3/article/details/79374563
1x1卷积
https://blog.csdn.net/u014114990/article/details/50767786
https://www.quora.com/How-are-1x1-convolutions-used-for-dimensionality-reduction
https://www.reddit.com/r/MachineLearning/comments/3oln72/1x1_convolutions_why_use_them/?st=is9xc9jn&sh=7b774d4d
理解错误的地方敬请谅解。
- 卷积
才发现一直理解错了CNN中的卷积操作。
假设输入输出大小不变,输入是NCinHW,输出是NCoHW。其中N为batchsize。卷积核的大小是kk。实际上共有CinCo个kk的卷积核,总共的参数是CinkkCo(无bias)或者Cinkk*Co+Co(有bias)。
pytorch中给出了conv2d的计算公式
(https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d):
out(Ni,Coj)=bias(Coj)+∑k=0Cin−1weight(Coj,k)∗input(Ni,k)out({{N}_{i}},C{{o}_{j}})=bias(C{{o}_{j}})+\sum\limits_{k=0}^{Cin-1}{weight(C{{o}_{j}},k)*input({{N}_{i}},k)}out(Ni,Coj)=bias(Coj)+k=0∑Cin−1weight(Coj,k)∗input(Ni,k)
其中weight即为卷积核,上式中输出的batch中的第Ni个特征图的第Coj个特征,即为输入的第Ni个特征图的第k个特征,和第Coj个卷积核中的第k个核进行卷积(cross-correlation)。
如下图所示,对于某个输入特征图,其某局域分别于Co个卷积核进行卷积,得到对应的特征Coi,而后将这些特征拼接起来,得到最终的特征图。实际上每个卷积核都是kkCin的大小
经过上面的卷积,就可以将输入的不同的通道的信息融合了(权重不同,类似于加权融合)。
如果输出Co数量大于输入Cin数量,输出特征数量就多于输入特征。否则输出就少于输入特征数量。
回到顶部(go to top)
2. 1*1卷积
上面的卷积理解了,1*1卷积就好理解了。
11主要用于降维或者升维(看Cin和Co哪个更大),其核大小为11。
实际上卷积核的数量为Cin11Co=CinCo(无bias)或者Cin*Co+Co(有bias)。
计算时,通道方向上每个卷积核将输入按照通道进行加权,得到对应的输出特征,之后将这些特征拼接起来,即可得到最终的特征图。
回到顶部(go to top)
3. pytorch中的验证
代码:
from __future__ import print_function
from __future__ import divisionimport torch.nn as nn
import numpy as npclass testNet(nn.Module):def __init__(self):super(testNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5, stride=1, padding=1, bias=True)def forward(self, x):x = self.conv1(x)return xdef get_total_params(model):model_parameters = filter(lambda p: p.requires_grad, model.parameters())num_params = sum([np.prod(p.size()) for p in model_parameters])return num_paramsdef main():net = testNet()print(get_total_params(net))if __name__ == '__main__':main()
上面代码中get_total_params用于得到模型总共的参数。
当kernel_size=5,bias=True时,参数共计760个:355*10+10=760。
当kernel_size=5,bias=False时,参数共计750个:355*10=750。
当kernel_size=1,bias=True时,参数共计40个:311*10+10=40。
当kernel_size=1,bias=False时,参数共计30个:311*10=30。
CNN中的卷积、1x1卷积及在pytorch中的验证相关推荐
- CNN 常用网络结构解析 1x1 卷积运算 示意图
AlexNet 网络结构: VGG : conv3x3.conv5x5.conv7x7.conv9x9和conv11x11,在224x224x3的RGB图上(设置pad=1,stride=4,outp ...
- Lesson 16.5 在Pytorch中实现卷积网络(上):卷积核、输入通道与特征图在PyTorch中实现卷积网络(中):步长与填充
卷积神经网络是使用卷积层的一组神经网络.在一个成熟的CNN中,往往会涉及到卷积层.池化层.线性层(全连接层)以及各类激活函数.因此,在构筑卷积网络时,需从整体全部层的需求来进行考虑. 1 二维卷积层n ...
- pytorch中的卷积操作详解
首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...
- 深度学习/联邦学习笔记(六)卷积神经及相关案例+pytorch
深度学习/联邦学习笔记(六) 卷积神经及相关案例+pytorch 卷积神经网络不同于一般的全连接神经网络,卷积神经网络是一个3D容量的神经元,即神经元是以三个维度来排列的:宽度.高度和深度 卷积神经网 ...
- Pytorch中的Conv1d()和Conv2d()函数
文章目录 一.Pytorch中的Conv1d()函数 二.Pytorch中的Conv2d()函数 三.Pytorch中的MaxPool1d()函数 四.pytorch中的MaxPool2d()函数 参 ...
- 循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用
循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用 1. Pytorch中LSTM和GRU模块使用 1.1 LSTM介绍 LSTM和GRU都是由torch.nn提供 通过观察文档, ...
- PyTorch中的C++扩展
今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...
- Lesson 15.2 学习率调度在PyTorch中的实现方法
Lesson 15.2 学习率调度在PyTorch中的实现方法 学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...
- Pytorch 学习(7):Pytorch中的Non-linear Activations (非线性层)实现
Pytorch 学习(7):Pytorch中的Non-linear Activations (非线性层)实现 Pytorch中的Non-linear Activations (非线性层)包括以下激活函 ...
- PyTorch学习笔记(15) ——PyTorch中的contiguous
本文转载自栩风在知乎上的文章<PyTorch中的contiguous>.我觉得很好,特此转载. 0. 前言 本文讲解了pytorch中contiguous的含义.定义.实现,以及conti ...
最新文章
- 格式工厂软件处理视频
- ABAP 7.53 中的ABAP SQL(原Open SQL)新特性
- AQS理解之三,由刚才写的锁转变成一个公平锁
- MTFlexbox自动化埋点探索
- iOS学习笔记---oc语言第三天
- datetime 索引_MySQL 性能优化:MySQL 中的隐式转换造成的索引失效
- Visual Studio中Debug和Release的区别
- C# 连接SQL 连接字符串
- 使用yarn运行react项目指令_Jenkins | 使用yarn构建前端项目
- 如何使用Super Vectorizer在 Mac 上将 PDF 转换为 SVG 矢量?
- 'gbk' codec can't decode byte 0x9d in position 7674: illegal multibyte sequence
- 卖家如何利用关键词进行SEO优化以提高排名?
- 软件测试 — 面试题
- 明解C语言(入门篇)第十章
- Linux:Redis搭建集群
- 1521 一维战舰(区间)
- 消费新品周报 | AWE海尔推出无尘洗衣机;卡西欧F1红牛车队合作新款运动手表...
- 记一次智能灯泡的破解
- 转贴: 辞职日记----记录31岁的程序员跳槽心态
- CD网站用户消费行为的分析报告
热门文章
- 阿里云服务器部署GeoServer以及跨域处理
- ArcMAP TIN与栅格DEM的坡度坡向对比分析
- ENVI-IDL基础学习(1)
- java中enum怎么用_java 中enum的使用方法详解
- 一种边播边下的播放策略
- 工作两年和研究生两年(专业硕士)有什么差异?
- 小红帽linux操作教程_linux入门教程 Redhat使用指南
- 保护计算机系统与数据有什么方法,电脑数据保护方法 看完保你不后悔
- js 正则匹配邮箱_比较正宗的验证邮箱的正则表达式js代码详解
- html 文本框 p,Javascript实现HTML表单form多个HttpPost请求