1、为什么要写这篇blog

因为最近在使用pytorch复现关于图像处理的深度学习论文时,需要求4维张量与4维张量(Batch,Channel,sizeA,sizeB)的余弦相似度和欧氏距离。
余弦相似度比较好解决,网上对于多维tensors的欧式距离求解就让人一言难尽了,直接使用torch.pairwise_distance根本解决不了问题,为此,我写下这篇blog,让自己好好梳理一遍原理,也希望能帮到和我遇到一样问题的小伙伴们。

2、函数介绍

以下,我将使用python中的pytorch框架进行讲解和展示。
(1)余弦相似度:torch.cosine_similarity
(2)欧氏距离:对输入数据进行改写,迎合torch.pairwise_distance函数的实现。

3、结果(代码)实现

3.1、余弦相似度

这是我在了解torch.cosine_similarity函数时使用的示例,输入的维度是[1,3,2,2],输出维度是[1,2,2]。

input1 = torch.tensor([[[[1, 2], [3, 4]],[[1, 2], [3, 4]],[[5, 6], [7, 8]]]], dtype=torch.float)
print(input1.shape)
input2 = torch.tensor([[[[5, 6], [7, 8]],[[5, 6], [7, 8]],[[1, 2], [3, 4]]]], dtype=torch.float)
output = torch.cosine_similarity(input1, input2)
print(output)
print(torch.cosine_similarity(torch.tensor([1,1,5], dtype=torch.float),\
torch.tensor([5,5,1], dtype=torch.float),dim=0))

最后一行代码就是其输出的原理(应该很好理解),结果展示:

3.2、欧氏距离

这是我在了解torch.pairwise_distance函数时使用的示例,输入的维度是[2,2,3,4],输出维度是[2,3,4],少了的维度(2)是channel所在的维度。

a = torch.randint(1,8,(2,2,3,4))
# print(a)
a = a.transpose(1,3)
print(a.shape)
# print(a)
b = torch.randint(1,4,(2,2,3,4))
# print(b)
b = b.transpose(1,3)
# print(b)c = torch.pairwise_distance(a,b)
# print(c)
print(c.shape)
c = c.transpose(1,2)
# print(c)
print(c.shape)

因为torch.pairwise_distance函数,会对最后一维进行展开,所以应该先把张量维度重构为(Batch,sizeA(B),sizeB(A),Channel),再进行计算即可。维度展示如下所示,

4、结语

要先理解其中机理才能更好地实现自己的需求啊,兄弟们,不能一味地直接调用函数呢!

PYTORCH学习(3):多维tensors求余弦相似度和欧氏距离相关推荐

  1. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  2. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  3. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  4. 【Pytorch学习笔记2】Pytorch的主要组成模块

    个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...

  5. 1.pytorch 学习笔记--Getting stared

    pytorch 学习笔记–Getting stared 1.什么是pytorch Pytorch 是一个基于Python的科学计算包,主要面向以下人群: 替代numpy以使用GPU做计算加速 一个深度 ...

  6. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  7. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  8. Pytorch学习-torch.max()和min()深度解析

    Pytorch学习-torch.max和min深度解析 max的使用 min同理 dim参数理解 二维张量使用max() 三维张量使用max() max的使用 min同理 参考链接: 参考链接: 对于 ...

  9. Pytorch学习- 小型知识点汇总 unsqueeze()/squeeze() 和 .max() 等等

    1. unsqueeze(input, dim, out=None)函数 - 升维作用 参考链接 在指定的地方上增加一个维度 0(-2) [行扩展]: 表示在张量最外层增加一个中括号变成第一维 1(- ...

最新文章

  1. python对文件的操作_python的 随手记----对文件的操作
  2. 命令行里给个注释,AI就能自动生成代码
  3. Can't access RabbitMQ web management interface after fresh install
  4. PHP中使用redis 执行lua脚本
  5. unity怎么做水面_防水博士小课堂 | 什么是背水面防水? 背水面防水施工到底该怎么做?...
  6. Knowledge Graph Alignment Network with Gated Multi-Hop Neighborhood Aggregation-学习笔记
  7. UE3 贴图支持及设置
  8. vrrp协议_虚拟路由冗余协议VRRP原理介绍
  9. HttpClient发送get,post接口请求
  10. 国内遥感卫星资源综述
  11. 基于单片机的电子万年历的设计
  12. 用python统计字母个数_如何用python统计字符串中字母个数?
  13. mpp格式文件怎么打开,mpp进度计划
  14. linux双网卡透明网桥,两种网桥透明网桥和源路由选择网桥
  15. 硬件-电子基础元器件(一)电阻
  16. UVA - 10105 Polynomial Coefficients
  17. 9 9简单的数独游戏python_如何使用tkinter gui python创建一个9*9的数独生成器?
  18. 学法语的你伤不起之吐槽各种语言
  19. 论文解读:ChangeFormer | A TRANSFORMER-BASED SIAMESE NETWORK FOR CHANGE DETECTION
  20. 【软件质量】软件复杂性

热门文章

  1. p2p开户银行审核模块功能实现
  2. 独立思考有四个层次-知识体系
  3. jenkins自定义邮件发送Editable
  4. arduino与风向传感器的接线_小白如何开始学习Arduino?
  5. 浅谈智慧灯杆的配套产业
  6. 1小时撸一套原型,你敢信?
  7. 国盟一年三度的研讨会面向广大的CISA考生
  8. STM32中断笔记——关于NVIC的两个问题
  9. 虚拟机 ubuntu 无法进入界面
  10. (2012.01.12-2012.04.01)八十二天的学习小记