TensorFlow2网络训练技巧

文章目录

  • TensorFlow2网络训练技巧
    • 简介
    • 过拟合与欠拟合
    • 过拟合问题
    • 动量(Momentum)SGD
    • 学习率衰减(learning rate decay)
    • 补充说明

简介

  • 在神经网络这种端到端模型的训练过程中,主要的关注点实际上并不多,参数初始化、激活函数、损失函数、优化方法等,但是过深的模型不免会带来欠拟合和过拟合的问题,为了解决过拟合带来的问题,采用了诸如数据增强、参数正则化、Dropout、Batch Normalization、划分验证集(交叉验证)等方法,这些被称为训练技巧(trick)。
  • 当然,为了应对训练速度慢的问题,有时候也采用一些特殊的训练技巧于优化器上,如加入动量的SGD等。
  • 上述的训练trick在TensorFlow2中都提供了简介高效的API接口,使用时直接调用这些接口即可很方便的控制训练、可视化训练、数据增广等。

过拟合与欠拟合

  • 模型的表达能力(Model Capacity)是多变的。一元线性回归模型的表达能力很弱,它只能拟合线性分布的数据;神经网络的表达能力很强,参数量庞大,可以拟合非常复杂的分布。
  • 深度学习中的模型都是层次非常深的,参数极其复杂。要训练这样的网络是比较困难的,需要大量的数据用于参数的学习调整。数据量过少,网络难以被充分训练,无法达到拟合训练集分布的效果,这种问题是机器学习中常见的欠拟合问题(underfitting),该情况下模型的表达能力不够,模型复杂度(estimated)小于数据真实复杂度(ground-truth)。还有另一种情况,模型复杂度(estimated)大于数据真实复杂度(ground-truth),这是因为训练后期模型为了降低loss过分拟合训练数据,从而导致拟合程度过高,模型失去泛化能力。这种问题在机器学习中称为过拟合问题(overfitting)。过拟合现象在训练可视化过程中的表现为随着训练轮次增加,训练集损失不断减少,验证集损失先减少后增加。
  • 现代机器学习中,神经网络这样的模型深度很深,模型的表达能力很强,常出现的问题是过拟合问题,欠拟合问题已经较少出现。

