手工实现:SVM with Stochastic Gradient Descent

  • 引入
  • 实际问题
  • 理论知识
    • SVM
      • 直观认识
      • 什么是分的好?
        • 1、是不是只要都分对了就是分的好?
        • 2、是不是只要训练集分的好就是分的好?
    • Gradient Descent
      • cost 最小的模型
      • Gradient Descent 是啥
    • SGD(Stochastic Gradient Descent)
  • 实现
    • 选择
    • 思路
    • 结果
  • 后记
  • Reference

引入

之前的两篇我们一直在搞Naive Bayes Classifier,今天我们来换一个分类器试试。
SVM(support vector machine),支持向量机,虽让从名字上看不出来,但是它确实是个分类器。如果搜索引擎搜一下SVM的话,其实已经有很多资料很详细的介绍了SVM,但是正如之前两篇一样,本篇文章会着重于直观的解释相关知识点,并以解决实际问题为导向手工实现该算法。

实际问题

我们能否基于人口普查数据建立分类器,并对收入情况进行预测?
本次试验中,我们将使用UCI的 “Adult Data Set”。该数据集包含了被调查者的年龄、受教育程度、 工作岗位、性别等信息,并且标出了被调查者的收入分类(>50K,<=50K)。

在实现过程中,我们将在训练集上建立 SVM, 在测试集上评估模型的精度。

理论知识

SVM

直观认识

首先让我们举个栗子,直观认识一下SVM。
假设我们现在有这样一个训练集:
维度有x(1)x_{(1)}x(1)​、x(2)x_{(2)}x(2)​两个;类有正类、负类两个,正类由红色表示,负类由蓝色表示。

如果我们用一个SVM分类,它将这样做:
1、使用某个函数 h(x)h(x)h(x),图例中为aTX+ba^TX+baTX+b,对所有的点进行处理,将它们投到一个新的维度(γ\gammaγ)当中。γ\gammaγ > 0 的点被预测为正类,γ\gammaγ < 0 的点被预测为负类。
2、通过一定方法找到步骤1中分的最好的函数(cost最低的),直观上来讲,这个函数对应的那个灰色矩形平面就是我们训练出来的模型。

3、若是要预测新的点的类别,只需要用同样函数处理新来的点(粉),γ\gammaγ > 0 则被预测为正类,γ\gammaγ < 0 则被预测为负类。

什么是分的好?

作为一只有追求的分类器就想要尽量分类分的好,但是究竟什么是分的好呢?

1、是不是只要都分对了就是分的好?

并不是~ 因为除了分对以外,我们还想要分的开。

如果我们只是以分对为标准,那么 l1l_1l1​ 和 l2l_2l2​ 就没有什么差别。

但是实际情况并非如此。如果我们在训练集上分的不够开,那么将来就很有可能出现像紫点这样的情况。

所以我们在训练时就要考虑尽量分开这件事,即:在交线附近划一个区域,如果有离得太近的点(e1e_1e1​)存在,模型也要付出一丢丢代价。(当然如果本身都已经分错了的话,代价就会更高,而且错的越离谱代价越大(离分界越远,代价越大))
体现在数学上,就是我们使用 Hinge Loss 作为 cost function。在这个例子中就是:
C(y,γ)=max(0,1−yγ)C(y, \gamma) = max(0, 1-y \gamma)C(y,γ)=max(0,1−yγ)
其中:
(1)yyy 为这个点实际的类别,为了方便计算,这里需要提前统一把正类的 yyy 赋值为1,负类的赋值为-1。
(2)γ\gammaγ 为 aTx+ba^Tx+baTx+b 的结果。
PS:这个函数之所以叫 Hinge Loss 是因为这个函数长得像 hinge。 感兴趣的盆友可以自己画画看。

2、是不是只要训练集分的好就是分的好?

