看到一段代码

def split_prediction(tensor, shape, split):tensor = F.interpolate(tensor, size=shape, mode='bilinear', align_corners=False).squeeze()heatmaps, rooms, icons = torch.split(tensor, split, 0)icons = F.softmax(icons, 0)rooms = F.softmax(rooms, 0)heatmaps = heatmaps.data.numpy()icons = icons.data.numpy()rooms = rooms.data.numpy()return heatmaps, rooms, icons
import torch
import torch.nn.functional as Finput = torch.randn(3,4)
print(input)
tensor([[-0.5526, -0.0194, 2.1469, -0.2567],[-0.3337, -0.9229, 0.0376, -0.0801],[ 1.4721, 0.1181, -2.6214, 1.7721]])b = F.softmax(input,dim=0) # 按列SoftMax,列和为1
print(b)
tensor([[0.1018, 0.3918, 0.8851, 0.1021],[0.1268, 0.1587, 0.1074, 0.1218],[0.7714, 0.4495, 0.0075, 0.7762]])c = F.softmax(input,dim=1)  # 按行SoftMax,行和为1
print(c)
tensor([[0.0529, 0.0901, 0.7860, 0.0710],[0.2329, 0.1292, 0.3377, 0.3002],[0.3810, 0.0984, 0.0064, 0.5143]])d = torch.max(input,dim=0)  # 按列取max,
print(d)
torch.return_types.max(
values=tensor([1.4721, 0.1181, 2.1469, 1.7721]),
indices=tensor([2, 2, 0, 2]))e = torch.max(input,dim=1)  # 按行取max,
print(e)
torch.return_types.max(
values=tensor([2.1469, 0.0376, 1.7721]),
indices=tensor([2, 2, 3]))

下面看看三维tensor解释例子:

函数softmax输出的是所给矩阵的概率分布;

b输出的是在dim=0维上的概率分布,b[0][5][6]+b[1][5][6]+b[2][5][6]=1

a=torch.rand(3,16,20)
b=F.softmax(a,dim=0)
c=F.softmax(a,dim=1)
d=F.softmax(a,dim=2)In [1]: import torch as t
In [2]: import torch.nn.functional as F
In [4]: a=t.Tensor(3,4,5)
In [5]: b=F.softmax(a,dim=0)
In [6]: c=F.softmax(a,dim=1)
In [7]: d=F.softmax(a,dim=2)In [8]: a
Out[8]:
tensor([[[-0.1581, 0.0000, 0.0000, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000],[-0.0344, 0.0000, -0.0344, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000]],[[-0.0344, 0.0000, -0.0344, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000],[-0.0344, 0.0000, -0.0344, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000]],[[-0.0344, 0.0000, -0.0344, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000],[-0.0344, 0.0000, -0.0344, 0.0000, -0.0344],[ 0.0000, -0.0344, 0.0000, -0.0344, 0.0000]]])In [9]: b
Out[9]: tensor([[[0.3064, 0.3333, 0.3410, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333]],[[0.3468, 0.3333, 0.3295, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333]],[[0.3468, 0.3333, 0.3295, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333, 0.3333, 0.3333]]])In [10]: b.sum()
Out[10]: tensor(20.0000)In [11]: b[0][0][0]+b[1][0][0]+b[2][0][0]
Out[11]: tensor(1.0000)In [12]: c.sum()
Out[12]: tensor(15.)In [13]: c
Out[13]:
tensor([[[0.2235, 0.2543, 0.2521, 0.2543, 0.2457],[0.2618, 0.2457, 0.2521, 0.2457, 0.2543],[0.2529, 0.2543, 0.2436, 0.2543, 0.2457],[0.2618, 0.2457, 0.2521, 0.2457, 0.2543]],[[0.2457, 0.2543, 0.2457, 0.2543, 0.2457],[0.2543, 0.2457, 0.2543, 0.2457, 0.2543],[0.2457, 0.2543, 0.2457, 0.2543, 0.2457],[0.2543, 0.2457, 0.2543, 0.2457, 0.2543]],[[0.2457, 0.2543, 0.2457, 0.2543, 0.2457],[0.2543, 0.2457, 0.2543, 0.2457, 0.2543],[0.2457, 0.2543, 0.2457, 0.2543, 0.2457],[0.2543, 0.2457, 0.2543, 0.2457, 0.2543]]])In [14]: n=t.rand(3,4)In [15]: n
Out[15]: tensor([[0.2769, 0.3475, 0.8914, 0.6845],[0.9251, 0.3976, 0.8690, 0.4510],[0.8249, 0.1157, 0.3075, 0.3799]])In [16]: m=t.argmax(n,dim=0)In [17]: m
Out[17]: tensor([1, 1, 0, 0])In [18]: p=t.argmax(n,dim=1)In [19]: p
Out[19]: tensor([2, 0, 0])In [20]: d.sum()
Out[20]: tensor(12.0000)In [22]: d
Out[22]: tensor([[[0.1771, 0.2075, 0.2075, 0.2075, 0.2005],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027],[0.1972, 0.2041, 0.1972, 0.2041, 0.1972],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027]],[[0.1972, 0.2041, 0.1972, 0.2041, 0.1972],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027],[0.1972, 0.2041, 0.1972, 0.2041, 0.1972],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027]],[[0.1972, 0.2041, 0.1972, 0.2041, 0.1972],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027],[0.1972, 0.2041, 0.1972, 0.2041, 0.1972],[0.2027, 0.1959, 0.2027, 0.1959, 0.2027]]])In [23]: d[0][0].sum()
Out[23]: tensor(1.)

