在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示。

torch.reshape用法

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。

torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。

torch.reshape(input, shape) → Tensor
input(Tensor)-要重塑的张量
shape(python的元组:ints)-新形状`

案例1
输入:

import torcha = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)

结果:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])

torch.view用法

view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。

一、普通用法(手动调整)

view()相当于reshape、resize,重新调整Tensor的形状。
案例2.
输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)print(a2)
print(a3)
print(a4)

输出

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

二、特殊用法:参数-1(自动调整size)

view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。
输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

输出

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])

torch.nn.Flatten(start_dim=1,end_dim=-1)

start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

使用nn.Flatten(),使用默认参数

官方给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释
整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。
1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二位置代表的维度,也就是样例中的1。
因此进行展平后的结果也就是[32,155]→[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0,2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。
因此结果就是[3215,5]→[160,25]

torch.flatten

torch.flatten()函数经常用于写分类神经网络的时候,经过最后一个卷积层之后,一般会再接一个自适应的池化层,输出一个BCHW的向量。这时候就需要用到torch.flatten()函数将这个向量拉平成一个Bx的向量(其中,x = CHW),然后送入到FC层中。

语句结构

 torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“摊平”的 tensor。
start_dim: “摊平”的起始维度。
end_dim: “摊平”的结束维度。
作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维。例1:

import torchdata_pool = torch.randn(2,2,3,3) # 模拟经过最后一个池化层或自适应池化层之后的输出,Batchsize*c*h*w
print(data_pool)y=torch.flatten(data_pool,1)
print(y)

输出结果:

结果是一个B*x的向量。

本文源于多篇资料的提炼汇总,部分参考资料如下。
参考资料:参考1;参考2;参考3;参考4

pytorch中的reshape()、view()、nn.flatten()和flatten()相关推荐

  1. 神经网路:pytorch中Variable和view参数解析

    在PyTorch中计算图的特点总结如下: autograd根据用户对Variable的操作来构建其计算图. requires_grad variable默认是不需要被求导的,即requires_gra ...

  2. python中size_x的意思,对pytorch中x = x.view(x.size(0), -1) 的理解说明

    在pytorch的CNN代码中经常会看到 x.view(x.size(0), -1) 首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6 ...

  3. PyTorch中contiguous、view、Sequential、permute函数的用法

    在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,以"行优先"进行存储. 1. tensor的连续性 tensor连续(conti ...

  4. pytorch 中pad函数toch.nn.functional.pad()的使用

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  5. python语言中ch用法_pytorch 中pad函数toch.nn.functional.pad()的用法

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  6. pytorch学习笔记七:nn网络层——池化层、线性层

    一.池化层 池化运算:对信号进行"收集" 并"总结",类似于水池收集水资源,因而得名池化层. 收集:由多变少,图像的尺寸由大变小 总结:最大值/平均值 下面是最 ...

  7. pytorch中交叉熵

    关于pytorch中交叉熵的使用,pytorch的交叉熵是其loss function的一种且包含了softmax的过程. pytorch中交叉熵函数是nn.CrossEntropyLoss().其参 ...

  8. PyTorch中的sampled_softmax_loss

    最近做一篇论文的复现,发现PyTorch中没有类似于tf.nn.sampled_softmax_loss的功能,经过一番寻找,在github上找到了答案:Stonesjtu/Pytorch-NCE.奇 ...

  9. Pytorch中 permute / transpose 和 view / reshape, flatten函数

    1.transpose与permute transpose() 和 permute() 都是返回转置后矩阵,在pytorch中转置用的函数就只有这两个 ,这两个函数都是交换维度的操作 transpos ...

最新文章

  1. eclipse 出现user operation is waiting
  2. 数据结构第5章例题 若矩阵Am×n中存在某个元素aij满足:aij是第i行中的最小值且是第j列中的最大值,则称该元素为矩阵A的一个鞍点。试编写一个算法,找出A中的所有鞍点。
  3. java svn安装地址_SVN的安装和配置
  4. mysql 关联查询_响应时间长?MySQL查询优化教程来了!
  5. ad20中怎么多选操作改层_在操作系统中CPU是怎么调度的
  6. centos 怎样显示metric_centos7系列问题
  7. c++ primer 6.5.1节练习答案
  8. 为解决WINDOWS JRE启动外壳,找了好几个方案
  9. Ulipad快捷键大总结
  10. 卫星地图破坏男女恋人之间的关系
  11. mac用navicat连接mysql_Mac OS下,使用Navicat连接MySQL出现的问题
  12. cntv客户端_cntv网络电视_cntv官方下载-太平洋下载中心
  13. 平面方程(Plane Equation)求解方法
  14. 蚂蚁上市招股书:员工持股40%月薪人均5万,直奔财富自由
  15. 计算机电子琴乐谱数字键,电子琴键盘与乐谱对照表.pdf
  16. 什么是色彩管理,OPPO 全链路色彩管理全在哪?
  17. html发票页面,HTML5 发票模板
  18. 基于51单片机的交通灯(资源链接见末尾)
  19. RT_thread STM32通用Bootloader 做OTA升级
  20. 百度飞桨AI抠图+图片合成

热门文章

  1. android指纹解锁分析,浅析4种手机指纹解锁方式的优劣势
  2. hiredis的各种windows版本
  3. YOJ3509-小豪搬宝藏
  4. 【解决方案】RTSP/Onvif安防视频直播解决方案EasyNVR在某省高速上云项目中的应用分析
  5. chinapay支付接口php,GitHub - tension/chinapay-for-ecshop: 上海银联(chinapay)支付插件 for ECSHOP...
  6. 蜻蜓FM实时推荐系统的发展和演进
  7. 顺丰菜鸟大战 本质是以数据获得企业竞争壁垒
  8. 操作系统如何建立异常处理?
  9. 冒泡,选择,插入排序
  10. 离散选择模型之Gumbel分布