Pytorch源码与运行原理浅析--网络篇(一)
前言
申请的专栏开通了,刚好最近闲下来了,就打算开这个坑了hhhhh
第一篇就先讲一讲pytorch的运行机制好了。。。
记得当时刚刚接触的时候一直搞不明白,为什么自己只是定义了几个网络,就可以完整的训练整个模型,它背后的机制又是如何,搞明白了这个,才有可能去做更多的定制的更改,比如更改loss,反传方式,梯度下降机制,甚至自定义参数更新速率(比如学习率随着迭代轮数下降),文章比较浅显,希望各位大神不吝赐教。
知识储备
看此文章的前提,大概需要你写过一个利用pytorch的训练程序,哪怕官网上的MNIST。
因为本文目的是告诉你为什么这么写
为什么不用TensorFlow
其实我之前是有用TF的,但是,emmmmmmmm.......
之后接触了Pytorch,那一整天都在感叹"还有这种操作?"
个人感觉TF不是一个易于理解和易于扩展的框架。
比如说,我想实现学习率随迭代轮数降低,需要修改哪些?
那么,让我们开始吧
从MNIST说起
网络定义篇
import torch.nn as nn
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x)
这一段是MNIST给的定义Net的代码,那么,让我们看一看,这一段代码说明了什么,首先,__init__方法直接定义了你的网络,这就是你的模型中含有的全部的东西,你的模型本身也只有__init__ 中的属性会被每一次训练的时候更改,可以说这个思路是十分的清晰了。
之后,是forward方法,这里定义了如何处理传入的数据(就是那个x),返回这个神经网络的output
这里,我把它比作名词和动词的关系,__init__()方法定义了网络本身,或者说定义了一个个的名词,而我们也需要一系列的"猜测"过程,猜出这些名词是什么。而forward()方法,则是一个个的动词,它提供了如何处理这些名词的方式。
而之后,我们来看看,运行的时候,发生了什么
首先,我们看看torch.nn.Module,看看它是如何定义的。
torch.nn.Module
源代码在此处
class Module(object):dump_patches = Falsedef __init__(self):self._backend = thnn_backendself._parameters = OrderedDict()self._buffers = OrderedDict()self._backward_hooks = OrderedDict()self._forward_hooks = OrderedDict()self._forward_pre_hooks = OrderedDict()self._modules = OrderedDict()self.training = Truedef forward(self, *input):raise NotImplementedError
(代码不完整,只截取了一段)
可以看到,Module类定义了一系列训练时使用的变量比如参数(感觉这是是缓存的参数,用来之后做参数更新用的),buffers,几个hooks(个人感觉这些hooks是之后与loss,反传之类的步骤通讯数据用的)
反传里面是有一个判断的逻辑,判断你的子类有没有定义网络,没有就报错(讲真,这个想法很棒啊QwQ,子类重写父类方法,没有重写就是个报错hhhhhh)
def register_buffer(self, name, tensor):self._buffers[name] = tensordef register_parameter(self, name, param):if '_parameters' not in self.__dict__:raise AttributeError("cannot assign parameter before Module.__init__() call")if param is None:self._parameters[name] = Noneelif not isinstance(param, Parameter):raise TypeError("cannot assign '{}' object to parameter '{}' ""(torch.nn.Parameter or None required)".format(torch.typename(param), name))elif param.grad_fn:raise ValueError("Cannot assign non-leaf Variable to parameter '{0}'. Model ""parameters must be created explicitly. To express '{0}' ""as a function of another variable, compute the value in ""the forward() method.".format(name))else:self._parameters[name] = param
buffer和parameter的注册,这里有一点需要提醒,在你自定义的网络中,如果你用了类似
self.some_dict['keys'] = nn.Conv2d(10, 20, kernel_size=5)
这种语句的话,pytorch是没有办法这个变量的,也不会参与之后的传参之类的
在定义了上面那句话之后你必须用类似
# method 1
setattr(self, 'some_name', self.some_dict['keys'])
# method 2
self.register_parameter('some_name', self.some_dict['keys'])
比如笔者自己的代码
self.LocalConv1 = {i + 1: nn.Conv2d(32, 32, 3, stride=1, padding=0) for i in range(4)}for i in self.LocalConv1:setattr(self, 'LocalConvPart%d' % i, self.LocalConv1[i])self.GlobalFullConnect = nn.Linear(7 * 2 * 32, 400)self.LocalFullConnect = {i + 1: nn.Linear(32 * 23 * 16, 100) for i in range(4)}for i in self.LocalFullConnect:setattr(self, 'LocalFullConnectPart%d' % i, self.LocalFullConnect[i])
建议使用方法1,因为Module类重载了__setattr__()方法,如下
def __setattr__(self, name, value):def remove_from(*dicts):for d in dicts:if name in d:del d[name]params = self.__dict__.get('_parameters')if isinstance(value, Parameter):if params is None:raise AttributeError("cannot assign parameters before Module.__init__() call")remove_from(self.__dict__, self._buffers, self._modules)self.register_parameter(name, value)elif params is not None and name in params:if value is not None:raise TypeError("cannot assign '{}' as parameter '{}' (torch.nn.Parameter or None expected)".format(torch.typename(value), name))self.register_parameter(name, value)else:modules = self.__dict__.get('_modules')if isinstance(value, Module):if modules is None:raise AttributeError("cannot assign module before Module.__init__() call")remove_from(self.__dict__, self._parameters, self._buffers)modules[name] = valueelif modules is not None and name in modules:if value is not None:raise TypeError("cannot assign '{}' as child module '{}' ""(torch.nn.Module or None expected)".format(torch.typename(value), name))modules[name] = valueelse:buffers = self.__dict__.get('_buffers')if buffers is not None and name in buffers:if value is not None and not torch
Pytorch源码与运行原理浅析--网络篇(一)相关推荐
- 【源码阅读计划】浅析 Java 线程池工作原理及核心源码
[源码阅读计划]浅析 Java 线程池工作原理及核心源码 为什么要用线程池? 线程池的设计 线程池如何维护自身状态? 线程池如何管理任务? execute函数执行过程(分配) getTask 函数(获 ...
- PyTorch源码浅析(1):THTensor
PyTorch源码浅析(1):THTensor PyTorch中Tensor的存储和表示分开,多个THTensor可能共享一个THStorage,每个THTensor可能拥有不同的view(e.g. ...
- pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)
写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...
- ELMo解读(论文 + PyTorch源码)
ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...
- Transformer-XL解读(论文 + PyTorch源码)
前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...
- PyTorch源码解读之torchvision.models
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets.torchvision.models.torchvisi ...
- Tomcat7.0源码分析——请求原理分析(上)
前言 谈起Tomcat的诞生,最早可以追溯到1995年.近20年来,Tomcat始终是使用最广泛的Web服务器,由于其使用Java语言开发,所以广为Java程序员所熟悉.很多早期的J2EE项目,由程序 ...
- 【赠书福利】掘金爆火小册同名《Spring Boot源码解读与原理剖析》正式出书了!...
关注我们丨文末赠书 承载着作者的厚望,掘金爆火小册同名读物<Spring Boot源码解读与原理剖析>正式出书! 本书前身是掘金社区销量TOP的小册--<Spring Boot源码解 ...
- PyTorch源码学习系列 - 1.初识
本系列文章会优先发布于微信公众号和知乎,欢迎大家关注 微信公众号:小飞怪兽屋 知乎: PyTorch源码学习系列 - 1.初识 - 知乎 (zhihu.com) 目录 本系列的目的 PyTorch是什 ...
最新文章
- ios 圆形旋转菜单_iOS 圆环菜单
- Spring Cloud实战小贴士:Zuul的饥饿加载(eager-load)使用
- leetcode算法题--除数博弈★
- linux如何使用物理内存_10 张图解再谈 Linux 物理内存和虚拟内存
- 关键词提取算法—TF/IDF算法
- 基于JAVA+SpringMVC+Mybatis+MYSQL的网上商城
- C#将数据库图片显示在pictureBox
- Qt一个进程运行另一个进程
- cornerstone 使用
- 如何将webp格式转换成jpg?
- 利用小程序快速生成App,只需七步
- uni-app自动定位当前位置
- [洛谷P4118][Ynoi2016]炸脖龙I([洛谷P3934]Nephren Ruq Insania)
- uniapp swiper内嵌video组件的坑
- Delphi 仿QQ皮肤控件设计与运行效果图
- 困扰所有SAP顾问多年的问题终于解决了
- 计算机中常用t来表示,2012年计算机等级考试一级B考点详解(4)
- 图片太大怎么压缩变小,如何压缩图片?
- 三国志10在win7下的安装
- 数据库对日期进行比较
热门文章
- 2022-2028年中国钢材市场投资分析及前景预测报告(全卷)
- starrocks问题小结
- Linux下环境变量配置方法梳理(.bash_profile和.bashrc的区别)
- 从底层吃透java内存模型(JMM)、volatile、CAS
- 卷积神经网络(CNN,ConvNet)
- 图像零交叉点,视频生成,视频识别,视频摘要,视频浓缩
- arm,asic,dsp,fpga,mcu,soc各自的特点
- 2021年大数据Spark(六):环境搭建集群模式 Standalone
- Android Environment 的作用以及常用的方法
- 微信小程序picker 轮滑1-100的实现