PyTorch中nn.xx与nn.functional.xx的区别

  • 1 总体
  • 2 两者的相同之处:
    • 2.1 功能相同:
    • 2.2 运行效率也是近乎相同。
  • 3 两者的差别之处:
    • 3.1 调用方式不一样
    • 3.2 与nn.Sequential()结合性不一样
    • 3.3 管理参数不一样
    • 3.4 使用Dropout时不一样

1 总体

  • nn.functional.xx是底层的函数接口
  • nn.xx是nn.functional.xxx的类封装,并且nn.Xxx都继承于一个共同祖先nn.Module。这一点导致nn.Xxx除了具有nn.functional.xxx功能之外,内部附带了nn.Module相关的属性和方法,例如train(), eval(),load_state_dict, state_dict 等。

换言之:

  • nn.Module 实现的 layer 是由 class Layer(nn.Module) 定义的特殊类
  • nn.functional 中的函数更像是纯函数,由 def function(input) 定义

2 两者的相同之处:

2.1 功能相同:

即nn.Conv2d和nn.functional.conv2d 都是进行卷积,nn.Dropout 和nn.functional.dropout都是进行dropout,。。。。。;

2.2 运行效率也是近乎相同。

3 两者的差别之处:

3.1 调用方式不一样

nn.Xxx 需要先实例化并传入参数,然后以函数调用的方式调用实例化的对象并传入输入数据。

nn.functional.xxx同时传入输入数据和weight, bias等其他参数 。

# torch.nn
inputs =  torch.randn(64, 3, 244, 244)
self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
outputs = self.conv(inputs)# torch.nn.functional   需要同时传入数据和 weight,bias等参数
inputs =  torch.randn(64, 3, 244, 244)
weight = torch.randn(64, 3, 3, 3)
bias = torch.randn(64)
outputs = nn.functinoal.conv2d(inputs, weight, bias, padding=1)

3.2 与nn.Sequential()结合性不一样

nn.xxx 能够放在 nn.Sequential里,而 nn.functional.xxx 就不行

3.3 管理参数不一样

nn.Xxx不需要你自己定义和管理weight;而nn.functional.xxx需要你自己定义weight,每次调用的时候都需要手动传入weight, 不利于代码复用。

import torch
import torch.nn as nn
import torch.nn.functional as F# torch.nn 定义的CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Conv2d(1, 16, krenel_size=5, padding=0)self.relu_1 = nn.ReLU(inplace=True)self.maxpool_1 = nn.MaxPool2d(kernel_size=2)self.conv_2 = nn.Conv2d(16, 32, krenel_size=5, padding=0)self.relu_2 = nn.ReLU(inplace=True)self.maxpool_2 = nn.MaxPool2d(kernel_size=2)   self.linear = nn.Linear(4*4*32, 10)def forward(self, x):x = x.view(x.size(0), -1)out = self.maxpool_1(self.relu_1(self.conv_1(x)))out = self.maxpool_2(self.relu_2(self.conv_2(out)))out = self.linear(out.view(x.size(0), -1))return out# torch.nn.functional 定义一个相同的CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1_weight = nn.Parameter(torch.randn(16, 1, 5, 5))self.bias_1_weight = nn.Parameter(torch.randn(16))self.conv_2_weight = nn.Parameter(torch.randn(32, 16, 5, 5))self.bias_2_weight = nn.Parameter(torch.randn(32))self.linear_weight = nn.Parameter(torch.randn(4 * 4 * 32, 10))self.bias_weight = nn.Parameter(torch.randn(10))def forward(self, x):x = x.view(x.size(0), -1)out = F.conv2d(x, self.conv_1_weight, self.bias_1_weight)out = F.conv2d(out, self.conv_2_weight, self.bias_2_weight)out = F.linear(out.view(x.size(0), -1), self.linear_weight, self.bias_weight)

3.4 使用Dropout时不一样

在使用Dropout时,推荐使用 nn.xxx。因为一般只有训练时才使用 Dropout,在验证或测试时不需要使用 Dropout。使用 nn.Dropout时,如果调用 model.eval() ,模型的 Dropout 层都会关闭;但如果使用 nn.functional.dropout,在调用 model.eval() 时,不会关闭 Dropout。

参考资料:
http://www.manongjc.com/detail/11-nnmwgaxcsvsaxiy.html
https://www.zhihu.com/question/66782101

