除了在Matlab中使用PRTools工具箱中的svm算法,Python中一样可以使用支持向量机做分类。因为Python中的sklearn库也集成了SVM算法,本文的运行环境是Pycharm。

一、导入sklearn算法包

skleran中集成了许多算法,其导入包的方式如下所示,

逻辑回归:from sklearn.linear_model import LogisticRegression

朴素贝叶斯:from sklearn.naive_bayes import GaussianNB

K-近邻:from sklearn.neighbors import KNeighborsClassifier

决策树:from sklearn.tree import DecisionTreeClassifier

支持向量机:from sklearn import svm

二、sklearn中svc的使用

(1)使用numpy中的loadtxt读入数据文件

loadtxt()的使用方法:

fname:文件路径。eg:C:/Dataset/iris.txt。

dtype:数据类型。eg:float、str等。

delimiter:分隔符。eg:‘,’。

converters:将数据列与转换函数进行映射的字典。eg:{1:fun},含义是将第2列对应转换函数进行转换。

usecols:选取数据的列。

以Iris兰花数据集为例子:

由于从UCI数据库中下载的Iris原始数据集的样子是这样的,前四列为特征列,第五列为类别列,分别有三种类别Iris-setosa, Iris-versicolor, Iris-virginica。

当使用numpy中的loadtxt函数导入该数据集时,假设数据类型dtype为浮点型,但是很明显第五列的数据类型并不是浮点型。

因此我们要额外做一个工作,即通过loadtxt()函数中的converters参数将第五列通过转换函数映射成浮点类型的数据。

首先,我们要写出一个转换函数:

1

2

3

def iris_type(s):

it= {'Iris-setosa':0,'Iris-versicolor':1,'Iris-virginica':2}

return it[s]

接下来读入数据,converters={4: iris_type}中“4”指的是第5列:

1

2

path= u'D:/f盘/python/学习/iris.data' # 数据文件路径

data= np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})

读入结果:

(2)将Iris分为训练集与测试集

1

2

3

x, y= np.split(data, (4,), axis=1)

x= x[:, :2]

x_train, x_test, y_train, y_test= train_test_split(x, y, random_state=1, train_size=0.6)

1. split(数据,分割位置,轴=1(水平分割) or 0(垂直分割))。

2. x = x[:, :2]是为方便后期画图更直观,故只取了前两列特征值向量训练。

3. sklearn.model_selection.train_test_split随机划分训练集与测试集。train_test_split(train_data,train_target,test_size=数字, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

(3)训练svm分类器

1

2

3

# clf = svm.SVC(C=0.1, kernel='linear', decision_function_shape='ovr')

clf= svm.SVC(C=0.8, kernel='rbf', gamma=20, decision_function_shape='ovr')

clf.fit(x_train, y_train.ravel())

kernel='linear'时,为线性核,C越大分类效果越好,但有可能会过拟合(defaul C=1)。

kernel='rbf'时(default),为高斯核,gamma值越小,分类界面越连续;gamma值越大,分类界面越“散”,分类效果越好,但有可能会过拟合。

decision_function_shape='ovr'时,为one v rest,即一个类别与其他类别进行划分,

decision_function_shape='ovo'时,为one v one,即将类别两两之间进行划分,用二分类的方法模拟多分类的结果。

(4)计算svc分类器的准确率

1

2

3

4

5

6

print clf.score(x_train, y_train)# 精度

y_hat= clf.predict(x_train)

show_accuracy(y_hat, y_train,'训练集')

print clf.score(x_test, y_test)

y_hat= clf.predict(x_test)

show_accuracy(y_hat, y_test,'测试集')

结果为:

如果想查看决策函数,可以通过decision_function()实现

1

2

print 'decision_function:\n', clf.decision_function(x_train)

print '\npredict:\n', clf.predict(x_train)

结果为:

decision_function中每一列的值代表距离各类别的距离。

(5)绘制图像

1.确定坐标轴范围,x,y轴分别表示两个特征

1

2

3

4

5

x1_min, x1_max= x[:,0].min(), x[:,0].max()# 第0列的范围

x2_min, x2_max= x[:,1].min(), x[:,1].max()# 第1列的范围

x1, x2= np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]# 生成网格采样点

grid_test= np.stack((x1.flat, x2.flat), axis=1)# 测试点

# print 'grid_test = \n', grid_testgrid_hat = clf.predict(grid_test) # 预测分类值grid_hat = grid_hat.reshape(x1.shape) # 使之与输入的形状相同

这里用到了mgrid()函数,该函数的作用这里简单介绍一下:

假设假设目标函数F(x,y)=x+y。x轴范围1~3,y轴范围4~6,当绘制图像时主要分四步进行:

【step1:x扩展】(朝右扩展):

[1 1 1]

[2 2 2]

[3 3 3]

【step2:y扩展】(朝下扩展):

[4 5 6]

[4 5 6]

[4 5 6]

【step3:定位(xi,yi)】:

[(1,4) (1,5) (1,6)]

[(2,4) (2,5) (2,6)]

[(3,4) (3,5) (3,6)]

【step4:将(xi,yi)代入F(x,y)=x+y】

因此这里x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]后的结果为:

再通过stack()函数,axis=1,生成测试点

2.指定默认字体

1

2

mpl.rcParams['font.sans-serif']= [u'SimHei']

mpl.rcParams['axes.unicode_minus']= False

3.绘制

1

2

3

4

5

6

7

8

9

10

11

12

cm_light= mpl.colors.ListedColormap(['#A0FFA0','#FFA0A0','#A0A0FF'])

cm_dark= mpl.colors.ListedColormap(['g','r','b'])

plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)

plt.scatter(x[:,0], x[:,1], c=y, edgecolors='k', s=50, cmap=cm_dark)# 样本

