个人网站:红色石头的机器学习之路
CSDN博客:红色石头的专栏
知乎:红色石头
微博:RedstoneWill的微博
GitHub:RedstoneWill的GitHub
微信公众号:AI有道(ID:redstonewill)

之前我用了六篇文章来详细介绍了支持向量机SVM的算法理论和模型,链接如下:

1. 线性支持向量机LSVM

2. 对偶支持向量机DSVM

3. 核支持向量机KSVM

4. 软间隔支持向量机

5. 核逻辑回归KLR

6. 支持向量回归SVR

实际上,支持向量机SVM确实是机器学习中一个非常重要也是非常复杂的模型。关于SVM的详细理论和推导,本文不再阐述,读者可以直接阅读上面的六篇文章。

学习完了复杂的理论知识,很多朋友可能非常想通过一个实际的例子,动手编写出一个SVM程序,应用到实际中。那么本文就将带领大家动手写出自己的SVM程序,并且应用到图像的分类问题中。我们将在经典的CIFAR10图像数据集上进行SVM程序验证。

话不多说,正式开始!

1. SVM的基本思想

简单来说,支持向量机SVM就是在特征空间中找到一条最佳的分类超平面,能够让正、负样本距离该超平面的间隔(margin)最大化。

以二维平面为例,确定一条直线对正负样本进行分类,如下图所示:

很明显,虽然分类线H1、H2、H3都能够将正负样本完全分开,但是毫无疑问H3更好一些。原因是正负样本距离H3都足够远,即间隔「margin」最大。这就是SVM的基本思想:尽量让所有样本距离分类超平面越远越好。

2. 线性分类与得分函数

在线性分类器算法中,输入为x,输出为y,令权重系数为W,常数项系数为b。我们定义得分函数s为:

s=Wx+bs=Wx+b

s=Wx+b

这是线性分类器的一般形式,得分函数s所属类别值越大,表示预测该类别的概率越大。

以图像识别为例,共有3个类别「cat,dog,ship」。令输入x的特征维度为4「即包含4个像素值」,W的维度是3x4,b的维度是3x1。在W和b确定后,得到各个类别的得分函数s为:

由上图可知,因为总有3个类别,得分函数s是3x1的向量。其中,cat score=-96.8,dog score=437.9,ship score=61.95。从s的值来说,dog score最高,cat score最低,则预测为狗的概率更大一些。而该图片真实标签是一只猫,显然,从得分函数s上来看,该线性分类器的预测结果是错误的。

通常为了简化计算,我们直接将W和b整合成一个矩阵,同时将x额外增加一个全为1的维度。这样,得分函数s的表达式得到了简化:

W:=[W  b]W:=[Wb]

W:=[W\ \ b]

x:=[x; 1]x:=[x;1]

x:=[x;\ 1]

s=Wxs=Wx

s=Wx

示例图如下:

3. 优化策略与损失函数

通常来说,SVM的优化策略是样本到分类超平面的距离最大化。也就是说尽量让正负样本距离分类超平面有足够宽的间隔,这是基于距离的衡量优化方式。针对上文提到的例子,图片真实标签是一只猫,但是得到的s值却是最低的,显然这不是我们希望看到的。最好的情况应该是cat score最高。这样才能保证预测cat的概率更大。此时,利用SVM的间隔最大化的思想,就要求cat score不仅仅要大于其它类别的s值,而且要达到一定的程度,可以说有个最低阈值。

因此,这种新的SVM优化策略可以这样理解:正确类别对应的得分函数s应该比其它类别的得分函数s大一个阈值 ΔΔ\Delta:

syi≥sj+Δsyi≥sj+Δ

s_{y_i}\geq s_j+\Delta

接下来,我们就可以根据这种思想定义SVM的损失函数:

Li=∑j≠yimax(0,sj−syi+Δ)Li=∑j≠yimax(0,sj−syi+Δ)

L_i=\sum_{j\neq y_i}max(0,s_j-s_{y_i}+\Delta)

其中,yiyiy_i表示正确的类别,j表示错误类别。从LiLiL_i的表达式可以看出,只有当syisyis_{y_i}比sjsjs_j大超过阈值 ΔΔ\Delta 时,LiLiL_i才为零,否则LiLiL_i大于零。这种策略类似于距离最大化策略。

举个例子来解释LiLiL_i的计算过程:例如得分函数s=[-1, 5, 4],y1y1y_1是真实样本,令Δ=3Δ=3\Delta=3,则:

