向AI转型的程序员都关注了这个号????????????

机器学习AI算法工程   公众号:datayx

本文主要用于记录DSSM模型学习期间遇到的问题及分析、处理经验。先统领性地提出深度学习模型训练过程的一般思路和要点,再结合具体实例进行说明。全文问题和解决方案尽可能具备通用性,同样适用于一般深度学习模型的训练。

深度学习模型训练要素概图

补充:目标函数一般包含经验风险(损失函数或代价函数)和结构风险(正则化项),此处仅指损失函数。

训练深度学习模型,主要需要考虑四个方面(受限于当前认知水平,仅总结了四个方面),分别是:

  1. 数据处理,包含数据清洗和分布;

  2. 模型结构,包括网络层结构设计和一些细节处理,前者主要有输入层设计和隐层设计(输出层设计划分至目标函数),后者主要有初始化、正则化和激活函数;

  3. 目标函数设计,包含目标的意义和难度,前者决定了模型的学习方向,后者影响对模型能否收敛影响很大;

  4. 模型的复杂性,主要包括模型结构复杂性(量化表现是参数数量)和数据复杂性(数据规模与数据本身的特性)。

问题与处理

  • 负样本采集方式过简

最初为了迅速跑通模型,对DSSM-LSTM做了简单的复现,此时的负样本并未采用随机负采样,而是统一选取了负样本空间的前n个(此部分工作已有人完成,我随后接手)。

实际使用模型时,负样本数量远多于正样本,而模型训练时只使用了固定的几类负样本,间接造成正样本多于负样本,显然是不合理的。为了使模型尽可能多地学到负样本特征,采用随机负采样为正样本配平负样,初期正负样本1:4。

由此引发了学习过程中最大的问题——模型无法收敛。

  • 模型不能收敛

使用随机负采样将负样本变得丰富,本是正常操作,却由此导致模型不能收敛(loss大多只在前三个epoch有明显下降,最终loss与最初相比下降幅度不足1/4),实在是不应该,这只能说明模型设计本身存在问题。

模型无法收敛,排除梯度问题以外,通常是问题或目标的复杂性超过了模型的学习能力,数据杂乱、数据复杂、模型结构复杂、损失函数“太难”等。

最初并没有这些经验,先是调整了batch_size和学习率,这仅仅改变了loss的绝对大小,并未改变loss居高不下的问题。随后更改了网络层神经元数量、梯度优化器,也尝试加入激活函数tanh,几乎没有效果。

在此过程中注意到另一个问题——batch_loss变化幅度大,即便在最初三个能下降的epoch中,batch_loss震荡也很厉害。

  • loss震荡幅度大

正常情况下,每个epoch中batch_loss是逐渐减小的,若loss较大且反复震荡,则会导致模型无法收敛,若loss很小,震荡则是趋于收敛的表现。

batch_loss较大,并且震荡,说明数据分布不均匀,经过检查发现数据是和标准问题对应的,比如前50个问题对应问题A,51-110问题对应问题B,其分布具有特定性而非随机性。

因此,每个batch包含的数据差别较大,以batch论,这些batch已经“不算一个数据集”了。解决方法就是随机打乱数据,使其分布没有“特点”,batch之间越接近,数据分布越好。

调整数据分布后,batch_loss相对稳定,loss有了进一步下降,与最初loss相比,最终loss约下降1/3(这是远远不够的,loss下降90%才可初步体现模型效果,至少下降95%才能有较好表现)。

  • 续模型不能收敛

当数据和模型结构无法影响模型收敛性之后,只好试图修改目标函数。修改前,计算loss之前使用softmax函数对输出做了归一化,模型的学习目标由query与正样本的相似度接近1变成了对应的softmax输出接近1。

为了对softmax的输出有直观的认识,模拟了几组数据:

从softmax(a)和softmax(b)可以看出原本巨大的输入差异,在输出层被缩小了,在b中0.9远大于0.01,对应的输出分别为0.37和0.15,差异没有那么大,在a中,0.6也远大于0.05,对应的输出分别为0.29和0.19,差异也没有那么大。

d与b、c相比可以看出最后一个维在整体数据中占比都是90%,但是随着维度的增加,其输出在逐渐下降。

