官方文档在这里。

conv1d具体不做介绍了,本篇只做pytorch的API使用介绍.

torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

计算公式

输入张量的Shape一般为(N,Cin,L)(N,C_{in}, L)(N,Cin​,L),其中N为batch_size,一般也可用B代替;
CinC_{in}Cin​输入张量倒数第二维,表示Channel的数量;
L是输入信号序列的长度;

输出张量的shape一般为(N,Cout,Lout)(N,C_{out},L_{out})(N,Cout​,Lout​), 下一节会介绍怎么计算来的。

公式截图如下,

星号表示卷积运算,torch的运算这里使用的是cross-correlation,即互相关运算

运算过程

假设输入的序列为 [1,5,7,3,2,1,6,9], 卷积核为[2,4,6,1,3]

那么对应nn.conv1d初始化就是 in_channel=1, out_channel=1, stride=1, padding=0

计算过程中,每计算一次(kernel_size),移动一步(stride),
这个过程就是上图公式中的weight(Cout,k)∗input(Ni,k)weight(C_{out},k) * input(N_i,k)weight(Cout​,k)∗input(Ni​,k),其中k是out_channel的,
如果out_channel个数大于1,则这样的过程会有多次。

入参

  • in_channels
    Number of channels in the input image
    输入的Channel数,对应的是输入数据的倒数第二维
  • out_channels
    Number of channels produced by the convolution
    输出的Channel数,对应输出数据的倒数第二维度
  • kernel_size
    Size of the convolving kernel
    即卷积核长度
    它可以是一个数字也可以是一个tuple(但是conv1d下,tuple是否有意义?
  • stride
    Stride of the convolution. Default: 1
    卷积核步长
  • padding
    Padding added to both sides of the input. Default: 0
  • padding_mode
    ‘zeros’, ‘reflect’, ‘replicate’ or ‘circular’. Default: ‘zeros’
  • dilation
    Spacing between kernel elements. Default: 1
  • groups
    Number of blocked connections from input channels to output channels. Default: 1
  • bias
    If True, adds a learnable bias to the output. Default: True
    是否需要bias

输出结果的shape计算

代码举例

1

net = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=4, stride=1, bias=False)
x = torch.linspace(1,10,10).view(1,1,10)y = net(x)
print(y.shape)

计算结果

torch.Size([1, 8, 7])

2. 修改stride

net = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=4, stride=2, bias=False)
x = torch.linspace(1,10,10).view(1,1,10)y = net(x)
print(y.shape)

计算结果

torch.Size([1, 8, 4])

3 添加padding

net = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=4, stride=1, padding=1, padding_mode='zeros',bias=False)
x = torch.linspace(1,10,10).view(1,1,10)y = net(x)
print(y.shape)

计算结果

torch.Size([1, 8, 5])

4 修改dilation

net = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=4, stride=2, dilation=2,bias=False)
x = torch.linspace(1,10,10).view(1,1,10)y = net(x)
print(y.shape)

计算结果

torch.Size([1, 8, 2])

【pytorch】nn.conv1d的使用相关推荐

  1. pytorch nn.Conv1d

    一维卷积nn.Conv1d一般用于文本数据 1.应用 import torch import torch.nn as nnx = torch.randn(1, 1, 32) # batch, chan ...

  2. PyTorch中的nn.Conv1d与nn.Conv2d

    本文主要介绍PyTorch中的nn.Conv1d和nn.Conv2d方法,并给出相应代码示例,加深理解. 一维卷积nn.Conv1d 一般来说,一维卷积nn.Conv1d用于文本数据,只对宽度进行卷积 ...

  3. Pytorch之nn.Conv1d学习个人见解

    Pytorch之nn.Conv1d学习个人见解 一.官方文档(务必先耐心阅读) 官方文档:点击打开<CONV1D> 二.Conv1d个人见解 Conv1d类构成 class torch.n ...

  4. 【Pytorch】torch.nn.Conv1d()理解与使用

    官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html?highlight=nn%20conv1d#torch.nn.C ...

  5. pytorch笔记(四)nn.Conv1d、nn.Conv2d、nn.Conv3d

    概念: nn.Conv1d:常用在文本 (B,C,L) (batch,channel,sequence_len) (批数量,通道数,句子长度) nn.Conv2d:常用在图像 (B,C,H,W) (b ...

  6. pytorch —— nn网络层 - 卷积层

    目录 1.1d/2d/3d卷积 2.卷积-nn.Conv1d() 2.1 Conv1d的参数说明 2.2 例子说明 3.卷积-nn.Conv2d() 3.1 深入了解卷积层的参数 4.转置卷积-nn. ...

  7. 根据PyTorch学习CONV1D

    新手刚学习卷积,不知道理解的有没有问题,如果有问题劳烦大家指出. 1. PyTorch中的torch.nn.Conv1d()函数 官方文档链接 torch.nn.Conv1d(in_channels, ...

  8. NLP中的卷积操作详解(torch.nn.Conv1d)

    NLP领域中,由于自然文本是一维的,通常使用一维卷积即可达到要求. 在实际应用中,经embedding层处理后的数据格式一般为(batch_size, word_embeddings_dim, max ...

  9. pytorch nn.Embedding

    pytorch nn.Embedding class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_n ...

最新文章

  1. ubuntu 10.0.4安装小企鹅(Fcitx)输入法
  2. 【Linux】——常见的rc的含义
  3. Hibernate框架--学习笔记(上):hibernate项目的搭建和常用接口方法、对象的使用
  4. 中石油训练赛 - DNA(字符串哈希)
  5. 第三次学JAVA再学不好就吃翔(part18)--数组操作
  6. php socket 实现ftp,用socket实现FTP教程
  7. gdal1.6linux编译,VS2015下编译64位GDAL总结
  8. c#中,如何获取日期型字段里的年、月、日?
  9. ITUT-T recommendations G.168 标准回声模型
  10. webpack打包处理字体文件
  11. C语言--小学生计算机辅助教学系统
  12. 小结一篇-(秀我工作一年)
  13. MATLAB 绘制堆叠柱状图
  14. 【U8】登录账套显示“账套XXX年度XXXX是以前版本的数据,请使用系统管理升级”
  15. 2018年全国计算机一级考试大纲,2018年全国计算机等级考试一级Photoshop考试大纲...
  16. 举个栗子!Tableau 技巧(139):突出显示文本表的行或列
  17. 关于\xEF\xBB\xBF的介绍
  18. Android 超级玛丽跳跃动画,Doodle Mario Jump 涂鸦马里奥跳跃
  19. python数据分析及可视化(一)课程介绍以及统计学的应用、介绍、分类、基本概念及描述性统计
  20. (九)Fabric2.0 通道实践-更新通道配置(修改区块交易数量)

热门文章

  1. java方法重载和重写在jvm_重载和重写在jvm运行中的区别(一)
  2. BLASTN format=6
  3. 对硕士而言,编制和稳定究竟有多重要?
  4. 微生物组数据库: 一站式环境基因组学数据云平台更新啦!
  5. 文件批量重命名的技术,你值得拥有
  6. 水稻微生物组时间序列分析
  7. 小米云能同步到华为手机上吗_有没有小米还没涉足的产业?对标百度网盘,小米云盘即将上线...
  8. R语言使用ggplot2包使用geom_dotplot函数绘制分组点图(双分类变量分组可视化)实战(dot plot)
  9. 深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测空气质量(PM2.5)+代码实战
  10. R语言关系操作符:>、<=、!=、>=、==、