目录

1,model.train(),model.eval()

2,tensor.unfold()

3,torch.nn.Unfold

4, pytorch张量约减操作

5,nn._init_()


1,model.train(),model.eval()

model.train():启用 BatchNormalization 和 Dropout

model.eval():不启用 BatchNormalization 和 Dropout

训练完train_datasets之后,model要来测试样本了。在模型输入测试数据之前model(test_datasets),需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的影响

2,tensor.unfold()

unfold(dim, size, step) → Tensor

dim维填充上所有大小为size的分片。两个分片之间的步长为step。 如果_sizedim_是dim维度的原始大小,则在返回tensor中的维度dim大小是_(sizedim-size)/step+1_ 维度大小的附加维度将附加在返回的tensor中。

>>> x = torch.arange(1, 8)
>>> x1234567
[torch.FloatTensor of size 7]>>> x.unfold(0, 2, 1)1  22  33  44  55  66  7
[torch.FloatTensor of size 6x2]>>> x.unfold(0, 2, 2)1  23  45  6
[torch.FloatTensor of size 3x2]

3,torch.nn.Unfold

torch.nn.Unfold(kernel_size: Union[T, Tuple[T, ...]], dilation: Union[T, Tuple[T, ...]] = 1, padding: Union[T, Tuple[T, ...]] = 0, stride: Union[T, Tuple[T, ...]] = 1)

从输入张量的一个Batch数据中提取滑动的局部块,

参数:卷积核的尺寸,空洞大小,填充大小和步长。

是batch 维度,是通道维度,是任意的空间维度。该运算在输入的空间维度内每一个kernel_size大小的slide,形成第三个维度中的列。

nfold的输出为,其中为kernel_size长和宽的乘积, L是channel的长宽根据kernel_size的长宽滑动裁剪后,得到的区块的数量。

d是所有空间维度,空间尺寸指输入的空间维度(

每个区块的大小为

下图中公式有错,一上面计算区块数量的公式为准

nn.Unfold对输入channel的每一个 的滑动窗口区块做了展平操作。

torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688,  0.5384],[-0.4693, -0.0775, -0.7504,  0.2283],[-0.1414,  1.0006, -0.0942,  2.2981],[-0.9429,  1.1908,  0.9374, -1.3168]],[[-1.8184, -0.3926,  0.1875,  1.3847],[-0.4124,  0.9766, -1.3303, -0.0970],[ 1.7679,  0.6961, -1.6445,  0.7482],[ 0.1729, -0.3196, -0.1528,  0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],[-0.1026,  0.5384,  1.0006,  2.2981],[-0.4693, -0.7504, -0.9429,  0.9374],[-0.0775,  0.2283,  1.1908, -1.3168],[-1.8184,  0.1875,  1.7679, -1.6445],[-0.3926,  1.3847,  0.6961,  0.7482],[-0.4124, -1.3303,  0.1729, -0.1528],[ 0.9766, -0.0970, -0.3196,  0.2180]]])

def unfold_x():#N*M,C,T,V,Input#window_size=3,dilation=1,stride=1x=torch.randn(2,3,5,4)print(f'----------------x \n{x}')unfold=torch.nn.Unfold(kernel_size=(3, 1),dilation=(1, 1),stride=(1, 1),padding=(1, 0))un_x=unfold(x)print('un_x shape {}'.format(un_x.size()))print(f'............un_x\n{un_x}')return

x=torch.randn(2,3,5,4)  print(f'----------------x \n{x}')

