文章目录

  • 前言
  • 一、问题描述
  • 二、官方文档代码
  • 三、optimizer的工作原理
  • 总结

前言

  本系列主要是对pytorch基础知识学习的一个记录,尽量保持博客的更新进度和自己的学习进度。本人也处于学习阶段,博客中涉及到的知识可能存在某些问题,希望大家批评指正。另外,本博客中的有些内容基于吴恩达老师深度学习课程,我会尽量说明一下,但不敢保证全面。


一、问题描述

  此次需要构建的神经网络其实和前几次相同,为了能更直观的理解问题,绘制了一张精美的神经网络结构图:

  到目前为止,我们已经使用了numpy,tensor,Pytorch自动求导以及Pytorch的nn模块来实现同一个神经网络。
  我们在nn的基础上使用优化算法来对神经网络进行优化,看过吴恩达老师深度学习课程的应该对优化算法有大致了解,一般来说有三种:动量梯度下降法(Momentum)、RMSprop算法和Adam优化算法。
  每个优化算法有对应的数学公式,在这里就不细说了。需要明白的是,这些优化算法主要改变反向传播后的参数更新环节,目的是在于加快神经网络的训练过程。

二、官方文档代码

  Pytorch已经将优化算法封装成optim包,我们要做的是把需要优化的参数以及使用到的学习率传入函数中即可。

import torchN, D_in, H, D_out = 64, 1000, 100, 10x = torch.randn(N, D_in)
y = torch.randn(N, D_out)# 定义神经网络需要计算的层
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),torch.nn.ReLU(),torch.nn.Linear(H, D_out))# 定义神经网络的损失函数
loss_fn = torch.nn.MSELoss(reduction="sum")learning_rate = 1e-4
# 定义使用的优化算法,这里使用的的是Adam优化算法
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)for t in range(500):# 前向传播y_pred = model(x)# 计算损失loss = loss_fn(y_pred, y)print(t, loss.item())# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()

  上述代码与之前代码的差别在于使用到了优化器,并且参数的更新去梯度的清零都是在优化器的基础上完成的。接下来我会浅析一下optimizer的工作原理,为什么是"浅析",是因为我不太懂其更底层的代码。

三、optimizer的工作原理

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  我们通过上述代码初始化了一个优化器,该优化器使用的是Adam优化算法,optim包里面还包含了其他优化算法。初始化时我们将我们定义的神经网络中的参数传入优化器中,并传入我们定义的学习率。
  然后在反向传播完成后,调用optim包中的step()方法完成参数更新:

 # 参数更新optimizer.step()

  这里我产生了一个疑问:为什么调用optim包中的函数,会对model对象中的属性进行更新。
  在前面我们知道,model.parameters()会返回一个迭代器,对这个迭代器遍历可以依次得到神经网络中的参数,也就是w1,b1,w2,b2。我们打印这四个值的id号:

pa = model.parameters()
for param in pa:print(id(param))

结果如下:

2268871549080
2268871549160
2268871549240
2268871549320

  我们查看 torch.optim.Adam() 返回值 optimizer 的属性:

  optimizer有列表类型的属性 param_groups ,其长度为1。查看列表中的元素,发现是一个字典类型的数据,该字典类型数据底下有key值为"params"的项,其value的值为一个列表,让我们打印列表中元素的id值:

for param2 in optimizer.param_groups[0]["params"]:print(id(param2))

结果如下:

2017646698648
2017646698728
2017646698808
2017646698888

  可以看出,打印结果与上面打印的model.parameters()中参数id值完全相同,这就解释了为什么调用optim中的方法会对model中的属性产生改变。至于为什么会这样,个人推测是采用了深复制,所以requires_grad属性的值也为True,感兴趣的可以去看看源码。
  关于使用 step() 更新参数的原理,我暂时还未弄明白,但是查阅相关资料后了解到,optim的所有优化函数均有step()方法。

总结

  使用Pytorch的优化器 optim 的大致步骤为:定义一个需要的优化器,并传入需要优化的参数和优化使用到的学习率;在反向传播前利用优化器对参数梯度进行清零;反向传播结束后利用优化器对参数进行更新。可以看出,使用优化器后,对神经网络参数的操作可以直接在优化器上进行。

