文章目录

  • 0. 定义
  • 1. 均方误差
  • 2. 交叉熵误差
  • 3. mini-batch中的损失函数
  • 4. 损失函数选择方法

神经网络的学习通过某个指标表示现在的状态,然后以这个指标为基准,寻找最优权重参数,这个指标就是损失函数(loss function)。

如上介绍,神经网络损失函数(loss function)也叫目标函数(objective function)的作用:衡量神经网络的输出与预期值之间的距离,以便控制、调节参数。这个损失函数可以使用任意函数,但一般使用均方误差交叉熵误差

0. 定义

后续公式中需要使用到的定义:

  • y k y_k yk​ 表示神经网络的输出,即 output
  • t k t_k tk​ 表示监督数据,也就是label
  • k k k 表示数据的维数

1. 均方误差

  • 均方误差公式:
    E = 1 2 ∑ k ( y k − t k ) 2 E=\frac{1}{2} \sum_{k}\left(y_{k}-t_{k}\right)^{2} E=21​k∑​(yk​−tk​)2

  • 代码实现:

    import numpy as np
    def mean_squared_error(y, t):return 0.5 * np.sum((y-t) ** 2)
    
  • 举例
    MNIST手写数字识别程序中中,一张图片的label为2,进行one-hot编码后,可以获得t,其中t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],假设该图像经过cnn的输出为y:y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]

    使用mean_squared_error(np.array(y), np.array(t)),经过计算,可以获得:loss = 0.09750000000000003。

2. 交叉熵误差

  • 公式
    E = − ∑ k t k log ⁡ y k E=-\sum_{k} t_{k} \log y_{k} E=−k∑​tk​logyk​
    其中log表示以e为底数的自然对数( l o g e log_e loge​, 即 l n ln ln), l n x ln x lnx函数图像为:

    • 通过上图可以看出来:当x介于0 ~ 1时,对应的函数值为负数。在神经网络中,输出值 y k y_k yk​介于 0 ~ 1之间,所以公式中有一个负号,使得loss为正。
def cross_entropy_error(y, t):delta = 1e-7 return -np.sum(t * np.log(y + delta))

上面实现代码,加上了一个微小的delta,是因为当出现np.log(0)时,np.log(0)会变为负无穷大,添加一个微小值可以防止负无穷大的发生。

使用「均方误差」中的y、t作为参数,调用 cross_entropy_error(np.array(y), np.array(t)),获得结果 0.510825457099338。

3. mini-batch中的损失函数

为了提高训练效率一般都会在每次迭代中使用小批量进行训练,因此计算损失函数时必须将所有的训练数据作为对象。即,如果训练数据有100个,我们就需要将这100个损失函数的总和作为学习的指标

计算公式为: E = − 1 N ∑ n ∑ k t n k log ⁡ y n k E=-\frac{1}{N} \sum_{n} \sum_{k} t_{n k} \log y_{n k} E=−N1​∑n​∑k​tnk​logynk​,也就是把每一个输出loss进行加总求和。

4. 损失函数选择方法

Python深度学习-u3.1:神经网络入门-理论 已经介绍了常见问题中损失函数的选择方法,现摘录如下:

  • 对于二分类问题,可以使用二元交叉熵(binary crossentropy)损失函数;
  • 对于多分类问题,可以用分类交叉熵(categorical crossentropy)损失函数;
  • 对于回归问题, 可以用均方误差(mean-squared error)损失函数;
  • 对于序列学习问题,可以用联结主义时序分类(CTC,connectionist temporal classification)损失函数,等等。
  • eg:imdb电影评论二分类问题
    面对的是一个二分类问题, 网络输出是一个概率值(网络最后一层使用 sigmoid 激活函数, 仅包含一个单元),那么最好使用 binary_ crossentropy(二元交叉熵)损失。 这并不是唯一可行的选择, 比如你还可以使用 mean_ squared_error(均方误差)。但对于输出概率值的模型,交叉熵(crossentropy)往往是最好的选择。
  • 对于分类、回归、序列预测等常见问题,你可以遵循一些简单的指导原则来选择正确的损失函数。参考方案如下:

    注:

    • 具有多个输出的神经网络可能具有多个损失函数(每个输出对应一个损失函数)。但是,梯度下降过程必须基于单个标量损失值。因此,对于具有多个损失函数的网络,需要将所有损失函数取平均,变为一个标量值
    • 只有在面对真正全新的研究问题时,你才需要自主开发目标函数。

