前面章节我们知道神经网络的目的是寻找最优参数,介绍了四种以及两种改进的方法来寻找最优参数,并画图进行了比较,具体可参阅
神经网络技巧篇之寻找最优参数的方法https://blog.csdn.net/weixin_41896770/article/details/121375510
神经网络技巧篇之寻找最优参数的方法【续】https://blog.csdn.net/weixin_41896770/article/details/121419590

现在我们通过以前的一个经典实例(MNIST手写数字识别)来测试对比四种方法,有不熟悉MNIST数据集的伙伴们可以先参阅以前的两篇文章

MNIST数据集手写数字识别(一)https://blog.csdn.net/weixin_41896770/article/details/119576575

MNIST数据集手写数字识别(二)https://blog.csdn.net/weixin_41896770/article/details/119710429我们只需在上一篇文章的基础上进行修改,把四种方法加入进来,构造一个5层的神经网络,隐藏层每层100个神经元,计算损失函数值来确认四种方法在学习进展上有多大程度的差异。

import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.optimizer import *
from common.multi_layer_net import MultiLayerNet#读取MNIST数据
(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True)
train_num=x_train.shape[0]#60000张训练数据
batch_num=200#每次随机抽取的数量
max_iter=500#迭代次数#四种方法,构造5层(4个隐藏层+1个输出层)神经网络来评估
mySGDDict={}
mySGDDict['SGD']=SGD()
mySGDDict['Momentum']=Momentum()
mySGDDict['AdaGrad']=AdaGrad()
mySGDDict['Adam']=Adam()networks={}
train_loss={}
for k in mySGDDict.keys():networks[k]=MultiLayerNet(inputSize=784,hiddenSizeList=[100,100,100,100],outputSize=10)train_loss[k]=[]
for i in range(max_iter):batch_mask=np.random.choice(train_num,batch_num)x_batch=x_train[batch_mask]t_batch=t_train[batch_mask]#分别使用四种方法进行训练for k in mySGDDict.keys():grads=networks[k].gradient(x_batch,t_batch)mySGDDict[k].update(networks[k].params,grads)loss=networks[k].loss(x_batch,t_batch)train_loss[k].append(loss)#保存损失函数的值if i%100==0:print('迭代次数:'+str(i+100))for k in mySGDDict.keys():loss=networks[k].loss(x_batch,t_batch)print(k+":"+str(loss))def smooth_curve(x):'''使得图形变得更光滑'''window_len=11s=np.r_[x[window_len-1:0:-1],x,x[-1:-window_len:-1]]w=np.kaiser(window_len,2)y=np.convolve(w/w.sum(),s,mode='valid')return y[5:len(y)-5]#画图
markers={'SGD':'o','Momentum':'x','AdaGrad':'s','Adam':'D'}#样式
x=np.arange(max_iter)
for k in mySGDDict.keys():plt.plot(x,smooth_curve(train_loss[k]),marker=markers[k],markevery=100,label=k)
plt.xlabel('iterations')
plt.ylabel('loss value')
plt.legend()
plt.show()

迭代次数:100
SGD:2.438386391808663
Momentum:2.3930390242209603
AdaGrad:2.2428315543940314
Adam:2.445649073594391
迭代次数:200
SGD:1.8504058784313966
Momentum:0.3388485347141527
AdaGrad:0.14634168979474382
Adam:0.14149500888618036
迭代次数:300
SGD:0.9987654073163665
Momentum:0.23492597754880898
AdaGrad:0.11494675925508649
Adam:0.13704056059009395
迭代次数:400
SGD:0.6673420358314794
Momentum:0.2281848809517749
AdaGrad:0.09351207134574782
Adam:0.12776555371113074
迭代次数:500
SGD:0.5480751297953668
Momentum:0.19940082236878404
AdaGrad:0.09958623479585509
Adam:0.10639644884259741

