Pytorch之深入理解torch.nn.Parameter()
先看一段代码:
import torch
import torch.nn as nn
a=torch.tensor([1,2],dtype=torch.float32)
print(a)
print(nn.Parameter(a))
print(nn.parameter.Parameter(a))
结论:
- nn.Parameter=nn.parameter.Parameter
- parameter本质仍然是一个tensor。
- nn.Parameter的作用是:将一个不可训练的类型Tensor转换成可以训练的类型parameter,并且会向宿主模型注册该参数,成为一部分。即
model.parameters()
会包含这个parameter。从而,在参数优化的时候可以自动一起优化,这就不需要我们单独对这个参数进行优化啦。
其中2的证明如下:
补充
不少童鞋会想这个nn.Parameter(tensor)和对一个tensor直接进行`requires_grad=True`有什么区别?这两者不是一样的! 即下面的w1和w2功能是不一样的:
#对一个tensor直接进行`requires_grad=True`
w1=torch.tensor([1,2],dtype=torch.float32,requires_grad=True)#nn.Parameter(tensor)
a=torch.tensor([3,4],dtype=torch.float32)
w2=nn.Parameter(a)
功能哪里不一样?其实就是上面说的结论中的第3条,对一个tensor直接进行requires_grad=True
确实也变成了可训练的tensor,但这个tensor无法像parameter那样自动包括在 model.parameters()
中。
最后,我们做一个验证:
class mod(nn.Module):def __init__(self):super(mod,self).__init__()self.w1=torch.tensor([1,2],dtype=torch.float32,requires_grad=True)a=torch.tensor([3,4],dtype=torch.float32)self.w2=nn.Parameter(a)def forward(self,inputs):o1=torch.dot(self.w1,inputs)#使用了带梯度的普通tensoro2=torch.dot(self.w2,inputs)#使用了parameterreturn o1+o2
model=mod()
for p in model.parameters():print(p)
我们发现,只有parameter会在model.parameters()
中,这意味这,w1参数需要手动单独优化。
补充:
上述好像只打印了参数,没有打印参数名称,有点low。高级的如下:
model.state_dict()
#或者
for para in model.named_parameters():print(para)
完
Pytorch之深入理解torch.nn.Parameter()相关推荐
- PyTorch中的torch.nn.Parameter() 详解
PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...
- PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...
- torch.nn.Parameter()
中心: 与torch.Tensor相比,torch.Tensor()只是生成一个张量, 而torch.nn.Parameter()可以将张量变为可以训练的参数,而不是一个不可变的张量, 用法: se ...
- torch.nn.parameter详解
:-- 目录: 参考: 1.parameter基本解释: 2.参数requires_grad的深入理解: 2.1 Parameter级别的requires_grad 2.2Module级别的requi ...
- nn.Module、nn.Sequential和torch.nn.parameter学习笔记
nn.Module.nn.Sequential和torch.nn.parameter是利用pytorch构建神经网络最重要的三个函数.搞清他们的具体用法是学习pytorch的必经之路. 目录 nn.M ...
- PySOT代码之SiamRPN++分析——基础知识:hanning、outer、tile、contiguous、flatten、meshgrid、torch.nn.Parameter
基础知识扩充 感谢大佬们的工作,许多内容都是直接拿来用的,原地址附在参考文献板块 np.hanning(M) 汉宁窗是通过使用加权余弦形成的锥形 M:整数,输出窗口中的点数.如果为零或更小,则返回一个 ...
- PyTorch:tensor、torch.nn、autograd、loss等神经网络学习手册(持续更新)
PyTorch1:tensor2.torch.nn.autograd.loss等神经网络学习手册(持续更新) 链接:画图.读写图片 文章目录 一.tensor 二.完整训练过程:数据.模型.可学习参数 ...
- torch.nn.parameter.Parameter分析
torch.nn.parameter.Parameter 作用 a kind of Tensor that is to be considered a module parameter. Parame ...
- pytorch深度学习框架—torch.nn模块(一)
pytorch深度学习框架-torch.nn模块 torch.nn模块中包括了pytorch中已经准备好的层,方便使用者调用构建的网络.包括了卷积层,池化层,激活函数层,循环层,全连接层. 卷积层 p ...
最新文章
- 第二课 , 启动 ./start-all.sh
- make: Nothing to be done for `first'
- java代码初体验_第一次Java 8体验
- 51nod1040 最大公约数之和,欧拉函数或积性函数
- IOS开发基础篇 -- 分类、类别
- mysql 多条记录选择一套_2020-11-09-Mysql(练习题第一套)
- Windows10+Ubuntu 18.04.2+ROS 安装笔记(SSD单硬盘)上
- 我就拜你为师的飞秋爱好者
- 谁动了你的主机-Windows“唤醒”和“开机”时自动拍照-狩猎者项目
- 蓝桥杯单片机篇:NE555 频率测量
- 严重: Catalina.stop: java.net.ConnectException: Connection refused: connect
- SpringBoot 2.0 系列005 --启动实战之SpringApplication应用
- 博客园在我的博客添加点击小心心特效
- 基于安卓手机的WAPI证书安装使用详解
- Mac制作映像(dmg)文件详细步骤
- 图片怎么转换成png格式?
- 妥妥的世界第一:为什么MT4软件的地位无法撼动?
- 生动形象解释虚数的意义
- D. Nearest Excluded Points(cf)坐标反向BFS
- 如果我有100块钱……