plt.scatter(x_test[:,0], x_test[:,1], s=120, facecolors='none', zorder=10)# 圈中测试集样本

plt.xlabel(u'花萼长度', fontsize=13)

plt.ylabel(u'花萼宽度', fontsize=13)

plt.xlim(x1_min, x1_max)

plt.ylim(x2_min, x2_max)

plt.title(u'鸢尾花SVM二特征分类', fontsize=15)

# plt.grid()

plt.show()

pcolormesh(x,y,z,cmap)这里参数代入x1,x2,grid_hat,cmap=cm_light绘制的是背景。

scatter中edgecolors是指描绘点的边缘色彩,s指描绘点的大小,cmap指点的颜色。

xlim指图的边界。

最终结果为:

svm算法python实现_(转载)python应用svm算法过程相关推荐

  1. python 统计检验_[转载]Python替代SPSS进行各项统计检验

    采用python的scipy库完成常用的假设检验, 配合pandas库非常好用 正态性检验 检验数据样本是否具有高斯分布. from scipy.stats import shapiro data = ...

  2. 第一章 第一节:Python基础_认识Python

    Python基础入门(全套保姆级教程) 第一章 第一节:Python基础_认识Python 1. 什么是编程 通俗易懂,编程就是用代码编写程序,编写程序有很多种办法,像c语言,javaPython语言 ...

  3. python网格搜索核函数_(转载)Python机器学习笔记GridSearchCV(网格搜索)

    转载声明 介绍 在机器学习模型中,需要人工选择的参数称为超参数.比如随机森林中决策树的个数,人工神经网络模型中隐藏层层数和每层的节点个数,正则项中常数大小等等,他们都需要事先指定.超参数选择不恰当,就 ...

  4. 排序算法python实现_用Python,Java和C / C ++实现的选择排序算法

    排序算法python实现 The Selection Sort Algorithm sorts the elements of an array. In this article, we shall ...

  5. 随机森林python实例_用Python实现随机森林算法的示例

    这篇文章主要介绍了用Python实现随机森林算法,小编觉得挺不错的,现在分享给大家,也给大家做个参考. 拥有高方差使得决策树(secision tress)在处理特定训练数据集时其结果显得相对脆弱.b ...

  6. python蚁群算法路径规划_使用python实现蚁群算法

    此次使用python实现蚁群算法是仿照蚁群优化算法的JAVA实现中的蚁群算法实现方法,使用的也是其中的数据(此处为上传数据),如需更深一步了解蚁群算法原理和具体实现过程,请参考蚁群优化算法的JAVA实 ...

  7. format函数python的顺序_[转载] Python中format函数用法

    Python中format函数用法 format优点 format是python2.6新增的格式化字符串的方法,相对于老版的%格式方法,它有很多优点. 1.不需要理会数据类型的问题,在%方法中%s只能 ...

  8. louvian算法 缺点 优化_机器学习中的优化算法(1)-优化算法重要性,SGD,Momentum(附Python示例)...

    本系列文章已转至 机器学习的优化器​zhuanlan.zhihu.com 优化算法在机器学习中扮演着至关重要的角色,了解常用的优化算法对于机器学习爱好者和从业者有着重要的意义. 这系列文章先讲述优化算 ...

  9. python 字符识别_使用python进行光学字符识别入门

    python 字符识别 语言模型设计 (Language Model Designing) Optical Character Recognition is the conversion of 2-D ...

  10. 类的继承python事例_【Python五篇慢慢弹(5)】类的继承案例解析,python相关知识延伸...

    作者:白宁超 2016年10月10日22:36:57 摘要:继一文之后,笔者又将python官方文档认真学习下.官方给出的pythondoc入门资料包含了基本要点.本文是对文档常用核心要点进行梳理,简 ...

最新文章

  1. 安卓系列七(广播机制)
  2. 工作中不能学的6种人
  3. django Error: [Errno 10013]
  4. oracle 11查询sid,oracle 11g 更改sid和dbname
  5. XAML Namespace http://schemas.microsoft.com/expression/blend/2008 is not resolved
  6. [ARM-assembly]-ARM交叉编译器下编译的各个镜像的反汇编文件分析
  7. System.getProperty()的用途
  8. idea 创建多模块依赖Maven项目
  9. 为什么我认为现阶段HIDS处于攻防不对等的地位?(ids、nta、绕过)
  10. android 系统(143)---Android实现App版本自动更新
  11. 洛谷——P1657 选书
  12. NoticeBoard 一个仿原生UI的消息通知控件
  13. 连接mysql超过连接次数处理办法
  14. Linux 命令(55)—— netstat 命令
  15. xcode中使用正则表达式来搜索替换代码
  16. 【高效程序员系列】1、好马配好鞍——舒适的工作环境
  17. AxureUX中后台管理信息系统通用原型方案
  18. 【CSS】从熟悉到更熟悉
  19. arduino中u8g2汉字显示总结
  20. mysql和JDBC学习

热门文章

  1. Go:十进制转二进制算法(附完整源码)
  2. 支付宝小程序花呗分期插件
  3. Facebook发币,互联网与区块链的生死竞速开始了
  4. 基于OpenPose的人体姿态检测(非常好)
  5. mysql alert 新增字段_mysql 增加字段 sql 语句
  6. 简单为蒲公英在线教学系统进行优化-04
  7. unity shader ASE连线 海浪效果
  8. 论文抽检判定抄袭的标准?
  9. 大彩科技新一代HMI人机界面KM4.3寸新品发布!
  10. 手机html音乐播放器代码隐藏,教程方法;仿酷狗html5手机音乐播放器主要部分代码电脑技巧-琪琪词资源网...