Li=max(0,−1−5+3)+max(0,4−5+3)=0+2=2Li=max(0,−1−5+3)+max(0,4−5+3)=0+2=2

L_i=max(0,-1-5+3)+max(0,4-5+3)=0+2=2

该损失函数由两部分组成:y1y1y_1与y0y0y_0,y1y1y_1与y2y2y_2。由于y1y1y_1与y0y0y_0的差值大于阈值 ΔΔ\Delta,则其损失函数为0;虽然y1y1y_1比y2y2y_2大,但差值小于阈值 ΔΔ\Delta,则计算得到其损失函数为2。总的损失函数即为2。

这类损失函数的表达式一般称作合页损失函数「Hinge Loss Function」:

显然,只有当sj−syi+Δ<0sj−syi+Δ<0s_j-s_{y_i} + \Delta 时,损失函数才为零。

这种合页损失函数的优点是体现了SVM距离最大化的思想;而且,损失函数大于零时,是线性函数,便于梯度下降算法求导。

除了这种线性hinge loss SVM之外,还有squared hinge loss SVM,即采用平方的形式:

Li=∑j≠yimax(0,sj−syi+Δ)2Li=∑j≠yimax(0,sj−syi+Δ)2

L_i=\sum_{j\neq y_i}max(0,s_j-s_{y_i}+\Delta)^2

这种squared hinge loss SVM与linear hinge loss SVM相比较,特点是对违背间隔阈值要求的点加重惩罚,违背的越大,惩罚越大。某些实际应用中,squared hinge loss SVM的效果更好一些。具体使用哪个,可以根据实际问题,进行交叉验证再确定。

对于超参数阈值 ΔΔ\Delta,一般设置 Δ=1Δ=1\Delta=1。因为,权重系数W是可伸缩的,直接影响着得分函数s的大小。所以说,Δ=1Δ=1\Delta=1 或 Δ=10Δ=10\Delta=10,实际上没有差别,对W的伸缩完全可以抵消掉 ΔΔ\Delta 的数值影响。因此,通常把 ΔΔ\Delta 设置为1即可。此时的损失函数为:

Li=∑j≠yimax(0,sj−syi+1)Li=∑j≠yimax(0,sj−syi+1)

L_i=\sum_{j\neq y_i}max(0,s_j-s_{y_i}+1)

SVM中,为了防止模型过拟合,可以使用正则化「Regularization」方法。例如使用L2正则化:

R(W)=∑k∑lw2k,lR(W)=∑k∑lwk,l2

R(W)=\sum_k\sum_lw_{k,l}^2

引入正则化项之后的损失函数为:

L=1NLi+λR(W)L=1NLi+λR(W)

L=\frac1NL_i+\lambda R(W)

其中,N是训练样本个数,λλ\lambda 是正则化参数,可调。一般来说,λλ\lambda 越大,对权重W的惩罚越大;λλ\lambda 越小,对权重W的惩罚越小。λλ\lambda 实际上是权衡损失函数第一项和第二项之间的关系:λλ\lambda 越大,对W的惩罚更大,牺牲正负样本之间的间隔,可能造成欠拟合「underfit」;λλ\lambda 越小,得到的正负样本间隔更大,但是W数值会变大,可能造成过拟合「overfit」。实际应用中,可通过交叉验证,选择合适的正则化参数λλ\lambda。

常数项b是否需要正则化?其实一般b是否正则化对模型的影响很小。可以对b进行正则化,也可以选择不。实际应用中,通常只对权重系数W进行正则化。

4. 线性SVM实战

首先,简单介绍一下我们将要用到的经典数据集:CIFAR-10。

CIFAR-10数据集由60000张3×32×32的 RGB 彩色图片构成,共10个分类。50000张训练,10000张测试(交叉验证)。这个数据集最大的特点在于将识别迁移到了普适物体,而且应用于多分类,是非常经典和常用的数据集。

这个数据集网上可以下载,我直接给大家下好了,放在云盘里,需要的自行领取。

链接:https://pan.baidu.com/s/1iZPwt72j-EpVUbLKgEpYMQ

密码:vy1e

下面的代码是随机选择每种类别下的5张图片并显示:

# Visualize some examples from the dataset.
# We show a few examples of training images from each class.
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):idxs = np.flatnonzero(y_train == y)idxs = np.random.choice(idxs, samples_per_class, replace=False)for i, idx in enumerate(idxs):plt_idx = i * num_classes + y + 1plt.subplot(samples_per_class, num_classes, plt_idx)plt.imshow(X_train[idx].astype('uint8'))plt.axis('off')if i == 0:plt.title(cls)
plt.show()