并不是~
因为仔细观察我们现在的cost function可以发现,如果在训练集上分的全对且相对比较开的情况下,同比例放大所有的 aaa 并不会对cost产生任何变化,因为都是0。这也就意味着,我们可选的 aaa 并不是唯一的。这是否意味着我们可以随便挑一个 aaa 呢?
其实并不是。假设我们现在选了一个巨大无比的 aaa,这时数据点中又出现一个新的点时,那么只要它稍稍在分界线外正类的一方,我们就会把毫无惩罚的分到正类。但是这并不是我们期待的结果。因为我们希望对于那些虽然分对,但是离分界太近的也要有一些惩罚。因此我们需要对参数有一定的限制。
这里为了增加限制,我们改变了 cost function(之前了解过regularization的盆友应该可以明白,其实加的这个限制就是正则化):
1N∑max(0,1−yiγi)+λ2aTa{1 \over N}\sum max(0, 1-y_i \gamma_i) + {\lambda \over 2} a^TaN1​∑max(0,1−yi​γi​)+2λ​aTa
(1)yiy_iyi​ 表示第 iii 个点实际的类。
(2)γi\gamma_iγi​ 表示第 iii 个点经过 aTx+ba^Tx+baTx+b 以后,得到的结果
(3)这里的 λ\lambdaλ 是一个 超参数(hyperparameter),并不是由训练得到的参数,而是一个需要提前定下来的值。

Gradient Descent

上面 “什么是分的好?”这部分已经解决了合理表示 Cost function 的问题,下一步,我们就需要通过一定方法找到 cost 最小的模型了,那么这个方法是什么呢?

cost 最小的模型

说到这先暂停一下,让我们先想一下“找到 cost 最小的模型”究竟是怎样的一件事。
1、Cost:我们的cost是针对某一个特定的模型而计算的。是所有点的实际label,和某一特定模型下所有点的预测 label 对比产生的结果
2、模型:当我们没有任何限制的时候,其实会出现无限多个模型(即便在限定使用特定核函数时(eg:之前的 aTX+ba^TX+baTX+b)),因为核函数的参数有无限多种可能(eg:aTX+ba^TX+baTX+b 中的向量 aaa 和 数值 bbb 都可以有无限多个选择)。换句话说,限定只能用 aTX+ba^TX+baTX+b 这个形式的前提下,每个特定模型都对应了一组特定的参数 a,ba,ba,b。
Cost−模型−参数Cost - 模型 - 参数Cost−模型−参数
理到这里,盆友们应该能够想明白一点,某个cost值 和 某个模型(某组参数)之间是有对应关系的。而这种对应关系用数学的形式表达就是函数了。如果用之前的例子写就是:
cost=f(a,b)cost = f(a,b)cost=f(a,b)
因此,我们找cost最小的模型的问题,就成功转化成了求cost关于 a,ba,ba,b 的函数 f(a,b)f(a,b)f(a,b) 的最小值的问题啦。
说到求函数最小值,就要引出我们这一部分的主角Gradient Descent了.

Gradient Descent 是啥

Gradient Descent,梯度下降,放在这里用来找我们 cost function 的最小值。篇幅问题,就不展开说了,只说下整体思路,实在不了解的盆友可以参考吴恩达老师《机器学习》这门课第二章梯度下降部分的视频,讲的很清楚。
梯度下降的大致的思路就是,在寻找一个函数的 “最小值” 时,我们可以从这个函数的某个点出发,不断沿着梯度下降方向走,直到 “不再变化” 。

其中每次的变化为:
学习率(η\etaη)* 当前位置的导数
这里就涉及到一个学习率 / 步长(η\etaη)的问题,η\etaη 同之前提到的 λ\lambdaλ 一样,是一个超参数,并非由训练得出,而是需要提前选择好的。要注意,它的选择会极大的影响整个寻找的速度和结果哦。
PS:这个式子有很巧妙的一点,如果 η\etaη 选择合适,方向基本正确,就算再慢也一定会慢慢趋近极小值。这是因为随着位置的更新,导数绝对值本身就会逐渐变小,所以即使 η\etaη 是个常数,变化也会逐渐放缓。当然如果步子太大跨过去了就另当别论了。

SGD(Stochastic Gradient Descent)

