前面几篇博文分析了每一种参数优化方案,现在做一个对比,代码参考斋藤的红鱼书第六章。

实验对mnist数据集的6万张图片训练,使用5层全连接神经网络(4个隐藏层,每个隐藏层有100个神经元),共迭代2000次,下图是损失函数随着训练迭代次数的变化:

可以看到SGD是最慢的,而AdaGrad最快, 且最终的识别精度也更高,这并不是一定的,跟数据也有关

贴出部分迭代过程变化:

===========iteration:1200===========
SGD:0.2986528195291609
Momentum:0.1037981040196782
AdaGrad:0.0668137679448615
Adam:0.05010293181776089
===========iteration:1300===========
SGD:0.17833478097202
Momentum:0.06128433751079029
AdaGrad:0.01779291355463178
Adam:0.036788168826807605
===========iteration:1400===========
SGD:0.30288604165486865
Momentum:0.07708723420976107
AdaGrad:0.036239187352732696
Adam:0.03584596636673899
===========iteration:1500===========
SGD:0.21648932214740826
Momentum:0.11593046640138721
AdaGrad:0.033343153287890816
Adam:0.039999528396092415
===========iteration:1600===========
SGD:0.23519516569365168
Momentum:0.06509188355944322
AdaGrad:0.0377409654184555
Adam:0.05803067028715449
===========iteration:1700===========
SGD:0.28851197390150085
Momentum:0.14561108131745754
AdaGrad:0.07160438141432544
Adam:0.07280250583341145
===========iteration:1800===========
SGD:0.14382629146685216
Momentum:0.03977221072571262
AdaGrad:0.015159891599626725
Adam:0.019623602905335474
===========iteration:1900===========
SGD:0.19067465612724083
Momentum:0.053986168113818435
AdaGrad:0.03665586658910679
Adam:0.038508895473566646

主要代码:(完整代码可去图灵社区找红鱼书,随书下载)

# coding: utf-8
# OptimizerCompare.pyimport numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from MultiLayerNet import MultiLayerNet
from util import smooth_curve
from optimizer import *# 0:读入MNIST数据==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000# 1:进行实验的设置==========
optimizers = {}
optimizers['SGD'] = SGD()
optimizers['Momentum'] = Momentum()
optimizers['AdaGrad'] = AdaGrad()
optimizers['Adam'] = Adam()
# optimizers['RMSprop'] = RMSprop()networks = {}
train_loss = {}
for key in optimizers.keys():networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],output_size=10)train_loss[key] = []# 2:开始训练==========
for i in range(max_iterations):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]for key in optimizers.keys():grads = networks[key].gradient(x_batch, t_batch)optimizers[key].update(networks[key].params, grads)loss = networks[key].loss(x_batch, t_batch)train_loss[key].append(loss)if i % 100 == 0:print("===========" + "iteration:" + str(i) + "===========")for key in optimizers.keys():loss = networks[key].loss(x_batch, t_batch)print(key + ":" + str(loss))# 3.绘制图形==========
markers = {"SGD": "o", "Momentum": "x", "AdaGrad": "s", "Adam": "D"}
x = np.arange(max_iterations)
for key in optimizers.keys():plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], \markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 1)
plt.legend()
plt.show()

