PyTorch中nn.xx与nn.functional.xx的区别
PyTorch中nn.xx与nn.functional.xx的区别
- 1 总体
- 2 两者的相同之处:
- 2.1 功能相同:
- 2.2 运行效率也是近乎相同。
- 3 两者的差别之处:
- 3.1 调用方式不一样
- 3.2 与nn.Sequential()结合性不一样
- 3.3 管理参数不一样
- 3.4 使用Dropout时不一样
1 总体
- nn.functional.xx是底层的函数接口
- nn.xx是nn.functional.xxx的类封装,并且nn.Xxx都继承于一个共同祖先nn.Module。这一点导致nn.Xxx除了具有nn.functional.xxx功能之外,内部附带了nn.Module相关的属性和方法,例如train(), eval(),load_state_dict, state_dict 等。
换言之:
- nn.Module 实现的 layer 是由 class Layer(nn.Module) 定义的特殊类
- nn.functional 中的函数更像是纯函数,由 def function(input) 定义
2 两者的相同之处:
2.1 功能相同:
即nn.Conv2d和nn.functional.conv2d 都是进行卷积,nn.Dropout 和nn.functional.dropout都是进行dropout,。。。。。;
2.2 运行效率也是近乎相同。
3 两者的差别之处:
3.1 调用方式不一样
nn.Xxx 需要先实例化并传入参数,然后以函数调用的方式调用实例化的对象并传入输入数据。
nn.functional.xxx同时传入输入数据和weight, bias等其他参数 。
# torch.nn
inputs = torch.randn(64, 3, 244, 244)
self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
outputs = self.conv(inputs)# torch.nn.functional 需要同时传入数据和 weight,bias等参数
inputs = torch.randn(64, 3, 244, 244)
weight = torch.randn(64, 3, 3, 3)
bias = torch.randn(64)
outputs = nn.functinoal.conv2d(inputs, weight, bias, padding=1)
3.2 与nn.Sequential()结合性不一样
nn.xxx 能够放在 nn.Sequential里,而 nn.functional.xxx 就不行
3.3 管理参数不一样
nn.Xxx不需要你自己定义和管理weight;而nn.functional.xxx需要你自己定义weight,每次调用的时候都需要手动传入weight, 不利于代码复用。
import torch
import torch.nn as nn
import torch.nn.functional as F# torch.nn 定义的CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Conv2d(1, 16, krenel_size=5, padding=0)self.relu_1 = nn.ReLU(inplace=True)self.maxpool_1 = nn.MaxPool2d(kernel_size=2)self.conv_2 = nn.Conv2d(16, 32, krenel_size=5, padding=0)self.relu_2 = nn.ReLU(inplace=True)self.maxpool_2 = nn.MaxPool2d(kernel_size=2) self.linear = nn.Linear(4*4*32, 10)def forward(self, x):x = x.view(x.size(0), -1)out = self.maxpool_1(self.relu_1(self.conv_1(x)))out = self.maxpool_2(self.relu_2(self.conv_2(out)))out = self.linear(out.view(x.size(0), -1))return out# torch.nn.functional 定义一个相同的CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16, 1, 5, 5))self.bias_1_weight = nn.Parameter(torch.randn(16))self.conv_2_weight = nn.Parameter(torch.randn(32, 16, 5, 5))self.bias_2_weight = nn.Parameter(torch.randn(32))self.linear_weight = nn.Parameter(torch.randn(4 * 4 * 32, 10))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self, x):x = x.view(x.size(0), -1)out = F.conv2d(x, self.conv_1_weight, self.bias_1_weight)out = F.conv2d(out, self.conv_2_weight, self.bias_2_weight)out = F.linear(out.view(x.size(0), -1), self.linear_weight, self.bias_weight)
3.4 使用Dropout时不一样
在使用Dropout时,推荐使用 nn.xxx。因为一般只有训练时才使用 Dropout,在验证或测试时不需要使用 Dropout。使用 nn.Dropout时,如果调用 model.eval() ,模型的 Dropout 层都会关闭;但如果使用 nn.functional.dropout,在调用 model.eval() 时,不会关闭 Dropout。
参考资料:
http://www.manongjc.com/detail/11-nnmwgaxcsvsaxiy.html
https://www.zhihu.com/question/66782101
PyTorch中nn.xx与nn.functional.xx的区别相关推荐
- Pytorch中的model.modules()和model.children()的区别
Pytorch中的model.modules()和model.children()的区别 背景:最近在做网络模型中可视化的过程中,需要将网络结构中的某一层的特征进行输出.所以就遇到了这个问题,小小记录 ...
- Pytorch中rand,randn, random以及normal的区别
Pytorch中rand,randn, random以及normal的区别 torch.rand() torch.randn() torch.normal() torch.randperm() tor ...
- PyTorch中的masked_select、masked_fill_()、 masked_fill()的区别
mask_fill的整体意思是使用 value 填充mask值为True的位置,所以需要注意mask的值,有时候需要~取反 masked_fill_(mask, value) - 函数名后面加下划线. ...
- pytorch 中pad函数toch.nn.functional.pad()的使用
padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...
- Pytorch中的 torch.as_tensor() 和 torch.from_numpy() 的区别
之前我写过一篇文章,比较了 torch.Tensor() 和 torch.tensor() 的区别,而这两者都是深拷贝的方法,返回张量的同时,会在内存中创建一个额外的数据副本,与原数据不共享内存,所以 ...
- Pytorch中的 torch.Tensor() 和 torch.tensor() 的区别
直接在搜索引擎里进行搜索,可以看到官方文档中两者对应的页面: 分别点击进去,第一个链接解释了什么是 torch.Tensor: torch.Tensor 是一个包含单一数据类型元素的多维矩阵(数组). ...
- python语言中ch用法_pytorch 中pad函数toch.nn.functional.pad()的用法
padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...
- pytorch学习笔记七:nn网络层——池化层、线性层
一.池化层 池化运算:对信号进行"收集" 并"总结",类似于水池收集水资源,因而得名池化层. 收集:由多变少,图像的尺寸由大变小 总结:最大值/平均值 下面是最 ...
- pytorch中交叉熵
关于pytorch中交叉熵的使用,pytorch的交叉熵是其loss function的一种且包含了softmax的过程. pytorch中交叉熵函数是nn.CrossEntropyLoss().其参 ...
- PyTorch中的sampled_softmax_loss
最近做一篇论文的复现,发现PyTorch中没有类似于tf.nn.sampled_softmax_loss的功能,经过一番寻找,在github上找到了答案:Stonesjtu/Pytorch-NCE.奇 ...
最新文章
- Java架构技术文档:并发编程+设计模式+常用框架+JVM+精选视频
- Rust编写的新终端多路复用器
- [No000077]打造自己的Eclipse
- 走进异步编程的世界 - 开始接触 async/await
- Win7安装OnlyOffice(不使用Docker)
- 【渝粤教育】 广东开放大学 10548_金融学k1_21秋考试
- 零空间,Markov‘s inequality, Chebyshev Chernoff Bound, Union Bound
- 从sqlite 迁移 mysql_将 Ghost 从 SQLite3 数据库迁移到 MySQL 数据库
- poj 1273 最大流
- linux下载镜像的命令wget,Linux wget命令整站下载做网站镜像
- 转:多线程--六种多线程方法解决UI线程阻塞
- mysql+asp.net开发注意大全:mysql创建数据库的时候,创建新用户,并且付给权限。mysql存储过程的编写,mysql数据库引擎的区别,mysql数据库文件夹备份...
- 公开课视频-《第01章 规划》-大企业云桌面部署实战-在线培训-视频(奉献)
- 数字图像处理实验九维纳滤波
- Python 针对Excel操作
- Hark的数据结构与算法练习之堆排序
- 第二十章:异步和文件I/O.(十九)
- Ubuntu进不去图形化界面的解决方案
- openmp配置指南_/openmp (启用 OpenMP 支持)
- Java项目:ssm医院管理系统
热门文章
- 四旋翼无人机飞行器基本知识(四旋翼无人机结构和原理+四轴飞行diy全套入门教程)
- C语言实现求斐波那契数列中的第n项
- matlab中sr锁存器,大家一致避免使用的锁存器为什么依然存在于FPGA中?我们对锁存器有什么误解?...
- 微信小程序获取二维码:报错47001 data format error
- 饼图-图例标记及文字的设置
- c语言常用函数doc下载,c语言常用函数.doc
- 无线通信与编码_新型OFDM波形集_使用MATLAB仿真实现UFMC并与OFDM作对比_含实现代码
- AD PCBlayout 总结
- 11、ARM嵌入式系统:中断使能
- C语言入门教程(一)