深度学习进阶课程10---减少Overfitting的方法Regularization
上一篇文章写了增加训练数据集的量是减少overfitting的途径之一,其实减小神经网络的规模,也可以减小overfitting,但是更深层更大的网络潜在有更强的学习能力。
我们可以采用另一种方式来减少overfitting,即使对于固定的神经网络和固定的训练集,我们仍然可以减少,Regularization
最常见的一种Regularization:L2 Regularization(weight decay)
新的损失函数 Regularized cross-entropy:
比之前的cross-entropy增加了一项:权重之和(对于神经网络里面的所有权重w相加)
λ>0;,是regularization的参数
n:训练集包含的实例个数
对于最开始写的二次cost函数,加上一项,变为:Regularized quadratic cost:
对于以上两种情况,可以概括表示为:
Regularization的Cost偏向于让神经网络学习比较小的权重w,除非第一项的Cost明显减少
λ:调整两项的相对重要程度,较小的λ倾向于让第一项Cost最小化,较大的λ倾向于最小化权重之和
对以上公式求偏导数:
以上两个偏导数可以用之前介绍的backpropagation算法求得,添加了一项:
对于偏向b,偏导数不变
根据梯度下降算法,更新法则变为:
对于随机梯度下降(stochastic gradient descent):
求和是对于一个mini-batch里面所有的x
做一个实验,看新的regularization表现如何
import mnist_loader #下载数据模块
import network2 #刚才实现好的network
import matplotlib.pyplot as plttraining_data,validation_data,test_data=mnist_loader.load_data_wrapper()net=network2.Network([784,30,10],cost=network2.CrossEntropyCost)
net.large_weight_initializer()#初始化net.SGD(training_data[:1000],400,10,0.5,evaluation_data=test_data,lmbda=0.1,monitor_evaluation_cost=True,monitor_evaluation_accuracy=True,monitor_training_cost=True,monitor_training_accuracy=True)
来看一下效果:
可以看到cost的表现还是非常不错的,曲线非常平缓,一直在下降
再来看一下accuracy
accuracy在test data上面持续增加
最高的accuracy也增加了,说明regularization减少了overfitting
再看一下用50000张训练集做测试,
同样的参数:30 epochs,学习率:0.5,mini-batch size:10
需要改变λ,因为n从10000变成50000了:
变了,所以需要增大λ,让分数的值相对不会变化太大,增大到5.0>>net.large_weight_initializer()
再来看一下效果:
可以看到效果好了很多
如果隐藏层用100个神经元呢
net=network2.Network([784,100,10],cost=network2.CrossEntropyCost)
最终结果在测试集上accuracy达到97.92,比隐藏层30个神经元提高很多
如果调整优化一下参数,学习率=0.1,λ=5.0,只需要30个epoch,准确率就超过了98%,达到了98.04,加入regularization不仅减小了overfitting,还可以避免陷入局部最小点(local minimum),更容易重现实验结果
那么为什么Regularization可以减少overfitting?
假设一个简单的数据集
构建以上模型来模拟这个点的分布,最后拟合如下:
拟合的非常好,误差为0,但是这个方程最高竟然达到了9次!
我们用一个简单的方程来模拟它
y=2xy=2xy=2x
可以看到简单的y=2x也能拟合的不错,这两个模型哪个更好一些呢?
y=2x更简单,仍然很好的描述了数据,巧合的概率很小,所以我们更偏向于y=2x
在神经网络中:
Regularization网络鼓励更小的权重,小的权重的情况下,x的一些随机的变化不会对神经网络的模型造成太大的影响,所以更小的可能受到数据局部噪音的影响
Un-regularization神经网络,权重更大,容易通过神经网络模型比较大的改变来适应数据,更容易学习到局部数据的噪音
Regularization更倾向于学到更简单的一些模型
简单的模型不一定总是更好,要从大量数据实验中获得,目前添加regularization可以更好的泛化更多的从实验中得来,理论的支持还在研究之中
深度学习进阶课程10---减少Overfitting的方法Regularization相关推荐
- 深度学习进阶课程11---减少overfitting的方法Regularization和Dropout
这篇文章继续写一下Regularization,写一下L1 regularization 公式如下: 跟L2 regularization相似,但不太一样,是对权重w的绝对值求和 求偏导: sgn() ...
- 深度学习进阶课程16---用ReL解决Vanishing Gradient问题
上篇文章发现一个新的问题,越往外越接近输出层,学习效率比较高 这篇文章来讨论一下如何解决这个问题 Exploding gradient problem: 修正以上问题: (1)初始化比较大的权重:比如 ...
- 【资源下载】DeepMindUCL深度学习与强化学习进阶课程
点击我爱计算机视觉标星,更快获取CVML新技术 本文课程介绍部分来自机器之心,因为原视频国内无法观看,所以我爱计算机视觉费了老大劲专门搬到国内分享给大家,下载方法见文末. 11月23日,DeepMin ...
- B站上线!DeepMind加UCL强强联手推出深度学习与强化学习进阶课程(附视频)
新智元报道 编辑:元子 [新智元导读]DeepMind和伦敦大学学院(University College London,UCL)合作,推出了一个系列的深度学习与强化学习精品进阶课程.该课程内 ...
- 干货 | 吴恩达亲自为这份深度学习专项课程精炼图笔记点了赞!(附下载)
来源:机器之心.AI有道 本文约7500字,建议阅读10+分钟. 本文整理了深度学习基础.卷积网络和循环网络的学习笔记,附下载哦~ [ 导读 ]吴恩达在推特上展示了一份由 TessFerrandez ...
- 【人工智能】深度学习专项课程精炼图笔记!必备收藏
本文为人工智能学习笔记记录 ,参考机器之心,AI有道,Google资源 目录 深度学习基础 1. 深度学习基本概念 2. logistic 回归 3. 浅层网络的特点 4. 深度神经网络的特点 5. ...
- 深度学习笔记(10) 优化算法(二)
深度学习笔记(10) 优化算法(二) 1. Adam 优化算法 2. 学习率衰减 3. 局部最优的问题 1. Adam 优化算法 Adam代表的是 Adaptive Moment Estimation ...
- 【零基础深度学习教程第二课:深度学习进阶之神经网络的训练】
深度学习进阶之神经网络的训练 神经网络训练优化 一.数据集 1.1 数据集分类 1.2 数据集的划分 1.3 同源数据集的重要性 1.4 无测试集的情况 二.偏差与方差 2.1 概念定义 2.1.1 ...
- 纽约大学深度学习PyTorch课程笔记(自用)Week3
纽约大学深度学习PyTorch课程笔记Week3 Week 3 3.1 神经网络参数变换可视化及卷积的基本概念 3.1.1 神经网络的可视化 3.1.2 参数变换 一个简单的参数变换:权重共享 超网络 ...
最新文章
- C语言求m中n个数字的组合
- java 03_Java基础03—流程控制
- 手把手教你实现Java发送邮件(1)-发送简单的文本
- 计算机教学论研究生,课程与教学论(计算机)专业硕士学位研究生培养方案
- 拿了股权的员工能不干活吗?
- 重要的Python数据分析库
- mysql处理微信表情
- 人工智能助力复工复产,模版OCR轻松搞定健康码识别
- @Aspect 注解使用详解
- revel MySQL_mysql – 如何在Revel Controller中访问Gorm?
- Windows10的虚拟桌面
- vue3学习笔记一:createApp, ref, reactive, onMounted,computed
- 此ca根目录证书不受信任
- 触动千亿电商市场 BitCherry星耀雅加达
- Jmeter录制手机app脚本
- 数字取证二 熟练掌握鉴证大师 了解NTFS分析、LogFile文件使用和USN日志分析
- 新的开始,fighting
- 苹果cms的php.ini,苹果cms伪静态设置教程
- 英语单词词根词缀和词性互相转换
- JS/JavaScript中两个等号 == 和 三个等号 === 的区别
热门文章
- db_connection.execute(sql_str, *args)执行sql语句
- 操作系统(第四版)期末复习总结(上)
- CDH6 安装 Apache atlas
- 深圳软件测试培训:简述关系型数据库和非关系型数据库
- 【ASP.net】浏览器和服务器的交互
- 如何让一个内向的人锻炼与人交流能力?
- Linux 6.2:华为代码加速核心功能 715 倍!
- python记忆式键入,在Python编程模式下输入命令”print(100+200)“执行的结果是()
- linux raid5模拟数据丢失,Linux服务器右异步RAID-5数据恢复实例分析
- 软装设计配饰色彩搭配教程