PyTorch中nn.xx与nn.functional.xx的区别相关推荐

  1. Pytorch中的model.modules()和model.children()的区别

    Pytorch中的model.modules()和model.children()的区别 背景:最近在做网络模型中可视化的过程中,需要将网络结构中的某一层的特征进行输出.所以就遇到了这个问题,小小记录 ...

  2. Pytorch中rand,randn, random以及normal的区别

    Pytorch中rand,randn, random以及normal的区别 torch.rand() torch.randn() torch.normal() torch.randperm() tor ...

  3. PyTorch中的masked_select、masked_fill_()、 masked_fill()的区别

    mask_fill的整体意思是使用 value 填充mask值为True的位置,所以需要注意mask的值,有时候需要~取反 masked_fill_(mask, value) - 函数名后面加下划线. ...

  4. pytorch 中pad函数toch.nn.functional.pad()的使用

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  5. Pytorch中的 torch.as_tensor() 和 torch.from_numpy() 的区别

    之前我写过一篇文章,比较了 torch.Tensor() 和 torch.tensor() 的区别,而这两者都是深拷贝的方法,返回张量的同时,会在内存中创建一个额外的数据副本,与原数据不共享内存,所以 ...

  6. Pytorch中的 torch.Tensor() 和 torch.tensor() 的区别

    直接在搜索引擎里进行搜索,可以看到官方文档中两者对应的页面: 分别点击进去,第一个链接解释了什么是 torch.Tensor: torch.Tensor 是一个包含单一数据类型元素的多维矩阵(数组). ...

  7. python语言中ch用法_pytorch 中pad函数toch.nn.functional.pad()的用法

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  8. pytorch学习笔记七:nn网络层——池化层、线性层

    一.池化层 池化运算:对信号进行"收集" 并"总结",类似于水池收集水资源,因而得名池化层. 收集:由多变少,图像的尺寸由大变小 总结:最大值/平均值 下面是最 ...

  9. pytorch中交叉熵

    关于pytorch中交叉熵的使用,pytorch的交叉熵是其loss function的一种且包含了softmax的过程. pytorch中交叉熵函数是nn.CrossEntropyLoss().其参 ...

  10. PyTorch中的sampled_softmax_loss

    最近做一篇论文的复现,发现PyTorch中没有类似于tf.nn.sampled_softmax_loss的功能,经过一番寻找,在github上找到了答案:Stonesjtu/Pytorch-NCE.奇 ...

最新文章

  1. Java架构技术文档:并发编程+设计模式+常用框架+JVM+精选视频
  2. Rust编写的新终端多路复用器
  3. [No000077]打造自己的Eclipse
  4. 走进异步编程的世界 - 开始接触 async/await
  5. Win7安装OnlyOffice(不使用Docker)
  6. 【渝粤教育】 广东开放大学 10548_金融学k1_21秋考试
  7. 零空间,Markov‘s inequality, Chebyshev Chernoff Bound, Union Bound
  8. 从sqlite 迁移 mysql_将 Ghost 从 SQLite3 数据库迁移到 MySQL 数据库
  9. poj 1273 最大流
  10. linux下载镜像的命令wget,Linux wget命令整站下载做网站镜像
  11. 转:多线程--六种多线程方法解决UI线程阻塞
  12. mysql+asp.net开发注意大全:mysql创建数据库的时候,创建新用户,并且付给权限。mysql存储过程的编写,mysql数据库引擎的区别,mysql数据库文件夹备份...
  13. 公开课视频-《第01章 规划》-大企业云桌面部署实战-在线培训-视频(奉献)
  14. 数字图像处理实验九维纳滤波
  15. Python 针对Excel操作
  16. Hark的数据结构与算法练习之堆排序
  17. 第二十章:异步和文件I/O.(十九)
  18. Ubuntu进不去图形化界面的解决方案
  19. openmp配置指南_/openmp (启用 OpenMP 支持)
  20. Java项目:ssm医院管理系统

热门文章

  1. 四旋翼无人机飞行器基本知识(四旋翼无人机结构和原理+四轴飞行diy全套入门教程)
  2. C语言实现求斐波那契数列中的第n项
  3. matlab中sr锁存器,大家一致避免使用的锁存器为什么依然存在于FPGA中?我们对锁存器有什么误解?...
  4. 微信小程序获取二维码:报错47001 data format error
  5. 饼图-图例标记及文字的设置
  6. c语言常用函数doc下载,c语言常用函数.doc
  7. 无线通信与编码_新型OFDM波形集_使用MATLAB仿真实现UFMC并与OFDM作对比_含实现代码
  8. AD PCBlayout 总结
  9. 11、ARM嵌入式系统:中断使能
  10. C语言入门教程(一)