要点

这节内容主要是用 Torch 实践 这个 优化器 动画简介 中起到的几种优化器, 这几种优化器具体的优势不会在这个节内容中说了, 所以想快速了解的话, 上面的那个动画链接是很好的去处.

下图就是这节内容对比各种优化器的效果:

伪数据

为了对比各种优化器的效果, 我们需要有一些数据, 今天我们还是自己编一些伪数据, 这批数据是这样的:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plttorch.manual_seed(1)    # reproducibleLR = 0.01
BATCH_SIZE = 32
EPOCH = 12# fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))# plot dataset
plt.scatter(x.numpy(), y.numpy())
plt.show()# 使用上节内容提到的 data loader
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)

每个优化器优化一个神经网络

为了对比每一种优化器, 我们给他们各自创建一个神经网络, 但这个神经网络都来自同一个 Net 形式.

# 默认的 network 形式
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.hidden = torch.nn.Linear(1, 20)   # hidden layerself.predict = torch.nn.Linear(20, 1)   # output layerdef forward(self, x):x = F.relu(self.hidden(x))      # activation function for hidden layerx = self.predict(x)             # linear outputreturn x# 为每个优化器创建一个 net
net_SGD         = Net()
net_Momentum    = Net()
net_RMSprop     = Net()
net_Adam        = Net()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

优化器 Optimizer

接下来在创建不同的优化器, 用来训练不同的网络. 并创建一个 loss_func 用来计算误差. 我们用几种常见的优化器, SGDMomentumRMSpropAdam.

# different optimizers
opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]loss_func = torch.nn.MSELoss()
losses_his = [[], [], [], []]   # 记录 training 时不同神经网络的 loss

训练/出图

接下来训练和 loss 画图.

for epoch in range(EPOCH):print('Epoch: ', epoch)for step, (b_x, b_y) in enumerate(loader):# 对每个优化器, 优化属于他的神经网络for net, opt, l_his in zip(nets, optimizers, losses_his):output = net(b_x)              # get output for every netloss = loss_func(output, b_y)  # compute loss for every netopt.zero_grad()                # clear gradients for next trainloss.backward()                # backpropagation, compute gradientsopt.step()                     # apply gradientsl_his.append(loss.data.numpy())     # loss recoder

SGD 是最普通的优化器, 也可以说没有加速效果, 而 Momentum 是 SGD 的改良版, 它加入了动量原则. 后面的 RMSprop 又是 Momentum 的升级版. 而 Adam 又是 RMSprop 的升级版. 不过从这个结果中我们看到, Adam 的效果似乎比 RMSprop 要差一点. 所以说并不是越先进的优化器, 结果越佳. 我们在自己的试验中可以尝试不同的优化器, 找到那个最适合你数据/网络的优化器.

所以这也就是在我 github 代码 中的每一步的意义啦.

Optimizer 优化器相关推荐

  1. 加速神经网络训练方法及不同Optimizer优化器性能比较

    本篇博客主要介绍几种加速神经网络训练的方法. 我们知道,在训练样本非常多的情况下,如果一次性把所有的样本送入神经网络,每迭代一次更新网络参数,这样的效率是很低的.为什么?因为梯度下降法参数更新的公式一 ...

  2. pytorch 7 optimizer 优化器 加速训练

    pytorch 7 optimizer 优化器 加速训练 import torch import torch.utils.data as Data import torch.nn.functional ...

  3. PyTorch 实现批训练和 Optimizer 优化器

    批训练 import torch import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1, 10, 10) # this i ...

  4. Optimizer优化器

    这节内容主要是对比在 Torch 实践中所会用到的几种优化器 编写伪数据 为了对比各种优化器的效果, 需要有一些数据, 可以自己编一些伪数据, 这批数据是这样的: 具体的数据生成代码如下: impor ...

  5. [Python人工智能] 四.TensorFlow创建回归神经网络及Optimizer优化器

    从本篇文章开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了TensorFlow基础和一元直线预测的案例,以及Session.变量.传入值和激励函数:这篇文章将详 ...

  6. PLSQL_性能优化系列04_Oracle Optimizer优化器

    2014-09-25 Created By BaoXinjian 一.摘要 1. Oracle优化器介绍 本文讲述了Oracle优化器的概念.工作原理和使用方法,兼顾了Oracle8i.9i以及最新的 ...

  7. 深度学习训练之optimizer优化器(BGD、SGD、MBGD、SGDM、NAG、AdaGrad、AdaDelta、Adam)的最全系统详解

    文章目录 1.BGD(批量梯度下降) 2.SGD(随机梯度下降) 2.1.SGD导致的Zigzag现象 3.MBGD(小批量梯度下降) 3.1 BGD.SGD.MBGD的比较 4.SGDM 5.NAG ...

  8. TensorFlow(四)优化器函数Optimizer

    因为大多数机器学习任务就是最小化损失,在损失定义的情况下,后面的工作就交给了优化器.因为深度学习常见的是对于梯度的优化,也就是说,优化器最后其实就是各种对于梯度下降算法的优化. 常用的optimize ...

  9. 优化器 optimizer

    优化器 optimizer optimizer 优化器,用来根据参数的梯度进行沿梯度下降方向进行调整模型参数,使得模型loss不断降低,达到全局最低,通过不断微调模型参数,使得模型从训练数据中学习进行 ...

最新文章

  1. java常用的7大排序算法汇总
  2. 「机器学习」机器学习算法优缺点对比(汇总篇)
  3. 离ExtJS 4.1 beta发布只剩26个bug了
  4. 从菜鸟到老司机,数据科学的17个必用数据集推荐
  5. GitLab代码回滚到特定版本
  6. oracle易忘函数用法(2)
  7. python的xpath用法介绍_python爬虫之xpath的基本使用详解
  8. 计算机网络——链路层之停止等待协议
  9. 自适应浮动表单填充布局脚本
  10. Intel 64/x86_64/IA-32/x86处理器基本执行环境 (1) - 32位执行环境概述
  11. C# break ,continue, return
  12. 安卓手机鸿蒙系统怎么下载,华为鸿蒙系统来了:安卓系统会成为下一个“塞班”吗?...
  13. 100道初级网络工程师测试题
  14. McAfee麦咖啡8.5企业版高级教程
  15. 《统计学》第八版贾俊平第十四章指数知识点总结及课后习题答案
  16. FPGA串口波特率计算方法
  17. CUDA安装教程及调试:本机win10+vs2013+NVIDIA GeForce GTX 1050Ti
  18. 广西行政村数据shp_广西自治区乡镇行政区划数据 精度1:10万
  19. 平面设计需要学习什么,平面设计是什么;夏雨老师
  20. Resources文件夹

热门文章

  1. Openssl 生成自签名证书
  2. 八年级上册计算机第三课教案,人教版八年级信息技术上册教案
  3. 再提几个三维地图网站
  4. 云时代架构之游戏服务器的架构演进
  5. 涉密计算机病毒库升级管理,涉密计算机管理
  6. 缓存穿透,击穿,雪崩
  7. Java后台实现pdf文件在浏览器中预览
  8. TI RTOS User Guide
  9. spice仿真1.1
  10. 关于使用LoadImage时的一个小错误(转)