过拟合问题

  • 检测

    • 划分数据集(Splitting)

      • 将有标注的训练数据拿出小部分划分为验证集,验证集由于包含标签数据,可以利用训练好的模型进行预测得到相关metrics(如accuracy等),用于检测模型的训练情况(包括是否过拟合)。
      • 可以对tensor进行直接划分(该方法不会随机打乱数据集)。训练时可以直接将验证集作为参数传入,在验证集上评测的指标与training设定的metric相同。验证集用于训练过程控制训练,模型最终应用在测试集上。
    • k折交叉验证(K-fold cross validation)
      • 之前的划分是一次性划分,有较大概率无法利用所有数据,因为只能使用划分的训练集进行训练,而不能使用验证集进行训练,这部分验证集数据信息就被放弃了。k折交叉验证是为了充分利用数据性能,多次进行数据集划分,要求每次划分的那部分验证集(如20%)在后面的4折中会作为训练集,这样5折下来,每一部分数据都被作为训练数据过(5折交叉验证是最常见的,更少不能充分验证,更多训练量过大)。
      • 在TensorFlow2中可以自行实现k折划分。当然,TensorFlow2也提供比较简单的接口,只需要指出划分比例,则会自动划分出验证集。
  • 处理
    • 充分的数据

      • 充分的数据可以有效训练网络,要求网络进行更多学习,减轻过拟合。正是ImageNet这样的大规模标注数据集,深度学习的发展才会如此迅速。
    • 降低模型复杂度
      • 正则化(Regularization)方法通过在loss函数中添加惩罚项,迫使参数的范数趋近于0,从而使得低复杂的参数较大,使得复杂网络退化。在有些地方,该方法也叫作weight decay(参数衰减)。
      • keras模块下的layer中参数正则化非常简单,只要传入正则化方法对象即可。实际使用中,TensorFlow2提供更灵活的方法。
    • Dropout方法
      • 一个简单粗暴的防止过拟合的方法,以一定概率关闭神经元连接,迫使网络学习更多。
      • 在TensorFlow2中,keras的layers模块将Dropout操作封装为一层的操作,通过堆叠使用。但是注意,添加了Dropout的网络前向传播是必须制定training参数,因为测试集上预测是不应该断开连接,这只是训练时的技巧。
        network = Sequential([layers.Dense(256, activation='relu'),layers.Dropout(0.5), # 0.5 rate to droplayers.Dense(128, activation='relu'),layers.Dropout(0.5), # 0.5 rate to droplayers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
        
    • 数据增广
    • Early Stopping
      • 早停是防止过拟合的一种常用手段,当训练时验证集metric已经饱和或者开始变坏达到指定次数时,停止训练。
      • 通过keras的callbacks模块可以很方便实现这个功能。
        es = keras.callbacks.EarlyStopping(monitor='val_acc', patience=5)
        

动量(Momentum)SGD

  • 梯度更新的方向不仅仅依赖于当前梯度的方向,而且依赖于上一次梯度的方向(可以理解为惯性)。
  • 通过添加动量项,可以使得梯度下降算法找到更好的局部最优解或者全局最优解。但是,有时候动量SGD有可能花费更多的时间找到不是很好的解。
  • 在TensorFlow2中,动量项的梯度更新不需要人为完成,只需要指定动量超参数权值,其余交由优化器完成即可。很多优化算法如Adam是默认使用momentum策略的,不需要人为指定。其中,指定动量项权值为0.9是一个常用策略。

学习率衰减(learning rate decay)

  • 训练后期,过大的学习率可能导致不断波动,难以优化。此时采用学习率衰减策略会是一个不错的方法,该策略后期会自动调整学习率。
  • 同样的,keras的callbacks模块提供了回调函数用于减少学习率。这里的衰减是触发执行的,即后期monitor监控的值不再变好的次数达到patience则会降低学习率。
    rl = keras.callbacks.ReduceLROnPlateau(monitor='val_acc', patience=5)
    
  • 也可以在训练过程中,手动确定衰减策略降低学习率。
    optimizer.learning_rate = 0.2 * (100-epoch)/100
    

补充说明

  • 本文主要针对TensorFlow2中训练技巧进行了简单使用上的介绍。
  • 博客同步至我的个人博客网站,欢迎浏览其他文章。
  • 如有错误,欢迎指正。

TensorFlow2-网络训练技巧相关推荐

  1. 深度学习 网络训练技巧

    网络训练技巧: 1.数据增强:缩放.随机位置截取.翻卷.随机旋转.亮度.对比度.颜色变化等方法. 2.学习率衰减:随着训练的进行不断的减小学习率. 例如:一开始学习率0.01,在10000步后降为0. ...

  2. 深度学习网络训练技巧

    (改了点) 转载请注明:炼丹实验室 新开了一个专栏,为什么叫炼丹实验室呢,因为以后会在这个专栏里分享一些关于深度学习相关的实战心得,而深度学习很多人称它为玄学,犹如炼丹一般.不过即使是炼丹也是可以摸索 ...

  3. 深度学习-网络训练技巧

    1.深度学习的一些基本概念,学习率.batch.epoch.optimizer.评价函数(损失函数)等 1.1 学习率(Learning Rate) 学习率:是控制模型学习效率(步长)的权重. 学习率 ...

  4. 网络训练技巧--参数初始化与优化方法

    实验日期 2021.10.07 实验环境 python 3.8.8 64-bit('base':conda) Initialization 深度学习的参数权重是很重要的,设置不当可能会导致梯度消失或梯 ...

  5. Numpy实现BP神经网络(包含Dropout、BN等训练技巧)

    BP神经网络 简介 本文主要通过在MNIST数据集上使用全连接神经网络对比有无归一化.有无Dropout以及有无Batch Normalization进行训练,可视化分析常用的深度网络训练技巧的原因及 ...

  6. 干货丨一文看懂生成对抗网络:从架构到训练技巧

    文章来源:机器之心 论文地址:https://arxiv.org/pdf/1710.07035.pdf 生成对抗网络(GAN)提供了一种不需要大量标注训练数据就能学习深度表征的方式.它们通过反向传播算 ...

  7. 深入理解生成对抗网络(GAN 基本原理,训练崩溃,训练技巧,DCGAN,CGAN,pix2pix,CycleGAN)

    文章目录 GAN 基本模型 模型 GAN 的训练 模式崩溃 训练崩溃 图像生成中的应用 DCGAN:CNN 与 GAN 的结合 转置卷积 DCGAN CGAN:生成指定类型的图像 图像翻译中的应用 p ...

  8. VGG16、VGG19网络架构及模型训练 tricks :训练技巧、测试技巧

    在上一篇文章的基础之上,总结一下论文中提出的训练技巧和测试技巧.上一篇文章参考:VGG论文笔记--VGGNet网络架构演变[VGG16,VGG19] 一.训练技巧 技巧1:Scale jitterin ...

  9. Caffe技巧之使用snapshot来继续网络训练

    Caffe技巧之使用snapshot来继续网络训练 Caffe技巧之使用snapshot来继续网络训练 Step 1设置solverprototxt Step 2设置运行脚本sh 有时候想在已经训练好 ...

最新文章

  1. 阿里NASA计划:城市大脑成智能研究第一平台
  2. html标签ref,HTML: param 标签
  3. 中国反渗透膜产业竞争现状与投资战略决策报告2021-2027年版
  4. SQL基础【六、and与or】
  5. errno的基本用法
  6. MOBIUS:百度凤巢新一代广告召回系统
  7. 第二季2:/package/mpp/sample的总体分析
  8. python实用黑客脚本_Python黑客攻防(十六)编写Dos脚本,进行容易攻击演示
  9. php打造自己的喜马拉雅,打造自己的私人知识宝库利器——mybase 7.3.5
  10. python字典实现原理-哈希函数-解决哈希冲突方法
  11. PLC哪些编程软件可以通用?
  12. element ui中el-image不显示图片
  13. 分析一块某宝上的WiFi摄像头模块
  14. ARM:ARM体系结构与编程、ARM指令流水线、ARM编程模型基础
  15. 学号密码错误的计算机流程图,自学考试管理信息系统练习题及答案
  16. 软考-嵌入式系统设计师-笔记:历年专业英语题
  17. 互联网思考悲伤之后 如何重新定位移动互联网方向
  18. 爬取百思不得姐上面的视频
  19. spring interation学习-01发送jms消息
  20. yolo-v3代码学习

热门文章

  1. Apollo配置中心原理简介
  2. ConcurrentHashMap的初步使用及场景
  3. Spring 事务core 模块-RowMapper
  4. 高仿真的类-ApplicationContext
  5. RequestToViewNameTranslator
  6. Spring IOC 容器根据Bean 名称或者类型进行autowiring 自动依赖注入
  7. CGLib 和JDK 动态代理对比
  8. 通过JSR250规范 提供的注解@PostConstruct 和@ProDestory标注的方法
  9. MybatisPlus实现分页
  10. 获取class文件对象的三种方式