Pytorch之nn.Conv1d学习个人见解

一、官方文档(务必先耐心阅读)

官方文档:点击打开《CONV1D》

二、Conv1d个人见解

Conv1d类构成

  • class torch.nn.Conv1d(in_channels, out_channels, kernel_size,stride=1, padding=0, dilation=1, groups=1, bias=True)
  • in_channels(int)—输入数据的通道数。在文本分类中,即为句子中单个词的词向量的维度。 (word_vector_num)
  • out_channels(int)—输出数据的通道数。设置 N 个输出通道数,就有 N 个1维卷积核。(new word_vector_num)
  • kernel_size(int or tuple) —卷积核的长度,1维卷积中卷积核的实际大小维度是(in_channels,kernel_size),顺序不可互换。
  • stride(int or tuple, optional)—卷积步长。
  • padding (int or tuple, optional)—输入的每一条边补充0的层数。
  • dilation(int or tuple, `optional``)—卷积核元素之间的间距。
  • groups(int, optional)—从输入通道到输出通道的阻塞连接数。
  • bias(bool, optional)—如果bias=True,添加偏置。

具体案例分析

  • 原始数据集说明:6批句子(batch_size),每批句子5个单词(sentence_word_num),每个单词的词向量为3维通道(word_vector_num),数据集的维度表示为 [6,5,3] 。

  • 模型输入数据集说明:在上步原始数据集中进行维度转换,6批句子(batch_size),每个单词的词向量为3维通道(word_vector_num),每批句子5个单词(sentence_word_num),数据集的维度表示为 [6,3,5] 。(注意:为什么需要维度转换呢?因为Conv1d模型的卷积核大小是[输入通道数,卷积核的长],那么数据集和卷积核的点积运算必须维度都一致

  • Conv1d模型参数说明:输入通道数设定为3(数量等同 word_vector_num ),输出通道数设定为8(数量表示new word_vector_num),卷积核的长设定为2。

  • Conv1d模型权重参数(W)维度则根据上步自动生成为 [8,3,2] ,表示 [输出通道数,输入通道数,卷积核的长],又因为卷积核等同表示 [输入通道数,卷积核的长],输出通道数等同表示卷积核的个数,则总而言之,此模型权重参数的维度表示:有8个大小为[3,2]的卷积核去对输入数据做卷积运算

  • 卷积过程中的数据计算说明(非常重要):模型输入数据是一个深度为6长为3宽为5的三维数据,卷积核长为3宽度为2的二维数据,步长默认为1进行移动。先考虑深度为1的情况(可以先暂时不考虑深度这一维进行理解),模型输入数据变成一个长为3宽为5的二维数据,每个卷积核每次完成一次移动后,实现模型输入数据的6个数和这个卷积核的6个数(3*2)进行内积再和,生成1个数。每个卷积核总共需要横向移动四次(见下图动画理解),那么每个卷积核完成卷积后生成数据维度是[1,4],那么8个卷积核完成卷积生成的数据维度是[8,4],若要加上深度这一维就是[1,8,4]。再考虑深度为6的情况,进行卷积后得到的数据是深度为1的情况下的6倍,也就是[6,8,4]。

  • 模型输出数据集说明:6批句子(batch_size),每个单词的词向量为8维通道(new word_vector_num),每批句子4个单词(new sentence_word_num),数据集的维度表示为 [6,8,4] 。

  • 源代码如下:

import torch as t
input = t.randn(6,5,3) # batch_size= 6(sentence_num), sentence_word_num= 5, word_vector_num = 3
print(input)
print(input.shape) # [6,5,3]
input = input.permute(0,2,1) # 维度转换(sentence_word_num <-> word_vector_num)
print(input)
print(input.shape) # [6,3,5]
conv1 = nn.Conv1d(3, 8, 2, bias=False) # in_channels = word_vector_num = 3,out_channels = 8(new word_vector_num), kernel_size = 2
print(conv1.weight.shape) # [8,3,2]
output = conv1(input)
print(output)
print(output.shape) # [6,8,4]
  • 代码运行结果如下:
tensor([[[-1.5697,  1.6189,  0.4521],[-0.9188, -0.5753,  1.4038],[ 1.0623,  0.6014, -0.7945],[-1.0525,  2.0641, -1.8544],[-1.0642, -0.2318,  0.1935]],[[-2.2800, -1.1117, -1.0796],[ 0.2286,  0.6835, -2.6689],[-0.5956,  0.7648,  2.7674],[-0.9383,  0.2043,  1.3341],[-1.0337, -1.4724, -0.9340]],[[-0.9657,  0.2571,  0.6817],[ 0.3036, -1.0275, -0.0496],[ 1.5626,  0.5038, -0.3329],[-0.1654,  1.8341,  0.1949],[-0.1841, -0.1558, -0.1641]],[[-0.2144, -1.3156,  0.8448],[-0.5384,  1.2287,  1.5028],[ 0.2343, -1.0956, -0.5923],[ 0.2661,  1.1084,  0.4200],[-2.7000, -1.0146,  0.2574]],[[-0.2548, -1.6011, -0.8730],[ 0.1237, -0.2313,  0.8306],[ 0.9188,  0.5165,  0.8517],[ 0.0083, -0.4545,  0.9021],[-0.8566, -0.9456,  1.4411]],[[ 0.0890, -0.9539,  0.1321],[-0.8780, -1.2702,  1.9250],[-0.4996, -0.4644, -0.8101],[-2.2298, -0.8780, -0.1641],[ 0.1206,  0.0420, -0.0975]]])
torch.Size([6, 5, 3])
tensor([[[-1.5697, -0.9188,  1.0623, -1.0525, -1.0642],[ 1.6189, -0.5753,  0.6014,  2.0641, -0.2318],[ 0.4521,  1.4038, -0.7945, -1.8544,  0.1935]],[[-2.2800,  0.2286, -0.5956, -0.9383, -1.0337],[-1.1117,  0.6835,  0.7648,  0.2043, -1.4724],[-1.0796, -2.6689,  2.7674,  1.3341, -0.9340]],[[-0.9657,  0.3036,  1.5626, -0.1654, -0.1841],[ 0.2571, -1.0275,  0.5038,  1.8341, -0.1558],[ 0.6817, -0.0496, -0.3329,  0.1949, -0.1641]],[[-0.2144, -0.5384,  0.2343,  0.2661, -2.7000],[-1.3156,  1.2287, -1.0956,  1.1084, -1.0146],[ 0.8448,  1.5028, -0.5923,  0.4200,  0.2574]],[[-0.2548,  0.1237,  0.9188,  0.0083, -0.8566],[-1.6011, -0.2313,  0.5165, -0.4545, -0.9456],[-0.8730,  0.8306,  0.8517,  0.9021,  1.4411]],[[ 0.0890, -0.8780, -0.4996, -2.2298,  0.1206],[-0.9539, -1.2702, -0.4644, -0.8780,  0.0420],[ 0.1321,  1.9250, -0.8101, -0.1641, -0.0975]]])
torch.Size([6, 3, 5])
torch.Size([8, 3, 2])
tensor([[[ 1.8743e-01, -1.4395e-01, -6.9980e-01, -8.2561e-01],[-2.7898e-01, -6.5680e-01,  5.2309e-01,  3.0150e-01],[-1.7926e-01,  1.0438e-01, -1.4334e-01,  2.2036e-01],[ 9.1778e-01,  3.4689e-01,  8.8961e-01,  4.0392e-01],[ 2.5770e-01,  5.3539e-01,  5.1576e-01, -1.7502e-01],[-5.9272e-01, -4.6085e-01,  1.0932e-02, -2.7211e-01],[-1.2418e+00,  4.5105e-01,  1.5149e+00, -7.5503e-01],[ 4.5389e-01, -3.1628e-01,  2.4424e-01, -1.5187e-01]],[[-1.0650e+00, -1.6615e-01,  1.0677e+00,  4.9309e-01],[-8.1073e-01,  1.1998e+00, -5.1610e-01, -8.7283e-01],[ 2.9464e-01, -1.3378e-01, -6.7559e-01, -1.9098e-01],[ 5.6014e-04, -3.3817e-01,  1.5722e+00,  5.0429e-01],[ 7.1028e-01, -1.3099e+00,  9.0939e-01,  9.6488e-01],[ 1.6606e-01, -3.9754e-01, -6.4322e-01,  4.8480e-01],[ 1.2543e+00, -7.9167e-01, -5.4348e-01, -2.5640e-01],[-2.1250e+00,  7.5991e-01,  1.2818e+00, -5.1833e-01]],[[ 4.8963e-02, -3.0574e-01, -2.1625e-01, -4.4589e-01],[-5.3250e-01,  3.3740e-02,  8.2394e-01,  4.8748e-02],[ 1.6242e-01,  3.1454e-01, -1.5465e-01,  2.2231e-01],[-1.6153e-02, -6.8735e-01,  4.7351e-01,  5.9774e-01],[ 2.0333e-01, -3.8176e-01, -2.0578e-01,  1.5212e-01],[-6.1877e-02, -1.3378e-01, -3.8114e-01, -4.3941e-01],[-5.9499e-01,  4.4317e-01,  6.7399e-01, -5.4335e-01],[-3.5491e-01, -2.9921e-01,  1.0920e+00,  4.3913e-01]],[[ 9.3993e-01, -4.9535e-02,  3.9259e-02,  8.4282e-01],[-3.1526e-02, -5.7992e-01,  2.8747e-01, -3.4273e-02],[-7.4271e-01,  2.4287e-01, -1.6298e-01, -6.4197e-01],[ 5.4584e-01,  4.5684e-01, -2.3048e-01,  9.3792e-01],[ 2.0335e-01,  5.2475e-01, -2.9436e-01,  7.0134e-01],[-2.3952e-01, -2.1741e-01, -6.2856e-02,  6.1455e-01],[ 3.9216e-01, -6.6250e-01,  5.9392e-01, -4.2417e-01],[ 5.9883e-01,  7.8288e-02,  6.9463e-04,  5.3361e-01]],[[ 3.7750e-01,  1.7484e-01,  4.7909e-01,  1.1213e+00],[ 4.9472e-02,  2.2069e-02,  1.9605e-01, -1.7306e-01],[-1.5364e-01, -3.4038e-03, -9.3162e-02, -5.0403e-01],[-8.2655e-01,  3.4773e-02,  6.0838e-02,  7.5271e-02],[-4.7433e-01, -1.9094e-01, -1.6035e-01,  8.9366e-02],[ 3.9928e-01, -5.0901e-01, -7.0766e-02,  3.0599e-01],[ 5.0398e-02, -1.3538e-01, -5.4527e-01, -6.1514e-01],[-5.4416e-01,  5.3959e-01,  8.7396e-01,  4.2533e-01]],[[ 1.2261e+00,  8.1240e-01,  5.9319e-01, -1.1802e-01],[-9.5330e-04, -9.8721e-01, -1.7303e-01, -7.0010e-01],[-5.1057e-01, -4.2958e-01, -5.3423e-01, -3.8530e-02],[-4.5270e-01,  4.7178e-01,  1.4625e-01,  7.5624e-02],[-2.9981e-01,  1.0551e+00,  4.4312e-01,  3.2369e-01],[ 5.6614e-01,  3.8799e-01,  9.5110e-01, -1.6010e-01],[-7.5309e-01,  4.6806e-01,  9.6832e-02,  5.8812e-02],[ 2.0502e-01, -5.2707e-01, -6.2798e-01, -1.0742e+00]]],grad_fn=<SqueezeBackward1>)
torch.Size([6, 8, 4])

三、Conv1d和Conv2d的联系和区别

  • 两者关于批次的理解是一样的:也就是按照有多少组数据进行理解,比如上面的案例是6批数据,也就是6组数据。
  • 输入通道数理解不同:Conv1d的通道数是指词向量的维度,Conv2d的通道数是指颜色通道比如:黑白图的通道数是1和RGB彩色图的通道数为3或者设置更多的颜色通道数。
  • 卷积核大小不同:Conv1d的卷积核是[输入通道数,卷积核的长],Conv2d的卷积核是[输入通道数,卷积核的长,卷积核的宽]。
  • 卷积核移动路线不同:Conv1d的卷积核只能横向移动,Conv2d的卷积核可以横向纵向移动。
  • 输出通道数理解相同,都是指卷积核的个数,也是新的输入通道数。
  • 对比理解可参考一个Conv2d案例:点击打开《图像相关层之卷积锐化图片示例》文章

Pytorch之nn.Conv1d学习个人见解相关推荐

  1. 【pytorch】nn.conv1d的使用

    官方文档在这里. conv1d具体不做介绍了,本篇只做pytorch的API使用介绍. torch.nn.Conv1d(in_channels, out_channels, kernel_size, ...

  2. PyTorch中的nn.Conv1d与nn.Conv2d

    本文主要介绍PyTorch中的nn.Conv1d和nn.Conv2d方法,并给出相应代码示例,加深理解. 一维卷积nn.Conv1d 一般来说,一维卷积nn.Conv1d用于文本数据,只对宽度进行卷积 ...

  3. 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...

  4. 【Pytorch】torch.nn.Conv1d()理解与使用

    官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html?highlight=nn%20conv1d#torch.nn.C ...

  5. Pytorch中nn.Module和nn.Sequencial的简单学习

    文章目录 前言 1.Python 类 2.nn.Module 和 nn.Sequential 2.1 nn.Module 2.1.1 torch.nn.Module类 2.1.2 nn.Sequent ...

  6. (pytorch-深度学习)使用pytorch框架nn.RNN实现循环神经网络

    使用pytorch框架nn.RNN实现循环神经网络 首先,读取周杰伦专辑歌词数据集. import time import math import numpy as np import torch f ...

  7. pytorch nn.Conv1d

    一维卷积nn.Conv1d一般用于文本数据 1.应用 import torch import torch.nn as nnx = torch.randn(1, 1, 32) # batch, chan ...

  8. pytorch笔记(四)nn.Conv1d、nn.Conv2d、nn.Conv3d

    概念: nn.Conv1d:常用在文本 (B,C,L) (batch,channel,sequence_len) (批数量,通道数,句子长度) nn.Conv2d:常用在图像 (B,C,H,W) (b ...

  9. Pytorch中nn.Conv2d的用法

    官网链接: nn.Conv2d     Applies a 2D convolution over an input signal composed of several input planes. ...

最新文章

  1. 华为鸿蒙概念机990,华为5G概念新机:鸿蒙OS系统+麒麟990+石墨烯 安卓机皇来势汹汹...
  2. hub-spock-ospf,nbma
  3. 个人思考与研究:道德经(二)
  4. jdk下载:各历史版本下载地址
  5. 1042 字符统计 (20分)——16行代码满分
  6. java.util.regex_java.util.regex.PatternSyntaxException:索引附近的...
  7. themleft模板库_Thymeleaf模板引擎常用总结
  8. @MySQL的存储引擎
  9. linxu其他用户登录mysql_Linux系统的MySQL用户如何开启远程登录权限
  10. mysql事务操作_mysql的事务操作
  11. hasset java_java HashSet的使用
  12. amr转换成mp3 java_java将amr文件转换为MP3格式(windowslinux均可使用,亲测)
  13. mysql语句大全(2)
  14. 这一次,让你彻底明白接口及抽象类
  15. 给实践者的算法学习指南
  16. Android:android2.3电话接听
  17. 系统辨识总论(System Identification)
  18. cdrx4自动排版步骤_coreldraw自动排版
  19. 数学建模层次分析法例题及答案_数学建模之层次分析法
  20. 整车EMC正向开发及仿真

热门文章

  1. 管理者和企业如何做好员工管理?
  2. Flume-day03_进阶案例
  3. 1-SII--SharedPreferences完美封装
  4. Kinect v2和Intel RealSense D435的三维重建对比
  5. RISC-V 指令格式
  6. eth_clockgen.v
  7. 成考计算机科学与技术考试科目,计算机科学与技术本科自考有哪些科目
  8. A hybrid method of exponential smoothing and recurrent
  9. shiro登录验证原理
  10. sap crm button_SAP携Intelligent RPA 2.0 参加中国流程自动化产业峰会