softmax的参数很迷

看一下

Pytorch中torch.nn.Softmax的dim参数使用含义

涉及到多维tensor时,对softmax的参数dim总是很迷,下面用一个例子说明

import torch.nn as nnm = nn.Softmax(dim=0)n = nn.Softmax(dim=1)k = nn.Softmax(dim=2)input = torch.randn(2, 2, 3)print(input)print(m(input))print(n(input))print(k(input))

输出:

inputtensor([[[ 0.5450, -0.6264, 1.0446],
[ 0.6324, 1.9069, 0.7158]],[[ 1.0092, 0.2421, -0.8928],
[ 0.0344, 0.9723, 0.4328]]])

dim=0

tensor([[[0.3860, 0.2956, 0.8741],
[0.6452, 0.7180, 0.5703]],[[0.6140, 0.7044, 0.1259],
[0.3548, 0.2820, 0.4297]]])

dim=0时,在第0维上sum=1,即:

[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …

dim=1

tensor([[[0.4782, 0.0736, 0.5815],
[0.5218, 0.9264, 0.4185]],[[0.7261, 0.3251, 0.2099],
[0.2739, 0.6749, 0.7901]]])

dim=1时,在第1维上sum=1,即:

[0][0][0]+[0][1][0]=0.4782+0.5218=1
[0][0][1]+[0][1][1]=0.0736+0.9264=1
… …

dim=2

tensor([[[0.3381, 0.1048, 0.5572],
[0.1766, 0.6315, 0.1919]],[[0.6197, 0.2878, 0.0925],
[0.1983, 0.5065, 0.2953]]])

dim=2时,在第2维上sum=1,即:

[0][0][0]+[0][0][1]+[0][0][2]=0.3381+0.1048+0.5572=1.0001(四舍五入问题)
[0][1][0]+[0][1][1]+[0][1][2]=0.1766+0.6315+0.1919=1
… …

补充知识:多分类问题torch.nn.Softmax的使用

为什么谈论这个问题呢?是因为我在工作的过程中遇到了语义分割预测输出特征图个数为16,也就是所谓的16分类问题。

因为每个通道的像素的值的大小代表了像素属于该通道的类的大小,为了在一张图上用不同的颜色显示出来,我不得不学习了torch.nn.Softmax的使用。

首先看一个简答的例子,倘若输出为(3, 4, 4),也就是3张4x4的特征图。

import torch
img = torch.rand((3,4,4))
print(img)tensor([[[0.0413, 0.8728, 0.8926, 0.0693],[0.4072, 0.0302, 0.9248, 0.6676],[0.4699, 0.9197, 0.3333, 0.4809],[0.3877, 0.7673, 0.6132, 0.5203]],[[0.4940, 0.7996, 0.5513, 0.8016],[0.1157, 0.8323, 0.9944, 0.2127],[0.3055, 0.4343, 0.8123, 0.3184],[0.8246, 0.6731, 0.3229, 0.1730]],[[0.0661, 0.1905, 0.4490, 0.7484],[0.4013, 0.1468, 0.2145, 0.8838],[0.0083, 0.5029, 0.0141, 0.8998],[0.8673, 0.2308, 0.8808, 0.0532]]])

我们可以看到共三张特征图,每张特征图上对应的值越大,说明属于该特征图对应类的概率越大。

import torch.nn as nn
sogtmax = nn.Softmax(dim=0)
img = sogtmax(img)
print(img)tensor([[[0.2780, 0.4107, 0.4251, 0.1979],[0.3648, 0.2297, 0.3901, 0.3477],[0.4035, 0.4396, 0.2993, 0.2967],[0.2402, 0.4008, 0.3273, 0.4285]],[[0.4371, 0.3817, 0.3022, 0.4117],[0.2726, 0.5122, 0.4182, 0.2206],[0.3423, 0.2706, 0.4832, 0.2522],[0.3718, 0.3648, 0.2449, 0.3028]],[[0.2849, 0.2076, 0.2728, 0.3904],[0.3627, 0.2581, 0.1917, 0.4317],[0.2543, 0.2898, 0.2175, 0.4511],[0.3880, 0.2344, 0.4278, 0.2686]]])

可以看到,上面的代码对每张特征图对应位置的像素值进行Softmax函数处理, 图中对应通道(维度)位置加和=1。

我们看到Softmax函数会对原特征图每个像素的值在对应维度(这里dim=0,也就是第一维)上进行计算,将其处理到0~1之间,并且大小固定不变。

print(torch.max(img,0))torch.return_types.max(
values=tensor([[0.4371, 0.4107, 0.4251, 0.4117],[0.3648, 0.5122, 0.4182, 0.4317],[0.4035, 0.4396, 0.4832, 0.4511],[0.3880, 0.4008, 0.4278, 0.4285]]),
indices=tensor([[1, 0, 0, 1],[0, 1, 1, 2],[0, 0, 1, 2],[2, 0, 2, 0]]))

可以看到这里3x4x4变成了1x4x4,而且对应位置上的值为像素对应每个通道上的最大值,并且indices是对应的分类。

清楚理解了上面的流程,那么我们就容易处理了。

看具体案例,这里输出output的大小为:16x416x416.

output = torch.tensor(output)sm = nn.Softmax(dim=0)
output = sm(output)mask = torch.max(output,0).indices.numpy()# 因为要转化为RGB彩色图,所以增加一维
rgb_img = np.zeros((output.shape[1], output.shape[2], 3))
for i in range(len(mask)):for j in range(len(mask[0])):if mask[i][j] == 0:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 255rgb_img[i][j][2] = 255if mask[i][j] == 1:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 180rgb_img[i][j][2] = 0if mask[i][j] == 2:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 180rgb_img[i][j][2] = 180if mask[i][j] == 3:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 180rgb_img[i][j][2] = 255if mask[i][j] == 4:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 255rgb_img[i][j][2] = 180if mask[i][j] == 5:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 255rgb_img[i][j][2] = 0if mask[i][j] == 6:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 0rgb_img[i][j][2] = 180if mask[i][j] == 7:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 0rgb_img[i][j][2] = 255if mask[i][j] == 8:rgb_img[i][j][0] = 255rgb_img[i][j][1] = 0rgb_img[i][j][2] = 0if mask[i][j] == 9:rgb_img[i][j][0] = 180rgb_img[i][j][1] = 0rgb_img[i][j][2] = 0if mask[i][j] == 10:rgb_img[i][j][0] = 180rgb_img[i][j][1] = 255rgb_img[i][j][2] = 255if mask[i][j] == 11:rgb_img[i][j][0] = 180rgb_img[i][j][1] = 0rgb_img[i][j][2] = 180if mask[i][j] == 12:rgb_img[i][j][0] = 180rgb_img[i][j][1] = 0rgb_img[i][j][2] = 255if mask[i][j] == 13:rgb_img[i][j][0] = 180rgb_img[i][j][1] = 255rgb_img[i][j][2] = 180if mask[i][j] == 14:rgb_img[i][j][0] = 0rgb_img[i][j][1] = 180rgb_img[i][j][2] = 255if mask[i][j] == 15:rgb_img[i][j][0] = 0rgb_img[i][j][1] = 0rgb_img[i][j][2] = 0cv2.imwrite('output.jpg', rgb_img)

pytorch中torch.max和F.softmax函数的维度解释相关推荐

  1. Pytorch中torch.nn.Softmax的dim参数含义

    自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义

  2. pytorch中torch.optim的介绍

    pytorch中torch.optim的介绍 这是torch自带的一个优化器,里面自带了求导,更新等操作.开门见山直接讲怎么使用: 常用的引入: import torch.optim as optim ...

  3. Pytorch学习-torch.max()和min()深度解析

    Pytorch学习-torch.max和min深度解析 max的使用 min同理 dim参数理解 二维张量使用max() 三维张量使用max() max的使用 min同理 参考链接: 参考链接: 对于 ...

  4. PyTorch 中 torch.optim优化器的使用

    一.优化器基本使用方法 建立优化器实例 循环: 清空梯度 向前传播 计算Loss 反向传播 更新参数 示例: from torch import optim input = ..... optimiz ...

  5. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  6. pytorch 之 torch.max() 和 torch.min() 记录

    两个函数用法相同,此处就介绍max函数. 1.torch.max(data),不指定维度,返回data的最大值. 2.torch.max(data,dim),返回data中指定维度的最大值. 3.to ...

  7. 【Pytorch神经网络理论篇】 08 Softmax函数(处理分类问题)

    1.1 Softmax函数简介 oftmax函数本质也为激活函数,主要用于多分类问题,且要求分类互斥,分类器最后的输出单元需要Softmax 函数进行数值处理. Tip:在搭建网络模型的时候,需要用S ...

  8. python中numpy函数fft_如何在PyTorch中正确使用Numpy的FFT函数?

    我最近被介绍给Pythorch,开始浏览图书馆的文档和教程. 在"使用numpy和scipy创建扩展"教程中( http://pytorch.org/tutorials/advan ...

  9. Pytorch中tensor.expand()和tensor.expand_as()函数

    Pytorch中tensor.expand函数 Tensor.expand()函数详解 Tensor.expand_as()函数 Tensor.expand()函数详解 函数语法: # 官方解释: D ...

  10. F.softmax函数dim解读

    F.softmax(score, dim=1) dim=1就是对score矩阵中 所有第1维下标不同,其他维下标均相同的元素进行操作(softmax) 比如a[0][8][15]和a[7][8][15 ...

最新文章

  1. 多级反馈队列调度算法原理
  2. 小明分享|WiFi协议迭代历程
  3. cdgb调试linux崩溃程序
  4. Elasticsearch 技术分析(七): Elasticsearch 的性能优化
  5. html5跟html4有什么区别,Html5和Html4的区别
  6. 安卓项目中的R.java文件丢失如何解决
  7. mysql将查到的数据删除_MySQL基本SQL语句之数据插入、删除数据和更新数据 | 旺旺知识库...
  8. [.NET] 在Windows系统中搭建基于.NET的iPhone应用程序虚机开发环境
  9. Codeforces D546:Soldier and Number Game
  10. 爬取奇迹秀工具箱里面的文本和软件网盘链接
  11. python 全国内地中高风险地区数量查询与可视化(分省)
  12. 交通灯控制系统C语言代码,《C语言代码-交通灯控制器》.doc
  13. 怎么彻底删除users下的文件夹_什么工具可以有效清理C:\Users\用户名\AppData目录下的文件?...
  14. 使用python爬取电子书_怎样用python3爬取电子书网站所有下载链接
  15. C# vs2019 智能提示中文突然变成英文
  16. 实验九 使用异步方式实现文件读\写
  17. Python实现GWO智能灰狼优化算法优化支持向量机回归模型(svr算法)项目实战
  18. 第四课 尚硅谷Scala语言学习-面向对象
  19. 统计软件与数据分析—Lesson2
  20. 彻底清楚搞懂toRef和toRefs是什么,也许你知道toRef和toRefs,一直有点蒙蔽,一直没搞懂它,看完这篇文章你彻底清楚

热门文章

  1. ArcMAP 用不同颜色区分地类
  2. 二级分类php代码,php smarty 二级分类代码和模版循环例子
  3. Python基础语法-05-装饰器
  4. Android 滑动冲突问题的简单解决思路
  5. hibernate一对多双向关联中怎么配置list
  6. 专治数仓疑难杂症!美团点评 Flink 实时数仓应用经验分享
  7. 短视频秒播优化实践(一)
  8. 因为一个YYYY-MM-dd的Bug,我被老板骂的狗血淋头!
  9. findViewById中NullPointerException的错误
  10. asp oracle数据库开发 adodb,asp怎么连接oracle数据库