前言

我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

class Module(nn.Module):def __init__(self):super().__init__()# ......def forward(self, x):# ......return xdata = ......  # 输入数据# 实例化一个对象
model = Module()# 前向传播
model(data)# 而不是使用下面的
# model.forward(data)

但是实际上model(data)是等价于model.forward(data),这是为什么呢???下面我们来分析一下原因。

forward函数

model(data)之所以等价于model.forward(data),就是因为在类(class)中使用了__call__函数,对__call__函数不懂得可以点击链接:可调用:__call__函数

class Student:def __call__(self):print('I can be called like a function')a = Student()
a()

输出结果:

I can be called like a function

由上面的__call__函数可知,我们可以将forward函数放到__call__函数中进行调用:

class Student:def __call__(self, param):print('I can called like a function')print('传入参数的类型是:{}   值为: {}'.format(type(param), param))res = self.forward(param)return resdef forward(self, input_):print('forward 函数被调用了')print('in  forward, 传入参数类型是:{}  值为: {}'.format(type(input_), input_))return input_a = Student()input_param = a('data')
print("对象a传入的参数是:", input_param)

输出结果:

I can called like a function
传入参数的类型是:<class 'str'>   值为: data
forward 函数被调用了
in  forward, 传入参数类型是:<class 'str'>  值为: data
对象a传入的参数是: data

到这里我们就可以明白了为什么model(data)等价于model.forward(data),是因为__call__函数中调用了forward函数。

Pytorch 中的 forward理解相关推荐

  1. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  2. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  3. Pytorch中的contiguous理解

    最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解. 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新 ...

  4. Pytorch中contiguous()函数理解

    引言 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的.换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据. 会改变元数据的操作是: n ...

  5. pytorch中的forward函数详细理解

    文章目录 前言 forward 的使用 forward 使用的解释 前言 最近在使用pytorch的时候,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 fo ...

  6. Pytorch中dim的理解

    dim的定义 dim 表示维度 x = torch.randn(2, 3, 3)print(x) print(x.size()) print(x.dim()) 输出: tensor([[[-1.694 ...

  7. pytorch中unsqueeze()函数理解

    unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度. 在第一个维度(中括号)的每个元素加中括号 0表示在张量最外层加一个中括号变成第一维. 直接看例子: import torch i ...

  8. pytorch中bilinear的理解

    直接参考官方文档,x1和x2是两个输入,A是参数矩阵,如下表达式 但仔细看实现发现这个表达式并不是简单连乘的关系.假设x1(shape是b,n)和x2(shape是b,m)是二维,那么A是个三维ten ...

  9. pytorch中数组维度的理解

    pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示 这里主要想说下维度的变化. dim = x ,表示在第x为上进行操作,那个维度会发生变化. ...

  10. pytorch中网络loss传播和参数更新理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...

最新文章

  1. playframework学习笔记1 -- 开发环境和第一个工程
  2. Xamarin Essentials教程获取路径文件系统FileSystem
  3. 简述html 布局的原理,css布局原理与实现-2019年9月4日20时
  4. Android 插件化总结
  5. 附加到SQL2012的数据库就不能再附加到低于SQL2012的数据库版本
  6. mysql最大述_mysql最大字段数量及 varchar类型总结
  7. 华为rh5885服务器oid_华为RH5885H v3机架服务器RAID配置实例
  8. storm hook的使用
  9. python 字体_python docx字体设置
  10. 出大问题!webpack 多入口html模板在后端
  11. 【kafka】kafka 执行 多个脚本 kafka-run-class.sh 导致 server 节点 时不时挂掉
  12. maven项目jsp无法识别jstl的解决办法
  13. Excel做文件归档
  14. 内存带宽测试程序——stream2-C语言版
  15. 使用Rsync+cwRsync实现数据异机备份+异地备份
  16. Codeforces Round #700 (Div. 2)-B. The Great Hero-题解-一行实现向上取整
  17. 移动互联网创业机会只剩3年
  18. 数学建模——熵权法步骤及程序详解
  19. 判断图有无环_数读湾区经济潜能:基于大数据分析的环杭州湾大湾区“一体化”发展潜能!...
  20. 李佳琦转行成直播一哥,他做对了哪些事?

热门文章

  1. 听王自如聊蜕变历程:云计算时代如何输出价值
  2. python简单语法题_Python练习+简单语法摘要,习题,总结
  3. HTML5实现音频和视频嵌入,如何利用HTML5实现音频和视频嵌入的方法
  4. 合并两张图片php,php多张图片合并方法分享
  5. 【计算机网络】-- 第一章--概述(概念、组成、功能、分类、性能指标、体系结构)
  6. ***【九度oj-1343】城际公路网
  7. PowerVR SDK记录
  8. CoffeeScript
  9. 【软件构造】黑盒测试与白盒测试
  10. 基于centos7和windows 搭建局域网wiki.js知识管理库的两种解决方案