混合神经网络中全连接层的一些技巧

我们做深度学习时往往要结合多种神经网络来构建模型,这种模型构建方式称之为混合神经网络模型,不管是CNN、RNN、BERT、RoBERTa还是混合神经网络模型,最后的最后一定要接一个全连接层来学习提取到的特征并转换为我们输出的维度(此处不讨论实现自注意力机制的全连接和Transformer中的全连接层)。全连接层的设置将最大化的挖掘模型的性能,怎么设置全连接层成了一个重点和难点。
结合自身经验,以下是个人认为的tricks:

  1. 批归一化:nn.BatchNorm1d(dim)
    批归一化操作是为了将输入中所有特征约束在(0~1)之间,防止强特征数值过大导致忽略弱特征,不过视具体情况而定吧,有的时候输入本身就有强弱之分,这时你归一化反而消除了强特征,多试试。

  2. 隐藏层层数:视情况而定,和模型的复杂性成反比。
    我们都知道全连接层一般越多越好,但是必须有非线性激活函数和Dropout,否则再多的线性层也等价于一层,但是实际上层数越多,模型要学习的东西就越复杂,这个时候模型难拟合或者过拟合,尤其是对于复杂模型来说,全连接层往往不需要过深,设置1-2层即可。

  3. 神经元个数:第一层适当增加,后续逐层递减。
    根据一些论文,有如下tricks:
    1、隐藏单元的数量不应该超过输入层中单元的两倍[1]
    2、隐藏单元的大小应该介于输入单元和输出单元之间[2]
    3、神经元的数量应捕获输入数据集方差的70~90%[3]
    根据我的实践,第一层适当增加比如100个神经元会取得不错的效果,后续则逐层递减,最后一层就是我们要输出的特征个数。

  4. 激活函数:RELU、GELU、ELU。
    传统的激活函数Sigmoid、Tanh把输入控制在(0,1)、(-1,1)之间,容易产生梯度消失和梯度爆炸问题,所以往往要结合RELU:max(0,x)及其变种来防止这个问题,而GELU是预训练模型BERT[4]采用的激活函数,是一种更好更快的激活函数。
    其形式:
    GELU(x)=0.5x(1+tanh[2/π(x+0.044715x3)])GELU(x)=0.5x(1+tanh[\sqrt{2/π}(x+0.044715x^3)])GELU(x)=0.5x(1+tanh[2/π​(x+0.044715x3)])
    其代码:

def gelu(input_tensor):cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))return input_tesnsor*cdf
  1. Dropout:0.1~0.5
    Dropout最早由Hinton[5]提出,用于防止过拟合,一般为0.5,但是Transformer[6]中使用的是0.1,根据经验,这个也和模型的复杂度成反比,太大模型难以拟合,太小过拟合,视情况而定。
  2. 权重衰减和学习率衰减:AdamW、get_linear_schedule_with_warmup
    学习率是一个非常重要的参数,他决定了模型的上限,学习率一般采用预热的方式,即线性地先增后减[7]:

    最大学习率根据 Learning Rate Finder 来确定,最小值则可以取最大值的十分之一。

[1] M. J. A. Berry and G. S. Linoff, Data Mining Techniques: For Marketing, Sales, and Customer Support, New York: John Wiley & Sons, 1997.
[2] A. Blum, Neural Networks in C++, New York: Wiley, 1992.
[3] Z. Boger and H. Guterman, “Knowledge extraction from artificial neural network models,” in Systems, Man, and Cybernetics, 1997. Computational Cybernetics and Simulation., 1997 IEEE International Conference, Orlando, FL, 1997.
[4] Devlin J, Chang M-W, Lee K, Toutanova K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv preprint, arXiv:1810.04805 [cs.CL] 2018.
[5] Hinton G E, Srivastava N, Krizhevsky A, et al. Improving neural networks by preventing co-adaptation of feature detectors. arXiv preprint, arXiv:1207.0580 [cs.NE] 2012
[6] Vaswani A, Attention Is All You Need, arXiv preprint, arXiv : 1706.03762 [cs.CL] 2017.
[7] Leslie N. Smith. A disciplined approach to neural network hyper-parameters: Part 1 – learning rate, batch size, momentum, and weight decay. arXiv preprint, arXiv : 1803.09820[cs.LG] 2018.

