pytorch常用函数总结

torch.max(input,dim)

求取指定维度上的最大值,,返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。比如:

demo.shape

Out[7]: torch.Size([10, 3, 10, 10])

torch.max(demo,1)[0].shape

Out[8]: torch.Size([10, 10, 10])

torch.max(demo,1)[0]这其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值对应的位置索引。例子如下:

a

Out[8]:

tensor([[1., 2., 3.],

[4., 5., 6.]])

a.max(1)

Out[9]:

torch.return_types.max(

values=tensor([3., 6.]),

indices=tensor([2, 2]))

class torch.nn.ParameterList(parameters=None)

将submodules保存在一个list中。

ParameterList可以像一般的Python list一样被索引。而且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。

参数说明:

modules (list, optional) – a list of nn.Parameter

例子:

class MyModule(nn.Module):

def __init__(self):

super(MyModule, self).__init__()

self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

def forward(self, x):

# ModuleList can act as an iterable, or be indexed using ints

for i, p in enumerate(self.params):

x = self.params[i // 2].mm(x) + p.mm(x)

return x

torch.cat()函数

cat是concatnate的意思:拼接,联系在一起。

先说cat( )的普通用法

如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:

C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)

C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)

相当于将tensor按照指定维度进行拼接,比如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为 128*96*64*64。

注意:

两个tensor要想进行拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,比如上面的例子,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第二个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进行拼接(注意,维度的下标是从0开始的,比如 A 的形状对应的维度下标为:\(128_0*64_1*32_2*32_3\))

contiguous()函数的使用

contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() ),示例如下:

x = torch.Tensor(2,3)

y = x.permute(1,0) # permute:二维tensor的维度变换,此处功能相当于转置transpose

y.view(-1) # 报错,view使用前需调用contiguous()函数

y = x.permute(1,0).contiguous()

y.view(-1) # OK

具体原因有两种说法:

1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;

2 维度变换后的变量是之前变量的浅拷贝,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;

tensor.repeat()函数

该函数传入的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例子更好理解,如下:

>>> import torch

>>>

>>> a = torch.randn(33, 55)

>>> a.size()

torch.Size([33, 55])

>>>

>>> a.repeat(1, 1).size()

torch.Size([33, 55])

>>>

>>> a.repeat(2,1).size()

torch.Size([66, 55])

>>>

>>> a.repeat(1,2).size()

torch.Size([33, 110])

>>>

>>> a.repeat(1,1,1).size()

torch.Size([1, 33, 55])

>>>

>>> a.repeat(2,1,1).size()

torch.Size([2, 33, 55])

>>>

>>> a.repeat(1,2,1).size()

torch.Size([1, 66, 55])

>>>

>>> a.repeat(1,1,2).size()

torch.Size([1, 33, 110])

>>>

>>> a.repeat(1,1,1,1).size()

torch.Size([1, 1, 33, 55])

>>>

>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,

>>> # 下面是一些错误示例

>>> a.repeat(2).size() # 1D < 2D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b = torch.randn(5,6,7)

>>> b.size() # 3D

torch.Size([5, 6, 7])

>>>

>>> b.repeat(2).size() # 1D < 3D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b.repeat(2,1).size() # 2D < 3D, error

Traceback (most recent call last):

File "", line 1, in

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

>>>

>>> b.repeat(2,1,1).size() # 3D = 3D, okay

torch.Size([10, 6, 7])

>>>

torch.masked_select()函数

a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])

b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)

c = torch.masked_select(a,b)

print(c)

用法:torch.masked_select(x, mask),mask必须转化成torch.ByteTensor类型。

torch.sort

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

对输入张量input沿着指定维按升序排序。如果不给定dim,则默认为输入的最后一维。如果指定参数descending为True,则按降序排序

返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。

参数:

input (Tensor) – 要对比的张量

dim (int, optional) – 沿着此维排序

descending (bool, optional) – 布尔值,控制升降排序

