在pytorch的CNN代码中经常会看到

x.view(x.size(0), -1)

首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度

a = torch.Tensor(2,3)

print(a)

# tensor([[0.0000, 0.0000, 0.0000],

# [0.0000, 0.0000, 0.0000]])

print(a.view(1,-1))

# tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

在CNN中卷积或者池化之后需要连接全连接层,所以需要把多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能

def forward(self,x):

x=self.pre(x)

x=self.layer1(x)

x=self.layer2(x)

x=self.layer3(x)

x=self.layer4(x)

x=F.avg_pool2d(x,7)

x=x.view(x.size(0),-1)

return self.fc(x)

卷积或者池化之后的tensor的维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,最后通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y),即将(channels,x,y)拉直,然后就可以和fc层连接了

补充:pytorch中view的用法(重构张量)

view在pytorch中是用来改变张量的shape的,简单又好用。

pytorch中view的用法通常是直接在张量名后用.view调用,然后放入自己想要的shape。如

tensor_name.view(shape)

Example:

1. 直接用法:

>>> x = torch.randn(4, 4)

>>> x.size()

torch.Size([4, 4])

>>> y = x.view(16)

>>> y.size()

torch.Size([16])

2. 强调某一维度的尺寸:

>>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions

>>> z.size()

torch.Size([2, 8])

3. 拉直张量:

(直接填-1表示拉直, 等价于tensor_name.flatten())

>>> y = x.view(-1)

>>> y.size()

torch.Size([16])

4. 做维度变换时不改变内存排列

>>> a = torch.randn(1, 2, 3, 4)

>>> a.size()

torch.Size([1, 2, 3, 4])

>>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension

>>> b.size()

torch.Size([1, 3, 2, 4])

>>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory

>>> c.size()

torch.Size([1, 3, 2, 4])

>>> torch.equal(b, c)

False

注意最后的False,在张量b和c是不等价的。从这里我们可以看得出来,view函数如其名,只改变“看起来”的样子,不会改变张量在内存中的排列。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持得牛网。如有错误或未考虑完全的地方,望不吝赐教。

python中size_x的意思,对pytorch中x = x.view(x.size(0), -1) 的理解说明相关推荐

  1. pyrorch中 out.view(out.size(0), -1)  out.view(-1, 1, 28, 28)  clamp(min,max)作用

    1.     view(out.size(0), -1) 目的是将多维的的数据如(none,36,2,2)平铺为一维如(none,144).作用类似于keras中的Flatten函数.只不过keras ...

  2. python size(0)_对x.view(x.size(0), -1)的一些理解

    一般地,在CNN等网络中,都是通过卷积过滤器对目标进行计算,然而这些计算都是建立在高维数据. 最后,项目需要对数据进行分类或者识别,就需要全连接层Linear,这时候就需要将高维数据平铺变为低位数据. ...

  3. python batchnorm2d_BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解

    BN原理.作用: 函数参数讲解: BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 1. ...

  4. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

  5. 机器学习花朵图像分类_在PyTorch中使用转移学习进行图像分类

    想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用"菜单中包括:颜值检测.植物花卉识别.文字识别.人脸美妆等有趣的智能应用.. ...

  6. pyTorch中tensor运算

    文章目录 PyTorch的简介 PyTorch中主要的包 PyTorch的安装 使用GPU的原因 使数据在GPU上运行 什么使Tensor(张量) 一些术语介绍 Tensor的属性介绍(Rank,ax ...

  7. 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层

    requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...

  8. 在PyTorch中使用卷积神经网络建立图像分类模型

    概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题--CNN的一个经典和广泛使用的应用 我们将以实用的格式介绍深度学习概念 介绍 我被神经网络的力量和能力所 ...

  9. Tensorflow 1.x 和 Pytorch 中 Conv2d Padding的区别

    Tensorflow 和 Pytorch 中 Conv2d Padding的区别 Pytorch中Conv2d的Padding 可以是整数,二元组,字符串三种形式. 整数(int).如果输入的padd ...

最新文章

  1. 作业——08 爬虫综合大作业
  2. UBOOT 2011-3版本分析(初步感受)
  3. docker安装zookeeper(单节点安装)
  4. lwip网络通信socket_lwIP在Socket模式下接口:BSD Socket API
  5. c语言实现软件锁相环,锁相环系统及锁相环系统的实现方法技术方案
  6. ADO.NET常用对象详解之:Command对象
  7. Connected to the target VM, address: '127.0.0.1:60885', transport: 'socket'
  8. 交换机配置常用的命令
  9. 数字PCR和实时PCR的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
  10. 亮道剧学铭:激光雷达系统量产上车没那么容易
  11. 【2D多目标跟踪】Quasi-Dense Similarity Learning for Multiple Object Tracking阅读笔记
  12. EGo1下板_数码管动态显示
  13. 【Three.js入门】标准网格材质、置换贴图、粗糙度贴图、金属贴图、法线贴图
  14. Hub能新建但不能打开项目 Failed to connect to pipe_20220313
  15. 关于onKeyDown方法
  16. 提取excel文件的链接
  17. 香港理工大学酒店管理html,香港理工大学大酒店管理硕士要求
  18. linux网卡下有两个system,systemd-networkd 作为网络管理服务,导致dhcp给所有网卡分配同样的IP...
  19. 充电线---E-Marker芯片
  20. 锁定乌镇2019世界互联网大会,5G无人驾驶汽车智慧开跑

热门文章

  1. IE浏览器不能上网的处理办法
  2. 实现删除商品信息功能
  3. 卷积码主要是对抗_【零基础学会LTE】【3】LTE 36.212 咬尾卷积码详解
  4. bucket sort sample sort 并行_Java 中 Arrays.sort 和 Arrays.parallelSort 哪个更快?
  5. 倒N字形排列java_Java实现n位数字的全排列
  6. linux mp4v2编译,Android 编译mp4 v2 2.0.0生成动态库
  7. 惠普服务器eth0的位置,HPUX下定位网卡位置
  8. 万圣节海报素材PSD分层模板
  9. 初学者UI设计临摹素材模板,请先搞清楚这4个分类!
  10. python中if控制语句_Python 极简教程(十二)逻辑控制语句 if else