文章目录

  • 引言
  • 一、模型的创建
    • 1.nn.Module

引言

  这一节,我们开始讲解模型模块。

一、模型的创建

  模型的构建有两个要素:

下面我们以LeNet模型为例,展示其模型创建过程

class LeNet(nn.Module):# 初始化构建子模块def __init__(self, classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)# 拼接子模块def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out# 权值的初始化def initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.xavier_normal_(m.weight.data)if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.Linear):nn.init.normal_(m.weight.data, 0, 0.1)m.bias.data.zero_()

但是我们什么时候实现模型的拼接与前向传播呢?
LeNet模型继承于Module,Module类中有__call__函数,__call__函数表明这一实例是可以像函数一样被调用的,__call__函数中会调用上面定义好的forword前向传播函数。

    def __call__(self, *input, **kwargs):for hook in self._forward_pre_hooks.values():result = hook(self, input)if result is not None:if not isinstance(result, tuple):result = (result,)input = resultif torch._C._get_tracing_state():result = self._slow_forward(*input, **kwargs)else:# 前向传播result = self.forward(*input, **kwargs)for hook in self._forward_hooks.values():hook_result = hook(self, input, result)if hook_result is not None:result = hook_resultif len(self._backward_hooks) > 0:var = resultwhile not isinstance(var, torch.Tensor):if isinstance(var, dict):var = next((v for v in var.values() if isinstance(v, torch.Tensor)))else:var = var[0]grad_fn = var.grad_fnif grad_fn is not None:for hook in self._backward_hooks.values():wrapper = functools.partial(hook, self)functools.update_wrapper(wrapper, hook)grad_fn.register_hook(wrapper)return result

1.nn.Module

  在模型模块有一个非常重要的概念是nn.Module,所有的模型、所有的网络层都是继承于nn.Module类的。下面我们先介绍一下torch.nn

这一节我们的重点是nn.Modulenn.Module的属性如下:

  • parameters:存储管理nn.Parameter类
    比如:权值、偏置等这些参数
  • modules :存储管理nn.Module类
    比如:在LeNet模型中的卷积层、池化层
  • buffers:存储管理缓冲属性,如BN层中的running_mean
  • ***_hooks:存储管理钩子函数
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
    

注:在Module模块中有一个机制:拦截所有类属性赋值语句,会跳转到Module中的__setattr__函数

    def __setattr__(self, name, value):def remove_from(*dicts):for d in dicts:if name in d:del d[name]params = self.__dict__.get('_parameters')if isinstance(value, Parameter):if params is None:raise AttributeError("cannot assign parameters before Module.__init__() call")remove_from(self.__dict__, self._buffers, self._modules)self.register_parameter(name, value)elif params is not None and name in params:if value is not None:raise TypeError("cannot assign '{}' as parameter '{}' ""(torch.nn.Parameter or None expected)".format(torch.typename(value), name))self.register_parameter(name, value)else:modules = self.__dict__.get('_modules')if isinstance(value, Module):if modules is None:raise AttributeError("cannot assign module before Module.__init__() call")remove_from(self.__dict__, self._parameters, self._buffers)modules[name] = valueelif modules is not None and name in modules:if value is not None:raise TypeError("cannot assign '{}' as child module '{}' ""(torch.nn.Module or None expected)".format(torch.typename(value), name))modules[name] = valueelse:buffers = self.__dict__.get('_buffers')if buffers is not None and name in buffers:if value is not None and not isinstance(value, torch.Tensor):raise TypeError("cannot assign '{}' as buffer '{}' ""(torch.Tensor or None expected)".format(torch.typename(value), name))buffers[name] = valueelse:object.__setattr__(self, name, value)

这个函数的主要作用是:对value的数据类型进行判断,

  • 判断是否为Parameters属性,如果是的话,就存储到register_parameter字典中
  • 判断是否为Module属性,如果是的话,就存储到modules字典中

nn.Module总结:

  • 一个module可以包含多个子module
  • 一个module相当于一个运算,必须实现forward()函数
  • 每个module都有8个字典管理它的属性

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