NN学习技巧之参数最优化的四种方法对比(SGD, Momentum, AdaGrad, Adam),基于MNIST数据集相关推荐

  1. DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较

    DL之DNN优化技术:DNN优化器的参数优化-更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解.图表可视化比较 目录 四种最优化方法简介 优化器案例理解 输出结 ...

  2. DL之DNN:自定义MultiLayerNet(5*100+ReLU+SGD/Momentum/AdaGrad/Adam四种最优化)对MNIST数据集训练进而比较不同方法的性能

    DL之DNN:自定义MultiLayerNet(5*100+ReLU+SGD/Momentum/AdaGrad/Adam四种最优化)对MNIST数据集训练进而比较不同方法的性能 目录 输出结果 设计思 ...

  3. 快速排序——寻找数组第K大数(由浅入深,四种方法对比讲解!)

    寻找数组第K大数是大厂面试中经常考到的一题,有的小机灵鬼直接用sort()进行排序,两行代码解决,这样看似可行,实则掉入了出题人的陷阱.面试官希望看到的是你对算法的理解,而不是函数的调用.下面,我就以 ...

  4. Java求两集合中元素交集的四种方法对比总结

    hello,你好呀,我是灰小猿,一个超会写bug的程序猿! 最近在做项目的时候有用到对两个集合中的元素进行对比求其交集的情况,因为涉及到的数据量比较大,所以在进行求两个集合中元素交集的时候,就应该考虑 ...

  5. Python遍历字典的四种方法对比

    #!/usr/bin/python from time import clockl = [(x,x) for x in xrange (10000000)] d = dict(l) t0 = cloc ...

  6. 有时间窗车辆路径问题(VRPTW)解决方案合集,[CW节约算法,TS(硬约束版),TS(惩罚函数版),LNS四种方法对比(附MATLAB代码)]

    前言 本文中构造初始解的方式不采用CW法,而是采用论文里说的构造初始解的方法. 然后在调试的过程中发现,自适应调整惩罚权重的策略效果不好,于是稍微更改了一下自适应调整权重的策略:只在解违反约束时,使权 ...

  7. java中遍历HashMap的四种方法及效率比较

    1. 推荐方法: 使用entrySet 遍历Map 类集合KV,而不是keySet 方式进行遍历. 代码示例如下: // 循环第二种HashMap<Integer, String> map ...

  8. 从零学习Fluter(八):Flutter的四种运行模式--Debug、Release、Profile和test以及命名规范...

    从零学习Fluter(八):Flutter的四种运行模式--Debug.Release.Profile和test以及命名规范 好几天没有跟新我的这个系列文章,一是因为这两天我又在之前的基础上,重新认识 ...

  9. excel多列多行堆叠成多列一行_「Excel技巧」Excel快速实现将一行转为多行多列的四种方法...

    今天来说说在Excel中,将表格里的一列转换为多行多列的几种方法. 例如,以下表格,是一个行业分类表,都放在同一列中.现我们准备把它转为多列. 表格里数据除掉标题行行,总共有60列数据,干脆我们就给它 ...

最新文章

  1. oracle函数 case,oracle的case函数和case控制结构 (摘)
  2. CVPR 2022|重新思考对齐Prototype的域自适应:基于Graph Matching的新范式
  3. AI + 3D!英伟达开源3D深度学习框架Kaolin
  4. BZOJ4066 简单题(KD-Tree)
  5. 阿里云播放器,判断直播时的状态
  6. java程序员面试真题及详解2017(纯手动)
  7. 图像列表控制(CImageList)
  8. 【有感】聆听哈佛幸福课 (上)
  9. 维吉尼亚密码原理详解及算法实现
  10. java log 乱码_Java日志文件乱码
  11. 11.绘制统计图形——误差棒图
  12. SwiftUI OCR功能大全之 基于 SwiftUI 构建文档扫描仪
  13. excel斜线表头的制作
  14. Linux kill,killall和killall5
  15. mysql 给用户取消权限_MySQL创建用户并授权及撤销用户权限
  16. 程序员如何兼职接单良心推荐
  17. Prime算法和Krustal算法(转自博客园华山大师兄)
  18. 【2阶】BootStrap制作简易CRM管理系统-crm-1
  19. 《go语言圣经》习题答案-第5章
  20. VMware桥接网络

热门文章

  1. 机器人的洪流—刷库、撞库那些事儿
  2. 让她/他心动的告白,页面制作(9个页面+链接+代码,原生HTML+CSS+JS实现)
  3. 微信小程序 实现阿里云上传
  4. 企业管理内容有哪些了,分别是什么?
  5. 元宇宙退潮,人工智能起飞,大厂 Al 新赛点在哪?
  6. 数据库原理与实践课设(宾馆管理系统),java+jdbc+sqlserver2017
  7. 蒂姆·库克的五项核心领导力
  8. 计算机应用文摘杂志影响因子,计算机应用文摘杂志
  9. FigDraw 12. SCI 文章绘图之相关性矩阵图(Correlation Matrix)
  10. 【亲密关系】001 亲密关系的影响因素