Pytorch使用--学习记录
目录
1,model.train(),model.eval()
2,tensor.unfold()
3,torch.nn.Unfold
4, pytorch张量约减操作
5,nn._init_()
1,model.train(),model.eval()
model.train():启用 BatchNormalization 和 Dropout
model.eval():不启用 BatchNormalization 和 Dropout
训练完train_datasets之后,model要来测试样本了。在模型输入测试数据之前model(test_datasets),需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的影响
2,tensor.unfold()
unfold(dim, size, step) → Tensor
在
dim
维填充上所有大小为size
的分片。两个分片之间的步长为step
。 如果_sizedim_是dim维度的原始大小,则在返回tensor中的维度dim大小是_(sizedim-size)/step+1_ 维度大小的附加维度将附加在返回的tensor中。
>>> x = torch.arange(1, 8)
>>> x1234567
[torch.FloatTensor of size 7]>>> x.unfold(0, 2, 1)1 22 33 44 55 66 7
[torch.FloatTensor of size 6x2]>>> x.unfold(0, 2, 2)1 23 45 6
[torch.FloatTensor of size 3x2]
3,torch.nn.Unfold
torch.nn.
Unfold
(kernel_size: Union[T, Tuple[T, ...]], dilation: Union[T, Tuple[T, ...]] = 1, padding: Union[T, Tuple[T, ...]] = 0, stride: Union[T, Tuple[T, ...]] = 1)
从输入张量的一个Batch数据中提取滑动的局部块,
参数:卷积核的尺寸,空洞大小,填充大小和步长。
,
是batch 维度,
是通道维度,
是任意的空间维度。该运算在输入的空间维度内每一个kernel_size大小的slide,形成第三个维度中的列。
nfold的输出为,其中
为kernel_size长和宽的乘积, L是channel的长宽根据kernel_size的长宽滑动裁剪后,得到的区块的数量。
d是所有空间维度,空间尺寸指输入的空间维度()
每个区块的大小为
下图中公式有错,一上面计算区块数量的公式为准
nn.Unfold对输入channel的每一个 的滑动窗口区块做了展平操作。
torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688, 0.5384],[-0.4693, -0.0775, -0.7504, 0.2283],[-0.1414, 1.0006, -0.0942, 2.2981],[-0.9429, 1.1908, 0.9374, -1.3168]],[[-1.8184, -0.3926, 0.1875, 1.3847],[-0.4124, 0.9766, -1.3303, -0.0970],[ 1.7679, 0.6961, -1.6445, 0.7482],[ 0.1729, -0.3196, -0.1528, 0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],[-0.1026, 0.5384, 1.0006, 2.2981],[-0.4693, -0.7504, -0.9429, 0.9374],[-0.0775, 0.2283, 1.1908, -1.3168],[-1.8184, 0.1875, 1.7679, -1.6445],[-0.3926, 1.3847, 0.6961, 0.7482],[-0.4124, -1.3303, 0.1729, -0.1528],[ 0.9766, -0.0970, -0.3196, 0.2180]]])
def unfold_x():#N*M,C,T,V,Input#window_size=3,dilation=1,stride=1x=torch.randn(2,3,5,4)print(f'----------------x \n{x}')unfold=torch.nn.Unfold(kernel_size=(3, 1),dilation=(1, 1),stride=(1, 1),padding=(1, 0))un_x=unfold(x)print('un_x shape {}'.format(un_x.size()))print(f'............un_x\n{un_x}')return
x=torch.randn(2,3,5,4) print(f'----------------x \n{x}')
tensor([[[[-0.1007, -0.1986, 0.2615, 0.2375],[ 0.3395, -0.2650, -2.3015, -0.5818],[-2.4892, 1.3659, 0.9418, 0.5290],[-0.1242, -1.1327, -0.7105, -0.3952],[-0.3351, -0.3885, -1.0516, 0.0144]],[[ 1.1538, -0.2460, -0.6409, 2.3420],[-1.6041, -0.0226, -1.1131, -1.2851],[ 1.5435, 2.1038, 0.1150, 0.7285],[-0.8543, 0.5684, -0.0907, -1.5588],[-0.1338, 1.2914, 0.5947, -0.1871]],[[-0.5479, 0.0572, -1.3323, 0.2371],[-0.3639, 0.8004, -2.4990, -2.6908],[-0.3635, 0.5411, 0.6723, -1.1272],[ 1.7912, 1.1216, 0.2887, 0.8244],[ 0.2222, 1.1524, 1.2438, 0.4919]]],[[[ 0.5019, 1.0633, 0.3409, -0.4121],[ 1.1162, 0.0055, 1.2277, -1.4919],[ 0.0533, -1.6769, -0.9581, 1.7418],[ 1.9506, -0.7145, -0.3485, 0.0497],[ 1.7571, -1.0860, 0.1596, 0.4369]],[[-0.9666, -0.7096, 0.3977, 0.9115],[-0.0983, 0.3316, 0.1486, 0.2869],[ 0.7518, -0.7357, 0.2328, -1.5851],[ 0.2918, 0.4178, 0.0045, -1.1917],[-1.2200, -1.2876, 1.9524, -2.4134]],[[ 2.6374, 1.4099, 0.8991, -0.7087],[ 0.2047, -0.6513, -0.8530, 0.7599],[ 0.2445, 0.5106, -2.3711, 0.5012],[ 1.2275, -0.0866, 0.6022, -0.0259],[ 0.0051, -0.0894, -0.1819, -0.7296]]]])
tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, -0.1007, -0.1986, 0.2615,0.2375, 0.3395, -0.2650, -2.3015, -0.5818, -2.4892, 1.3659,0.9418, 0.5290, -0.1242, -1.1327, -0.7105, -0.3952],[-0.1007, -0.1986, 0.2615, 0.2375, 0.3395, -0.2650, -2.3015,-0.5818, -2.4892, 1.3659, 0.9418, 0.5290, -0.1242, -1.1327,-0.7105, -0.3952, -0.3351, -0.3885, -1.0516, 0.0144],[ 0.3395, -0.2650, -2.3015, -0.5818, -2.4892, 1.3659, 0.9418,0.5290, -0.1242, -1.1327, -0.7105, -0.3952, -0.3351, -0.3885,-1.0516, 0.0144, 0.0000, 0.0000, 0.0000, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, 1.1538, -0.2460, -0.6409,2.3420, -1.6041, -0.0226, -1.1131, -1.2851, 1.5435, 2.1038,0.1150, 0.7285, -0.8543, 0.5684, -0.0907, -1.5588],[ 1.1538, -0.2460, -0.6409, 2.3420, -1.6041, -0.0226, -1.1131,-1.2851, 1.5435, 2.1038, 0.1150, 0.7285, -0.8543, 0.5684,-0.0907, -1.5588, -0.1338, 1.2914, 0.5947, -0.1871],[-1.6041, -0.0226, -1.1131, -1.2851, 1.5435, 2.1038, 0.1150,0.7285, -0.8543, 0.5684, -0.0907, -1.5588, -0.1338, 1.2914,0.5947, -0.1871, 0.0000, 0.0000, 0.0000, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, -0.5479, 0.0572, -1.3323,0.2371, -0.3639, 0.8004, -2.4990, -2.6908, -0.3635, 0.5411,0.6723, -1.1272, 1.7912, 1.1216, 0.2887, 0.8244],[-0.5479, 0.0572, -1.3323, 0.2371, -0.3639, 0.8004, -2.4990,-2.6908, -0.3635, 0.5411, 0.6723, -1.1272, 1.7912, 1.1216,0.2887, 0.8244, 0.2222, 1.1524, 1.2438, 0.4919],[-0.3639, 0.8004, -2.4990, -2.6908, -0.3635, 0.5411, 0.6723,-1.1272, 1.7912, 1.1216, 0.2887, 0.8244, 0.2222, 1.1524,1.2438, 0.4919, 0.0000, 0.0000, 0.0000, 0.0000]],[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5019, 1.0633, 0.3409,-0.4121, 1.1162, 0.0055, 1.2277, -1.4919, 0.0533, -1.6769,-0.9581, 1.7418, 1.9506, -0.7145, -0.3485, 0.0497],[ 0.5019, 1.0633, 0.3409, -0.4121, 1.1162, 0.0055, 1.2277,-1.4919, 0.0533, -1.6769, -0.9581, 1.7418, 1.9506, -0.7145,-0.3485, 0.0497, 1.7571, -1.0860, 0.1596, 0.4369],[ 1.1162, 0.0055, 1.2277, -1.4919, 0.0533, -1.6769, -0.9581,1.7418, 1.9506, -0.7145, -0.3485, 0.0497, 1.7571, -1.0860,0.1596, 0.4369, 0.0000, 0.0000, 0.0000, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, -0.9666, -0.7096, 0.3977,0.9115, -0.0983, 0.3316, 0.1486, 0.2869, 0.7518, -0.7357,0.2328, -1.5851, 0.2918, 0.4178, 0.0045, -1.1917],[-0.9666, -0.7096, 0.3977, 0.9115, -0.0983, 0.3316, 0.1486,0.2869, 0.7518, -0.7357, 0.2328, -1.5851, 0.2918, 0.4178,0.0045, -1.1917, -1.2200, -1.2876, 1.9524, -2.4134],[-0.0983, 0.3316, 0.1486, 0.2869, 0.7518, -0.7357, 0.2328,-1.5851, 0.2918, 0.4178, 0.0045, -1.1917, -1.2200, -1.2876,1.9524, -2.4134, 0.0000, 0.0000, 0.0000, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, 2.6374, 1.4099, 0.8991,-0.7087, 0.2047, -0.6513, -0.8530, 0.7599, 0.2445, 0.5106,-2.3711, 0.5012, 1.2275, -0.0866, 0.6022, -0.0259],[ 2.6374, 1.4099, 0.8991, -0.7087, 0.2047, -0.6513, -0.8530,0.7599, 0.2445, 0.5106, -2.3711, 0.5012, 1.2275, -0.0866,0.6022, -0.0259, 0.0051, -0.0894, -0.1819, -0.7296],[ 0.2047, -0.6513, -0.8530, 0.7599, 0.2445, 0.5106, -2.3711,0.5012, 1.2275, -0.0866, 0.6022, -0.0259, 0.0051, -0.0894,-0.1819, -0.7296, 0.0000, 0.0000, 0.0000, 0.0000]]])
4, pytorch张量约减操作
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
a_n=torch.from_numpy(A)
b_n=torch.from_numpy(b)
print(a_n)
print(b_n)
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
#每个张量的最后一个维度u对应行,相乘相加后形成v的每个元素
#a_n[v,1]点乘b_n[n,c,t,1]得到 agg[n,c,t,v11]
print(agg)
print(agg.size())
tensor([[ 0., 1., 2., 3., 4.],[ 5., 6., 7., 8., 9.],[10., 11., 12., 13., 14.],[15., 16., 17., 18., 19.],[20., 21., 22., 23., 24.]], dtype=torch.float64)
tensor([[[[ 0., 1., 2., 3., 4.]],[[ 5., 6., 7., 8., 9.]],[[10., 11., 12., 13., 14.]]],[[[15., 16., 17., 18., 19.]],[[20., 21., 22., 23., 24.]],[[25., 26., 27., 28., 29.]]]], dtype=torch.float64)
tensor([[[[ 30., 80., 130., 180., 230.]],[[ 80., 255., 430., 605., 780.]],[[ 130., 430., 730., 1030., 1330.]]],[[[ 180., 605., 1030., 1455., 1880.]],[[ 230., 780., 1330., 1880., 2430.]],[[ 280., 955., 1630., 2305., 2980.]]]], dtype=torch.float64)
torch.Size([2, 3, 1, 5])
5,nn._init_()
torch.nn.init.uniform_(tensor, a=0, b=1)
服从~U ( a , b ) 均匀分布
Pytorch使用--学习记录相关推荐
- Pytorch深度学习记录:对CIFAR-10的深度学习模型搭建与测试
前言 CIFAR-10介绍 下载.分类与读入数据集 数据集下载 解压与分类数据集 读入数据 搭建神经网络 卷积层(Convolutional layer) 池化层(Pooling lay) 残差网络( ...
- Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)
Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...
- add函数 pytorch_Pytorch学习记录-Pytorch可视化使用tensorboardX
Pytorch学习记录-Pytorch可视化使用tensorboardX 在很早很早以前(至少一个半月),我做过几节关于tensorboard的学习记录. https://www.jianshu.co ...
- 【项目实战】vue-springboot-pytorch前后端结合pytorch深度学习 html打开本地摄像头 监控人脸和记录时间
是一个项目的一个功能之一,调试了两小时,终于能够 javascript设置开始计和暂停计时 监控人脸 记录时间了 效果图: 离开页面之后回到页面会从0计时(不是关闭页面,而是页面失去焦点) 离开摄像头 ...
- ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch深度学习搭建记录(一路爬坑的一天...)
ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch 深度学习环境搭建记录(一路爬坑的一天-) 1.gpu驱动下载 参考:https://blog.csdn.net/f ...
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- PyTorch学习记录-1PyTorch安装
学习建议里有PyTorch,所以我就开始了PyTorch的学习. 首先就是安装啦,去官网很清楚,可以选择自己的版本和平台,然后下面就会出现 Run this command: 后面跟着的命令复制运行 ...
- seq2seq模型_Pytorch学习记录-Seq2Seq模型对比
Pytorch学习记录-torchtext和Pytorch的实例4 0. PyTorch Seq2Seq项目介绍 在完成基本的torchtext之后,找到了这个教程,<基于Pytorch和tor ...
- PyTorch深度学习实践
根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...
最新文章
- linux问答学知识
- 自动化运维之SaltStack实践
- wxWidgets:wxStaticBox类用法
- 浙江绿盟科技2011.10.14校园招聘会笔试题
- php xmlhttprequest,DOM XMLHttpRequest
- CSDN挑战编程——《数学问题》
- 原生Java高仿抖音短视频APP双端源码
- ApacheCN/iBooker 未来计划 2019.11
- Windows API 逐个逐个学(3)----Windows系统基本服务API GetSystemDirectory
- PHRefreshTriggerView
- ps aux 和ps -aux和 ps -ef的选择
- 在CentOS6.9中搭建HBase
- 视频画面帧的展示控件SurfaceView及TextureView对比
- 如何识别媒体偏见_面部识别,种族偏见和非洲执法
- NRF24L01+模块:一对一双向通信,成功!
- 计算机研究生开题报告ppt模板,硕士开题报告ppt模板
- 全屏网页时钟屏保flipclock-beautify,简约风格,电脑手机均支持访问
- Unity之AB包的创建加载
- Python实现自动通关别踩白块儿
- 完整登录、注册页面(无功能)
热门文章
- HKPM智慧集贸市场管理系统在某农贸市场的应用
- 在桥式结构中的注意事项 — 探头的CMRR
- Google 应用与游戏出海 11 月刊: 领取您的节假季突围攻略
- Matlab pcode p文件 p代码 p文件代转m文件,pcode文件解密工具
- 标准小知识3一一直流电子负载
- RDBMS之SQL:SQL语言的各种方言的简介(MySQL/Hive SQL/PQL/OracleSQL/SQLite影响力排序)、主流语言的对比之详细攻略
- 7-2 复数计算 (10分)
- 淘宝一月打假结果:假货投诉占三成
- vi操作笔记及资料下载
- 测试驱动开发系列之五--测试的模式与反模式