#SVM 回顾一下之前的SVM,找到一个间隔最大的函数,使得正负样本离该函数是最远的,是否最远不是看哪个点离函数最远,而是找到一个离函数最近的点看他是不是和该分割函数离的最近的。

使用large margin来regularization。 之前讲SVM的算法:www.jianshu.com/p/8fd28df73… #线性分类 线性SVM就是一种线性分类的方法。输入,输出,每一个样本的权重是,偏置项bias是。得分函数$$s = wx +b$$ 算出这么多个类别,哪一个类别的分数高,那就是哪个类别。比如要做的图像识别有三个类别,假设这张图片有4个像素,拉伸成单列: 得到的结果很明显是dog分数最大,cat的分数最低,但是图片很明显是猫,什么分类器是错误的。 一般来说习惯会把w和b合并了,x加上一个全为1的列,于是有$$W=[w;b];X = [x;1]$$#损失函数 之前的SVM是把正负样本离分割函数有足够的空间,虽然正确的是猫,但是猫的得分是最低的,常规方法是将猫的分数提高,这样才可以提高猫的正确率。但是SVM里面是要求一个间隔最大化,提到这里来说,其实就是cat score不仅仅是要大于其他的分数,而且是要有一个最低阈值,cat score不能低于这个分数。 所以正确的分类score应该是要大于其他的分类score一个阈值:$$s_{y_i} >= s_j + \triangle$$ 就是正确分类的分数,就是其他分类的分数。所以,这个损失函数就是:$$Loss_{y_i} = \sum_{j != y_i}max(0, s_j - s_{y_i}+\triangle)$$只有正确的分数比其他的都大于一个阈值才为0,否则都是有损失的。只有损失函数才是0的。这种损失函数称为合页损失函数,用的就是SVM间隔最大化的思想解决,如果损失函数为0,那么不用求解了,如果损失函数不为0,就可以用梯度下降求解。max求解梯度下降有点不现实,所以自然就有了square的合页损失函数。


这种squared hinge loss SVM与linear hinge loss SVM相比较,特点是对违背间隔阈值要求的点加重惩罚,违背的越大,惩罚越大。某些实际应用中,squared hinge loss SVM的效果更好一些。具体使用哪个,可以根据实际问题,进行交叉验证再确定。 对于的设置,之前SVM其实讨论过,对于一个平面是可以随意伸缩的,只需要增大w和b就可以随意把增大,所以把它定为1,也就是设置。因为w的增长或缩小完全可以抵消的影响。这个时候损失函数就是:


最后还要增加的就是过拟合,regularization的限制了。L2正则化:


加上正则化之后就是:


N是训练样本的个数,取平均损失函数,就是惩罚的力度了,可以小也可以大,如果大了可能w不足以抵消正负样本之间的间隔,可能会欠拟合,因为是在w可以自由伸缩达到的条件,如果w太小,可能就不足以增长到1了。如果小了,可能就会造成overfit。对于参数b就没有这么讲究了。 #代码实现 首先是对CIFAR10的数据读取:


def load_pickle(f):version = platform.python_version_tuple()if version[0] == '2':return pickle.load(f)elif version[0] == '3':return pickle.load(f, encoding='latin1')raise ValueError("invalid python version: {}".format(version))def loadCIFAR_batch(filename):with open(filename, 'rb') as f:datadict = load_pickle(f)x = datadict['data']y = datadict['labels']x = x.reshape(10000, 3, 32, 32).transpose(0, 3, 2, 1).astype('float')y = np.array(y)return x, ydef loadCIFAR10(root):xs = []ys = []for b in range(1, 6):f = os.path.join(root, 'data_batch_%d' % (b, ))x, y = loadCIFAR_batch(f)xs.append(x)ys.append(y)X = np.concatenate(xs)Y = np.concatenate(ys)x_test, y_test = loadCIFAR_batch(os.path.join(root, 'test_batch'))return X, Y, x_test, y_test
复制代码

首先要读入每一个文件的数据,先用load_pickle把文件读成字典形式,取出来。因为常规的图片都是(数量,高,宽,RGB颜色),在loadCIFAR_batch要用transpose来把维度调换一下。最后把每一个文件的数据都集合起来。 之后就是数据的格式调整了:

