my torch voyage
文章目录
- ImportError: dlopen: cannot load any more object with static TLS
- 一些节约显存的trick
- Unfold 和Fold
- torch 初始化通过函数加载预训练模型是否是原地修改网络参数
- 再次注意torch eval()
随便记录一下torch的使用
ImportError: dlopen: cannot load any more object with static TLS
20200410
在import torch
的时候抛出了这个错误,百度一下,
- 有人更换cv2和torch的导入顺序,我尝试了下,然并卵
- 考虑到我导入的包有些是自己写的包,考虑到重名问题改了,没用
- 在最开始的地方,第一步就导入torch,成功解决
一些节约显存的trick
with torch.no_grad()
再验证的时候加上这个会再计算梯度的时候节约显存计算和使用torch.cuda.empty_cache()
用在每个epoch结束或者开始的地方,释放一些没有用的显存,他的机制是释放没有变量引用的显存python del
主动删除一些有引用的变量显存,在使用完某些变量或者中间计算量的时候手动删除这些变量较少显存占用
实际上我是用了上面三个方法还是不行2333,只是延迟了显存爆表的时间。。。,最后无奈还是减少了batchsize,现在看情况,最尴尬的情况就是我在本地测试的时候没有使用del和empty_cache,显存也没有变。。,但是后者确实可以在nvidia-smi中看到释放的显存,但是总的显存并没有变化,用于训练+测试好像没有效果,回头把他们全都注释掉
Unfold 和Fold
torch.nn.Unfold
torch.nn.functional.unfold
torch.tensor.unfold
从一个批次的输入样本中提取出滑动的局部区域块
不太懂啊
torch.nn.Unfold(kernel_size, dilation=1,padding=0,stride=1)
参考
输入为(N, C, H, W)
输出为(N, Cx∏\prod∏(kernel_size), L)
就是把卷积核滑动过程中经过的区域提取出来,遍历每个通道(channel, C), 每个patch 的大小为∏(kernelsize)\prod(kernel_size)∏(kernelsize), 每个通道有L 个patch
torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)
则相反
- 待实验
torch 初始化通过函数加载预训练模型是否是原地修改网络参数
这句话的意思是,将网络模型作为参数传递给初始化函数,但是初始化函数并没有返回值,这样的情况下原网络的参数是否被修改,即是值传递还是传递引用
实验证明是传递引用
# coding=utf-8
"""
@filename : temp.py
@author : keda_wl
@time : 2021-09-30 14:44:15
@description: 网络通过函数调用加载参数是原地修改值吗?
"""import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.l = nn.Linear(4, 2, bias=False)def forward(self, x):return self.l(x)def func1(net1):net1.apply(func2)print(net1.state_dict())torch.save(net1.state_dict(), 'temp.pth')return net1def func2(m):classname = m.__class__.__name__if classname.find('Linear') != -1:torch.nn.init.zeros_(m.weight.data)def func3(net2):func4(net2)print(net2.state_dict())def func4(net4):net4.load_state_dict(torch.load('temp.pth'))# torch.nn.init
net1 = Net()
print(net1.state_dict())
func1(net1)
# load state dict
net2 = Net()
print(net2.state_dict())
func3(net2)"""
OrderedDict([('l.weight', tensor([[-0.0960, 0.0352, 0.4729, -0.3162],[-0.4901, -0.2909, 0.2769, -0.3116]]))])
OrderedDict([('l.weight', tensor([[0., 0., 0., 0.],[0., 0., 0., 0.]]))])
OrderedDict([('l.weight', tensor([[-0.4603, 0.1903, -0.0606, -0.4371],[ 0.1231, 0.3120, -0.4781, 0.4559]]))])
OrderedDict([('l.weight', tensor([[0., 0., 0., 0.],[0., 0., 0., 0.]]))])
"""
再次注意torch eval()
20220915 今天使用训练好的torch 模型测试,对比另一份代码,超参都一样,但是结果总是对不上。经过debug 后,发现时模型加载参数后,没有调用model.eval()
, 模型结构中包含batchnorm
所以会出错。
记录下再次注意包含batchnorm dropout 之类的模型需要调用model.eval()
my torch voyage相关推荐
- Pyinstaller 打包 torch 后执行失败 OSError: could not get source code
1. 问题现象 系统环境 Python 3.6.9 torch 1.2.0 torchvision 0.4.0 Pyinstaller 4.5.1 Pyinstaller 打包 torch 后执行失败 ...
- torch.nn.functional.cross_entropy.ignore_index
ignore_index表示计算交叉熵时,自动忽略的标签值,example: import torch import torch.nn.functional as F pred = [] pred.a ...
- torch.backends.cudnn.deterministic 使用cuda保证每次结果一样
为什么使用相同的网络结构,跑出来的效果完全不同,用的学习率,迭代次数,batch size 都是一样?固定随机数种子是非常重要的.但是如果你使用的是PyTorch等框架,还要看一下框架的种子是否固定了 ...
- PyTorch的torch.cat
字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. 例子理解 import torch A=torch.ones(2,3) # ...
- pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法
squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...
- PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...
- PyTorch官方中文文档:torch.optim 优化器参数
内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...
- torch中的copy()和clone()
torch中的copy()和clone() 1.torch中的copy()和clone() y = torch.Tensor(2,2):copy(x) --- 1 修改y并不改变原来的x y = x: ...
- torch.nn.Embedding理解
Pytorch官网的解释是:一个保存了固定字典和大小的简单查找表.这个模块常用来保存词嵌入和用下标检索它们.模块的输入是一个下标的列表,输出是对应的词嵌入. torch.nn.Embedding(nu ...
- torch.nn.Linear()函数的理解
import torch x = torch.randn(128, 20) # 输入的维度是(128,20) m = torch.nn.Linear(20, 30) # 20,30是指维度 outpu ...
最新文章
- 浅谈搜索引擎百度分词技术
- 【面试题】Spring框架中Bean的生命周期
- 加工生产调度(信息学奥赛一本通-T1425)
- [转载] numpy.inf
- matlab min函数_数学建模与MATLAB非线性规划
- Web服务器点击劫持(ClickJacking)的安全防范
- 【VR】Leap Motion 官网文档 FingerModel (手指模型)
- 使用5502自带的UART口发送数据乱码的问题
- JDK 安装 Java环境变量配置
- linux操作系统第三版课后题答案,linux操作系统( 课后习题答案).doc
- android 解压zip工具,ZArchiver解压缩工具
- 并行网络测试软件,Manul:一款基于覆盖率引导的并行模糊测试工具
- win10触摸板手势教程
- MyCobot六轴机械臂开箱及开发前的准备工作(一)
- 按照分类方法判断图片里是否有鹅蛋
- 鸡和兔子若干只,头有35个,脚有94个,求兔子和鸭个多少只
- [C#]Unicode与汉字互转
- guid分区怎么装win7_如何在GUID分区装win7系统并以UEFI启动?
- GB2312、GB18030、GBK、UNICODE、BIG5之间兼容关系如何?
- rabbitmq direct reply-to 在springAMQP和python之间的使用
热门文章
- Windows server 2008 安装Hyper-V
- nanopi 2 fire s5p4418 初次体验 (1)uboot,linux kernel编译
- Java基础每日一练—第5天:预测身高案列
- QT设置窗体标题及背景颜色
- 编译原理之Frist集与Follow集
- unity 打包一直停留在 detecting current sdk tools version
- Pangu Separates Heaven and Earth(签到题)
- idea 设置eplice 前进后退快捷键
- iPhone/iPad怎么进入恢复模式?
- Hex Fiend——mac 下 WinHex的完美替代