TensorFlow2 大幅提高模型准确率的神奇操作

  • 过拟合
  • Regulation
    • 公式
    • 例子
  • 动量
    • 公式
    • 例子
  • 学习率递减
    • 过程
    • 例子
  • Early Stopping
  • Dropout

过拟合

当训练集的的准确率很高, 但是测试集的准确率很差的时候就, 我们就遇到了过拟合 (Overfitting) 的问题. 如图:


过拟合产生的一大原因是因为模型过于复杂. 下面我们将通过讲述 5 种不同的方法来解决过拟合的问题, 从而提高模型准确度.

Regulation

Regulation 可以帮助我们通过约束要优化的参数来防止过拟合.

公式

未加入 regulation 的损失:

加入 regulation 的损失:

λ 和 lr (learning rate) 类似. 如果 λ 的值越大, regularion 的力度也就越强, 权重的值也就越小.

例子

添加了 l2 regulation 的网络:

network = tf.keras.Sequential([tf.keras.layers.Dense(256, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),tf.keras.layers.Dense(64, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),tf.keras.layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),tf.keras.layers.Dense(10)
])

动量

动量 (Momentum) 是指运动物体的租用效果. 在梯度下降的过程中, 通过在优化器中加入动量, 我们可以减少摆动从而达到更优的效果.

未添加动量:


添加动量:

公式

未加动量的权重更新:

  • w: 权重 (weight)
  • k: 迭代的次数
  • α: 学习率 (learning rate)
  • ∇f(): 微分

添加动量的权重更新:

  • β: 动量权重
  • z: 历史微分

例子

添加了动量的优化器:

optimizer = tf.keras.optimizers.SGD(learning_rate=0.02, momentum=0.9)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.02, momentum=0.9)

注: Adam 优化器默认已经添加动量, 所以无需自行添加.

学习率递减

简单的来说, 如果学习率越大, 我们训练的速度就越大, 但找到最优解的概率也就越小. 反之, 学习率越小, 训练的速度就越慢, 但找到最优解的概率就越大.

过程

我们可以在训练初期把学习率调的稍大一些, 使得网络迅速收敛. 在训练后期学习率小一些, 使得我们能得到更好的收敛以获得最优解. 如图:

例子

learning_rate = 0.2  # 学习率
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)  # 优化器# 迭代
for epoch in range(iteration_num):optimizer.learninig_rate = learning_rate * (100 - epoch) / 100  # 学习率递减

Early Stopping

之前我们提到过, 当训练集的准确率仍在提升, 但是测试集的准确率反而下降的时候, 我们就遇到了过拟合 (overfitting) 的问题.

Early Stopping 可以帮助我们在测试集的准确率下降的时候停止训练, 从而避免继续训练导致的过拟合问题.

Dropout

Learning less to learn better

Dropout 会在每个训练批次中忽略掉一部分的特征, 从而减少过拟合的现象.

dropout, 通过强迫神经元, 和随机跳出来的其他神经元共同工作, 达到好的效果. 消除减弱神经元节点间的联合适应性, 增强了泛化能力.

例子:

network = tf.keras.Sequential([tf.keras.layers.Dense(256, activation=tf.nn.relu),tf.keras.layers.Dropout(0.5),  # 忽略一半tf.keras.layers.Dense(128, activation=tf.nn.relu),tf.keras.layers.Dropout(0.5),  # 忽略一半tf.keras.layers.Dense(64, activation=tf.nn.relu),tf.keras.layers.Dropout(0.5),  # 忽略一半tf.keras.layers.Dense(32, activation=tf.nn.relu),tf.keras.layers.Dense(10)
])