tensor([[[[-0.1007, -0.1986,  0.2615,  0.2375],[ 0.3395, -0.2650, -2.3015, -0.5818],[-2.4892,  1.3659,  0.9418,  0.5290],[-0.1242, -1.1327, -0.7105, -0.3952],[-0.3351, -0.3885, -1.0516,  0.0144]],[[ 1.1538, -0.2460, -0.6409,  2.3420],[-1.6041, -0.0226, -1.1131, -1.2851],[ 1.5435,  2.1038,  0.1150,  0.7285],[-0.8543,  0.5684, -0.0907, -1.5588],[-0.1338,  1.2914,  0.5947, -0.1871]],[[-0.5479,  0.0572, -1.3323,  0.2371],[-0.3639,  0.8004, -2.4990, -2.6908],[-0.3635,  0.5411,  0.6723, -1.1272],[ 1.7912,  1.1216,  0.2887,  0.8244],[ 0.2222,  1.1524,  1.2438,  0.4919]]],[[[ 0.5019,  1.0633,  0.3409, -0.4121],[ 1.1162,  0.0055,  1.2277, -1.4919],[ 0.0533, -1.6769, -0.9581,  1.7418],[ 1.9506, -0.7145, -0.3485,  0.0497],[ 1.7571, -1.0860,  0.1596,  0.4369]],[[-0.9666, -0.7096,  0.3977,  0.9115],[-0.0983,  0.3316,  0.1486,  0.2869],[ 0.7518, -0.7357,  0.2328, -1.5851],[ 0.2918,  0.4178,  0.0045, -1.1917],[-1.2200, -1.2876,  1.9524, -2.4134]],[[ 2.6374,  1.4099,  0.8991, -0.7087],[ 0.2047, -0.6513, -0.8530,  0.7599],[ 0.2445,  0.5106, -2.3711,  0.5012],[ 1.2275, -0.0866,  0.6022, -0.0259],[ 0.0051, -0.0894, -0.1819, -0.7296]]]])
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000, -0.1007, -0.1986,  0.2615,0.2375,  0.3395, -0.2650, -2.3015, -0.5818, -2.4892,  1.3659,0.9418,  0.5290, -0.1242, -1.1327, -0.7105, -0.3952],[-0.1007, -0.1986,  0.2615,  0.2375,  0.3395, -0.2650, -2.3015,-0.5818, -2.4892,  1.3659,  0.9418,  0.5290, -0.1242, -1.1327,-0.7105, -0.3952, -0.3351, -0.3885, -1.0516,  0.0144],[ 0.3395, -0.2650, -2.3015, -0.5818, -2.4892,  1.3659,  0.9418,0.5290, -0.1242, -1.1327, -0.7105, -0.3952, -0.3351, -0.3885,-1.0516,  0.0144,  0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  1.1538, -0.2460, -0.6409,2.3420, -1.6041, -0.0226, -1.1131, -1.2851,  1.5435,  2.1038,0.1150,  0.7285, -0.8543,  0.5684, -0.0907, -1.5588],[ 1.1538, -0.2460, -0.6409,  2.3420, -1.6041, -0.0226, -1.1131,-1.2851,  1.5435,  2.1038,  0.1150,  0.7285, -0.8543,  0.5684,-0.0907, -1.5588, -0.1338,  1.2914,  0.5947, -0.1871],[-1.6041, -0.0226, -1.1131, -1.2851,  1.5435,  2.1038,  0.1150,0.7285, -0.8543,  0.5684, -0.0907, -1.5588, -0.1338,  1.2914,0.5947, -0.1871,  0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000, -0.5479,  0.0572, -1.3323,0.2371, -0.3639,  0.8004, -2.4990, -2.6908, -0.3635,  0.5411,0.6723, -1.1272,  1.7912,  1.1216,  0.2887,  0.8244],[-0.5479,  0.0572, -1.3323,  0.2371, -0.3639,  0.8004, -2.4990,-2.6908, -0.3635,  0.5411,  0.6723, -1.1272,  1.7912,  1.1216,0.2887,  0.8244,  0.2222,  1.1524,  1.2438,  0.4919],[-0.3639,  0.8004, -2.4990, -2.6908, -0.3635,  0.5411,  0.6723,-1.1272,  1.7912,  1.1216,  0.2887,  0.8244,  0.2222,  1.1524,1.2438,  0.4919,  0.0000,  0.0000,  0.0000,  0.0000]],[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.5019,  1.0633,  0.3409,-0.4121,  1.1162,  0.0055,  1.2277, -1.4919,  0.0533, -1.6769,-0.9581,  1.7418,  1.9506, -0.7145, -0.3485,  0.0497],[ 0.5019,  1.0633,  0.3409, -0.4121,  1.1162,  0.0055,  1.2277,-1.4919,  0.0533, -1.6769, -0.9581,  1.7418,  1.9506, -0.7145,-0.3485,  0.0497,  1.7571, -1.0860,  0.1596,  0.4369],[ 1.1162,  0.0055,  1.2277, -1.4919,  0.0533, -1.6769, -0.9581,1.7418,  1.9506, -0.7145, -0.3485,  0.0497,  1.7571, -1.0860,0.1596,  0.4369,  0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000, -0.9666, -0.7096,  0.3977,0.9115, -0.0983,  0.3316,  0.1486,  0.2869,  0.7518, -0.7357,0.2328, -1.5851,  0.2918,  0.4178,  0.0045, -1.1917],[-0.9666, -0.7096,  0.3977,  0.9115, -0.0983,  0.3316,  0.1486,0.2869,  0.7518, -0.7357,  0.2328, -1.5851,  0.2918,  0.4178,0.0045, -1.1917, -1.2200, -1.2876,  1.9524, -2.4134],[-0.0983,  0.3316,  0.1486,  0.2869,  0.7518, -0.7357,  0.2328,-1.5851,  0.2918,  0.4178,  0.0045, -1.1917, -1.2200, -1.2876,1.9524, -2.4134,  0.0000,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  2.6374,  1.4099,  0.8991,-0.7087,  0.2047, -0.6513, -0.8530,  0.7599,  0.2445,  0.5106,-2.3711,  0.5012,  1.2275, -0.0866,  0.6022, -0.0259],[ 2.6374,  1.4099,  0.8991, -0.7087,  0.2047, -0.6513, -0.8530,0.7599,  0.2445,  0.5106, -2.3711,  0.5012,  1.2275, -0.0866,0.6022, -0.0259,  0.0051, -0.0894, -0.1819, -0.7296],[ 0.2047, -0.6513, -0.8530,  0.7599,  0.2445,  0.5106, -2.3711,0.5012,  1.2275, -0.0866,  0.6022, -0.0259,  0.0051, -0.0894,-0.1819, -0.7296,  0.0000,  0.0000,  0.0000,  0.0000]]])