pytorch基础(四):使用optim优化函数相关推荐

  1. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  2. pyTorch——基础学习笔记

    pytorch基础学习笔记博文,在整理的时候借鉴的大量的网上资料,存在和一部分图片定义的直接复制黏贴,在本博文的最后将会表明所有的参考链接.由于参考的内容众多,所以博文的更新是一个长久的过程,如果大佬 ...

  3. 深度学习之Pytorch基础教程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...

  4. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  5. 【深度学习】深度学习之Pytorch基础教程!

    作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现.尤其是近两年,Google.Facebook.Microsoft等巨头都围绕深度学习重点投资了一系 ...

  6. 《深度学习之pytorch实战计算机视觉》第6章 PyTorch基础(代码可跑通)

    上一篇文章<深度学习之pytorch实战计算机视觉>第5章 Python基础讲了Python基础.接下来看看第6章 PyTorch基础. 目录 6.1 PyTorch中的Tensor 6. ...

  7. PyTorch基础(part5)--交叉熵

    学习笔记,仅供参考,有错必纠 文章目录 原理 代码 初始设置 导包 载入数据 模型 原理 交叉熵(Cross-Entropy) Loss=−(t∗ln⁡y+(1−t)ln⁡(1−y))Loss =-( ...

  8. PyTorch基础(part4)

    学习笔记,仅供参考,有错必纠 文章目录 PyTorch 基础 MNIST数据识别 常用代码 导包 载入数据 定义网络结构 PyTorch 基础 MNIST数据识别 常用代码 # 支持多行输出 from ...

  9. PyTorch基础(part3)

    学习笔记,仅供参考,有错必纠 文章目录 PyTorch 基础 线性回归 常用代码 导包 生成数据 构建神经网络模型 非线性回归 生成数据 构建神经网络模型 PyTorch 基础 线性回归 常用代码 # ...

  10. 深度学习导论(3)PyTorch基础

    深度学习导论(3)PyTorch基础 一. Tensor-Pytorch基础数据结构 二. Tensor索引及操作 1. Tensor索引类型 2. Tensor基础操作 3. 数值类型 4. 数值类 ...

最新文章

  1. 他保送北大、读完博士选择回中学任教,“做科研太枯燥,自己更适合教书”...
  2. Spring使用webjar
  3. CI/CD with drone
  4. ALV中动态内表+行转化为列
  5. rocketmq怎么保证数据不会重复_rocketmq如何保证消息不丢失
  6. isFinite使用说明
  7. 自定义控件之圆形的image
  8. 手把手教你做iOS的soap应用(webservice)
  9. 数控系统市场下行压力逐渐增大
  10. VGG19识别CIFAR10数据集(Pytorch实战)
  11. Django 使用 squashmigrations 合并 migration 文件
  12. mac电脑外接显示器后没有声音
  13. [Pytorch系列-71]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练pix2pix模型
  14. 【加拿大留学】蒙特利尔中国公派学者 学生学习生活指南【蒙特利尔留学必看,第一次出国必看】
  15. 苹果手机使用说明书_苹果手机11个使用小技巧
  16. 用robot framework + python实现http接口自动化测试框架
  17. 利用百度人脸识别API,实现人脸登陆JavaWeb
  18. 【译】「食人的大鹫」的运动方法 程序动画技术
  19. 全球十年来含金量最高护照阿联酋列榜首,超过111个国家免签
  20. PHP 实现小偷程序

热门文章

  1. asp.net mvc 图片裁剪上传
  2. mysql 分页 pageindex_根据当前页号(pageIndex)和页大小(pageSize)获取分页数据
  3. HG30A-3多用表校验仪
  4. 【数学】多元函数微分学(宇哥笔记)
  5. 马上谈薪了,五险一金你还不知道?作为毕业生,钱不能白交!!!
  6. 动手学深度学习(tensorflow)---学习笔记整理(五、过拟合和欠拟合相关问题篇)
  7. 两个对象值相同(x.equals(y) == true),但却可有不同的hashCode,这句话对不对?
  8. C语言|temp=a,a=b,b=temp;|同行语句可以用逗号隔开
  9. 点击编辑按钮,前端table表格行内指定td可修改。(table是动态生成的)
  10. 【软路由】旁路由使用配置教程