回归定义

Regression 就是找到一个函数 function ,通过输入特征 x,输出一个数值 Scalar。

应用举例

  • 股市预测(Stock market forecast)

    • 输入:过去10年股票的变动、新闻咨询、公司并购咨询等
    • 输出:预测股市明天的平均值
  • 自动驾驶(Self-driving Car)
    • 输入:无人车上的各个sensor的数据,例如路况、测出的车距等
    • 输出:方向盘的角度
  • 商品推荐(Recommendation)
    • 输入:商品A的特性,商品B的特性
    • 输出:购买商品B的可能性
  • Pokemon精灵攻击力预测(Combat Power of a pokemon):
    • 输入:进化前的CP值、物种(Bulbasaur)、血量(HP)、重量(Weight)、高度(Height)
    • 输出:进化后的CP值

模型步骤

  • step1:模型假设,选择模型框架(线性模型)
  • step2:模型评估,如何判断众多模型的好坏(损失函数)
  • step3:模型优化,如何筛选最优的模型(梯度下降)

二、利用回归进行预测

预测宝可梦的CP值

2.1 背景

操作:找一个function,输入为一只宝可梦x 的各种属性值,输出为进化后的CP值y ,如图所示。

一只宝可梦记作 x ,其属性值如下:

标记 解释
x c p x_{cp}xcp​ 进化前的CP值(14)
x s x_sxs​ 物种(妙蛙种子)
x h p x_{hp}xhp​ 生命值(10)
x w x_wxw​ 重量(11.62kg)
x h x_hxh​ 身高(0.88m)

2.2 如何解决这个问题呢?

上一节我们知道有三个步骤:

  1. 建模:选一个模型(function set)
  2. 评估:评价模型中function的好坏
  3. 择优:找一个最好的function

详细步骤

2.2.1 找模型(使用一次函数)

先找一个简单的 一次函数

即进化后的CP值y等于某一个常数项b,加上某一个数值w乘以输入的宝可梦x在进化前的CP值。

w和b是参数,为任意值,取不同的数值就得到不同的function

如下图所示:

f3​ 显然不太可能是正确的,因为进化后CP值变负数了。 这就需要训练集来告诉我们哪些function才是合理的function

  • 将无穷多个w i ∗ x ii​相加之和再加上b,就得到了一个 线性模型
  • x i 中抽取出来的各种属性值称作 “特征”(feature),当做function的输入
  • w i 为权重(weight),b \ 为偏置(bias)

2.2.2 评价映射的好坏

1. 收集训练集

先收集训练集,才能找这个function。这是一个监督学习的模型,需要手动给出输入和输出,本例中,它们都是数值。

举例来说,杰尼龟 能进化成 卡咪龟,用x 1 x^1x1表示杰尼龟的各种特征,用^y 1 y^1y1表示进化后的CP值979:x 1 → y 1 x^1→y^1x1→y1。其中这个979是我们实际观察到的正确数值。

继续上述操作,收集更多的x i → y i x^i→y^ixi→yi。首先我们通过少量数据来进行训练,这里收集了10只宝可梦进化后的CP值,将输入作为横轴,输出作为y轴,标记在图像上:

2. 定义映射的好坏

定义另一个函数,用以评价Model中function的糟糕程度——损失函数 L LL(Loss function):

  • 输入:一个function
  • 输出:糟糕程度
  • 定义:L ( f ) = L ( w , b ) L(f) = L(w, b)L(f)=L(w,b),此处使用 最小二乘法

f由w和b决定,因此损失函数实际上是用来衡量一组参数的好坏

用真正的数值减去预测的数值再取平方,就是估测的误差。再将它们相加取和就得到损失函数:

损失函数值越大,即误差越大,function效果越差。

2.2.3 找出最好的映射

