上一篇文章(绘制隐藏层的激活值的分布[直方图])我们知道权重的初始值对激活层的值影响很大,也直接关系到神经网络学习是否顺利的至关重要的一环。

现在通过MNIST数据集的实例来比较下,直观感受不同的初始值对神经网络学习的影响程度。基于标准差分别为0.01,“Xavier初始值”,"He初始值"三个画图比较,其中Xavier和He的初始值分别对应的是Sigmoid和ReLU激活函数。

import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.optimizer import *
from common.multi_layer_net import MultiLayerNet#读取MNIST数据
(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True)
train_num=x_train.shape[0]#60000张训练数据
batch_num=200#每次随机抽取的数量
max_iter=500#迭代次数#权重初始值
weight_inits={'std':0.01,'Xavier':'sigmoid','He':'relu'}
mySGD=SGD(lr=0.01)networks={}
train_loss={}#5层实验
for k,weight_init in weight_inits.items():networks[k]=MultiLayerNet(inputSize=784,hiddenSizeList=[100,100,100,100],outputSize=10,weight_init_std=weight_init)train_loss[k]=[]#分别保存它们的损失函数值for i in range(max_iter):batch_mask=np.random.choice(train_num,batch_num)x_batch=x_train[batch_mask]t_batch=t_train[batch_mask]for k in weight_inits.keys():grads=networks[k].gradient(x_batch,t_batch)mySGD.update(networks[k].params,grads)loss=networks[k].loss(x_batch,t_batch)train_loss[k].append(loss)if i%100==0:print('迭代次数:'+str(i+100))for k in weight_inits.keys():loss=networks[k].loss(x_batch,t_batch)print(k+":"+str(loss))def smooth_curve(x):'''使得图形变得更光滑'''window_len=11s=np.r_[x[window_len-1:0:-1],x,x[-1:-window_len:-1]]w=np.kaiser(window_len,2)y=np.convolve(w/w.sum(),s,mode='valid')return y[5:len(y)-5]markers={'std':'o','Xavier':'s','He':'D'}
x=np.arange(max_iter)
for k in weight_inits.keys():plt.plot(x,smooth_curve(train_loss[k]),marker=markers[k],markevery=100,label=k)
plt.xlabel('iterations')
plt.ylabel('loss value')
plt.legend()
plt.show()

迭代次数:100
std:2.3025414131108293
Xavier:2.3072241661349584
He:2.446616670013899
迭代次数:200
std:2.3018910751129114
Xavier:2.243486197591573
He:1.6773269832035373
迭代次数:300
std:2.302147254233616
Xavier:2.1178870700671233
He:0.8535613642618993
迭代次数:400
std:2.3034931951252866
Xavier:1.7651264489359435
He:0.5502846316949196
迭代次数:500
std:2.300246888304067
Xavier:1.2227568236767556
He:0.48176101990404274

 可以看出标准差为0.01的时候,无法顺利进行学习。

导入的类等其他代码可以参阅:

全连接的多层神经网络结构(MultiLayerNet)https://blog.csdn.net/weixin_41896770/article/details/121451390

基于MNIST数据集的不同权重初始值的比较相关推荐

  1. DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程

    DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程 目录 输出结果 设计思路 核心代码 更多输出 相关文章: ...

  2. DL之DNN优化技术:自定义MultiLayerNet【5*100+ReLU】对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化、He参数初始化)性能差异

    DL之DNN优化技术:自定义MultiLayerNet[5*100+ReLU]对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化.He参数初始化)性能差异 导读 #思路:观察不同的权 ...

  3. 关于神经网络权重初始值的设置的研究

    关于神经网络权重初始值的设置的研究 一.权重初始值 二.权重初始值会影响隐藏层的激活值分布 三.Xavier初始值 四.He初始值 五.基于MNIST数据集的权重初始值的比较 一.权重初始值 权值衰减 ...

  4. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  5. CNN应用Relu激活函数时设计权重初始值设置方法

    CNN应用Relu激活函数时,根据√(2/n)设计权重初始值 学习<深度学习入门(基于Python的理论与实现)>时,设计DeepConvNet,需要应用Relu激活函数,使用了ReLU的 ...

  6. DL之DNN优化技术:采用三种激活函数(sigmoid、relu、tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化

    DL之DNN优化技术:采用三种激活函数(sigmoid.relu.tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化 目录

  7. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  8. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  9. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

最新文章

  1. logicaldoc 6.5 结合postgresql 9.x安装部署—基于windows平台
  2. 使用PIL库将一张小图贴到大图的指定位置
  3. 【学术相关】博士毕业也会看第一学历吗?
  4. 使用RSS订阅喜欢的微博博主
  5. 242. 有效的字母异位词 golang
  6. 局域网通讯工具_五大核心开启工业通讯创新之门——西门子工业网络专家计划打造最强行业生态...
  7. 一定要陪一个男人创业,你会和他一样快速成长,并内心变得强大
  8. SmallMQ实现发布
  9. 算法学习(四)冒泡排序
  10. Android简明开发教程十六:Button 画刷示例
  11. 【转】斐讯K2刷华硕固件教程
  12. python求最值_python求极值点(波峰波谷)
  13. 网吧服务器发消息,网吧盗号常见途径总结以及解决办法
  14. Axure制作音乐App原型图
  15. Python爬虫理论 | (2) 网络请求与响应
  16. 近期你已经授权登录过_原来我的微信、QQ 授权登录过这么多应用!(附查找及解绑方法)...
  17. 【性能测试基础】性能专有名词解析及性能瓶颈分析技巧
  18. GRU和LSTM的单元结构
  19. MySQL8.0安装失败
  20. 2022年山东省安全员C证考试题及在线模拟考试

热门文章

  1. TechDay实录|摘取皇冠上的明珠,中文NLP的不二选择——PaddlePaddle
  2. volatile关键字与synchronization关键字的区别?
  3. sed中支持变量的处理方法
  4. Java集合(二、LinkHashMap)
  5. 简单的php文件上传实例
  6. C++提高部分_C++函数模板的概念---C++语言工作笔记080
  7. C++_二维数组_定义方式_数组名称的作用_案例考试成绩统计---C++语言工作笔记021
  8. Elasticsearch--高级-映射mapping_添加行的字段映射---全文检索引擎ElasticSearch工作笔记018
  9. Netty工作笔记0035---Reactor模式图剖析
  10. 人工智能TensorFlow工作笔记007---认识张量