TensorFlow2 大幅提高模型准确率的神奇操作相关推荐

  1. 深度学习提高模型准确率方法

    这里写目录标题 深度学习 数据 使用更多数据 更改图像大小 减少颜色通道 算法 模型改进 增加训练轮次 迁移学习 添加更多层 调整超参数 总结 深度学习 我们已经收集好了一个数据集,建立了一个神经网络 ...

  2. 内涝预测过程的噪音_提高人工智能模型准确率的测试过程中需要注意什么?

    黑马程序员视频库 播妞微信号:boniu236 传智播客旗下互联网资讯.学习资源免费分享平台 现在人工智能行业发展迅猛,那么人工智能产品特别是使用分类算法实现的产品中判断其能否上线通常是通过算法自带的 ...

  3. 数据挖掘读书笔记--第八章(下):分类:模型评估与选择、提高分类器准确率技术

    散记知识点 --"评估分类器,提高分类器" 5. 模型评估与选择 5.1 评估分类器性能 (1) 评估分类器性能的度量 评估分类器性能的度量主要有:准确率(识别率).敏感度(召回率 ...

  4. 【深度学习】90.94%准确率!谷歌刷新ImageNet新纪录!Model soups:提高模型的准确性和稳健性...

    丰色 发自 凹非寺 转载自:量子位(QbitAI) 如何最大限度地提升模型精度? 最近,谷歌等机构发现: 性能不好的微调模型先不要扔,求一下平均权重! 就能在不增加推理时间以及内存开销的情况下,提高模 ...

  5. 迁移学习训练集准确率一直上不去_可以提高你的图像识别模型准确率的7个技巧...

    假定,你已经收集了一个数据集,建立了一个神经网络,并训练了您的模型. 但是,尽管你投入了数小时(有时是数天)的工作来创建这个模型,它还是能得到50-70%的准确率.这肯定不是你所期望的. 下面是一些提 ...

  6. 使用迁移学习后使用微调再次提高模型训练的准确率

    使用迁移学习后使用微调再次提高模型训练的准确率 1.微调 所谓微调:冻结模型库的底部的卷积层,共同训练新添加的分类器层和顶部部分卷积层.这允许我们"微调"基础模型中的高阶特征表示, ...

  7. 提升10%!如何将机器学习模型准确率从80%提高到90%以上

    全文共2402字,预计学习时长7分钟 图源:unsplash 说实在的,如果你有过项目实践经历,就会明白80%的精确度并不算糟糕.但在现实世界中,人们期望精确度不会少于80%.事实上,我工作过的大多数 ...

  8. MASTER:全局上下文建模大幅提高文本识别精度

    点击我爱计算机视觉标星,更快获取CVML新技术 今天跟大家分享一篇昨天新出的场景文本识别方法MASTER,其发明了一种Multi-Aspect 全局上下文建模方法,有效改进了文本识别精度,在多个数据集 ...

  9. 如何一步一步提高图像分类准确率?

    一.问题描述 当我们在处理图像识别或者图像分类或者其他机器学习任务的时候,我们总是迷茫于做出哪些改进能够提升模型的性能(识别率.分类准确率)...或者说我们在漫长而苦恼的调参过程中到底调的是哪些参数. ...

最新文章

  1. 三维点云对应关系聚合算法的性能评价
  2. 安卓高手之路 图形系统(3 底层SurfceFlinger系统)
  3. ArcGIS中标注之一上下标、分数等特殊形式标注(转)
  4. 用vector实现一个变长数组
  5. UDP千兆以太网FPGA_verilog实现(四、代码前期准备-UDP和IP协议构建)
  6. 程序员幽默:老板让明天带条鱼来大家观察
  7. win7系统如何访问xp系统的服务器,WIN7系统怎么让XP系统访问呢
  8. React开发(171):处理删除与批量删除操作
  9. 【HDU - 5017】Ellipsoid(爬山算法,模拟退火,三分)
  10. 替换jar包_替换代码的情况下不停机!这操作可能工作6年的Java程序员都不会
  11. Argument list too long 文件数过多
  12. 生成core文件的步骤
  13. mysql大数据高并发处理
  14. s3c2410_gpio_setpin()等系列函数
  15. 批量建立域帐号,摆脱管理员的痛!(原创+实战)
  16. 北航计算机组成原理课程设计-2020秋 PreProject-Logisim-Logisim仿真与调试应用与挑战
  17. STM32跑html协议,STM32移植SBUS协议
  18. php把amr转换成mp3,PHP 将amr音频文件转换为mp3格式
  19. 阿铭Linux_网站维护学习笔记20190408
  20. 如何添加使用微信小程序,教程在这里,微信小程序怎样添加使用

热门文章

  1. 【博学谷学习记录】超强总结,用心分享 | 架构师 Kafka学习总结
  2. 东芝打印机2551C A3试卷打印技巧
  3. 新形象,新征程,新时代,光环云荣获2021GIDC“最佳IDC技术融合趋势奖
  4. superset table 表头汉化 ; JS 动态属性名 key
  5. Py之matplotlib:matplotlib绘图中与颜色相关的参数(color颜色参数、linestyle线型参数、marker标记参数)可选列表集合(建议收藏)
  6. common.js 2017
  7. 将十六进制数的ASCII码转换为十进制数。十六进制数的值域为0~65535,最大转换为五位十进制数。要求将缓冲区的000CH的ASCII码转换为十进制,并将结果显示在屏幕上。
  8. Seasar サイトマップ
  9. ctf攻防世界crypto新手区
  10. 华为手机的操作技巧,快来看看