常见损失函数 损失函数选择方法相关推荐

  1. 了解机器学习回归的3种最常见的损失函数

    机器学习中的损失函数是衡量你的ML模型的预测结果准确性的一个指标. 损失函数将以两项作为输入:模型的输出值和标准答案的期望值.损失函数的输出称为损失,它是衡量我们的模型在预测结果方面做得有多好. 损失 ...

  2. 经验 | 深度学习中常见的损失函数(loss function)总结

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作分享,不代表本公众号立场,侵权联系删除 转载于:机器学习算法与自然语言处理出品    单位 | 哈工大SCIR实 ...

  3. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  4. 机器学习-常见的损失函数比较

    在机器学习每一个算法中都会有一个目标函数,算法的求解过程是通过对这个目标函数优化的过程.在分类或者回归问题中,通常使用损失函数(代价函数)作为其目标函数.损失函数用来评价模型的预测值和真实值不一样的程 ...

  5. 深度学习中常见的损失函数

    文章来源于AI的那些事儿,作者黄鸿波 2018年我出版了<TensorFlow进阶指南 基础.算法与应用>这本书,今天我把这本书中关于常见的损失函数这一节的内容公开出来,希望能对大家有所帮 ...

  6. 机器学习之常见的损失函数(loss function)

    解决一个机器学习问题主要有两部分:数据和算法.而算法又有三个部分组成:假设函数.损失函数.算法优化.我们一般在看算法书或者视频教学时,更多的是去推算或者说参数估计出其假设函数,而往往不太注重损失函数, ...

  7. 机器学习中常见的损失函数_机器学习中最常见的损失函数

    机器学习中常见的损失函数 现实世界中的DS (DS IN THE REAL WORLD) In mathematical optimization and decision theory, a los ...

  8. 交叉熵损失函数优缺点_【损失函数】常见的损失函数(loss function)总结

    阅读大概需要7分钟 跟随小博主,每天进步一丢丢 机器学习算法与自然语言处理出品 @公众号原创专栏作者 yyHaker 单位 | 哈工大SCIR实验室 损失函数用来评价模型的预测值和真实值不一样的程度, ...

  9. 神经网络模型简介及常见的损失函数

    神经网络模型常见的损失函数 1.神经网络模型简介 神经网络模型一般包含输入层.隐含层和输出层,每一层都是由诸多神经元组成.输入层神经元的个数一般和输入模型的特征(单个样本的维数)有关,输出层神经元的个 ...

最新文章

  1. 【Netty】从 BIO、NIO 聊到 Netty
  2. android AVD运行chrome,contentshell,chromeshell失败解决方法
  3. Class.newInstance()与new、Constructor.newInstance()的区别
  4. CRMEB系统安装访问不了
  5. 使用vbs脚本检查网站是否使用asp.net
  6. 像科学家一样思考python 第二版 epub_Kindle Python教程 – 像计算机科学家一样思考python(第2版) epub,mobi...
  7. mysql索引有几种使用索引的好处_mysql索引的类型和优缺点
  8. SVG 和 CSS3 实现一个超酷爱心 Like 按钮
  9. 前端学习(130):HTML和CSS发展历史
  10. 【元胞自动机】基于matlab元胞自动机图像处理【含Matlab源码 234期 】
  11. 什么是CentOS系统?
  12. Flash CS6 新功能
  13. Mac 安装Yarn
  14. 十三经结业:《诗经》之《蒹葭》赏析
  15. sqlserver 18456登录错误处理
  16. Dubbo笔记 ⑤ : 服务发布流程 - Protocol#export
  17. cpu设计和实现(pc跳转和延迟槽)
  18. 以大多数人的努力程度之低,根本轮不到拼智商
  19. python画螺旋状图形教程_Python实现的绘制三维双螺旋线图形功能示例
  20. Java-SpringBoot-使用SNMP对交换机/服务器进行简单的数据采集

热门文章

  1. 网线属于计算机网络的哪一层,网线的种类分哪几种?
  2. 北京化工大学数据结构2022/10/27作业 题解
  3. Java HashMap底层实现
  4. PMP第13章知识点回顾,练习题
  5. 数据结构 — 浅析红黑树原理以及实现
  6. 通过keil使用汇编语言生成二进制文件,并使用vivado仿真cortexm0处理器
  7. 仿热血江湖帮战客方血帮战 准备记时器结束事件
  8. forward 和 redirect
  9. esmtp 源码 分析
  10. thinkphp图片拖动验证码