这反映了softmax的两个特性:

其一,缩小原本数据之间的大小差异;

其二,随着维度的增加优势输入(在整体数据中占比较大)的输出会削弱,即输出逐渐下降。

由数据b、c和d可以看出,最后一维这种占比90%的绝对优势维度,其输出也不会达到0.9,且随着维度的增加其值越来越小。因此以某一维度的softmax输出逼近1为学习目标,几乎不可能实现,即损失函数的学习目标太难。

由此,以0.4作为softmax输出的学习目标,间接达到softmax的输入值大于0.9,即query与正样本的相似度大于0.9。更改损失函数后,模型loss迅速下降,终于可以正常训练。

  • 模型差异较大

模型调试阶段,一直以A语料为训练数据,以Top10的语义召回率R为评价指标,随着参数调优,R从0.6逐渐上升,一度达到0.91,由此确定了模型的最佳参数。使用最佳参数配置训练了B语料的模型,R只达到了0.76,同样的配置使用C语料训练模型,R只有0.61。处理同样的任务,

A、B、C语料来自于同样的场景,在模型结果上差距较大,这基本不是模型的问题,更多的可能是数据的问题。在这种假设下,对三种语料的特点做了对比分析。

注):data_size数据集大小,ques_types多分类总类别,quiz<=3,数据量不超过3的类别比例。

从上表中可以看出一条基本规律:数据规模越小,数据类别越多的语料训练出来的模型效果越差。数据规模小说明数据不充分,这对于深度学习模型训练来说确实不利,数据类别多说明数据特性复杂,会增大模型训练的难度。

此外,在C语料中76%类别的问题对应的样本不超过3条, 在B语料中13%类别的问题对应的样本不超过3条,在A语料中仅有8%类别的问题对应的样本不超过3条  ,这表明C语料不仅在整体数据上不充分,在单个类别上更加缺少数据。B语料类别虽然与C接近,但其数据规模相对充分,因此模型训练效果比C的好;同时,B语料规模与A语料接近,但其类别远多于A,因此其模型训练效果不如A。

总之,对于多文本分类问题,语料规模越大,单个类别样本越充足,其模型训练效果越好。

  • 语料模型的微调

上文已分析了机票模型表现差的原因,即数据不充分、特性复杂,但是这并不意味着完全丧失了进一步优化的可能性。

数据就是这个情况,难以改变,目标函数也已被证实有效,无需大的变动,剩下的唯有调整模型结构了和一些超参数了。考虑到数据规模小,相应的应该减少模型参数(模型结构调整),于是从输入层和隐层两个角度对其神经元数量做了削减。

结果表明,输入层神经元减少不仅无益于模型性能提升,反而下降了。这主要是因为,输入层负责将文本转为语义向量对其进行语义表征,而维度降低也意味着表征能力下降,所以不利于模型学习。

而对隐层神经元数量的减小则进一步加快了模型的收敛,并且使模型性能有了一定提升,最终将C语料训练的模型的语义召回率从0.61提升至0.7。此后,再怎么调整模型语义召回率也难以超越0.7。

所以,数据不好是深度学习模型训练的硬伤,虽然可以在算法设计层面进行一定优化,但这种优化是有限的,治标不治本,要想从根本上解决问题,仍需提升数据质量。


阅读过本文的人还看了以下文章:

TensorFlow 2.0深度学习案例实战

基于40万表格数据集TableBank,用MaskRCNN做表格检测

《基于深度学习的自然语言处理》中/英PDF

Deep Learning 中文版初版-周志华团队

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  


机大数据技术与机器学习工程

 搜索公众号添加: datanlp

长按图片,识别二维码