接下来,就是对SVM计算hinge loss,包含L2正则化,代码如下:

scores = X.dot(W)
correct_class_score = scores[range(num_train), list(y)].reshape(-1,1) # (N,1)
margin = np.maximum(0, scores - correct_class_score + 1)
margin[range(num_train), list(y)] = 0
loss = np.sum(margin) / num_train + 0.5 * reg * np.sum(W * W)

计算W梯度的代码如下:

num_classes = W.shape[1]
inter_mat = np.zeros((num_train, num_classes))
inter_mat[margin > 0] = 1
inter_mat[range(num_train), list(y)] = 0
inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)dW = (X.T).dot(inter_mat)
dW = dW/num_train + reg*W

根据SGD算法,每次迭代后更新W:

W -=  learning_rate * dW

训练过程中,使用交叉验证的方法选择最佳的学习因子 learning_rate 和正则化参数 reg,代码如下:

learning_rates = [1.4e-7, 1.5e-7, 1.6e-7]
regularization_strengths = [8000.0, 9000.0, 10000.0, 11000.0, 18000.0, 19000.0, 20000.0, 21000.0]results = {}
best_lr = None
best_reg = None
best_val = -1   # The highest validation accuracy that we have seen so far.
best_svm = None # The LinearSVM object that achieved the highest validation rate.for lr in learning_rates:for reg in regularization_strengths:svm = LinearSVM()loss_history = svm.train(X_train, y_train, learning_rate = lr, reg = reg, num_iters = 2000)y_train_pred = svm.predict(X_train)accuracy_train = np.mean(y_train_pred == y_train)y_val_pred = svm.predict(X_val)accuracy_val = np.mean(y_val_pred == y_val)if accuracy_val > best_val:best_lr = lrbest_reg = regbest_val = accuracy_valbest_svm = svmresults[(lr, reg)] = accuracy_train, accuracy_valprint('lr: %e reg: %e train accuracy: %f val accuracy: %f' %(lr, reg, results[(lr, reg)][0], results[(lr, reg)][1]))
print('Best validation accuracy during cross-validation:\nlr = %e, reg = %e, best_val = %f' %(best_lr, best_reg, best_val))

训练结束后,选择最佳的学习因子 learning_rate 和正则化参数 reg,在测试图片集上进行验证,代码如下:

# Evaluate the best svm on test set
y_test_pred = best_svm.predict(X_test)
test_accuracy = np.mean(y_test == y_test_pred)
print('linear SVM on raw pixels final test set accuracy: %f' % test_accuracy)

linear SVM on raw pixels final test set accuracy: 0.384000

最后,有个比较好玩的操作,我们可以将训练好的权重W可视化:

# Visualize the learned weights for each class.
# Depending on your choice of learning rate and regularization strength, these may
# or may not be nice to look at.
w = best_svm.W[:-1,:] # strip out the bias
w = w.reshape(32, 32, 3, 10)
w_min, w_max = np.min(w), np.max(w)
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
for i in range(10):plt.subplot(2, 5, i + 1)# Rescale the weights to be between 0 and 255wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)plt.imshow(wimg.astype('uint8'))plt.axis('off')plt.title(classes[i])

可以明显看出,由W重构的图片具有所属样本类别相似的地方,这正是线性SVM学习到的东西。

5. 总结

本文讲述的线性SVM利用距离间隔最大的思想,利用hinge loss的优化策略,来构建一个机器学习模型,并将这个简单模型应用到CIFAR-10图片集中进行训练和测试。实际测试的准确率在40%左右。准确率虽然不是很高,但是此SVM是线性模型,没有引入核函数构建非线性模型,也没有使用AlexNet,VGG,GoogLeNet,ResNet等卷积网络。测试结果比随机猜测10%要好很多,是一个不错的可实操的有趣模型。

完整代码,点击「源码」获取。

源码



参考资料:

http://cs231n.github.io/linear-classify/

