Pytorch之nn.Conv1d学习个人见解
Pytorch之nn.Conv1d学习个人见解
一、官方文档(务必先耐心阅读)
官方文档:点击打开《CONV1D》
二、Conv1d个人见解
Conv1d类构成
- class torch.nn.Conv1d(in_channels, out_channels, kernel_size,stride=1, padding=0, dilation=1, groups=1, bias=True)
- in_channels(int)—输入数据的通道数。在文本分类中,即为句子中单个词的词向量的维度。 (word_vector_num)
- out_channels(int)—输出数据的通道数。设置 N 个输出通道数,就有 N 个1维卷积核。(new word_vector_num)
- kernel_size(int or tuple) —卷积核的长度,1维卷积中卷积核的实际大小维度是(in_channels,kernel_size),顺序不可互换。
- stride(int or tuple, optional)—卷积步长。
- padding (int or tuple, optional)—输入的每一条边补充0的层数。
- dilation(int or tuple, `optional``)—卷积核元素之间的间距。
- groups(int, optional)—从输入通道到输出通道的阻塞连接数。
- bias(bool, optional)—如果bias=True,添加偏置。
具体案例分析
原始数据集说明:6批句子(batch_size),每批句子5个单词(sentence_word_num),每个单词的词向量为3维通道(word_vector_num),数据集的维度表示为 [6,5,3] 。
模型输入数据集说明:在上步原始数据集中进行维度转换,6批句子(batch_size),每个单词的词向量为3维通道(word_vector_num),每批句子5个单词(sentence_word_num),数据集的维度表示为 [6,3,5] 。(注意:为什么需要维度转换呢?因为Conv1d模型的卷积核大小是[输入通道数,卷积核的长],那么数据集和卷积核的点积运算必须维度都一致)
Conv1d模型参数说明:输入通道数设定为3(数量等同 word_vector_num ),输出通道数设定为8(数量表示new word_vector_num),卷积核的长设定为2。
Conv1d模型权重参数(W)维度则根据上步自动生成为 [8,3,2] ,表示 [输出通道数,输入通道数,卷积核的长],又因为卷积核等同表示 [输入通道数,卷积核的长],输出通道数等同表示卷积核的个数,则总而言之,此模型权重参数的维度表示:有8个大小为[3,2]的卷积核去对输入数据做卷积运算。
卷积过程中的数据计算说明(非常重要):模型输入数据是一个深度为6长为3宽为5的三维数据,卷积核长为3宽度为2的二维数据,步长默认为1进行移动。先考虑深度为1的情况(可以先暂时不考虑深度这一维进行理解),模型输入数据变成一个长为3宽为5的二维数据,每个卷积核每次完成一次移动后,实现模型输入数据的6个数和这个卷积核的6个数(3*2)进行内积再和,生成1个数。每个卷积核总共需要横向移动四次(见下图动画理解),那么每个卷积核完成卷积后生成数据维度是[1,4],那么8个卷积核完成卷积生成的数据维度是[8,4],若要加上深度这一维就是[1,8,4]。再考虑深度为6的情况,进行卷积后得到的数据是深度为1的情况下的6倍,也就是[6,8,4]。
模型输出数据集说明:6批句子(batch_size),每个单词的词向量为8维通道(new word_vector_num),每批句子4个单词(new sentence_word_num),数据集的维度表示为 [6,8,4] 。
源代码如下:
import torch as t
input = t.randn(6,5,3) # batch_size= 6(sentence_num), sentence_word_num= 5, word_vector_num = 3
print(input)
print(input.shape) # [6,5,3]
input = input.permute(0,2,1) # 维度转换(sentence_word_num <-> word_vector_num)
print(input)
print(input.shape) # [6,3,5]
conv1 = nn.Conv1d(3, 8, 2, bias=False) # in_channels = word_vector_num = 3,out_channels = 8(new word_vector_num), kernel_size = 2
print(conv1.weight.shape) # [8,3,2]
output = conv1(input)
print(output)
print(output.shape) # [6,8,4]
- 代码运行结果如下:
tensor([[[-1.5697, 1.6189, 0.4521],[-0.9188, -0.5753, 1.4038],[ 1.0623, 0.6014, -0.7945],[-1.0525, 2.0641, -1.8544],[-1.0642, -0.2318, 0.1935]],[[-2.2800, -1.1117, -1.0796],[ 0.2286, 0.6835, -2.6689],[-0.5956, 0.7648, 2.7674],[-0.9383, 0.2043, 1.3341],[-1.0337, -1.4724, -0.9340]],[[-0.9657, 0.2571, 0.6817],[ 0.3036, -1.0275, -0.0496],[ 1.5626, 0.5038, -0.3329],[-0.1654, 1.8341, 0.1949],[-0.1841, -0.1558, -0.1641]],[[-0.2144, -1.3156, 0.8448],[-0.5384, 1.2287, 1.5028],[ 0.2343, -1.0956, -0.5923],[ 0.2661, 1.1084, 0.4200],[-2.7000, -1.0146, 0.2574]],[[-0.2548, -1.6011, -0.8730],[ 0.1237, -0.2313, 0.8306],[ 0.9188, 0.5165, 0.8517],[ 0.0083, -0.4545, 0.9021],[-0.8566, -0.9456, 1.4411]],[[ 0.0890, -0.9539, 0.1321],[-0.8780, -1.2702, 1.9250],[-0.4996, -0.4644, -0.8101],[-2.2298, -0.8780, -0.1641],[ 0.1206, 0.0420, -0.0975]]])
torch.Size([6, 5, 3])
tensor([[[-1.5697, -0.9188, 1.0623, -1.0525, -1.0642],[ 1.6189, -0.5753, 0.6014, 2.0641, -0.2318],[ 0.4521, 1.4038, -0.7945, -1.8544, 0.1935]],[[-2.2800, 0.2286, -0.5956, -0.9383, -1.0337],[-1.1117, 0.6835, 0.7648, 0.2043, -1.4724],[-1.0796, -2.6689, 2.7674, 1.3341, -0.9340]],[[-0.9657, 0.3036, 1.5626, -0.1654, -0.1841],[ 0.2571, -1.0275, 0.5038, 1.8341, -0.1558],[ 0.6817, -0.0496, -0.3329, 0.1949, -0.1641]],[[-0.2144, -0.5384, 0.2343, 0.2661, -2.7000],[-1.3156, 1.2287, -1.0956, 1.1084, -1.0146],[ 0.8448, 1.5028, -0.5923, 0.4200, 0.2574]],[[-0.2548, 0.1237, 0.9188, 0.0083, -0.8566],[-1.6011, -0.2313, 0.5165, -0.4545, -0.9456],[-0.8730, 0.8306, 0.8517, 0.9021, 1.4411]],[[ 0.0890, -0.8780, -0.4996, -2.2298, 0.1206],[-0.9539, -1.2702, -0.4644, -0.8780, 0.0420],[ 0.1321, 1.9250, -0.8101, -0.1641, -0.0975]]])
torch.Size([6, 3, 5])
torch.Size([8, 3, 2])
tensor([[[ 1.8743e-01, -1.4395e-01, -6.9980e-01, -8.2561e-01],[-2.7898e-01, -6.5680e-01, 5.2309e-01, 3.0150e-01],[-1.7926e-01, 1.0438e-01, -1.4334e-01, 2.2036e-01],[ 9.1778e-01, 3.4689e-01, 8.8961e-01, 4.0392e-01],[ 2.5770e-01, 5.3539e-01, 5.1576e-01, -1.7502e-01],[-5.9272e-01, -4.6085e-01, 1.0932e-02, -2.7211e-01],[-1.2418e+00, 4.5105e-01, 1.5149e+00, -7.5503e-01],[ 4.5389e-01, -3.1628e-01, 2.4424e-01, -1.5187e-01]],[[-1.0650e+00, -1.6615e-01, 1.0677e+00, 4.9309e-01],[-8.1073e-01, 1.1998e+00, -5.1610e-01, -8.7283e-01],[ 2.9464e-01, -1.3378e-01, -6.7559e-01, -1.9098e-01],[ 5.6014e-04, -3.3817e-01, 1.5722e+00, 5.0429e-01],[ 7.1028e-01, -1.3099e+00, 9.0939e-01, 9.6488e-01],[ 1.6606e-01, -3.9754e-01, -6.4322e-01, 4.8480e-01],[ 1.2543e+00, -7.9167e-01, -5.4348e-01, -2.5640e-01],[-2.1250e+00, 7.5991e-01, 1.2818e+00, -5.1833e-01]],[[ 4.8963e-02, -3.0574e-01, -2.1625e-01, -4.4589e-01],[-5.3250e-01, 3.3740e-02, 8.2394e-01, 4.8748e-02],[ 1.6242e-01, 3.1454e-01, -1.5465e-01, 2.2231e-01],[-1.6153e-02, -6.8735e-01, 4.7351e-01, 5.9774e-01],[ 2.0333e-01, -3.8176e-01, -2.0578e-01, 1.5212e-01],[-6.1877e-02, -1.3378e-01, -3.8114e-01, -4.3941e-01],[-5.9499e-01, 4.4317e-01, 6.7399e-01, -5.4335e-01],[-3.5491e-01, -2.9921e-01, 1.0920e+00, 4.3913e-01]],[[ 9.3993e-01, -4.9535e-02, 3.9259e-02, 8.4282e-01],[-3.1526e-02, -5.7992e-01, 2.8747e-01, -3.4273e-02],[-7.4271e-01, 2.4287e-01, -1.6298e-01, -6.4197e-01],[ 5.4584e-01, 4.5684e-01, -2.3048e-01, 9.3792e-01],[ 2.0335e-01, 5.2475e-01, -2.9436e-01, 7.0134e-01],[-2.3952e-01, -2.1741e-01, -6.2856e-02, 6.1455e-01],[ 3.9216e-01, -6.6250e-01, 5.9392e-01, -4.2417e-01],[ 5.9883e-01, 7.8288e-02, 6.9463e-04, 5.3361e-01]],[[ 3.7750e-01, 1.7484e-01, 4.7909e-01, 1.1213e+00],[ 4.9472e-02, 2.2069e-02, 1.9605e-01, -1.7306e-01],[-1.5364e-01, -3.4038e-03, -9.3162e-02, -5.0403e-01],[-8.2655e-01, 3.4773e-02, 6.0838e-02, 7.5271e-02],[-4.7433e-01, -1.9094e-01, -1.6035e-01, 8.9366e-02],[ 3.9928e-01, -5.0901e-01, -7.0766e-02, 3.0599e-01],[ 5.0398e-02, -1.3538e-01, -5.4527e-01, -6.1514e-01],[-5.4416e-01, 5.3959e-01, 8.7396e-01, 4.2533e-01]],[[ 1.2261e+00, 8.1240e-01, 5.9319e-01, -1.1802e-01],[-9.5330e-04, -9.8721e-01, -1.7303e-01, -7.0010e-01],[-5.1057e-01, -4.2958e-01, -5.3423e-01, -3.8530e-02],[-4.5270e-01, 4.7178e-01, 1.4625e-01, 7.5624e-02],[-2.9981e-01, 1.0551e+00, 4.4312e-01, 3.2369e-01],[ 5.6614e-01, 3.8799e-01, 9.5110e-01, -1.6010e-01],[-7.5309e-01, 4.6806e-01, 9.6832e-02, 5.8812e-02],[ 2.0502e-01, -5.2707e-01, -6.2798e-01, -1.0742e+00]]],grad_fn=<SqueezeBackward1>)
torch.Size([6, 8, 4])
三、Conv1d和Conv2d的联系和区别
- 两者关于批次的理解是一样的:也就是按照有多少组数据进行理解,比如上面的案例是6批数据,也就是6组数据。
- 输入通道数理解不同:Conv1d的通道数是指词向量的维度,Conv2d的通道数是指颜色通道比如:黑白图的通道数是1和RGB彩色图的通道数为3或者设置更多的颜色通道数。
- 卷积核大小不同:Conv1d的卷积核是[输入通道数,卷积核的长],Conv2d的卷积核是[输入通道数,卷积核的长,卷积核的宽]。
- 卷积核移动路线不同:Conv1d的卷积核只能横向移动,Conv2d的卷积核可以横向纵向移动。
- 输出通道数理解相同,都是指卷积核的个数,也是新的输入通道数。
- 对比理解可参考一个Conv2d案例:点击打开《图像相关层之卷积锐化图片示例》文章
Pytorch之nn.Conv1d学习个人见解相关推荐
- 【pytorch】nn.conv1d的使用
官方文档在这里. conv1d具体不做介绍了,本篇只做pytorch的API使用介绍. torch.nn.Conv1d(in_channels, out_channels, kernel_size, ...
- PyTorch中的nn.Conv1d与nn.Conv2d
本文主要介绍PyTorch中的nn.Conv1d和nn.Conv2d方法,并给出相应代码示例,加深理解. 一维卷积nn.Conv1d 一般来说,一维卷积nn.Conv1d用于文本数据,只对宽度进行卷积 ...
- 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...
- 【Pytorch】torch.nn.Conv1d()理解与使用
官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html?highlight=nn%20conv1d#torch.nn.C ...
- Pytorch中nn.Module和nn.Sequencial的简单学习
文章目录 前言 1.Python 类 2.nn.Module 和 nn.Sequential 2.1 nn.Module 2.1.1 torch.nn.Module类 2.1.2 nn.Sequent ...
- (pytorch-深度学习)使用pytorch框架nn.RNN实现循环神经网络
使用pytorch框架nn.RNN实现循环神经网络 首先,读取周杰伦专辑歌词数据集. import time import math import numpy as np import torch f ...
- pytorch nn.Conv1d
一维卷积nn.Conv1d一般用于文本数据 1.应用 import torch import torch.nn as nnx = torch.randn(1, 1, 32) # batch, chan ...
- pytorch笔记(四)nn.Conv1d、nn.Conv2d、nn.Conv3d
概念: nn.Conv1d:常用在文本 (B,C,L) (batch,channel,sequence_len) (批数量,通道数,句子长度) nn.Conv2d:常用在图像 (B,C,H,W) (b ...
- Pytorch中nn.Conv2d的用法
官网链接: nn.Conv2d Applies a 2D convolution over an input signal composed of several input planes. ...
最新文章
- 华为鸿蒙概念机990,华为5G概念新机:鸿蒙OS系统+麒麟990+石墨烯 安卓机皇来势汹汹...
- hub-spock-ospf,nbma
- 个人思考与研究:道德经(二)
- jdk下载:各历史版本下载地址
- 1042 字符统计 (20分)——16行代码满分
- java.util.regex_java.util.regex.PatternSyntaxException:索引附近的...
- themleft模板库_Thymeleaf模板引擎常用总结
- @MySQL的存储引擎
- linxu其他用户登录mysql_Linux系统的MySQL用户如何开启远程登录权限
- mysql事务操作_mysql的事务操作
- hasset java_java HashSet的使用
- amr转换成mp3 java_java将amr文件转换为MP3格式(windowslinux均可使用,亲测)
- mysql语句大全(2)
- 这一次,让你彻底明白接口及抽象类
- 给实践者的算法学习指南
- Android:android2.3电话接听
- 系统辨识总论(System Identification)
- cdrx4自动排版步骤_coreldraw自动排版
- 数学建模层次分析法例题及答案_数学建模之层次分析法
- 整车EMC正向开发及仿真
热门文章
- 管理者和企业如何做好员工管理?
- Flume-day03_进阶案例
- 1-SII--SharedPreferences完美封装
- Kinect v2和Intel RealSense D435的三维重建对比
- RISC-V 指令格式
- eth_clockgen.v
- 成考计算机科学与技术考试科目,计算机科学与技术本科自考有哪些科目
- A hybrid method of exponential smoothing and recurrent
- shiro登录验证原理
- sap crm button_SAP携Intelligent RPA 2.0 参加中国流程自动化产业峰会