einops包中的rearrange,reduce, repeat及einops.layers.torch中的Rearrange,Reduce。对高维数据的处理方式
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
一.rearrange和Rearrange,作用:从函数名称也可以看出是对张量尺度进行重排,
区别:
1.einops.layers.torch中的Rearrange,用于搭建网络结构时对张量进行“隐式”的处理
例如:
class PatchEmbedding(nn.Module):def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):self.patch_size = patch_sizesuper().__init__()self.projection = nn.Sequential(# using a conv layer instead of a linear one -> performance gainsnn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),Rearrange('b e (h) (w) -> b (h w) e'),)
这里的Rearrange('b e (h) (w) -> b (h w) e'),表示将4维张量转换为3维,且原来的最后两维合并为一维:(16,512,4,16)->(16,64,512)
这样只要我们知道初始的张量维度就可以操作注释来对其进行维度重排。
2.eniops中的rearrange,用于对张量‘显示’的处理,是一个函数
例如:
rearrange(images, 'b h w c -> b (h w) c')
将4维张量转换为3维,同样的,只要我们知道初始维度,就可以操作注释对其进行重排
值得注意的是:这里的注释给定以后就代表当前维度,不能更改,例如:
image = torch.randn(1,2,3,2) # torch.Size([1,2,3,2]) out = rearrange(image, 'b c h w -> b (c h w)', c=2,h=3,w=2) # torch.Size([1,12])
# h,w的值更改
err1 = rearrange(image, 'b c h w -> b (c h w)', c=2,h=2,w=3) # 报错
二.repeat:即将tensor中的某一维度进行重复,以扩充该维度数量
B = 16
cls_token = torch.randn(1, 1, emb_size)
cls_tokens = repeat(cls_token, '() n e -> b n e', b=B)#维度为1的时候可用()代替
将(1,1,emb_size)的张量处理为(B,1,emb_size)
R = 16
a = torch.randn(2,3,4)
b = repeat(a, 'b n e -> (r b) n e', r = R)
#(2R, 3, 4)
c = repeat(a, 'b n e -> b (r n) e', r = R)
#(2, 3R, 4)#错误用法:
d = repeat(a, 'b n e -> c n e', c = 2R)
#将(2,3,4)维张量处理为(2R, 3, 4)......
上面都是同纬度的扩充,我们看一个升维的扩充:
R = 5
a = torch.randn(2, 3, 4)
d = repeat(a,'b n e-> b n c e ', c = R)
#将(2,3,4)维张量处理为(2, 3, 5, 4)......
这里我们同样只须操作维度注释即可完成相应的张量操作。
三.Reduce 和 reduce:
x = torch.randn(100, 32, 64)
# perform max-reduction on the first axis:
y0 = reduce(x, 't b c -> b c', 'max') #(32, 64)#指定h2,w2,相当于指定池化核的大小
x = torch.randn(10, 512, 30, 40)
# 2d max-pooling with kernel size = 2 * 2
y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
#(10, 512, 15, 20)# go back to the original height and width
y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
#(10, 128, 30, 40)
#指定h1,w1,相当于指定池化后张量的大小
# 2d max-pooling to 12 * 16 grid:
y3 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=12, w1=16)
#(10, 512, 12, 16)# 2d average-pooling to 12 * 16 grid:
y4 = (reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'mean', h1=12, w1=16)
#(10, 512, 12, 16)# Global average pooling
y5 = reduce(x, 'b c h w -> b c', 'mean')
#(10, 512)
Redece同理。
注意:这里我们以张量为例,einops同样可以处理numpy下的数据
einops包中的rearrange,reduce, repeat及einops.layers.torch中的Rearrange,Reduce。对高维数据的处理方式相关推荐
- Torch 中添加自己的 nn Modules:以添加 Dropout、 Triplet Loss 为例
Preface 因为要复现前面阅读的一篇论文:<论文笔记:Deep Relative Distance Learning: Tell the Difference Between Similar ...
- einops.rearrange、repeat、reduce==>对维度进行操作
1.einops.rearrange==>重新指定维度 def rearrange(tensor, pattern, **axes_lengths): einops.rearrange is a ...
- R语言使用scatterplot3d包的scatterplot3d函数可视化3D散点图(3D scatter plots)、在3D散点图中添加垂直线和数据点描影、3D图中添加回归平面
R语言使用scatterplot3d包的scatterplot3d函数可视化3D散点图(3D scatter plots).在3D散点图中添加垂直线和数据点描影.3D图中添加回归平面(overlaid ...
- 【FPGA】SRIO中的关键问题总结(一)SRIO中的关键数据包格式总结
目录 1 SRIO事务及其类型 2 常用的I/O逻辑操作事务 3 HELLO包格式(重点) 4 SRIO数据包包格式 5 控制符号数据包格式 1 SRIO事务及其类型 SRIO(Serial Rapi ...
- openresty开发系列16--lua中的控制结构if-else/repeat/for/while
openresty开发系列16--lua中的控制结构if-else/repeat/for/while一)条件 - 控制结构 if-else if-else 是我们熟知的一种控制结构.Lua 跟其他语言 ...
- Maven 手动安装JAR包到本地maven仓库后,但在项目中依旧报错找不到JAR包解决方法
Maven 手动安装JAR包到本地maven仓库后,但在项目中依旧报错找不到JAR包解决方法 参考文章: (1)Maven 手动安装JAR包到本地maven仓库后,但在项目中依旧报错找不到JAR包解决 ...
- torch中的expand和repeat
在torch中,如果要改变某一个tensor的维度,可以利用view.expand.repeat.transpose和permute等方法,这里对这些方法的一些容易混淆的地方做个总结. expand和 ...
- maven打jar包,并将依赖jar打入外部lib文件中
在pom.xml中加入如下配置,在mainClass里写程序的入口方法 <!-- maven打jar包,并将依赖jar打入外部lib文件中 --> <plugins> < ...
- java: 程序包 javax.smartcardio 不可见 (程序包 javax.smartcardio 已在模块 java.smartcardio 中声明, 但该模块不在模块图中)
java: 程序包 javax.smartcardio 不可见(程序包 javax.smartcardio 已在模块 java.smartcardio 中声明, 但该模块不在模块图中) 这是JDK9 ...
最新文章
- php把excel变成数组,PHP excel读取excel文件转换为数组
- Package xxx is not available, but is referred to by another package
- LCIS code force 10D
- linux里的挂载错误无法开机怎么办,Linux基础知识 - 开机挂载错误
- C++教程[又能学英文,又能学编程]
- apache启用gzip压缩方法
- Silverlight 3 OOB 原理
- 各种简单的困难的模板,持续更新
- 疯狂的图形(利用C# + GDI plus模拟杂乱无章的现实场景)
- HUPlayer 使用向导和常见问题
- linux usb有线网卡驱动_基于USB设备的Linux网络驱动程序开发
- IO流文件指针(移动和获取文件读指针)
- MongoDB中updateOne的正常使用
- 帮Customer Architecture写的小脚本
- 【阿里云·云原生架构·白皮书】保姆级解读 一、 云原生架构定义
- 进阶版拉依达准则(3sigm准则)的提出与应用
- GPU加速(一)CUDA C编程及GPU基本知识
- 2021年软考考试时间确定
- 树莓派与PCF8591模数转换器的那些事儿
- 1603: 海岛争霸
热门文章
- python:collections模块
- JavaScript高级程序设计-读书笔记(6)
- C Primer Plus(第五版)7
- 选择“关机”还是“睡眠”?
- 黑顶帽—lhMorpBlackTopHat
- I've got so many hongbaos(should it be translated as red bags?)
- 【数据结构与算法】之深入解析“根据身高重建队列”的求解思路与算法示例
- iOS之深入解析静态库和动态库
- Metal之加载TGA与PNG/JPEG纹理图片
- 2019年第十届蓝桥杯 - 省赛 - C/C++大学C组 - D. 质数