基于线性SVM的CIFAR-10图像集分类相关推荐

  1. Few-Shot Classification of Aerial Scene Images via Meta-Learning(基于元学习的航拍场景图像小样本分类)

    Abstract: 基于卷积神经网络(CNN)的方法近年来在航空场景分类领域占据主导地位.虽然取得了显著的成功,但基于cnn的方法存在过多的参数,并依赖于大量的训练数据.在本工作中,我们将小样本学习引 ...

  2. 基于迁移深度学习的遥感图像场景分类

    前述 根据语义特征对遥感图像场景进行分类是一项具有挑战性的任务.因为遥感图像场景的类内变化较大,而类间变化有时却较小.不同的物体会以不同的尺度和方向出现在同一类场景中,而同样的物体也可能出现在不同的场 ...

  3. 线性SVM与非线性SVM

    所谓线性SVM与非线性SVM是指其选用的核类型. 用于分类问题时,SVM可供选择的参数并不多,惩罚参数C,核函数及其参数选择.对于一个应用,是选择线性核,还是多项式核,还是高斯核?还是有一些规则的. ...

  4. [读论文]弱监督学习的精确 3D 人脸重建:从单个图像到图像集-Accurate 3D Face Reconstruction with Weakly-Supervised Learning:From

    论文地址:Accurate 3D Face Reconstruction with Weakly-Supervised Learning:From Single Image to Image Set ...

  5. 【支持向量机SVM系列教程1】线性SVM

    文章目录 1 线性SVM 1.1 优化的目标 1.2 直观展示 1.3 公式表达 1.3.1 约束条件 1.3.2 硬间隔形式 1.3.3 软间隔形式 1.4 sklearn中的线性SVM 1.4.1 ...

  6. 【手把手教你】搭建神经网络(CT扫描3D图像的分类)

    大家好,我是羽峰,今天要和大家分享的是一个基于tensorflow的CT扫描3D图像的分类.文章会把整个代码进行分割讲解,完整看完,相信你一定会有所收获. 欢迎关注"羽峰码字" 目 ...

  7. 基于SVM的思想做CIFAR 10图像分类

    #SVM 回顾一下之前的SVM,找到一个间隔最大的函数,使得正负样本离该函数是最远的,是否最远不是看哪个点离函数最远,而是找到一个离函数最近的点看他是不是和该分割函数离的最近的. 使用large ma ...

  8. Matlab 基于svm的图像物体分类

    matlab 图像分类 本周工作日志,老师布置了一个小作业,让我们使用matlab实现图像物体分类 目录 文章目录 matlab 图像分类 目录 1分类原理 2程序流程 补充 1分类原理 基于一个很朴 ...

  9. matlab图像分类器,Matlab 基于svm的图像物体分类

    Matlab 基于svm的图像物体分类 发布时间:2018-05-16 20:27, 浏览次数:1623 , 标签: Matlab svm 本周工作日志,老师布置了一个小作业,让我们使用matlab实 ...

最新文章

  1. 5.java String对象
  2. 企业 SpringBoot 教程(六)springboot整合mybatis
  3. 洛谷 P1897电梯里的爱情 题解
  4. 存储引擎 boltdb 的设计奥秘?
  5. img = img1*mask + img2*(1-mask) How do that ?
  6. pda连接电脑无法存取文件_手机无法连接电脑怎办
  7. Intel X86 CPU寄存器学习笔记
  8. Agilent RF fundamentals (2)- fundamental units of RF
  9. join丢失数据_15、Hive数据倾斜与解决方案
  10. 程序员必看—程序员如何高效提升自己?
  11. 希尔伯特变换与三瞬属性简介
  12. NAS HomeAssistant
  13. python数据中元素可以改变的是_下列Python数据中其元素可以改变的是( )。 (2.0分)_学小易找答案...
  14. 大数据Apache Druid(四):使用Imply进行Druid集群搭建
  15. 【企业架构】现代企业架构方法——第 1 章
  16. 性能测试工具ab和wrk
  17. unity 人物走动声音_Unity3D实现人物走动 教程
  18. 超级轻量级: KV存储引擎实现
  19. 西门子S7-1200PLC脉冲控制伺服程序案例
  20. 对软件公司财务管理方面的一些想法

热门文章

  1. 想要成为Linux大神,你应该和我一样这样做!
  2. 一次PostgreSQL行估算偏差导致的慢查询分析
  3. java动态代理(JDK和cglib)
  4. sql已经完成,生成表
  5. apache 和 nginx 301重定向配置方法
  6. 非技术(一)——从最近的股票市场看到的
  7. input标签加disabled属性后无法获得其value值
  8. VitrualBox、vagrant、homestead的关系
  9. Tomcat 9.0.6 HostManager页面 403 Access Denied 错误
  10. python小白-day4递归和算法基础