def data_validation(x_train, y_train, x_test, y_test):num_training = 49000num_validation = 1000num_test = 1000num_dev = 500mean_image = np.mean(x_train, axis=0)x_train -= mean_imagemask = range(num_training, num_training + num_validation)X_val = x_train[mask]Y_val = y_train[mask]mask = range(num_training)X_train = x_train[mask]Y_train = y_train[mask]mask = np.random.choice(num_training, num_dev, replace=False)X_dev = x_train[mask]Y_dev = y_train[mask]mask = range(num_test)X_test = x_test[mask]Y_test = y_test[mask]X_train = np.reshape(X_train, (X_train.shape[0], -1))X_val = np.reshape(X_val, (X_val.shape[0], -1))X_test = np.reshape(X_test, (X_test.shape[0], -1))X_dev = np.reshape(X_dev, (X_dev.shape[0], -1))X_train = np.hstack([X_train, np.ones((X_train.shape[0], 1))])X_val = np.hstack([X_val, np.ones((X_val.shape[0], 1))])X_test = np.hstack([X_test, np.ones((X_test.shape[0], 1))])X_dev = np.hstack([X_dev, np.ones((X_dev.shape[0], 1))])return X_val, Y_val, X_train, Y_train, X_dev, Y_dev, X_test, Y_testpass
复制代码

数据要变成一个长条。 先看看数据长啥样:

def showPicture(x_train, y_train):classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']num_classes = len(classes)samples_per_classes = 7for y, cls in enumerate(classes):idxs = np.flatnonzero(y_train == y)idxs = np.random.choice(idxs, samples_per_classes, replace=False)for i, idx in enumerate(idxs):plt_index = i*num_classes +y + 1plt.subplot(samples_per_classes, num_classes, plt_index)plt.imshow(x_train[idx].astype('uint8'))plt.axis('off')if i == 0:plt.title(cls)plt.show()
复制代码

然后就是使用谷歌的公式了:

    def loss(self, x, y, reg):loss = 0.0dw = np.zeros(self.W.shape)num_train = x.shape[0]scores = x.dot(self.W)correct_class_score = scores[range(num_train), list(y)].reshape(-1, 1)margin = np.maximum(0, scores - correct_class_score + 1)margin[range(num_train), list(y)] = 0loss = np.sum(margin)/num_train + 0.5 * reg * np.sum(self.W*self.W)num_classes = self.W.shape[1]inter_mat = np.zeros((num_train, num_classes))inter_mat[margin > 0] = 1inter_mat[range(num_train), list(y)] = 0inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)dW = (x.T).dot(inter_mat)dW = dW/num_train + reg*self.Wreturn loss, dWpass
复制代码

操作都是常规操作,算出score然后求loss最后SGD求梯度更新W。

    def train(self, X, y, learning_rate=1e-3, reg=1e-5, num_iters=100,batch_size=200, verbose=False):num_train, dim = X.shapenum_classes = np.max(y) + 1if self.W is None:self.W = 0.001 * np.random.randn(dim, num_classes)# Run stochastic gradient descent to optimize Wloss_history = []for it in range(num_iters):X_batch = Noney_batch = Noneidx_batch = np.random.choice(num_train, batch_size, replace = True)X_batch = X[idx_batch]y_batch = y[idx_batch]# evaluate loss and gradientloss, grad = self.loss(X_batch, y_batch, reg)loss_history.append(loss)self.W -=  learning_rate * gradif verbose and it % 100 == 0:print('iteration %d / %d: loss %f' % (it, num_iters, loss))return loss_historypass
复制代码

预测:

    def predict(self, X):y_pred = np.zeros(X.shape[0])scores = X.dot(self.W)y_pred = np.argmax(scores, axis = 1)return y_pred
复制代码

最后运行函数:

 svm = LinearSVM()tic = time.time()cifar10_name = '../Data/cifar-10-batches-py'x_train, y_train, x_test, y_test = loadCIFAR10(cifar10_name)X_val, Y_val, X_train, Y_train, X_dev, Y_dev, X_test, Y_test = data_validation(x_train, y_train, x_test, y_test)loss_hist = svm.train(X_train, Y_train, learning_rate=1e-7, reg=2.5e4,num_iters=3000, verbose=True)toc = time.time()print('That took %fs' % (toc - tic))plt.plot(loss_hist)plt.xlabel('Iteration number')plt.ylabel('Loss value')plt.show()y_test_pred = svm.predict(X_test)test_accuracy = np.mean(Y_test == y_test_pred)print('accuracy: %f' % test_accuracy)w = svm.W[:-1, :]  # strip out the biasw = w.reshape(32, 32, 3, 10)w_min, w_max = np.min(w), np.max(w)classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']for i in range(10):plt.subplot(2, 5, i + 1)wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)plt.imshow(wimg.astype('uint8'))plt.axis('off')plt.title(classes[i])plt.show()复制代码

首先是画出整个loss函数趋势:

最后再可视化一下w权值,看看每一个种类提取处理的特征是什么样子的:

转载于:https://juejin.im/post/5c321164e51d45524975d0d1