全连接层调参tricks相关推荐

  1. 深度学习调参tricks总结!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:山竹小果,来源:NewBeeNLP 寻找合适的学习率(learni ...

  2. 深度学习调参tricks总结

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨山竹小果 来源丨NewBeeNLP 编辑丨极市平台 导读 本文总结了一系列深度学习工作中的调参策 ...

  3. 网络骨架:Backbone(神经网络基本组成——BN层、全连接层)

    BN层 为了追求更高的性能,卷积网络被设计得越来越深,然而网络却变得难以训练收敛与调参.原因在于,浅层参数的微弱变化经过多层线性变化与激活函数后会被放大,改变了每一层的输入分布,造成深层的网络需要不断 ...

  4. Lesson 16.1016.1116.1216.13 卷积层的参数量计算,1x1卷积核分组卷积与深度可分离卷积全连接层 nn.Sequential全局平均池化,NiN网络复现

    二 架构对参数量/计算量的影响 在自建架构的时候,除了模型效果之外,我们还需要关注模型整体的计算效率.深度学习模型天生就需要大量数据进行训练,因此每次训练中的参数量和计算量就格外关键,因此在设计卷积网 ...

  5. 卷积神经网络CNN要点:CNN结构、采样层、全连接层、Zero-padding、激活函数及Dropout

    CNN结构: 卷积层:特征提取: 采样层:特征选择: 全连接层:根据特征进行分类. 采样层(pooling): max-pooling:克服卷积层权值参数误差: average-pooling:克服卷 ...

  6. 卷积核和全连接层的区别_「动手学计算机视觉」第十六讲:卷积神经网络之AlexNet...

    前言 前文详细介绍了卷积神经网络的开山之作LeNet,虽然近几年卷积神经网络非常热门,但是在LeNet出现后的十几年里,在目标识别领域卷积神经网络一直被传统目标识别算法(特征提取+分类器)所压制,直到 ...

  7. 2 RepMLP:卷积重参数化为全连接层进行图像识别 (Arxiv)

    论文地址: RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition​arxiv ...

  8. 卷积层和全连接层的区别_CNN卷积层、全连接层的参数量、计算量

    我们以VGG-16为例,来探讨一下如何计算卷积层.全连接层的参数量.计算量.为了简单.直观地理解,以下讨论中我们都会忽略偏置项,实践中必须考虑偏置项. [卷积层的参数量] 什么是卷积层的参数? 卷积层 ...

  9. “重参数宇宙”再添新成员:RepMLP,清华大学旷视科技提出将重参数卷积嵌入到全连接层

    编辑:Happy 首发:AIWalker paper: https://arxiv.org/abs/2105.01883 code: https://github.com/DingXiaoH/RepM ...

最新文章

  1. pytorch 模型可视化_高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力
  2. 新入公司 问问题 ,快速了解代码的方法
  3. 文计笔记2: 计算机硬件知识
  4. ubuntu 定时执行任务at
  5. php取消mysql警告_mysql登录警告问题的解决方法
  6. java 生成器 设计模式_Java中的生成器设计模式
  7. Jquery Datatable 数据填充报错:requested unknown parameter ‘XXX‘ for row xx, column xx 解决方法
  8. python变量名区分大小写_python变量名要不要区分大小写
  9. tcpdump 命令快速实用参考手册
  10. PyTorch 1.0 中文文档:torch.Storage
  11. flash制作文字笔顺_用FLASH制作汉字笔顺动画
  12. 型机器人同人本子_唯美的人×机器人漫画《純情愛玩生化女友》
  13. linux 终端翻译,linux下终端使用有道翻译
  14. 【论文笔记】ARBITRAR: User-Guided API Misuse Detection
  15. 建筑企业收并购系列二:股转与吸收合并
  16. 手游平台系统搭建sdk服务端接口文档
  17. ACCU天气API以及Okhttp、Retrofit、RxJava的使用
  18. 如何快速了解一个系统
  19. ubuntu linux下直观的网络流量监控
  20. usb接口驱动_UART串行总线舵机转接板规格、接线说明 amp; 驱动安装

热门文章

  1. OpenWrt共享打印机关键问题
  2. 【消费战略方法论】认识消费者的恒常原理(三):消费者刺激反馈原理
  3. opencv4.3 Stitcher图像拼接方法——学习笔记1
  4. android 图片比例计算器,Algeo图形计算器
  5. 喜讯 | 美格智能荣获2022“物联之星”年度榜单之中国物联网企业100强
  6. 高校计算机考试准考证号
  7. rocketmq源码分析
  8. Unity5.X导入FBX文件,播放动画时位置变动的解决方法
  9. 自动输送线图纸输送机链板流水线倍速链皮带线SW机械设计
  10. 我眼中的Linux系统和红帽RHCE认证