使用权值衰减算法解决神经网络过拟合问题、python实现

  • 一、what is 过拟合
  • 二、过拟合原因
  • 三、权值衰减
  • 四、实验验证
    • 4.1制造过拟合现象
    • 4.2使用权值衰减抑制过拟合

一、what is 过拟合

过拟合指只能拟合训练数据,但不能很好拟合不包含在训练数据中的其他数据的状态。

二、过拟合原因

模型参数过多、表现力强

训练数据少

三、权值衰减

这玩意在之前提到过,就是减小权值,通过在学习过程中对大的权重进行惩罚,来抑制过拟合。

深度学习目的是减小损失函数的值,那么为损失函数加上权重平方范数,就可以抑制权重变大。

L2范数是什么,就是相当于各个元素的平方和。如下面数学式子表示。

正则化是什么,regularizer,也就是规则化,也就是,向你的模型加入某些规则。为损失函数加上权重平方范数其实就是加上了正则化项,这个正则化项就是L2范数的权值衰减。

L2范数的权值衰减数学表达:

W是权重。λ是控制正则化强度的超参数,它越大,对权重施加的惩罚越多。二分之一是调整常量,这样的话,求导后是λW。

求权重的梯度的计算中,要为之前的误差反向传播法的结果加上正则化项的导数λW。

四、实验验证

4.1制造过拟合现象

我们制造过拟合,就需要增加网络参数,减少训练数据,那么就从MNIST数据集里只选出来300个数据,然后增加网络复杂幅度用7层网络,每层100个神经元,激活函数ReLU。

代码:

import os
import syssys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.optimizer import SGD(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)# 为了再现过拟合,减少学习数据
x_train = x_train[:300]
t_train = t_train[:300]# weight decay(权值衰减)的设定 =======================
weight_decay_lambda = 0 # 不使用权值衰减的情况
#weight_decay_lambda = 0.1
# ====================================================network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100], output_size=10,weight_decay_lambda=weight_decay_lambda)
optimizer = SGD(lr=0.01)max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100train_loss_list = []
train_acc_list = []
test_acc_list = []iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0for i in range(1000000000):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]grads = network.gradient(x_batch, t_batch)optimizer.update(network.params, grads)if i % iter_per_epoch == 0:train_acc = network.accuracy(x_train, t_train)test_acc = network.accuracy(x_test, t_test)train_acc_list.append(train_acc)test_acc_list.append(test_acc)print("epoch:" + str(epoch_cnt) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc))epoch_cnt += 1if epoch_cnt >= max_epochs:break# 3.绘制图形==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(max_epochs)
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

由图可看出,过了100个epoch,用训练数据测得识别精度几乎为100%,但是测试数据与100%有较大差距,说明,模型对训练时没有使用的测试数据拟合的不好。

4.2使用权值衰减抑制过拟合

在上面代码中修改weight_decay_lambda = 0.1,可以看到如下结果。

测试数据和训练数据识别精度有差距,但是差距比之前的减小了。

有人说,那他的测试数据识别精度也没提升啊。

其实这是因为训练数据的精度也没到100%哦,如果再有多个训练数据进入网络训练,它精度到100%的时候,测试数据精度就会比之前过拟合时候要高。