基于SVM的思想做CIFAR 10图像分类相关推荐

  1. 基于svm图像分类C语言,基于SVM的图像分类算法与实现.PDF

    , ( ) 计算机工程与应用 40 ComputerEngineeringandApplications 基于SVM的图像分类算法与实现 张淑雅 赵一鸣 李均利 , , , , ZHANGShu-ya ...

  2. 【opencv机器学习】基于SVM和神经网络的车牌识别

    基于SVM和神经网络的车牌识别 深入理解OpenCV:实用计算机视觉项目解析 本文用来学习的项目来自书籍<实用计算机视觉项目解析>第5章Number Plate Recognition 提 ...

  3. 基于机器学习的车牌识别系统(Python实现基于SVM支持向量机的车牌分类)

    基于机器学习的车牌识别系统(Python实现基于SVM支持向量机的车牌分类) 一.数据集说明 训练样本来自于github上的EasyPR的c++版本,包含一万三千多张数字及大写字母的图片以及三千多张中 ...

  4. 基于SVM技术的手写数字识别

    老师常说,在人工智能未发展起来之前,SVM技术是一统江湖的,SVM常常听到,但究竟是什么呢?最近研究了一下基于SVM技术的手写数字识别.你没有看错,又是手写数字识别,就是喜欢这个手写数字识别,没办法( ...

  5. Python 基于SVM和KNN算法的红酒分类

    Python 机器学习之红酒分类问题 文章目录 Python 机器学习之红酒分类问题 前言 一.问题和目标是什么 1.原题 2.题目分析 二.算法简介 三.代码实现 1.算法流程框架 2.第三方库调用 ...

  6. 从原理到实现,详解基于朴素ML思想的协同过滤推荐算法

    作者丨gongyouliu 编辑丨Zandy 来源 | 大数据与人工智能(ID: ai-big-data) 作者在<协同过滤推荐算法>.<矩阵分解推荐算法>这两篇文章中介绍了几 ...

  7. matlab图像分类器,Matlab 基于svm的图像物体分类

    Matlab 基于svm的图像物体分类 发布时间:2018-05-16 20:27, 浏览次数:1623 , 标签: Matlab svm 本周工作日志,老师布置了一个小作业,让我们使用matlab实 ...

  8. 基于SVM的猫咪图片识别器

    分享一下我老师大神的人工智能教程!零基础,通俗易懂!http://blog.csdn.net/jiangjunshow 也欢迎大家转载本篇文章.分享知识,造福人民,实现我们中华民族伟大复兴! 基于SV ...

  9. 基于SVM算法的男女生分类器

    基于SVM算法的男女生分类器 题目:采用SVM设计男女生分类器.采用的特征包含身高.体重.鞋码.50m成绩.肺活量.是否喜欢运动共六个特征.要求:采用平台提供的软件包进行分类器的设计以及测试,尝试不同 ...

最新文章

  1. 116. Leetcode 1143. 最长公共子序列 (动态规划-子序列问题)
  2. SilverLigth的Chart不要图例(Legend)的方法
  3. 了解JavaScript中的prototype (实例)
  4. C语言判断点是否在矩阵内
  5. linux终奌站 信息 格式 更改 /etc/bashrc
  6. C#开源项目一览表[转](包含国内和国外)
  7. [VB]使用ADO Recordset对象导入Excel
  8. 波士顿动力新机器人登场!
  9. python2.7.5 怎么装redis_python中Redis的简要介绍以及Redis的安装,配置
  10. rust 手动关闭子线程_Rust入坑指南:齐头并进(上)
  11. debian 安装五笔输入法
  12. html自动选择省市,jQuery中国省市区地址三级联动插件Distpicker
  13. Scrum敏捷开发模式
  14. 开源无国界!CSDN 董事长蒋涛、GitHub 副总裁 Thomas Dohmke 对话实录
  15. OpenCV—直线拟合fitLine
  16. spss多因素方差分析
  17. 【SPIE独立出版|往届已检索、湘潭大学主办】第二届绿色通信、网络与物联网国际学术会议 (CNIoT 2022)
  18. 前端播放rtmp协议的视频流文件
  19. 神经网络方法研究及应用,基于神经网络的控制
  20. 【解决方案】根据当前系统时钟或签名文件中的时间戳验证时要求的证书不在有效期内

热门文章

  1. 那些年移动互联网行业曾经走过的弯路
  2. Hive UDAF开发
  3. 从市场角度看服务器虚拟化
  4. 在c#中使用WINDOWS API(转)
  5. Java 工程师成神之路 | 2019正式版
  6. APP签名MD5获取
  7. linux很容易忽略的rz上传、sz下载命令
  8. 一维数组和二维数组创建,输出,Arrays.fill()替换
  9. Nginx实战基础篇一 源码包编译安装部署web服务器
  10. MySQL之InnoDB索引的一些问题