这里主要讲不同常见优化器代码的实现,以及在一个小数据集上做一个简单的比较。

备注:pytorch需要升级到最新版本

其中,SGD和SGDM,还有Adam是pytorch自带的优化器,而RAdam是最近提出的一个说是Adam更强的优化器,但是一般情况下真正的大佬还在用SGDM来做优化器

导入必要库:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.utils.data as Data
from torch.optim.optimizer import Optimizer
import math

主程序部分:

LR = 0.01
BATCH_SIZE = 32
EPOCH = 12# fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 300), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2
)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = nn.Linear(1, 20)self.prediction = nn.Linear(20, 1)def forward(self, x):x = F.relu(self.hidden(x))x = self.prediction(x)return xdef main():net_SGD = Net()net_Momentum = Net()net_Adam = Net()net_RAdam = Net()nets = [net_SGD, net_Momentum, net_Adam, net_RAdam]opt_SGD = optim.SGD(net_SGD.parameters(), lr=LR)opt_Momentum = optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9)opt_Adam = optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))opt_RAdam = RAdam(net_RAdam.parameters(),lr=LR,weight_decay=0)optimizers = [opt_SGD, opt_Momentum, opt_Adam, opt_RAdam]loss_func = nn.MSELoss()losses_his = [[], [], [], []]# trainingfor epoch in range(EPOCH):print('EPOCH:', epoch)for step, (batch_x, batch_y) in enumerate(loader):b_x = batch_xb_y = batch_yfor net, opt, l_his in zip(nets, optimizers, losses_his):out = net(b_x)loss = loss_func(out, b_y)opt.zero_grad()loss.backward()opt.step()l_his.append(loss.item())labels = ['SGD', 'Momentum', 'Adam','RAdam']for i, l_his in enumerate(losses_his):plt.plot(l_his, label=labels[i])plt.legend(loc='best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show()if __name__ == '__main__':main()

下图是优化器的对比:

可以看出来,Adam的效果可以说是非常好的。然后SGDM其次,SGDM是大佬们经常会使用的,所以在这里虽然看起来SGDM效果不如Adam,但是依然推荐在项目中,尝试一下SGDM的效果。

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑获取一折本站知识星球优惠券,复制链接直接打开:https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群请扫码进群:

【深度学习】常见优化器的PyTorch实现相关推荐

  1. 妈耶,讲得好详细,十分钟彻底看懂深度学习常用优化器SGD、RMSProp、Adam详解分析

    深度学习常用优化器学习总结 常用优化器 SGD RMS Prop Adam 常用优化器 SGD 基本思想:通过当前梯度和历史梯度共同调节梯度的方向和大小 我们首先根据pytorch官方文档上的这个流程 ...

  2. 深度学习常见优化算法,图解AdaGrad、RMSProp,Adam

    1. AdaGrad AdaGrad算法是梯度下降法的改进算法,其优点是可以自适应学习率.该优化算法在较为平缓处学习速率大,有比较高的学习效率,在陡峭处学习率小,在一定程度上可以避免越过极小值点.在S ...

  3. 深度学习相关优化器以及在tensorflow的使用(转)

    参考链接:https://arxiv.org/pdf/1609.04747.pdf 优化器对比论文 https://www.leiphone.com/news/201706/e0PuNeEzaXWsM ...

  4. 深度学习各类优化器详解(动量、NAG、adam、Adagrad、adadelta、RMSprop、adaMax、Nadam、AMSGrad)

    深度学习梯度更新各类优化器详细介绍 文章目录 <center>深度学习梯度更新各类优化器详细介绍 一.前言: 二.梯度下降变形形式 1.批量归一化(BGD) 2.随机梯度下降(SGD) 3 ...

  5. 【深度学习】优化器详解

    优化器 深度学习模型通过引入损失函数,用来计算目标预测的错误程度.根据损失函数计算得到的误差结果,需要对模型参数(即权重和偏差)进行很小的更改,以期减少预测错误.但问题是如何知道何时应更改参数,如果要 ...

  6. 深度学习TensorFlow优化器的选择

    原文链接:https://blog.csdn.net/junchengberry/article/details/81102058 在很多机器学习和深度学习的应用中,我们发现用的最多的优化器是 Ada ...

  7. 深度学习:优化器工厂,各种优化器介绍,numpy实现深度学习(一)

    文章目录 简单概括参数更新: 优化器 Vanilla Update: Vanilla 代码实现: Momentum Update: Momentum 代码实现: Nesterov Momentum U ...

  8. 深度学习之优化器(优化算法)

    前言 前面已经讲过几中梯度下降算法了,并且给了一个收尾引出这一章节,想看的小伙伴可以去看看这一篇文章:机器学习之梯度下降算法.前面讲过对SGD来说,最要命的是SGD可能会遇到"峡谷" ...

  9. PyTorch框架学习十三——优化器

    PyTorch框架学习十三--优化器 一.优化器 二.Optimizer类 1.基本属性 2.基本方法 三.学习率与动量 1.学习率learning rate 2.动量.冲量Momentum 四.十种 ...

最新文章

  1. 开源大数据周刊-第34期
  2. 阿里云自定义监控tomcat进程数
  3. uniapp点击图片放大_手机做图片放大镜效果很难?看这里,分分钟就能学会!
  4. how is sap-ui-core.js initialize the reqeust of sap-ui-core-dbg.js
  5. codeforces D.MADMAX 动态规划、记忆化搜索
  6. RabbitMq队列 queue
  7. python 文件相似度分析_使用Python做人群相似度分析
  8. 如何实现微信小程序API的Promise化
  9. 软件工程 可行性分析与需求分析
  10. 上三角矩阵法Matlab,在MATLAB中重塑/变换上三角矩阵
  11. 产品经理——产品方法论
  12. 【深度学习Deep Learning系列】word2vec和doc2vec
  13. 基于 VIVADO 的 AM 调制解调(2)工程实现
  14. TouchBar Dino for mac(TouchBar上的小恐龙跑酷游戏)
  15. 电脑操作系统(Androidx86、Windows、Linux)说明
  16. [数学基础知识] Cramér‘s V 相关系数和Python算法实现
  17. 快速编写HTML代码常用的方法
  18. 华为ac、瘦ap简单上线(旁挂式)
  19. aspose-words基本操作
  20. MySQL SELECT查询语句练习2(中级篇)

热门文章

  1. [转]Android输入法框的梳理
  2. SQL Server 数据库使用备份还原造成的孤立用户和对象名‘xxx’无效的错误的解决办法...
  3. 淘宝服务端高并发分布式架构的十四次演进之路
  4. Elasticsearch入门之从零开始安装ik分词器
  5. EntityFreamWork 项目总结
  6. 从零开始学Xamarin.Forms(四) Android 准备步骤(添加第三方Xamarin.Forms.Labs库)
  7. NSDate与NSDateFormatter的相关用法
  8. doubleClick-v2-as3.0 学习笔记(2)--Video相关
  9. Placing a Method with Eval parameter into a DataList
  10. (转)虚函数和纯虚函数区别