PyTorch基础(15)-- torch.flatten()方法
前言
最近在复现论文中一个块的时候需要使用到torch.flatten()这个方法,这个方法其实很简单,但其中有一些细节可能需要注意,且有个关键点很容易忘记,故在此记录以备查阅。
方法解析
flatten的中文含义为“扁平化”,具体怎么理解呢?我们可以尝试这么理解,假设你的数据为1维数据,那么这个数据天然就已经扁平化了,如果是2维数据,那么扁平化就是将2维数据变为1维数据,如果是3维数据,那么就要根据你自己所选择的“扁平化程度”来进行操作,假设需要全部扁平化,那么就直接将3维数据变为1维数据,如果只需要部分扁平化,那么有一维的数据不会进行扁平操作,具体看下面的案例分析。
可以看到,torch.flatten()方法有三个参数,分别:
- input tensor:该方法的输入
- start_dim:开始flatten的维度
- end_dim:结束flatten的维度
案例解析
- 导包
import numpy as np
import torch
- 案例1 – 全部扁平化
x = np.arange(27)
x = np.reshape(x, (3,3,3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x) # 默认扁平化程度为最高
print('after flatten', x)
- 案例2 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x, start_dim=0, end_dim=1)
print('after flatten', x)
- 案例3 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
print(x.shape)
x = torch.flatten(x, start_dim=1, end_dim=2)
print('after flatten', x)
PyTorch基础(15)-- torch.flatten()方法相关推荐
- 3.Pytorch基础模块torch的API之Indexing,Slicing,Joining,Mutating Ops实例详解
文章目录 0. torch 1. Tensors 2. Creation Ops 3. Indexing,Slicing,Joining,Mutating Ops 3.1 torch.cat() 3. ...
- pytorch基础-使用 TORCH.AUTOGRAD 进行自动微分(5)
在训练神经网络时,最常用的算法是反向传播.PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计 ...
- PyTorch基础(十)----- torch.max()方法
一.前言 这个方法跟上一篇文章的torch.max()方法非常类似,只不过一个是求最大值,一个是求平均值.在某些情况下,甚至可以代替下采样中的最大池化和平均池化,所以说,这两个方法的用处还是蛮大的. ...
- PyTorch基础(六)----- torch.eq()方法
一.torch.eq()方法详解 对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True:若不同,返回False. torch.eq(input, other, *, out ...
- Pytorch基础知识(15)基于PyTorch的多标签图像分类
早在 2012 年,神经网络就首次赢得了 ImageNet 大规模视觉识别挑战.Alex Krizhevsky,Ilya Sutskever 和 Geoffrey Hinton 彻底改变了图像分类领域 ...
- 【PyTorch】 torch.flatten()与nn.Flatten()的区别
问题 torch.flatten()与nn.Flatten()都可以实现展开Tensor,那么二者的区别是什么呢? 方法 经过查阅相关资料,发现二者主要区别有: (1) 默认的dim不同,torch. ...
- pytorch基础知识+构建LeNet对Cifar10进行训练+PyTorch-OpCounter统计模型大小和参数量+模型存储与调用
整个环境的配置请参考我另一篇博客.ubuntu安装python3.5+pycharm+anaconda+opencv+docker+nvidia-docker+tensorflow+pytorch+C ...
- 深度学习之Pytorch基础教程!
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...
- 【深度学习】深度学习之Pytorch基础教程!
作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现.尤其是近两年,Google.Facebook.Microsoft等巨头都围绕深度学习重点投资了一系 ...
最新文章
- 数据科学究竟是什么?
- iOS开发:使用Block在两个界面之间传值(Block高级用法:Block传值)
- java boolean io流_java基础入门-day22-IO流
- Guacamole-RDP没有声音解决办法
- php 上传多文件_php 多文件上传的实现实例
- SpringCloud项目总结
- 虚拟机安装windows服务出现无法打开内核设备“\\.Global\vmx86”
- linux make
- iOS输入框禁止输入emoji表情
- asp html5留言板,ASP.NET MVC 开发实例:简单留言板的开发
- 原型设计Axure RP mac
- 报表控件FastReport.NET使用教程:如何在 Visual Studio 中使用报表组件
- 小区门口的健身房,就是韭菜收割厂
- Transformer入门教程(八)时间维度
- Typora极简教程
- 从硬件到软件,苹果一直坚持的造车梦....
- background-size设置背景图片自适应 在ie8下失效的问题
- mysql 数据连续不走索引6_MySql组合索引最左侧原则失效
- 软考哪个含金量更高?
- java基于ssm空气质量检测系统源码网站空气质量监测源码