根据损失函数的定义可得:找到 使损失函数值最小 的function即为最好的function。可穷举w和b,代入使得损失函数最小,但这个非常消耗时间
,显然不可接受。因此就需要一种方法来较快地寻找 —— 梯度下降法(Gradient Descent)。

梯度下降法

在高等数学中我们学过,梯度就是可微函数 f 在各个方向上求偏导数的向量( f x ′ , f y ′ , f z ′ , . . . ) (f'_x, f'_y, f'_z, ...)(fx′​,fy′​,fz′​,...),表示某一函数在该点处的方向导数沿着该方向取得最大值。

前提:函数 L ( w ) L(w)L(w) 可微分。
做法:

  • 随机选取一个初始点 w 0 w^0w0
  • 做 L LL 对 w ww 在 w 0 w_0w0​ 处的微分,即切线的斜率
    • 斜率为负↘,则增大w:w ww 右移 → L ( w ) L(w)L(w) 减小
    • 斜率为正↗,则减小w:w ww 左移 → L ( w ) L(w)L(w) 减小

那么问题来了,参数值如何增加呢,该增加多少呢?

w向右走的步伐大小取决于两个条件:

  1. 现在的微分值大小:微分值越大,越陡峭,步伐越大;反之越小。
  2. 学习率(Learning rate)η:事先定好的数值,η越大,步伐越大,参数更新的幅度越大,学习的效率更快;反之越小。

注意是减号,微分正,减小w。

  • 重复上述步骤的迭代,直到w移动到 局部最小值

梯度下降法不能找到全局最小值,但是在 线性回归问题 中并没有局部最小,因为线性回归模型的损失函数是 凸函数(convex),局部极值必为全局极值。
【若函数可二阶导(或二阶可偏导),则可利用二阶导数(或偏导数)大于零证明】

以上是一个参数的问题,现在由一个参数推广到两个参数:

类似地,将b利用w的方法,分别计算出L对w和b的偏导数,反复迭代,求出最小的损失函数。

梯度下降法参数调整在等高线中展示效果如下:

2.3 测试训练结果


可以看到曲线拟合得并不完美,在某些点处有很大的误差,或许拟合曲线并不是一条直线。那么我们还能做的更好吗?当然,可以考虑换个更复杂的模型。

2.3.1 换模型(二次函数)

y = b + w 1 ∗ x c p + w 2 ∗ ( x c p ) 2 y=b + w_1*x_{cp}+w_2*(x_{cp})^2y=b+w1​∗xcp​+w2​∗(xcp​)2
可以用上述方法找到一个最好的 function:

  • b = − 10.3 b = -10.3b=−10.3
  • w 1 = 1.0 , w 2 = 2.7 × 1 0 − 3 w_1=1.0, w_2=2.7×10^{-3}w1​=1.0,w2​=2.7×10−3

在训练集上的图像如下:

看起来更加合理了,且平均误差为15.4。那么在测试集上表现如何呢:

平均误差为18.4!有没有可能做得更好呢?


上图分别为三次、四次、五次函数。可以发现,到五次函数的时候已经将训练集拟合得很好了,但是测试集上误差反而更大!这种现象称为 “过拟合”(Overfitting)

2.3.2 过拟合

如何解释这种现象呢?

如上图所示,次数越高的函数,它的解空间包含了低次数函数的解空间。因此越复杂的模型,它包含越多的函数,理论上就可以找出一个函数使得训练误差越来越低,前提是梯度下降法能真正帮我们找到最好的函数,不能出现局部最优。

但是在测试集上的结果与训练集的结果是不一样的:

在到第三个式子为止,测试集的误差一直在降低;但是越往后走,误差就暴增了!越复杂的模型并不一定得到越好的结果,这就是 过拟合。因此我们需要选择一个合适的模型。

2.4 收集大量数据来训练测试

现在我们收集60只宝可梦,把原来和进化后的CP值作图(如下图所示),会发现右上角几个点被某种“隐藏力量”影响了。这到底是什么呢?答案就是物种。

隐藏因素

我们将不同的物种用不同的颜色表示

只考虑进化前的CP值是远远不够的,还要考虑物种对进化的影响,因此需要重新设计模型:不同的物种就代入不同的线性函数。

但是把 if 放在 function 里面,这样不就不是一个线性模型了吗?后续还能对损失函数进行微分、用梯度下降法吗?

可以把上述式子改写成 一个 线性函数:

通过控制δ函数的值,来控制函数的结构,具体如下:

那么对于更换后的模型,它的结果怎么样呢?


对于不同的物种,线的颜色不一样。当分不同的种类来考虑时,所预测的误差比原来的模型误差更小。

但这里还是有不能拟合得很好的点,还有其他可能的因素影响着模型预测结果。

可以把所有因素加入模型,看看结果如何:

可以看到又是过拟合了。怎么办呢?使用 正则化 (Regularization)

我们重新定义损失函数,加入一些辅助的额外项λ,让我们找到比较好的function。其中,λ项越趋于0越好:

这里不需要考虑 bais 这一项。因为我们要找一个平滑的function,调整b的大小只是将函数图像上下移动而已,和平滑程度并无关系。

当我们加上λ项时,就说这个function是比较 平滑的(smooth) ——当输入有变化时,输出对输入的变化是比较 不敏感 的。

假设模型为一次函数模型,在某一个 x i x_ixi​ 加上一个 Δ x i \Delta x_iΔxi​,即:
y = b + Σ w i x i + Δ x i y = b + {\Sigma}w_ix_i + \Delta x_iy=b+Σwi​xi​+Δxi​
这个时候输出的变化:
y → y + w i Δ x i y→y+w_i \Delta x_iy→y+wi​Δxi​
当w i w_iwi​越接近0,输出的变化就越小,也即输出对输入的变化不敏感。

为什么我们喜欢比较平滑的 function ?

  • 假设我们的输入在测试的时候被噪声干扰了,那么一个比较平滑的function就能受到比较小的影响,从而给出一个好的结果。

实验结果

  • λ值越大,代表考虑平滑项的影响力越大,找到的function就越平滑。
  • 当λ值越大时,在测试集上的误差就越大,因为这时就越倾向于考虑 w ww 的数值,而减小考虑训练的误差。
  • 有趣的是,在训练集上得到的误差越大,在测试集上得到的误差是可能比较小的。我们喜欢平滑的function,但不喜欢太平滑的function。

如何找出function有多平滑呢?

  • 这就需要我们调λ的值来决定平滑程度
  • 如上图,选择λ = 100,在测试集上的误差最小。

总结

《深度学习》李宏毅 -- task2 回归相关推荐

  1. 深度学习导论(2)深度学习案例:回归问题

    深度学习导论(2)深度学习案例:回归问题 问题分析 优化方法 代码 采样数据 计算误差 计算梯度 梯度更新 main函数 结果输出 这篇文章将介绍深度学习的小案例:回归问题的问题分析.优化以及实现代码 ...

  2. 深度学习基础--SOFTMAX回归(单层神经网络)

    深度学习基础–SOFTMAX回归(单层神经网络) 最近在阅读一本书籍–Dive-into-DL-Pytorch(动手学深度学习),链接:https://github.com/newmonkey/Div ...

  3. 深度学习原理-----逻辑回归算法

    系列文章目录 深度学习原理-----线性回归+梯度下降法 深度学习原理-----逻辑回归算法 深度学习原理-----全连接神经网络 深度学习原理-----卷积神经网络 深度学习原理-----循环神经网 ...

  4. 深度学习pytorch--线性回归(三)

    线性回归pytorch框架实现 线性回归的简洁实现 生成数据集 读取数据 定义模型 初始化模型参数 定义损失函数 定义优化算法 训练模型 小结 完整代码: 线性回归的简洁实现 随着深度学习框架的发展, ...

  5. 深度学习pytorch--线性回归(一)

    线性回归 线性回归案例 提出问题 模型定义 模型训练 (1) 训练数据 (2) 损失函数 (3) 优化算法 模型预测 线性回归的表示方法 神经网络图 矢量计算 小结 线性回归案例 线性回归输出是一个连 ...

  6. 深度学习——李宏毅第一课2020

    李宏毅深度学习课程 预测宝可梦的战斗力 Regression Market Forecast--预测明天股价如何? self-driving car--预测方向盘角度 Recommendation-- ...

  7. 深度学习-李宏毅PPT总结

    前言: 深度学习话题十分火热,网上的资料也非常多,这的确很头疼,太容易迷失.个人认为寻找大牛的授课ppt作为入门方式就可以,跟随大牛的脚步先画出一条直线,再补充骨肉.Anyway,这篇文章十分适合机器 ...

  8. 【深度学习】logistic回归模型

    目录 神经网络 数据.符号准备: logistic回归: 损失函数和代价函数: 梯度下降法: 向量化: 神经网络 我们学习深度学习的目的就是用于去训练神经网络,而神经网络是什么呢?我们先来看下面一个基 ...

  9. 深度学习:Softmax回归

    在前面,我们介绍了线性回归模型的原理及实现.线性回归适合于预测连续值,而对于分类问题的离散值则束手无策.因此引出了本文所要介绍的softmax回归模型,该模型是针对多分类问题所提出的.下面我们将从so ...

  10. 【动手学深度学习】Softmax 回归 + 损失函数 + 图片分类数据集

    学习资料: 09 Softmax 回归 + 损失函数 + 图片分类数据集[动手学深度学习v2]_哔哩哔哩_bilibili torchvision.transforms.ToTensor详解 | 使用 ...

最新文章

  1. 为实现流行病预测:联邦政府在疫情暴发建模方面的努力和机遇
  2. “绳索”与“链接”:《死亡搁浅》的玩法解构
  3. 路由器配置——OSPF协议(2)
  4. 楼层效果_1一28高楼最好最吉利的楼层是哪层?选楼层要注意什么?
  5. linux cpu负载巡检,linux服务器巡检报告.doc
  6. Comet:基于 HTTP 长连接的“服务器推”技术 (实例)
  7. Storm入门学习随记
  8. [转载] python中numpy包使用方法总结
  9. OpenStack混合云的集成问题如何克服?
  10. 有关古文的C语言编程题,文言文考试也编程,文言语言!!!(附c/c++自译)
  11. zip和unzip命令使用
  12. VTK学习笔记(二十三)vtk空间几何变换
  13. (bug更正)利用KVC和associative特性在NSObject中存储键值
  14. 8421码转16进制的c语言,16进制数转换成8421BCD编码函数
  15. html制作3d动画效果,【分享】HTML5的Canvas制作3D动画效果分享
  16. Android攻城狮Dialog
  17. 计算机网络原理学习资源——相关书籍推荐
  18. [Cue]emulator unknown skin name 'WVGA800'
  19. 中国计算机设计大赛来啦!用飞桨驱动智慧救援机器狗
  20. ECharts圆环图(详细示例——满满的注释)

热门文章

  1. echarts+php+mysql 绘图实例
  2. 老李推荐:第8章2节《MonkeyRunner源码剖析》MonkeyRunner启动运行过程-解析处理命令行参数...
  3. CentOS下NTP安装配置
  4. 一起学java【5】---原生态数据类型使用陷阱
  5. 【深度学习】ImageDataGenerator的使用--读书笔记
  6. onenote 不能同步的原因及解决方法(教训总结)
  7. OpenCV Error: Unsupported format or combination of formats (Unsupported combination of input and out
  8. Remote System Explorer Operation卡死Eclipse解决方案
  9. SQL Server 2012 - 数据表的操作
  10. const char *p;和char * const p的区别