pytorch的paramter
register_parameter
nn.Parameters 与 register_parameter 都会向 _parameters
写入参数,但是后者可以支持字符串命名。
从源码中可以看到,nn.Parameters为Module添加属性的方式也是通过register_parameter向 _parameters
写入参数。
def __setattr__(self, name, value):def remove_from(*dicts):for d in dicts:if name in d:del d[name]params = self.__dict__.get('_parameters')if isinstance(value, Parameter):if params is None:raise AttributeError("cannot assign parameters before Module.__init__() call")remove_from(self.__dict__, self._buffers, self._modules)self.register_parameter(name, value)elif params is not None and name in params:if value is not None:raise TypeError("cannot assign '{}' as parameter '{}' ""(torch.nn.Parameter or None expected)".format(torch.typename(value), name))self.register_parameter(name, value)else:modules = self.__dict__.get('_modules')if isinstance(value, Module):if modules is None:raise AttributeError("cannot assign module before Module.__init__() call")remove_from(self.__dict__, self._parameters, self._buffers)modules[name] = valueelif modules is not None and name in modules:if value is not None:raise TypeError("cannot assign '{}' as child module '{}' ""(torch.nn.Module or None expected)".format(torch.typename(value), name))modules[name] = valueelse:buffers = self.__dict__.get('_buffers')if buffers is not None and name in buffers:if value is not None and not isinstance(value, torch.Tensor):raise TypeError("cannot assign '{}' as buffer '{}' ""(torch.Tensor or None expected)".format(torch.typename(value), name))buffers[name] = valueelse:object.__setattr__(self, name, value)
import torch
from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()print('before register:\n', self._parameters, end='\n\n')self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))print('after register and before nn.Parameter:\n', self._parameters, end='\n\n')self.my_param2 = nn.Parameter(torch.randn(2, 2))print('after register and nn.Parameter:\n', self._parameters, end='\n\n')def forward(self, x):return xmymodel = MyModel()for k, v in mymodel.named_parameters():print(k, v)
程序返回为:
before register:OrderedDict()after register and before nn.Parameter:OrderedDict([('my_param1', Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],[-0.4345, -0.9904, -0.9329],[ 1.4990, -1.7540, -0.4479]], requires_grad=True))])after register and nn.Parameter:OrderedDict([('my_param1', Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],[-0.4345, -0.9904, -0.9329],[ 1.4990, -1.7540, -0.4479]], requires_grad=True)), ('my_param2', Parameter containing:
tensor([[ 1.0205, -1.3145],[-1.1108, 0.4288]], requires_grad=True))])my_param1 Parameter containing:
tensor([[-1.3542, -0.4591, -2.0968],[-0.4345, -0.9904, -0.9329],[ 1.4990, -1.7540, -0.4479]], requires_grad=True)
my_param2 Parameter containing:
tensor([[ 1.0205, -1.3145],[-1.1108, 0.4288]], requires_grad=True)
pytorch的paramter相关推荐
- pytorch学习笔记(十二):详解 Module 类
Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...
- PyTorch学习笔记(9)——nn.Conv2d和其中的padding策略
一. Caffe.Tensorflow的padding策略 在之前的转载过的一篇文章--<tensorflow ckpt文件转caffemodel时遇到的坑>提到过,caffe的paddi ...
- 通过anaconda2安装python2.7和安装pytorch
①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...
- 记录一次简单、高效、无错误的linux上安装pytorch的过程
1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...
- 各种注意力机制PyTorch实现
给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...
- PyTorch代码调试利器_TorchSnooper
GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...
- pytorch常用代码
20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...
- API pytorch tensorflow
pytorch与tensorflow API速查表 方法名称 pytroch tensorflow numpy 裁剪 torch.clamp(x, min, max) tf.clip_by_value ...
- tensor转换 pytorch tensorflow
一.tensorflow的numpy与tensor互转 1.数组(numpy)转tensor 利用tf.convert_to_tensor(numpy),将numpy转成tensor >> ...
最新文章
- face detection[PyramidBox]
- RHEL5系列之三:GNOME桌面的简单管理应用(1)
- [转载]漫谈游戏中的阴影技术
- UIKit封装的系统动画
- jdk1.8 base64注意事项
- Mybatis构建sql语法
- AS 中强制类型转换
- UI设计教程学习分享:APP布局
- How to install and configure vsftpd
- RHCE课程-RH253Linux服务器架设笔记五-APACHE服务器配置(2)
- 电脑突然显示只有C盘,其他磁盘不显示了----解决方法(很简单)
- C语言中将字符串转换为数字
- 影视APP下载页面html源码
- 计算机word虚线在哪里,在word中画虚线的五种方法
- 关于机械硬盘坏道(超时无响应、低速区域、掉盘)的修复尝试
- 七年级计算机上册知识树,七年级上知识树.doc
- 硬盘损坏,怪我咯?3分钟拯救硬盘里的小姐姐!
- sklearn机器学习(五)线性回归算法测算房价
- C/C++ 中 exit() 函数
- 消防工程师答题做试题模拟真题微信小程序,margin:25px 50px 75px 100px;
热门文章
- awk以空格为分隔符的问题
- 又一个 Golang 编写的僵尸网络:KmsdBot
- SQL分组查询,结果只取最新记录
- 基于springboot实现学校线上教学平台管理系统【源码+论文】分享
- JAVA计算机毕业设计学术会议信息网站Mybatis+源码+数据库+lw文档+系统+调试部署
- 如何在Visio里面添加“左”箭头
- 基于遗传算法优化的BP神经网络
- 2022年语音合成(TTS)和语音识别(ASR)年度总结
- 四步轻松实现用Visio画UML类图
- DNS是什么?有哪些公共 DNS ?