4, pytorch张量约减操作

agg = torch.einsum('vu,nctu->nctv', a_n, b_n)

a_n=torch.from_numpy(A)
b_n=torch.from_numpy(b)
print(a_n)
print(b_n)
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
#每个张量的最后一个维度u对应行,相乘相加后形成v的每个元素
#a_n[v,1]点乘b_n[n,c,t,1]得到 agg[n,c,t,v11]
print(agg)
print(agg.size())
tensor([[ 0.,  1.,  2.,  3.,  4.],[ 5.,  6.,  7.,  8.,  9.],[10., 11., 12., 13., 14.],[15., 16., 17., 18., 19.],[20., 21., 22., 23., 24.]], dtype=torch.float64)
tensor([[[[ 0.,  1.,  2.,  3.,  4.]],[[ 5.,  6.,  7.,  8.,  9.]],[[10., 11., 12., 13., 14.]]],[[[15., 16., 17., 18., 19.]],[[20., 21., 22., 23., 24.]],[[25., 26., 27., 28., 29.]]]], dtype=torch.float64)
tensor([[[[  30.,   80.,  130.,  180.,  230.]],[[  80.,  255.,  430.,  605.,  780.]],[[ 130.,  430.,  730., 1030., 1330.]]],[[[ 180.,  605., 1030., 1455., 1880.]],[[ 230.,  780., 1330., 1880., 2430.]],[[ 280.,  955., 1630., 2305., 2980.]]]], dtype=torch.float64)
torch.Size([2, 3, 1, 5])

5,nn._init_()