深度学习模型训练的一般方法(以DSSM为例)相关推荐

  1. AI佳作解读系列(一)——深度学习模型训练痛点及解决方法

    AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 参考文章: (1)AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 (2)https://www.cnblogs.com/carson ...

  2. 深度学习模型训练的结果及改进方法

    深度学习模型训练的结果及改进方法 模型在训练集上误差较大: 解决方法:1. 选择新的激活函数2. 使用自适应的学习率 在训练集上表现很好,但在测试集上表现很差(过拟合): 解决方法:1. 减少迭代次数 ...

  3. 深度学习模型训练过程

    深度学习模型训练过程 一.数据准备 基本原则: 1)数据标注前的标签体系设定要合理 2)用于标注的数据集需要无偏.全面.尽可能均衡 3)标注过程要审核 整理数据集 1)将各个标签的数据放于不同的文件夹 ...

  4. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  5. 笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解

    笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解 针对特定场景任务从模型选择.模型训练.超参优化.效果展示这四个方面进行模型开发. 一.模型选择 从任务类型出发,选择最合适的模型. ...

  6. 深度学习模型训练和关键参数调优详解

    深度学习模型训练和关键参数调优详解 一.模型选择 1.回归任务 人脸关键点检测 2.分类任务 图像分类 3.场景任务 目标检测 人像分割 文字识别 二.模型训练 1.基于高层API训练模型 加载数据集 ...

  7. 基于linux火焰识别算法,一种基于深度学习模型的火焰识别方法与流程

    本发明属于通信领域,具体涉及一种基于深度学习模型的火焰识别方法. 背景技术: 随着我国工业化与城镇水平的不断提高,现代设施大型公共建筑朝着空间大.进深广功能复杂的多元化方向发展,这对于防烟火朝着空间大 ...

  8. dcm格式的文件里有什么,哪些对于深度学习模型训练有用

    DCM格式的文件通常包含医学图像,如X射线.CT或MRI扫描.这些图像可以用来辅助医生诊断疾病,并且对于深度学习模型训练也非常有用.在医学图像分析方面,深度学习模型可以用来做图像分割.疾病诊断.肿瘤检 ...

  9. 深度学习100问之提高深度学习模型训练效果(调参经验)

    声明 1)本文仅供学术交流,非商用.所以每一部分具体的参考资料并没有详细对应.如果某部分不小心侵犯了大家的利益,还望海涵,并联系博主删除. 2)博主才疏学浅,文中如有不当之处,请各位指出,共同进步,谢 ...

最新文章

  1. 人员梯度培养_关键人才的梯队培养
  2. websocket在.net4.5中实现的简单demo
  3. vue 新窗口打开外链接
  4. SpringBoot2.0 基础案例(10):整合Mybatis框架,集成分页助手插件
  5. centos6.8 安装python3.6
  6. Linux实战 | Centos6.8安装matlab的mount挂载问题的解决方法_3
  7. glibc版本查看_[译] 写一个简单的内存分配器(替换glibc中的malloc函数)
  8. java awt jar_【Java学习笔记】操作JAR文件
  9. 相似矩阵对角化 | 找到一个可逆矩阵 P 使得 P^(-1)AP 成为一个对角矩阵
  10. 网约车源码 打车APP 同城打车代驾小程序源码
  11. 二进制,十进制,十六进制转化
  12. 二进制文件是什么?到底二进制文件和纯文本文件的区别是什么?为什么图像、音频是二进制文件?
  13. 计算机毕业设计Java车辆调度管理系统(源码+系统+mysql数据库+lw文档
  14. Slot-Gated Modeling for Joint Slot Filling and Intent Prediction论文笔记
  15. cisco packet tracer配置网络路由
  16. android 多闹钟实现代码,Android编程实现闹钟的方法详解
  17. 不小心清空了回收站怎么恢复,回收站删除的东西可以恢复吗
  18. SC16IS750芯片SPI转串口
  19. C++学习记录---(6)类和对象-----友元和运算符重载
  20. python的两种退出方式

热门文章

  1. Jquery第二篇【选择器、DOM相关API、事件API】
  2. 20个开发人员非常有用的Java功能代码(二)
  3. 使用手机模拟器与android操作系统
  4. Zygo保存zxg(Zemax File)文件(光学领域知道Zygo的一定要看)
  5. 创建界面_《魔兽世界》智慧烈风buff延长 9.0版本角色创建界面改动
  6. python os模块打开文件_Python 文件操作之OS模块
  7. java转动的风扇课程设计,课程设计—智能风扇设计报告
  8. 长沙android工程师,长沙安卓工程师辅导
  9. char flag[20]c语言,C语言试卷
  10. idea运行springboot出现 Disconnected from the target VM, address: ‘127.0.0.1:xxxx‘, transport: ‘socket‘