使用支持向量机训练mnist数据
1 # encoding: utf-8 2 import numpy as np 3 import matplotlib.pyplot as plt 4 import cPickle 5 import gzip 6 7 class SVC(object): 8 def __init__(self, c=1.0, delta=0.001): # 初始化 9 self.N = 0 10 self.delta = delta 11 self.X = None 12 self.y = None 13 self.w = None 14 self.wn = 0 15 self.K = np.zeros((self.N, self.N)) 16 self.a = np.zeros((self.N, 1)) 17 self.b = 0 18 self.C = c 19 self.stop=1 20 self.k=0 21 self.cls=0 22 self.train_result=[] 23 24 def kernel_function(self,x1, x2): # 核函数 25 return np.dot(x1, x2) 26 27 def kernel_matrix(self, x): # 核矩阵 28 for i in range(0, len(x)): 29 for j in range(i, len(x)): 30 self.K[j][i] = self.K[i][j] = self.kernel_function(self.X[i], self.X[j]) 31 32 def get_w(self): # 计算更新w 33 ay = self.a * self.y 34 w = np.zeros((1, self.wn)) 35 for i in range(0, self.N): 36 w += self.X[i] * ay[i] 37 return w 38 39 def get_b(self, a1, a2, a1_old, a2_old): # 计算更新B 40 y1 = self.y[a1] 41 y2 = self.y[a2] 42 a1_new = self.a[a1] 43 a2_new = self.a[a2] 44 b1_new = -self.E[a1] - y1 * self.K[a1][a1] * (a1_new - a1_old) - y2 * self.K[a2][a1] * ( 45 a2_new - a2_old) + self.b 46 b2_new = -self.E[a2] - y1 * self.K[a1][a2] * (a1_new - a1_old) - y2 * self.K[a2][a2] * ( 47 a2_new - a2_old) + self.b 48 if (0 < a1_new) and (a1_new < self.C) and (0 < a2_new) and (a2_new < self.C): 49 return b1_new[0] 50 else: 51 return (b1_new[0] + b2_new[0]) / 2.0 52 53 def gx(self, x): # 判别函数g(x) 54 return np.dot(self.w, x) + self.b 55 56 def satisfy_kkt(self, a): # 判断样本点是否满足kkt条件 57 index = a[1] 58 if a[0] == 0 and self.y[index] * self.gx(self.X[index]) > 1: 59 return 1 60 elif a[0] < self.C and self.y[index] * self.gx(self.X[index]) == 1: 61 return 1 62 elif a[0] == self.C and self.y[index] * self.gx(self.X[index]) < 1: 63 return 1 64 return 0 65 66 def clip_func(self, a_new, a1_old, a2_old, y1, y2): # 拉格朗日乘子的裁剪函数 67 if (y1 == y2): 68 L = max(0, a1_old + a2_old - self.C) 69 H = min(self.C, a1_old + a2_old) 70 else: 71 L = max(0, a2_old - a1_old) 72 H = min(self.C, self.C + a2_old - a1_old) 73 if a_new < L: 74 a_new = L 75 if a_new > H: 76 a_new = H 77 return a_new 78 79 def update_a(self, a1, a2): # 更新a1,a2 80 partial_a2 = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2] 81 if partial_a2 <= 1e-9: 82 print "error:", partial_a2 83 a2_new_unc = self.a[a2] + (self.y[a2] * ((self.E[a1] - self.E[a2]) / partial_a2)) 84 a2_new = self.clip_func(a2_new_unc, self.a[a1], self.a[a2], self.y[a1], self.y[a2]) 85 a1_new = self.a[a1] + self.y[a1] * self.y[a2] * (self.a[a2] - a2_new) 86 if abs(a1_new - self.a[a1]) < self.delta: 87 return 0 88 self.a[a1] = a1_new 89 self.a[a2] = a2_new 90 self.is_update = 1 91 return 1 92 93 def update(self, first_a): # 更新拉格朗日乘子 94 for second_a in range(0, self.N): 95 if second_a == first_a: 96 continue 97 a1_old = self.a[first_a] 98 a2_old = self.a[second_a] 99 if self.update_a(first_a, second_a) == 0: 100 return 101 self.b= self.get_b(first_a, second_a, a1_old, a2_old) 102 self.w = self.get_w() 103 self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)] 104 self.stop=0 105 106 def train(self, x, y, max_iternum=100): # SMO算法 107 x_len = len(x) 108 self.X = x 109 self.N = x_len 110 self.wn = len(x[0]) 111 self.y = np.array(y).reshape((self.N, 1)) 112 self.K = np.zeros((self.N, self.N)) 113 self.kernel_matrix(self.X) 114 self.b = 0 115 self.a = np.zeros((self.N, 1)) 116 self.w = self.get_w() 117 self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)] 118 self.is_update = 0 119 for i in range(0, max_iternum): 120 self.stop=1 121 data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a))) if x > 0 and x< self.C] 122 if len(data_on_bound) == 0: 123 data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a)))] 124 for data in data_on_bound: 125 if self.satisfy_kkt(data) != 1: 126 self.update(data[1]) 127 if self.is_update == 0: 128 for data in [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]: 129 if self.satisfy_kkt(data) != 1: 130 self.update(data[1]) 131 if self.stop: 132 break 133 return self.w, self.b 134 135 def fit(self,x, y): # 训练模型, 一对一法k(k-1)/2个SVM进行多类分类 136 self.cls, y = np.unique(y, return_inverse=True) 137 self.k=len(self.cls) 138 for i in range(self.k): 139 for j in range(i): 140 a,b=self.sub_data(x,y,i,j) 141 self.train_result.append([i,j,self.train(a,b)]) 142 143 def predict(self,x_new): # 预测 144 p=np.zeros(self.k) 145 for i,j,w in self.train_result: 146 self.w=w[0] 147 self.b=w[1] 148 if self.classfy(x_new)==1: 149 p[j]+=1 150 else: 151 p[i]+=1 152 return self.cls[np.argmax(p)] 153 154 def sub_data(self,x,y,i,j): # 数据分类 155 subx=[] 156 suby=[] 157 for a,b in zip(x,y): 158 if b==i: 159 subx.append(a) 160 suby.append(-1) 161 elif b==j: 162 subx.append(a) 163 suby.append(1) 164 return subx,suby 165 166 def classfy(self,x_new): # 预测 167 y_new=self.gx(x_new) 168 cl = int(np.sign(y_new)) 169 if cl == 0: 170 cl = 1 171 return cl 172 173 174 def load_data(): 175 f = gzip.open('../data/mnist.pkl.gz', 'rb') 176 training_data, validation_data, test_data = cPickle.load(f) 177 f.close() 178 return (training_data, validation_data, test_data) 179 180 if __name__ == "__main__": 181 svc = SVC() 182 np.random.seed(0) 183 l=1000 184 training_data, validation_data, test_data = load_data() 185 svc.fit(training_data[0][:l],training_data[1][:l]) 186 predictions = [svc.predict(a) for a in test_data[0][:l]] 187 num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1][:l])) 188 print "%s of %s values correct." % (num_correct, len(test_data[1][:l])) #72/100 #808/1000 #8194/10000(较慢)
转载于:https://www.cnblogs.com/qw12/p/5744302.html
使用支持向量机训练mnist数据相关推荐
- 【python】利用两层神经网络(网络必须用类)来训练mnist数据(要求准确率90%以上)
要求: 用python自建一个class类,不能使用其他高级库函数,如pytorch,tensorflow,含有两个隐含层,隐含层数量可以指定. 准确率达到90以上. 画出学习曲线:损失曲线核准确率曲 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- linux手写数字识别,OpenCV 3.0中的SVM训练 mnist 手写字体识别
前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的 ...
- 二、如何保存MNIST数据集中train和test的图片?
如何保存MNIST数据集中train和test的图片? 介绍一种非诚神奇的图片保存方法,尤其是利用字典-format-结合来用,创建保存路径,这是一种史上很难用到的一种方法,哈哈哈哈,有点吹牛皮,不说 ...
- MNIST数据集合在PaddlePaddle环境下使用简单神经网络识别效果
简 介: 通过PaddlePaddle构造的简单神经网络对于MNIST数据库进行实验,可以看到使用普通的稠密网络,便可以达到很高的识别效果.训练结果存在一定的随机性.这取决于训练其实的条件.由于在Pa ...
- 使用Tensorflow操作MNIST数据
MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...
- LeNet训练MNIST
jupyter notebook: https://github.com/Penn000/NN/blob/master/notebook/LeNet/LeNet.ipynb LeNet训练MNIST ...
- 训练MNIST数据集模型
1. 数据集准备 详细信息见: Caffe: LMDB 及其数据转换 mnist是一个手写数字库,由DL大牛Yan LeCun进行维护.mnist最初用于支票上的手写数字识别, 现在成了DL的入门练习 ...
- CAFFE学习笔记(一)Caffe_Example之训练mnist
CAFFE学习笔记(一)Caffe_Example之训练mnist 0.参考文献 [1]caffe官网<Training LeNet on MNIST with Caffe>; [ ...
最新文章
- Office 2003出现发送错误报告怎么办
- windows下qt5 kinect 2.0开发与环境配置
- windows2008下配置iis时出现错误“由于扩展配置问题而无法提供您请求的页面。如果该页面是脚本,请添加处理程序。如果应下载文件,请添加 MIME 映射。”...
- Protocol Buffers的应用
- mysql mycat 路由规则_Mycat分库路由规则
- 01 掌握运算符的分类 1204
- 对bmp文件内存压缩 与 解压缩
- Session的clear方法和flush方法
- 使用Websocket框架之GatewayWorker开发电商平台买家与卖家实时通讯
- 累计独立访客(UV)不低于 1000是什么意思?如何查看自己小程序的UV数量?
- SQL AlawaysOn 之三:SQL服务器加入域
- 内容领先地位无法撼动,腾讯音乐与环球续约将共建新厂牌
- iOS 获取指南针的数据
- 选择框,单选框,组合框,列表框
- 学习总结-《父与子的编程之旅》chapter 6
- Arcgis创建新色带
- 关于实习和秋招的准备
- photoclip / 移动端图片上传剪裁插件 /一款手势驱动的裁图插件
- Android自定义底部弹出窗-dialog(2种实现分析+源码)
- 37.创建自定义的指令的限制使用 通过restrict 设置
热门文章
- webService学习3:客户端生成webservice代码
- 计算机考研:计算机组成原理考点分析
- idea xml文件引入类提示_IntelliJ IDEA:引用XML模式和DTD
- javascript对页面简单的加密和解密
- lib和dll的区别、生成以及使用详解
- 《leetcode》search-insert-position
- TensorFlow学习笔记(四)自己动手求Weights和biases
- define,require的基本用法
- 聊聊高并发(三十五)Java内存模型那些事(三)理解内存屏障
- Elasticsearch技术解析与实战(二)文档的CRUD操作