torch.nn.init.uniform_(tensor, a=0, b=1) 服从~U ( a , b ) 均匀分布

Pytorch使用--学习记录相关推荐

  1. Pytorch深度学习记录:对CIFAR-10的深度学习模型搭建与测试

    前言 CIFAR-10介绍 下载.分类与读入数据集 数据集下载 解压与分类数据集 读入数据 搭建神经网络 卷积层(Convolutional layer) 池化层(Pooling lay) 残差网络( ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. add函数 pytorch_Pytorch学习记录-Pytorch可视化使用tensorboardX

    Pytorch学习记录-Pytorch可视化使用tensorboardX 在很早很早以前(至少一个半月),我做过几节关于tensorboard的学习记录. https://www.jianshu.co ...

  4. 【项目实战】vue-springboot-pytorch前后端结合pytorch深度学习 html打开本地摄像头 监控人脸和记录时间

    是一个项目的一个功能之一,调试了两小时,终于能够 javascript设置开始计和暂停计时 监控人脸 记录时间了 效果图: 离开页面之后回到页面会从0计时(不是关闭页面,而是页面失去焦点) 离开摄像头 ...

  5. ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch深度学习搭建记录(一路爬坑的一天...)

    ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch 深度学习环境搭建记录(一路爬坑的一天-) 1.gpu驱动下载 参考:https://blog.csdn.net/f ...

  6. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  7. PyTorch学习记录-1PyTorch安装

    学习建议里有PyTorch,所以我就开始了PyTorch的学习. 首先就是安装啦,去官网很清楚,可以选择自己的版本和平台,然后下面就会出现 Run this command:  后面跟着的命令复制运行 ...

  8. seq2seq模型_Pytorch学习记录-Seq2Seq模型对比

    Pytorch学习记录-torchtext和Pytorch的实例4 0. PyTorch Seq2Seq项目介绍 在完成基本的torchtext之后,找到了这个教程,<基于Pytorch和tor ...

  9. PyTorch深度学习实践

    根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...

最新文章

  1. linux问答学知识
  2. 自动化运维之SaltStack实践
  3. wxWidgets:wxStaticBox类用法
  4. 浙江绿盟科技2011.10.14校园招聘会笔试题
  5. php xmlhttprequest,DOM XMLHttpRequest
  6. CSDN挑战编程——《数学问题》
  7. 原生Java高仿抖音短视频APP双端源码
  8. ApacheCN/iBooker 未来计划 2019.11
  9. Windows API 逐个逐个学(3)----Windows系统基本服务API GetSystemDirectory
  10. PHRefreshTriggerView
  11. ps aux 和ps -aux和 ps -ef的选择
  12. 在CentOS6.9中搭建HBase
  13. 视频画面帧的展示控件SurfaceView及TextureView对比
  14. 如何识别媒体偏见_面部识别,种族偏见和非洲执法
  15. NRF24L01+模块:一对一双向通信,成功!
  16. 计算机研究生开题报告ppt模板,硕士开题报告ppt模板
  17. 全屏网页时钟屏保flipclock-beautify,简约风格,电脑手机均支持访问
  18. Unity之AB包的创建加载
  19. Python实现自动通关别踩白块儿
  20. 完整登录、注册页面(无功能)

热门文章

  1. HKPM智慧集贸市场管理系统在某农贸市场的应用
  2. 在桥式结构中的注意事项 — 探头的CMRR
  3. Google 应用与游戏出海 11 月刊: 领取您的节假季突围攻略
  4. Matlab pcode p文件 p代码 p文件代转m文件,pcode文件解密工具
  5. 标准小知识3一一直流电子负载
  6. RDBMS之SQL:SQL语言的各种方言的简介(MySQL/Hive SQL/PQL/OracleSQL/SQLite影响力排序)、主流语言的对比之详细攻略
  7. 7-2 复数计算 (10分)
  8. 淘宝一月打假结果:假货投诉占三成
  9. vi操作笔记及资料下载
  10. 测试驱动开发系列之五--测试的模式与反模式