out (tuple, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。

例子:

>>> x = torch.randn(3, 4)

>>> sorted, indices = torch.sort(x)

>>> sorted

-1.6747 0.0610 0.1190 1.4137

-1.4782 0.7159 1.0341 1.3678

-0.3324 -0.0782 0.3518 0.4763

[torch.FloatTensor of size 3x4]

>>> indices

0 1 3 2

2 1 0 3

3 1 0 2

[torch.LongTensor of size 3x4]

>>> sorted, indices = torch.sort(x, 0)

>>> sorted

-1.6747 -0.0782 -1.4782 -0.3324

0.3518 0.0610 0.4763 0.1190

1.0341 0.7159 1.4137 1.3678

[torch.FloatTensor of size 3x4]

>>> indices

0 2 1 2

2 0 2 0

1 1 0 1

[torch.LongTensor of size 3x4]

torch的拼接函数_pytorch常用函数总结(持续更新)相关推荐

  1. oracle共享函数,oracle常用函数及示例分享

    oracle很多常用的函数如果了解的话可以加速开发,原本想总结下自己工作中使用oracle函数的一些场景,后发现川哥哥的博客总结的很好,为了方便查询函数就转摘过来. 总结的很不错,简单易懂,没什么事就 ...

  2. PHP 常用函数 - 其他常用函数

    PHP 常用函数 PHP 常用函数 - 字符串函数 PHP 常用函数 - 数组函数 PHP 常用函数 - 数学函数 PHP 常用函数 - 目录.文件函数 PHP 常用函数 - 其他常用函数 文章目录 ...

  3. linux常用指令(持续更新)

    linux常用指令(持续更新) 基本访问指令: 直接进入用户的home目录: cd ~ 进入上一个目录: cd - 进入当前目录的上一层目录: cd .. 进入当前目录的上两层目录: cd ../.. ...

  4. Android常用开发网址(持续更新)

    2019独角兽企业重金招聘Python工程师标准>>> Android常用开发网址(持续更新) 环境搭建 android镜像 http://www.androiddevtools.c ...

  5. 程序员常用英语积累---持续更新

    程序员常用英语积累---持续更新: Distribution: 分发 Direction    : 方向 Description: 描述 Destination: 目标 Definition   : ...

  6. 工具篇:Git与Github+GitLib常用操作(不定期持续更新)

    工具篇:Git与Github+GitLib常用操作(不定期持续更新) 前言: 写这个主要是打算自己用的,里边很多东西都是只要我自己看得懂,但是用了两个星期发现真是越用越简单,越用越好用,私以为得到了学 ...

  7. python常用函数import_python 常用函数集合

    1.常用函数 round() :  四舍五入 参数1:要处理的小数 参数2:可选,如果不加,就是不要小数,如果加,就是保留几位小数 abs() :绝对值函数 max() :列表.字符串,得到最大的元素 ...

  8. MYSQL天花板函数和地板函数_2020-08-04常用函数

    •单行函数语法 –语法: 函数名[(参数1,参数2,-)] –其中的参数可以是以下之一: •变量 •列名 •表达式 •单行函数特征 –单行函数对单行操作 –每行返回一个结果 –有可能返回值与原参数数据 ...

  9. Hive内置函数与常用函数汇总

    目录 Hive内置函数汇总 字符函数(字符串操作) 数学函数 集合函数 类型转换函数 日期函数 条件函数 聚合函数 表生成函数 辅助功能类函数 数据屏蔽函数(从Hive 2.1.0开始) Hive常用 ...

  10. matlab doc函数,matlab常用函数.doc

    matlab常用函数.doc MatLab 常用函数 1. 特殊变量与常数 ans 计算结果的变量名 computer 确定运行的计算机 eps 浮点相对精度 Inf 无穷大 I 虚数单位 name ...

最新文章

  1. 向量空间和计算机科学与技术,向量空间
  2. 海信最后的倔强,激光电视最终难逃“过渡产品”的命运?
  3. openssl升级_CVE20201967: openssl 拒绝服务漏洞通告
  4. VTK:可视化之ChooseTextColor
  5. python3爬虫初探(八)requests
  6. androidid什么时候会变_高瓷绿松石是什么意思?为何绿松石的瓷度要比颜色重要?...
  7. android viewflipper 动画,Android自定义ViewFlipper实现滚动效果
  8. 点到线段的距离 计算几何
  9. Kalman Fuzzy Actor-Critic Learning Automaton Algorithm for the Pursuit-Evasion Differential Game
  10. Debian 7 源(32/64bit)好用的源
  11. 面试官:为何Redis使用跳表而非红黑树实现SortedSet?
  12. 攻城掠地(优先队列)
  13. 多语言 - 国际化处理 上
  14. 最短路小结(三种算法+各种常见变种)
  15. xgboost的使用简析
  16. DirectX学习笔记(十五):粒子系统实现
  17. Jasperreport_6.18的吐血记录五之柱形图
  18. Kafka术语:AR、OSR、ISR、HW和LEO以及之间的关系
  19. 《电脑音乐制作实战指南:伴奏、录歌、MTV全攻略》——导读
  20. win10系统想下载win7系统自带的游戏——分享游戏压缩包

热门文章

  1. 用 WebSocket 实现一个简单的客服聊天系统
  2. Nginx配置文件(作为Web服务器)
  3. 类库从自带的配置文件中获取信息(DLL文件 获取 DLL文件自带的配置信息) z...
  4. [C++] socket - 2 [UDP通信C/S实例]
  5. SharePoint2013更改网站集端口方法
  6. Exchange 2010分层通讯薄(HAB)配置指南
  7. Android引领移动互联网革命的七大理由
  8. 【SpringBoot_ANNOTATIONS】自动装配 05 @Profile环境搭建
  9. 那些年使用Android studio遇到的问题
  10. 【JavaEE】第零章(2020.03.06)模式 表 索引