torch.nn.flatten

torch.nn.flatten是一个类,作用为将连续的几个维度展平成一个tensor(将一些维度合并)

  • 参数为合并开始的维度,合并结束的维度(维度就是索引,从 0 开始)

    • 开始维度默认为 1。因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。
    • 结束维度默认为 -1,也就是一直合并到最后一维
  • 默认参数情况

    x = torch.ones(2, 2, 2, 2)F = torch.nn.Flatten()
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[1., 1., 1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1., 1., 1.]])
    >>torch.Size([2, 8])
    
  • 输入一个参数情况:该参数为合并开始的维度

    x = torch.ones(2, 2, 2, 2)F = torch.nn.Flatten(2)
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.]]])
    >>torch.Size([2, 2, 4])
    
  • 输入两个参数情况:第一个参数代表合并开始维度,第二个参数代表合并结束维度(合并范围包含开始维度和结束维度)

    x = torch.ones(2, 2, 2, 2)F = torch.nn.Flatten(1, 2)
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[[1., 1.],[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.],[1., 1.]]])
    >>torch.Size([2, 4, 2])
    

torch.flatten

作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维

t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.shape)
>>torch.Size([2, 2, 2])print(torch.flatten(t))
>>tensor([1, 2, 3, 4, 5, 6, 7, 8])print(torch.flatten(t, 1))
>>tensor([[1, 2, 3, 4],[5, 6, 7, 8]])print(torch.flatten(t, 0, 1).shape)
>>torch.Size([4, 2])

若输入是 0 维 tensor,则输出的是一维 tensor

t = torch.tensor(1)
print("before flatten:")
print(t)
print(t.shape)
>>before flatten:tensor(1)torch.Size([])print("\n")
print("after flatten:")
print(torch.flatten(t))
print(torch.flatten(t).shape)
>>after flatten:tensor([1])torch.Size([1])

torch.flatten与torch.nn.flatten相关推荐

  1. pytorch中的reshape()、view()、nn.flatten()和flatten()

    在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示. torch.reshape用法 reshape()可以由torch ...

  2. 【PyTorch】 torch.flatten()与nn.Flatten()的区别

    问题 torch.flatten()与nn.Flatten()都可以实现展开Tensor,那么二者的区别是什么呢? 方法 经过查阅相关资料,发现二者主要区别有: (1) 默认的dim不同,torch. ...

  3. nn.Flatten()函数详解及示例

    torch.nn.Flatten(start_dim=1, end_dim=- 1) 作用:将连续的维度范围展平为张量. 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于 ...

  4. torch学习二之nn.Convolution

    torch学习二之nn.Convolution nn.Conv1d 函数参数 输入数据维度转换 关于kernel nn.Conv2D nn.Conv1d 一维卷积通常用于处理文本数据 函数参数 首先看 ...

  5. PyTorch 笔记(08)— Tensor 比较运算(torch.gt、lt、ge、le、eq、ne、torch.topk、torch.sort、torch.max、torch.min)

    1. 常用函数 比较函数中有一些是逐元素比较,操作类似逐元素操作,还有一些类似归并操作,常用的比较函数如下表所示. 表中第一行的比较操作已经实现了运算符重载,因此可以使用 a>=b,a>b ...

  6. torch.Tensor和torch.tensor的区别

    torch.Tensor和torch.tensor的区别 2019-06-10 16:34:48 Vic_Hao 阅读数 4058更多 分类专栏: Pytorch 在Pytorch中,Tensor和t ...

  7. pytorch torch.Tensor.new_ones()(返回一个与size大小相同的用1填充的张量。 默认返回的Tensor具有与此张量相同的torch.dtype和torch.device)

    from https://pytorch.org/docs/1.1.0/tensors.html?highlight=new_ones#torch.Tensor.new_ones new_ones(s ...

  8. pytorch拼接函数:torch.stack()和torch.cat()--详解及例子

    原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...

  9. torch.unsqueeze()和torch.unsqueeze()

    参考:torch.squeeze() 和torch.unsqueeze()用法的通俗解释 import torch x = torch.tensor([[1, 2, 3],[1, 2, 3],[1, ...

最新文章

  1. mysql 判断日期是否在某范围内_判断时间是否在某个区间内
  2. css之命名规范 BEM
  3. C语言试题十五之编写函数void function(int x,int pp[],int *n),求出能整除x且不是偶数的各整数,并按从小到大的顺序放在pp所指的数组中,这些除数的个数通过形参n返回
  4. tasker运行java_Tasker 打开桌面快捷方式(以微信公众号为例)[No Root]
  5. 我的第一个python web开发框架(23)——代码版本控制管理与接口文档
  6. 教你如何用 Python 三行代码做动图!
  7. 在PostgreSQL中创建数据库的副本
  8. 改变win7登陆时的界面
  9. 高通9008驱动_安卓手机高通9008模式下如何救砖
  10. 一个程序员父亲的呼吁:不要教你的孩子从小学编程!
  11. SpringBootJ2EE相关介绍
  12. 计算机考证要考PS吗
  13. Android 集成百度地图服务和驾车导航jar包冲突、驾车导航引入armeabi-v7a平台
  14. HTML与CSS--------p标签
  15. MySQL 中文字段排序问题(根据中文拼音排序)
  16. 暴雪battle注册账户不转到中国
  17. 入职阿里一周年,我能谈点什么 | 可惜主语不是我~
  18. JS-节点的属性 获取各种节点(全)
  19. [绍棠_Swift] SwiftyJSON的使用详解(附样例,用于JSON数据处理)
  20. 什么是CDN什么是高防CDN

热门文章

  1. 黑客瞄准里约奥运会,多种手法可能让你中招
  2. openstack restful api 使用
  3. poj 2404 Jogging Trails
  4. python 爬虫保存封面_Python爬虫Demo--获取网易云音乐专辑封面
  5. 用python浪漫告白_Python实现浪漫表白
  6. 市盈率指标详解及相关文献概述
  7. redis分布式方案redis cluster的介绍和实践
  8. 计算机组成原理 机器数的浮点表示法
  9. Verilog 流水线设计
  10. 阿里云上创建Oracle RAC-静默模式