前面的说完,似乎Gradient Descent已经可以解决我们找 cost function 最小值的需求了,为什么又出来一个Stochastic Gradient Descent 呢?盆友憋着急,且听我慢慢说。
让我们来看一下之前被求导的函数:
Cost=1N∑max(0,1−yiγi)+λ2aTaCost = {1 \over N}\sum max(0,1-y_i\gamma_i) + {\lambda \over 2} a^TaCost=N1​∑max(0,1−yi​γi​)+2λ​aTa
其中包含了 ∑max(0,1−yiγi)\sum max(0,1-y_i\gamma_i)∑max(0,1−yi​γi​), 是要包含所有的数据点的。
盆友们请想象一下,当我们数据量比较大的情况下,计算并对上面这个式子求导是多么可怕的一个计算量。
所以为了不让我们的电脑起飞,我们需要考虑在算 gradient 的时候做出一点变化。SGD应运而生。
简单来说,SGD就是在计算gradient的时候,不用全部的数据点,而是只抽取一部分(batch)来算。

实现

选择

正式开始这部分前,得先把实验过程中的一些选择列清楚,这些东西可以自行根据需求改变,但是改变以后结果可能会发生比较大的变化,所以得提前列一下:
1、本实验中并未使用所有的attribute,只选用了一部分attribute。
2、本实验的 η\etaη 并非常数,而是 η=mk∗s+n\eta ={m \over k * s+n}η=k∗s+nm​其中:
(1)m,n,k 是类似 λ\lambdaλ 超参数的存在,但是 m,n,k 与 λ\lambdaλ在本实验中的处理方式并不相同,后面会详细展开。
(2)s是所处的season,后面展开。
3、超参数的选择:
(1)m,n,k:对于这三者,本次实验使用预先实验选取。简单来说就是针对每个参数,都选几个可选值放在那。给定一个 λ\lambdaλ 的情况下,采用三层 for 循环,把每个参数组合的表现记录下来,选择表现最好的那一组。
(2)λ\lambdaλ:由于想要把不同 λ\lambdaλ 下,Accuracy 和 Magnitude 的变化画出来,所以没有在实验中选择某个特定的 λ\lambdaλ,只是选了几个可选的,每个都跑了一遍。
4、season、step:整个训练过程分成了x个season,每个season y走步。分season 是考虑到学习率(η\etaη)其实可以随着训练的推进不断变化的,这里假定η=mk∗s+n\eta ={m \over k * s+n}η=k∗s+nm​。
5、batch:本次实验选了 batch = 1

思路

同样,这次还是不会直接贴代码,只写个思路。其实写到这,如果理论明白了,实现并没有很复杂,如果实在有需要的盆友请私聊。

第一轮,预实验:#确定 m, n, k给定一个lambda(本实验中给的0.01):for m in [...]:for n in [...]:for k in [...]:accuracy = 0for s in [...]: # season从原来的train,分出validation,剩下的作新的trainEta = m/(k*s+n)for j in range(...): # step选batch同步更新a,b(要注意,不同情况下,梯度是不同的式子)每个season算一个accuracy当前参数下 accuracy = 所有记录下来的accuracy取平均# 这里我没有取前几个season,因为随机的缘故,前几个season可能太不稳定找到平均accuracy最高的(m,n,k)
第二轮,正式实验:
Lambda = [...]
for i in Lambda:for s in range(...):分 validationEta = 1/(0.1*s+20)for j in range(...):每30步记录accuracy 和 magnitude,后面作图用同步更新a,b
作Accuracy 和 Magnitude的图

结果

第一轮预实验取 λ\lambdaλ =0.01,第二轮正式实验 η=1/(0.1∗s+20)\eta = 1/(0.1*s+20)η=1/(0.1∗s+20) 的实验结果:
1、Accuracy随训练过程的变化:

2、Magnitude随训练过程的变化:

后记

1、如果单纯为了解决这个收入预测问题,实现-选择部分很多地方都还有更好的选择,但鉴于本文的重点不在这,就不展开讨论了。但是很欢迎各位盆友讨论。
2、为了尽量保证认知逻辑的通顺,本文的行文顺序和标题名称并不像常规算法文章一样,所以在此加一个索引,方便各位盆友查找。
(1)SVM的直观介绍:理论知识 - SVM - 直观认识
(2)Cost function:理论知识 - SVM - 什么是分的好?
(3)求 Cost function 最小值:理论知识 - SVM - Gradient Descent;
理论知识 - SVM - SGD(Stochastic Gradient Descent)

