torch的拼接函数_pytorch常用函数总结(持续更新)
pytorch常用函数总结
torch.max(input,dim)
求取指定维度上的最大值,,返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。比如:
demo.shape
Out[7]: torch.Size([10, 3, 10, 10])
torch.max(demo,1)[0].shape
Out[8]: torch.Size([10, 10, 10])
torch.max(demo,1)[0]这其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值对应的位置索引。例子如下:
a
Out[8]:
tensor([[1., 2., 3.],
[4., 5., 6.]])
a.max(1)
Out[9]:
torch.return_types.max(
values=tensor([3., 6.]),
indices=tensor([2, 2]))
class torch.nn.ParameterList(parameters=None)
将submodules保存在一个list中。
ParameterList可以像一般的Python list一样被索引。而且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。
参数说明:
modules (list, optional) – a list of nn.Parameter
例子:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
torch.cat()函数
cat是concatnate的意思:拼接,联系在一起。
先说cat( )的普通用法
如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:
C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)
C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
相当于将tensor按照指定维度进行拼接,比如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为 128*96*64*64。
注意:
两个tensor要想进行拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,比如上面的例子,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第二个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进行拼接(注意,维度的下标是从0开始的,比如 A 的形状对应的维度下标为:\(128_0*64_1*32_2*32_3\))
contiguous()函数的使用
contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() ),示例如下:
x = torch.Tensor(2,3)
y = x.permute(1,0) # permute:二维tensor的维度变换,此处功能相当于转置transpose
y.view(-1) # 报错,view使用前需调用contiguous()函数
y = x.permute(1,0).contiguous()
y.view(-1) # OK
具体原因有两种说法:
1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;
2 维度变换后的变量是之前变量的浅拷贝,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;
tensor.repeat()函数
该函数传入的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例子更好理解,如下:
>>> import torch
>>>
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>>
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>>
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>>
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>>
>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,
>>> # 下面是一些错误示例
>>> a.repeat(2).size() # 1D < 2D, error
Traceback (most recent call last):
File "", line 1, in
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>>
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
File "", line 1, in
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
File "", line 1, in
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>
torch.masked_select()函数
a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])
b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)
c = torch.masked_select(a,b)
print(c)
用法:torch.masked_select(x, mask),mask必须转化成torch.ByteTensor类型。
torch.sort
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
对输入张量input沿着指定维按升序排序。如果不给定dim,则默认为输入的最后一维。如果指定参数descending为True,则按降序排序
返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。
参数:
input (Tensor) – 要对比的张量
dim (int, optional) – 沿着此维排序
descending (bool, optional) – 布尔值,控制升降排序
out (tuple, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。
例子:
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
-1.6747 0.0610 0.1190 1.4137
-1.4782 0.7159 1.0341 1.3678
-0.3324 -0.0782 0.3518 0.4763
[torch.FloatTensor of size 3x4]
>>> indices
0 1 3 2
2 1 0 3
3 1 0 2
[torch.LongTensor of size 3x4]
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
-1.6747 -0.0782 -1.4782 -0.3324
0.3518 0.0610 0.4763 0.1190
1.0341 0.7159 1.4137 1.3678
[torch.FloatTensor of size 3x4]
>>> indices
0 2 1 2
2 0 2 0
1 1 0 1
[torch.LongTensor of size 3x4]
torch的拼接函数_pytorch常用函数总结(持续更新)相关推荐
- oracle共享函数,oracle常用函数及示例分享
oracle很多常用的函数如果了解的话可以加速开发,原本想总结下自己工作中使用oracle函数的一些场景,后发现川哥哥的博客总结的很好,为了方便查询函数就转摘过来. 总结的很不错,简单易懂,没什么事就 ...
- PHP 常用函数 - 其他常用函数
PHP 常用函数 PHP 常用函数 - 字符串函数 PHP 常用函数 - 数组函数 PHP 常用函数 - 数学函数 PHP 常用函数 - 目录.文件函数 PHP 常用函数 - 其他常用函数 文章目录 ...
- linux常用指令(持续更新)
linux常用指令(持续更新) 基本访问指令: 直接进入用户的home目录: cd ~ 进入上一个目录: cd - 进入当前目录的上一层目录: cd .. 进入当前目录的上两层目录: cd ../.. ...
- Android常用开发网址(持续更新)
2019独角兽企业重金招聘Python工程师标准>>> Android常用开发网址(持续更新) 环境搭建 android镜像 http://www.androiddevtools.c ...
- 程序员常用英语积累---持续更新
程序员常用英语积累---持续更新: Distribution: 分发 Direction : 方向 Description: 描述 Destination: 目标 Definition : ...
- 工具篇:Git与Github+GitLib常用操作(不定期持续更新)
工具篇:Git与Github+GitLib常用操作(不定期持续更新) 前言: 写这个主要是打算自己用的,里边很多东西都是只要我自己看得懂,但是用了两个星期发现真是越用越简单,越用越好用,私以为得到了学 ...
- python常用函数import_python 常用函数集合
1.常用函数 round() : 四舍五入 参数1:要处理的小数 参数2:可选,如果不加,就是不要小数,如果加,就是保留几位小数 abs() :绝对值函数 max() :列表.字符串,得到最大的元素 ...
- MYSQL天花板函数和地板函数_2020-08-04常用函数
•单行函数语法 –语法: 函数名[(参数1,参数2,-)] –其中的参数可以是以下之一: •变量 •列名 •表达式 •单行函数特征 –单行函数对单行操作 –每行返回一个结果 –有可能返回值与原参数数据 ...
- Hive内置函数与常用函数汇总
目录 Hive内置函数汇总 字符函数(字符串操作) 数学函数 集合函数 类型转换函数 日期函数 条件函数 聚合函数 表生成函数 辅助功能类函数 数据屏蔽函数(从Hive 2.1.0开始) Hive常用 ...
- matlab doc函数,matlab常用函数.doc
matlab常用函数.doc MatLab 常用函数 1. 特殊变量与常数 ans 计算结果的变量名 computer 确定运行的计算机 eps 浮点相对精度 Inf 无穷大 I 虚数单位 name ...
最新文章
- 向量空间和计算机科学与技术,向量空间
- 海信最后的倔强,激光电视最终难逃“过渡产品”的命运?
- openssl升级_CVE20201967: openssl 拒绝服务漏洞通告
- VTK:可视化之ChooseTextColor
- python3爬虫初探(八)requests
- androidid什么时候会变_高瓷绿松石是什么意思?为何绿松石的瓷度要比颜色重要?...
- android viewflipper 动画,Android自定义ViewFlipper实现滚动效果
- 点到线段的距离 计算几何
- Kalman Fuzzy Actor-Critic Learning Automaton Algorithm for the Pursuit-Evasion Differential Game
- Debian 7 源(32/64bit)好用的源
- 面试官:为何Redis使用跳表而非红黑树实现SortedSet?
- 攻城掠地(优先队列)
- 多语言 - 国际化处理 上
- 最短路小结(三种算法+各种常见变种)
- xgboost的使用简析
- DirectX学习笔记(十五):粒子系统实现
- Jasperreport_6.18的吐血记录五之柱形图
- Kafka术语:AR、OSR、ISR、HW和LEO以及之间的关系
- 《电脑音乐制作实战指南:伴奏、录歌、MTV全攻略》——导读
- win10系统想下载win7系统自带的游戏——分享游戏压缩包
热门文章
- 用 WebSocket 实现一个简单的客服聊天系统
- Nginx配置文件(作为Web服务器)
- 类库从自带的配置文件中获取信息(DLL文件 获取 DLL文件自带的配置信息) z...
- [C++] socket - 2 [UDP通信C/S实例]
- SharePoint2013更改网站集端口方法
- Exchange 2010分层通讯薄(HAB)配置指南
- Android引领移动互联网革命的七大理由
- 【SpringBoot_ANNOTATIONS】自动装配 05 @Profile环境搭建
- 那些年使用Android studio遇到的问题
- 【JavaEE】第零章(2020.03.06)模式 表 索引