从图中可知,横轴是学习的迭代次数,纵轴是损失函数的值,与SGD相比较,其他三种方法学习得更快,而且速度基本相同,仔细看的话,AdaGrad的学习进行得稍微快一点,实际上的结果还跟学习率等超参数以及多少层神经网络有关系。

基于MNIST数据集的最优参数的方法的比较相关推荐

  1. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  2. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  3. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  4. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  5. 神经网络--基于mnist数据集取得最高的识别准确率

    前言: Hello大家好,我是Dream. 今天来学习一下如何基于mnist数据集取得最高的识别准确率,本文是从零开始的,如有需要可自行跳至所需内容~ 本文目录: 1.调用库函数 2.调用数据集 3. ...

  6. 文献记录(part80)--基于平均互信息的最优社区发现方法

    学习笔记,仅供参考,有错必纠 关键词:AMI-COPRA 算法 ;AMI-GN 算法 ;平均互信息 ;AMI 方法 ;社区发现; 基于平均互信息的最优社区发现方法 摘要 本文提出一种基于平均互信息的最 ...

  7. 神经网络技巧篇之寻找最优参数的方法【续】

    上一篇文章介绍了四种寻找最优参数的方法,这次做一个补充,对其中两种方法(Momentum和AdaGrad)做一些改进,让参数的更新收敛更快速 Nesterov 是对Momentum动量SGD的一个改进 ...

  8. 神经网络技巧篇之寻找最优参数的方法

    在神经网络的学习中,其中一个重要目的就是找到使损失函数的值尽可能小的参数,为了找到这个最优参数,我们使用梯度(导数)作为线索,沿着梯度方向来更新参数,并重复这个步骤,从而逐渐靠近最优参数,这个过程叫做 ...

  9. matlab朴素贝叶斯手写数字识别_基于MNIST数据集实现手写数字识别

    介绍 在TensorFlow的官方入门课程中,多次用到mnist数据集.mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte ...

最新文章

  1. Ubuntu Linux系统下apt-get命令详解
  2. ASSERT(断言)的用法
  3. mysql中两种备份方法的优缺点_Mysql两种存储引擎的优缺点
  4. [转]机器学习和深度学习资料汇总【01】
  5. 从flink-example分析flink组件(3)WordCount 流式实战及源码分析
  6. 从一个故障说说Java的三个BlockingQueue
  7. html 标题树,html树
  8. python编写函数_浅谈Python 函数式编程
  9. @1.0.0 dev: `webpack-dev-server --inline --progress --config
  10. java连本地mysql注意事项_java数据库连接及注意事项
  11. PHP框架剥离的判断是否为手机移动终端的函数
  12. Java 标准开发包_JDK 9系列全套官方下载链接
  13. MPEG4Extractor分析
  14. 交换机分布缓存_缓存比普通交换机也大许多
  15. su、sudo命令和限制root远程登录
  16. 关于线性空间和线性映射
  17. QRJDC搭建实现QQ扫码登录对接青龙对接傻妞
  18. Java反射使用的Field类介绍
  19. centos mysql ssh连接,使用SSH隧道连接MYSQL
  20. lombok @data 忽略属性_Lombok使用指南

热门文章

  1. dubbo ---- 入门
  2. 蒟蒻吃药计划-治疗系列 #round6 数据结构初步-指针|链表|结构体
  3. eclipse中ctrl+h默认打开是JavaSearch,怎么设置成默认打开是FileSearch
  4. 等值首尾和-----------2012年12月27日
  5. java学习日记(9)———socket,网络编程的学习
  6. Source code manager common
  7. AndroidStudio_安卓原生开发_Android开发中界面调试很别扭? 设置应用屏宽屏高_应用大小_design_width_in_dp---Android原生开发工作笔记140
  8. 大数据之-Hadoop3.x_MapReduce_序列化案例FlowBean---大数据之hadoop3.x工作笔记0097
  9. 微服务升级_SpringCloud Alibaba工作笔记0025---Nacos持久化切换配置
  10. Python工作笔记003---正则中的re.I re.M_以及m.group和m.groups的解释