Reference

1、Applied Machine Learning, D.A. Forsyth
2、Adult Data Set: https://archive.ics.uci.edu/ml/datasets/Adult

若需引用请注明出处。
若有错误欢迎指正、讨论。

手工实现:SVM with Stochastic Gradient Descent相关推荐

  1. UA MATH567 高维统计专题3 含L1-norm的凸优化6 Stochastic Gradient Descent简介

    UA MATH567 高维统计专题3 含L1-norm的凸优化6 Stochastic Gradient Descent简介 Stochastic Gradient Descent的思想 Varian ...

  2. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比

     随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比.实现对比 标签: 梯度下降最优化迭代 2013 ...

  3. Minibatch Stochastic Gradient Descent

    Reference: https://d2l.ai/chapter_linear-networks/linear-regression.html https://d2l.ai/chapter_line ...

  4. 随机梯度下降(Stochastic gradient descent)

    总目录 一. 凸优化基础(Convex Optimization basics) 凸优化基础(Convex Optimization basics) 二. 一阶梯度方法(First-order met ...

  5. 【文献阅读】Federated Accelerated Stochastic Gradient Descent

    文章目录 1 Federated Accelerated Stochastic Gradient Descent (FedAc) 2 challenge 3 how to do 4 baseline ...

  6. 论文阅读------Stochastic Gradient Descent with Differentially Private updates

    论文阅读------Stochastic Gradient Descent with Differentially Private updates SGD 代价函数 正则化描述 训练过程 梯度偏导 参 ...

  7. 论文笔记——Asynchronous Decentralized Parallel Stochastic Gradient Descent

    论文笔记--Asynchronous Decentralized Parallel Stochastic Gradient Descent 改变了中心化的结构,使用了分布式的结构 算法过程 每个wor ...

  8. Kaggle(L3) - Stochastic Gradient Descent Notebook

    Use Keras and Tensorflow to train your first neural network. Introduction In the first two lessons, ...

  9. 随机梯度下降算法SGD(Stochastic gradient descent)

    SGD是什么 SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一.SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数.它的基本 ...

最新文章

  1. win8计算机安全模式,安全模式,教您Win8怎么进入安全模式
  2. bat脚本交互输入_基于winserver操作系统的自动清理Oracle3天前归档日志脚本分享...
  3. Python数据类型一
  4. DataGridView的使用
  5. 人力资源管理4个过程及相关重点
  6. flink介绍:有界流和无界流
  7. kibana数据可视化
  8. java 类 request_java普通类得到request对象
  9. 【C++深度剖析教程14】经典问题解析三之关于赋值的疑问
  10. Linux 命令(37)—— free 命令
  11. M1 MacBook Air值不值得买?使用体验分享
  12. 阿里p7java什么水平_转头条:阿里p7架构师:三年经验应该具备什么样的技能?
  13. java中 什么意思?比如130
  14. Python基础阶段:体脂率计算练习
  15. 零知识证明 Zero Knowledge Proof 以及 Layer2、跨链介绍
  16. CSS漂亮盒子(下)
  17. Web 利用纯html和css画出一个android机器人
  18. haoi2008木棍分割解题报告
  19. bing词典案例分析
  20. 财商教育—百万富翁的生活习惯

热门文章

  1. 计算机按键变成音符怎么弄,电脑键盘的数字都变成了符号、怎么弄才能变成数字阿?...
  2. java校招笔试题目_Java校招笔试题
  3. 云端服务器跑python代码,断开后台运行
  4. 使用弗洛伊德算法(Floyd-Warshall)找到所有对最短路径长度
  5. 蓝奏云PHP解析接口,蓝奏云下载地址解析API[直链]
  6. 咕咚已上传服务器怎到不丁微信,微信运动怎样使用咕咚数据?
  7. 【2-SAT初学+模板题讲解】POJ3683 Priest John's Busiest Day
  8. 51单片机实验——按键外部中断实现四进制计数器
  9. 解决方案:惠普15-bc012tx笔记本电脑电池掉电快的检测及解决
  10. birt 报表与润乾报表对比