使用权值衰减算法解决神经网络过拟合问题、python实现相关推荐

  1. 解决神经网络过拟合问题—Dropout方法、python实现

    解决神经网络过拟合问题-Dropout方法 一.what is Dropout?如何实现? 二.使用和不使用Dropout的训练结果对比 一.what is Dropout?如何实现? 如果网络模型复 ...

  2. 神经网络怎么解决过拟合,解决神经网络过拟合

    如何防止神经网络过拟合,用什么方法可以防止? 你这个问题本来就问的很模糊,你是想问神经网络的过拟合变现什么样还是为什么出现过拟合呢. 为此针对于第一个问题,神经网络的过拟合与支持向量机.高斯混合模型等 ...

  3. 神经网络过拟合怎么解决,神经网络过拟合怎么办

    神经网络如何防止过拟合? 你这个问题本来就问的很模糊,你是想问神经网络的过拟合变现什么样还是为什么出现过拟合呢. 为此针对于第一个问题,神经网络的过拟合与支持向量机.高斯混合模型等建模方法的过拟合类似 ...

  4. 人工蜂群算法python_改进的人工蜂群算法解决聚类问题(在Python中的分步实现)...

    在 之前的文章 中,我介绍了如何通过实施名为Artificial Bee Colony(ABC)的群集智能(SI)算法来解决现实世界中的优化问题. 现在是时候让我们掌握一些真实的数据并解释我们如何使用 ...

  5. 使用early stopping解决神经网络过拟合问题

    神经网络训练多少轮是一个很关键的问题,训练轮数少了欠拟合(underfit),训练轮数多了过拟合(overfit),那如何选择训练轮数呢? Early stopping可以帮助我们解决这个问题,它的作 ...

  6. 抑制过拟合的方法之权值衰减

    机器学习很常见的一个需要解决的问题就是过度拟合(overift),过拟合的意思是它能够很好的拟合训练数据,但是对于训练数据之外的数据可能就显得差强人意了,也就是常说的泛化能力比较差,所以抑制过拟合就显 ...

  7. 权值衰减和L2正则化傻傻分不清楚?

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Divyanshu Mishra 编译:ronghuaiyang 导读 权 ...

  8. 权值衰减和 L2 正则化傻傻分不清楚?

    作者 | Divyanshu Mishra 编译 | ronghuaiyang 转自 | AI公园 导读 权值衰减和L2正则化,到底是不是同一个东西,这篇文章给你答案. 神经网络是伟大的函数逼近器和特 ...

  9. 神经网络过拟合什么意思,神经网络中解决过拟合

    神经网络,什么过拟合?,什么是欠拟合? 欠拟合是指模型不能在训练集上获得足够低的误差.而过拟合是指训练误差和测试误差之间的差距太大.相关介绍:人工神经网络(ANN)或联结主义系统是受构成动物大脑的生物 ...

最新文章

  1. mysql8 php7_windows10-nginx-mysql8.0-php7.0环境搭建
  2. 计算机术语new一个,微机原理第一章计算机基础知识(new)
  3. javascript下載csv檔案
  4. CSS3常用动画总结
  5. 箭头函数特殊性与普通函数的区别
  6. MYSQL出错代码列表大全(中文)
  7. PTA-7-1 将数组中的数逆序存放 (20分)(C语言)
  8. SQL Server 2014 导入Excel
  9. 解决远程桌面关闭后teamviewer不能连接的问题
  10. python要学多久可以找到工作-学习Python多久能找到工作?老男孩Python开发培训
  11. python常用代码总结-常见的排序算法的总结及python代码实现
  12. Python实现对给定的列表中连续数字的寻找
  13. 内核的解压缩过程详解
  14. 如何彻底删除adobe?adobe官方清理工具怎么用?
  15. Geant4学习一:写一个简单程序
  16. 微信公众号程序开发接入流程
  17. window.open() 打开IE缓慢的原因
  18. 【龙讯module小课堂】浅谈对gap的认识:PWmat中修正gap的module
  19. NOIP模拟系列 [BZOJ4668]冷战
  20. ISP(图像信号处理)白平衡White Balance

热门文章

  1. Oracle 创建表 练习题
  2. maven仓库理解、下载及设置
  3. js中toFixed方法的两个坑
  4. 非线性最优化(二)——高斯牛顿法和Levengerg-Marquardt迭代
  5. sentry + vue实现错误日志监控
  6. 谷歌发布最新版安卓Android,谷歌发布安卓 9 正式版,代号 Android Pie
  7. mysql永远不用utf8_永远不要在 MySQL 中使用「utf8」
  8. 腾讯视频如何多倍速播放视频
  9. win10使用网络共享功能的方法
  10. Mybatis判断int类型是否为空