问题的引出

关于pytorch中dim的描述个人总是弄的不是很清楚,好多地方存在着疑问,这次在实验过程中需要比较两个高维tensor的相似度,由于需要确定在哪一维进行比较,故去测试了pytorch中关于tensor维度的一些现象。

dim

关于dim许多博客都有比我更加专业的解释,dim具体的解释也不是本文的重点,这里盗用其他博客里的一张图,这张图也是我认为对dim比较好的直观的解释(原文链接),本文的重点在于对高维tensor维度上操作,即不同的操作在不同的维度上进行会有怎样的不同

Cosine_similarity

如果你想对两个tensor比较他们之间的相似度,那么torch.cosine_similarity函数是一个不错的选择,但是在该函数的参数列表中,有一个dim值,在官方文档中,值介绍了这个参数用来指定在哪一维上进行操作,但我在实际使用过程中却对这个概念理解的不好,后来经过不断的实验终于弄懂了dim的含义。

先从二维开始

大部分博客只说明了有关二维的情况,而二维的情况是比较好想的,重点是在高维如三维情况下的tensor,那么这里我们还是从二维开始,先去看一下基本的在维度上的操作
首先直观上我们可以发现,当dim选择在哪一维上操作时,相应的那一维就消失了(这里说的消失,指直观现象,但个人觉得不是特别好理解),

p1 = torch.rand([2,3])
p2 = torch.rand([2,3])
print(p1)
print(p2)
p3 = torch.cosine_similarity(p1,p2,dim=1)
print(p3)
print(p3.shape)

对第一维操作

上述代码作用在第一维上,那么他的结果是怎么样的呢?

他的结果有两个元素(个人认为这里讨论维度容易混乱,不如直接说元素的个数),对应原来的p1的shape我们发现在第一维上操作使得第一维消失了,即[2,3]->[2],这也是大多数博客的解释,但我认为这并没有揭示真正的工作过程,同时如果应用到高维的情况,很容易得到一个令人疑惑的维度。
下面让我们来试着理解一下dim的含义,上述的例子中的图片说的已经比较明显了,在dim=1上操作,实际的含义为在以第一维为单位进行操作,即对每一行进行操作(说法不严谨,但为了方便理解),或者也可以这样进行理解就是固定第一维(即tensor的列),去比较第0维(tensor的行)。

对第0维操作

那么按照该思想,如果按照第0维操作,即对每一列为单位进行操作,那么得到的应该是一个有三个元素(为了方便理解,有不严谨的表述)的结果,分别为对应列之间的相似度:

通过实验验证,我发现确实如此。

三维的情况

在实际应用中,tensor的形状一般是[batch,seq_len,embed]这样三维的形状,那么在三维中对不同的维度操作会有怎样的差别

对第0维操作

p1 = torch.rand([2,2,3])
p2 = torch.rand([2,2,3])
print(p1)
print(p2)
p3 = torch.cosine_similarity(p1,p2,dim=0)
print(p3)
print(p3.shape)

当取dim=0时,注意此时第0维实际上是batch的维度,则固定batch不动,比较后面的[2,3]的元素,那么后面的是怎么比较的呢?依旧是按照第0维,这里的第0维实际上是后面那个[seq_len,embed]的第0维,即对列进行操作,所以其结果为13三个元素,而原来的tensor有两个batch,所以分别比较后就有23个元素,第一行元素是第一个batch中的[2,3]个元素按照第0位求相似的的结果,同理第二行也是

对第一维或第二维操作

这里的过程实际上就和二维的情况一样的,不过需要注意的是二维情况中的第0维对应三维情况的第1维,二维情况的第1维实际上对应三维中的第2维。
不同的是,三维中是一个batch内进行比较,所以,只要在二维操作的基础上加上batch的一维就可以了

