pytorch中网络loss传播和参数更新理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。
TensorFlow: 228--->266
Keras: 42--->56
Pytorch: 87--->252
在使用pytorch中,自己有一些思考,如下:
1. loss计算和反向传播
import torch.nn as nn
criterion = nn.MSELoss().cuda()
output = model(input)
loss = criterion(output, target)
loss.backward()
通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;
最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。
需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。
2. 网络参数更新
在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。
优化算法中,简单的有一阶优化算法:
其中就是通常说的学习率,是函数的梯度;
自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。
在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
注意:1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。
2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:
optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01)
3) 在优化前需要先将梯度归零,即optimizer.zeros()。
3. loss计算和参数更新
import torch.nn as nn
import torch
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward() # 误差反向传播计算参数梯度
optimizer.step() # 通过梯度做一步参数更新
————————————————
版权声明:本文为CSDN博主「少年木」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/yangzhengzheng95/article/details/85268896
pytorch中网络loss传播和参数更新理解相关推荐
- pytorch 中维度(Dimension)概念的理解
pytorch 中维度(Dimension)概念的理解 Dimension为0(即维度为0时) 维度为0时,即tensor(张量)为标量.例如:神经网络中损失函数的值即为标量. 接下来我们创建一个di ...
- Pytorch中dilation(Conv2d)参数详解
目录 一.Conv2d 二.Conv2d中的dilation参数 一.Conv2d 首先我们看一下Pytorch中的Conv2d的对应函数(Tensor通道排列顺序是:[batch, channel, ...
- 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用
在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...
- pytorch 中网络参数 weight bias 初始化方法
权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生. 在pytorch的使用过程中有几种权重初始化的方法供大家参考. 注意:第一种方法不推荐.尽量使用后两种方法. # ...
- Pytorch中KL loss
1. 概念 KL散度可以用来衡量两个概率分布之间的相似性,两个概率分布越相近,KL散度越小. 上述公式表示P为真实事件的概率分布,Q为理论拟合出来的该事件的概率分布.D(P||Q)(P拟合Q)和D(Q ...
- pytorch中的nn.LSTM模块参数详解
直接去官网查看相关信息挺好的,但是为什么有的时候进不去 官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM 使用示例,在使用中解释参数 单 ...
- 图解 Pytorch 中 nn.Conv2d 的 groups 参数
文章目录 普通卷积复习 Groups是如何改变卷积方式的 实验验证 参考资料 普通卷积复习 首先我们先来简单复习一下普通的卷积行为. 从上图可以看到,输入特征图为3,经过4个filter卷积后生成了4 ...
- PyTorch中网络里面的inplace=True字段的意思
在例如nn.LeakyReLU(inplace=True)中的inplace字段是什么意思呢?有什么用? inplace=True的意思是进行原地操作,例如x=x+5,对x就是一个原地操作,y=x+5 ...
- pytorch中tensor.mean(axis, keepdim)参数理解小实验
虽然没试过其他形式的多维数据,不过想来应该是一样的吧 -- 1.结论 keepdim=True 运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1): keepdim=F ...
最新文章
- mysql分库一个库和多个库_数据库分库后不同库之间的关联
- [云炬创业基础笔记]第七章创业资源测试3
- PHP的JSON封装
- 为vim编辑器增加行号功能
- Dapr牵手.NET学习笔记:用docker-compose部署服务
- 你需要知道的Linux 系统下外设时钟管理
- php mysql 开发微博_php+mysql基于Android的手机微博应用开发
- android+wear+游戏,技术帝:Android Wear手表运行一代PS游戏
- [软件工程]在线教程
- 一个简单好用的日志框架NLog
- Golang 变量申明方式
- DEDE的安装 和 DEDE文件和目录详解与安全问题
- Oracle odi 数据表导出到文件
- Supervisor 自动管理进程
- react 如何引入打印控件 CLodop
- ADS1110/ADS1271
- 绕过apple id的那些事
- 设计模式之工厂模式(C++)
- 由“三姬分金”到“海盗分金”
- android studio 设备调试及Logcat查看
热门文章
- matlab中-psi_matlab输出论文仿真图
- 详细设计说明书读后感_专利申请详细步骤是怎样的,要多久时间
- 【c语言】蓝桥杯算法提高 c++_ch02_03
- 汇博工业机器人码垛机怎么写_一文带您理解码垛机器人,原来它这么简单!
- 愤怒的小鸟素材包_点映预售开启|愤怒的小鸟2搞笑升级,萌贱无敌!
- 下午花一小时整理的JVM运行时方法区
- MyClouds-V1.0 发布,微服务治理及快速开发平台
- 样条表示---OpenGL的逼近样条函数
- 电信应在短时间内放弃CDMA网络
- OpenStack 系列之File Share Service(Manila)详解