PyTorch学习—8.模型创建步骤与nn.Module属性相关推荐

  1. PyTorch 入坑七:模块与nn.Module学习

    PyTorch 入坑七 模型创建概述 PyTorch中的模块 torch模块 torch.Tensor模块 torch.sparse模块 torch.cuda模块 torch.nn模块 torch.n ...

  2. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  3. pytorch学习--UNet模型

    详细Unet网络结构可以查看Unet算法原理详解 深度网络训练之中需要大量的有标样本,Unet作者提供了一种新的训练方法,可以更有效的运用相应的有标样本,使网络即使通过少量的训练图片也可以进行更精确的 ...

  4. pytorch中的神经网络模块基础类——torch.nn.Module

    1.torch.nn.Module概要 pytorch官网对torch.nn.Module的描述如下. torch.nn.Module是所有的神经网络模块的基类,且所有的神经网络模块都可以包含其他的子 ...

  5. 『PyTorch』第十五弹_torch.nn.Module的属性设置查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  6. Pytorch训练一个模型的步骤总结

    大概停在这部分很久了,总结并提醒自己一下! 目前遇到的步骤大概如下: 1.指定设备 [2.设置随机种子] 3.创建数据集(数据导入,预处理和打包) 4.创建模型 5.创建优化器 [6.学习率调整策略] ...

  7. PyTorch学习笔记(1)nn.Sequential、nn.Conv2d、nn.BatchNorm2d、nn.ReLU和nn.MaxPool2d

    文章目录 一.nn.Sequential 二.nn.Conv2d 三.nn.BatchNorm2d 四.nn.ReLU 五.nn.MaxPool2d 一.nn.Sequential torch.nn. ...

  8. PyTorch nn.Module 一些疑问

    在阅读书籍时,遇到了一些不太理解,或者介绍的不太详细的点. 从代码角度学习理解Pytorch学习框架03: 神经网络模块nn.Module的了解. Pytorch 03: nn.Module模块了解 ...

  9. python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类

    任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...

  10. pytorch ——模型创建与nn.Module

    1.网络模型创建步骤 模型模块中分为两个部分,模型创建和权值初始化: 模型创建又分为两部分,构建网络层和拼接网络层:网络层有卷积层,池化层,激活函数等:构建网络层后,需要进行网络层的拼接,拼接成LeN ...

最新文章

  1. 看2021年2月苏州各区新房均价,谈一点个人的思考
  2. Eclipse无法识别(手机)设备的解决方案
  3. Golang 计算MD5值示例
  4. android元素离边框间距,RecyclerView Item 的分割线 距边框距离问题总结
  5. SharePoint 2013 Nintex Workflow 工作流帮助(六)
  6. 科学计数怎么转换成数字_勒夫迈 | 激光尘埃粒子计数器传感器工作原理
  7. nyoj 买水果(组合数求法与分析)
  8. 接入阿里云云呼叫中心
  9. 一元三次方程c语言程序,一元三次方程求解
  10. U盘插入后只显示安全删除硬件问题
  11. typora定制主题分享--绿豆沙背景主题+新night背景主题
  12. 车性能测试软件是什么,3DMark制造商推首款汽车性能测试软件
  13. 有一个已经排好序的数组,要求输入一个数后,按原来排序的规律将它插入数组中
  14. Unity3d C#通过使用大华SDK控制大华摄像头旋转、变焦等云台操作和预置点等控制操作(含源码)
  15. 如何查SCI、EI、SSCI检索?怎么开检索报告?
  16. 快递驿站取件管理系统|基于SpringBoot的快递栈系统设计与实现
  17. 商旅服务平台的会员制规划与运营
  18. 啊屋童装商城android,我们采访了100位漂亮妈妈 她们手机里居然都有一款叫啊屋童装商城的app...
  19. FCOS网络总体流程
  20. python代码turtle是什么意思_python中turtle库中的Turtle()是什么,有什么用?

热门文章

  1. iOS Nib文件一览
  2. c:翻转一个长句中的每个单词
  3. property 、classmethod 、 staticmethod 的用法
  4. delphi控件属性大全-详解-简介
  5. 题目1439:Least Common Multiple
  6. 【CLR Via C#笔记】 值类型与拆装箱、参数传递
  7. LaTeX (1)——LaTex环境的下载与安装(Tex live 2020+ Tex studio编辑器、 proTeXt(MiKTeX+TeXstudio编辑器))
  8. 如何快速上手使用STM32库函数
  9. django之多表查询-2
  10. 导出数据报ORA-39002: 操作无效 ORA-39070: 无法打开日志文件。 ORA-39087: 目录名 DUMP_DIR 无效...