Pytorch关于高维tensor的dim上操作的理解--以cosine_similarity的dim参数为例相关推荐

  1. Pytorch List Tensor转Tensor,,reshape拼接等操作

    Pytorch List Tensor转Tensor,reshape拼接等操作 持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操 ...

  2. pytorch中常用对张量shape的操作

    常用shape操作 目录 文章目录 常用shape操作 目录 1. 增加/删除维度 > 删除: torch.tensor.squeeze(dim) 举例 > 增加: torch.tenso ...

  3. Lesson 16.5 在Pytorch中实现卷积网络(上):卷积核、输入通道与特征图在PyTorch中实现卷积网络(中):步长与填充

    卷积神经网络是使用卷积层的一组神经网络.在一个成熟的CNN中,往往会涉及到卷积层.池化层.线性层(全连接层)以及各类激活函数.因此,在构筑卷积网络时,需从整体全部层的需求来进行考虑. 1 二维卷积层n ...

  4. Pytorch 其它有关Tensor的话题,GPU,向量化

    3.1.4 其它有关Tensor的话题 这部分的内容不好专门划分一小节,但是笔者认为仍值得读者注意,故而将其放在这一小节. GPU/CPU tensor可以很随意的在gpu/cpu上传输.使用tens ...

  5. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  6. pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换

    pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换 1, 创建pytorch 的Tensor张量: torch.rand((3,224,224)) #创建随机值的三维张量,大小为 ...

  7. pytorch 创建张量tensor

    pytorch 创建张量tensor 先看下面一张图 通过上图有了一个直观了解后,我们开始尝试创建一下. 先创建一个标量和一个向量 a = torch.tensor([1]) #标量 print(a) ...

  8. PyTorch Variable与Tensor 【详解】

    Variable 与 Tensor tensor 是 PyTorch 中的完美组件,高效的数据格式,但是构建神经网络还远远不够,我们需要能够构建计算图的 tensor,这就是 Variable.Var ...

  9. python输入参数改变图形_Python基于Tensor FLow的图像处理操作详解

    本文实例讲述了Python基于Tensor FLow的图像处理操作.分享给大家供大家参考,具体如下: 在对图像进行深度学习时,有时可能图片的数量不足,或者希望网络进行更多的学习,这时可以对现有的图片数 ...

  10. python图像处理教程_Python基于Tensor FLow的图像处理操作详解

    本文实例讲述了Python基于Tensor FLow的图像处理操作.分享给大家供大家参考,具体如下: 在对图像进行深度学习时,有时可能图片的数量不足,或者希望网络进行更多的学习,这时可以对现有的图片数 ...

最新文章

  1. 干货|TensorFlow开发环境搭建(Ubuntu16.04+GPU+TensorFlow源码编译)
  2. Survey | 多任务学习综述
  3. altium Designer丝印显示汉字,更换字体,数码管风格,镂空效果
  4. nginx源码编译和集群及高可用
  5. laravel的一个简单文件博客项目katana的使用
  6. go基本语法:channel未关闭遍历结束后会报错deadlock
  7. NEFU394 素数价值
  8. CF785E Anton and Permutation
  9. 前端学习(1485):restful接口规则
  10. 使用计算机时 正确的关机顺序是( ),《计算机应用基础》半期考试卷
  11. mysql-proxy 2进制版本安装
  12. C++ 常用基础概念
  13. c语言程序设计的实验仪器和设备,C语言程序设计实验.doc
  14. 超分辨率分析(一)--传统方案综述
  15. 跨域小结(为什么form表单提交没有跨域问题,但ajax提交有跨域问题)
  16. NUC980 DIY项目大挑战 - EtherCAT实现
  17. ubuntu16安装Times New Roma字体 / WPS 安装Times New Roma字体
  18. python查询mysql decimal报错_【2020Python修炼记】MySQL之 表相关操作
  19. java基础知识总结,javaweb参考资料大全
  20. JavaScript面试题整理汇总

热门文章

  1. JavaWeb 登陆界面
  2. c语言大于一小于10,C语言首先输入一个大于2且小于10的整数
  3. 解决:Elasticsearch failed to map source
  4. 微信小程序 java社区快递柜取件管理系统python php
  5. 【渗透测试基础-4】资产收集之nmap扫描
  6. 输入若干个字符串,查找其中的最大字母,在该字母后面插入字符串“(max)”
  7. jQuery基础之正则表达式及表单验证
  8. 图benchmark
  9. Python报错集合篇7-KeyError: 1
  10. Livezilla on Linux 安装配置教程