摘要

本人在学习《Python机器学习基础教程》时的一些小实验。

一、认识鸢尾花数据

python的机器学习库scikit-learn中保存了大量的经典的实验数据集,在学习阶段没有办法搜集大数据的情况下,可以用这些数据进行学习。
首先从sklearn.datasets中导入鸢尾花(iris)的数据集

from sklearn.datasets import load_iris
iris = load_iris()

对象iris是一个Bunch对象,与字典很类似,可以用iris.keys()查看所有键。这里只介绍需要用得到的键

print(iris['DESCR'], '\n') #鸢尾花数据集的摘要
print("The target name of iris: {}".format(iris['target_names'])) #标签名字
print("The target by number: \n{}".format(iris['target'])) #标签
.. _iris_dataset:Iris plants dataset
--------------------**Data Set Characteristics:**:Number of Instances: 150 (50 in each of three classes):Number of Attributes: 4 numeric, predictive attributes and the class:Attribute Information:- sepal length in cm- sepal width in cm- petal length in cm- petal width in cm- class:- Iris-Setosa- Iris-Versicolour- Iris-Virginica:Summary Statistics:============== ==== ==== ======= ===== ====================Min  Max   Mean    SD   Class Correlation============== ==== ==== ======= ===== ====================sepal length:   4.3  7.9   5.84   0.83    0.7826sepal width:    2.0  4.4   3.05   0.43   -0.4194petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)============== ==== ==== ======= ===== ====================:Missing Attribute Values: None:Class Distribution: 33.3% for each of 3 classes.:Creator: R.A. Fisher:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov):Date: July, 1988The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other... topic:: References- Fisher, R.A. "The use of multiple measurements in taxonomic problems"Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions toMathematical Statistics" (John Wiley, NY, 1950).- Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.(Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.- Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New SystemStructure and Classification Rule for Recognition in Partially ExposedEnvironments".  IEEE Transactions on Pattern Analysis and MachineIntelligence, Vol. PAMI-2, No. 1, 67-71.- Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactionson Information Theory, May 1972, 431-433.- See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS IIconceptual clustering system finds 3 classes in the data.- Many, many more ... The target name of iris: ['setosa' 'versicolor' 'virginica']
The target by number:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 00 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 22 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 22 2]

说明:

  1. feature_name是特征名,一共有四种,分别代表花萼长,花萼宽,花瓣长,花瓣宽
:Attribute Information:- sepal length in cm- sepal width in cm- petal length in cm- petal width in cm
  1. target_name是标签(类别)名,一共分为三类
- Class:- Iris-Setosa- Iris-Versicolour- Iris-Virginica
  1. target是标签的代号,如下
{'Iris-Setosa':0, 'Iris-Versicolour':1, 'Iris-Virginica':2}
  1. data是鸢尾花的样本数据集,一共有150个样本点,每个样本点提供了4个特征和分类的数据.
    我们用pandas的DataFrame来展现数据集(限于篇幅只展示部分)
import pandas as pd
from IPython.display import displaydata = {'sl':iris['data'][:,0],'sw':iris['data'][:,1],'pl':iris['data'][:,2],'pw':iris['data'][:,3],'target':iris['target']}
data_pandas = pd.DataFrame(data)display(data_pandas[data_pandas.target==0])
display(data_pandas[data_pandas.target==1])
display(data_pandas[data_pandas.target==2])

二、划分训练数据和测试数据

由于没有别的鸢尾花数据,我们只能把鸢尾花数据分为训练数据和测试数据,在训练数据上跑模型,在测试数据上检验我们的模型。
测试数据和训练数据的比例为1:3

sklearn.model_selection中可以导入函数train_test_split,它的主要作用是根据传入的X,Y,随机数种子来随机划分数据。

from sklearn.model_selection import train_test_split as tts
X_train, X_test, y_train, y_test = tts(iris['data'], iris['target'],\random_state=1)
print("X_train shape: {}".format(X_train.shape))
print("X_test shape: {}".format(X_test.shape))
print("y_train shape: {}".format(y_train.shape))
print("y_test shape: {}".format(y_test.shape))>>>
X_train shape: (112, 4)
X_test shape: (38, 4)
y_train shape: (112,)
y_test shape: (38,)

此时我们再展示一下训练数据都包含哪些样本点(部分)

data_train = {'sl':X_train[:,0],'sw':X_train[:,1],'pl':X_train[:,2],'pw':X_train[:,3],'target':y_train}
data_pandas = pd.DataFrame(data_train)
display(data_pandas)

三、建立神经网络(多层感知机)模型

(这里不介绍神经网络的实现方法)

sklearn.neural_network中导入多层感知机分类器MLPClassifier,它有几个重要的参数

  1. solver:建模的方法,有{'lbfgs', 'sgd', 'adam'}, default='adam
  2. random_state:随机数种子,用于权重的初始化
  3. hidden_layer_sizes:隐层数目和隐层节点数目,例如[10,100]表示两个隐层,第一个有10个节点,第二个有100个节点
  4. max_iter:最大迭代次数
  5. activation : 隐层激活函数,有{'identity', 'logistic', 'tanh', 'relu'}, default='relu'
  6. epsilon : 精度,默认为1e-8
  7. 还有一些参数比如正则化参数alpha,学习率learning_rate等等,有需要用时再查找。
from sklearn.neural_network import MLPClassifier as MLP
mlp = MLP(solver='lbfgs', random_state=1, \hidden_layer_sizes=[10], max_iter=1000)
mlp.fit(X_train, y_train)print("Accuracy on training set: {:.3f}".format(mlp.score(X_train, y_train)))
print("Accuracy on testing  set: {:.3f}".format(mlp.score(X_test, y_test)))>>>
Accuracy on training set: 0.982
Accuracy on testing  set: 1.000

一个隐层,10个隐层节点可以保持在测试机上100%的精度,下面看看每个输入对隐层节点的权重ωij\omega_{ij}ωij​的可视化

plt.figure(figsize=(20, 5))
plt.imshow(mlp.coefs_[0], interpolation='none', cmap='viridis')
plt.yticks(range(4), iris.feature_names)
plt.xlabel("Columns in weight matrix")
plt.ylabel("Input feature")
plt.colorbar()>>>
<matplotlib.colorbar.Colorbar at 0x2234ae09c88>


颜色越浅表示权重越高,这里面有sw特征对隐层第4个节点,pl特征对隐层第10个节点的权重较高,可以初步判断这两个特征对于判别一个鸢尾花的品种比较重要。

四、调参

1. 单层隐层

通过调整参数比如hidden_layer_sizes来分析精度随着它的变化的情况。
这里选用一层隐层,节点个数从1变化到100

#调整隐层节点个数
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as nphidden_lst = []
train_score = []
test_score = []
for hidden in range(1, 101):mlp = MLP(solver='lbfgs', random_state=1, hidden_layer_sizes=[hidden],\max_iter=1000)mlp.fit(X_train, y_train)hidden_lst.append(hidden)train_score.append(mlp.score(X_train, y_train))test_score.append(mlp.score(X_test, y_test))plt.figure(figsize=(20, 5))
plt.plot(hidden_lst, train_score,label="train_score")
plt.plot(hidden_lst, test_score,label="test_score ")
plt.ylabel("Accuracy")
plt.xlabel("hidden_num")
plt.legend()
print("Max accuracy of train set: {0}, min accuracy: {1}, mean accuracy: {2}".format(max(train_score), min(train_score),\np.mean(train_score)))
print("Max accuracy of test  set: {0}, min accuracy: {1}, mean accuracy: {2}".format(max(test_score), min(test_score),\np.mean(test_score)))>>>
Max accuracy of train set: 1.0, min accuracy: 0.36607142857142855, mean accuracy: 0.9594642857142859
Max accuracy of test  set: 1.0, min accuracy: 0.23684210526315788, mean accuracy: 0.9578947368421052

可以发现,训练集和测试集上的最大精度达到100%,最小精度分别为0.37, 0.24,平均精度都在0.95左右。

精度曲线

所以,实际上将隐层节点设置在10个左右就可以得到很好的精度,模型也简单。

2. 多层隐层

首先是双层隐层,每个隐层的节点个数为10

mlp = MLP(solver='lbfgs', random_state=1, hidden_layer_sizes=[10,10], max_iter=1000)
mlp.fit(X_train, y_train)print("Accuracy on training set: {:.3f}".format(mlp.score(X_train, y_train)))
print("Accuracy on testing  set: {:.3f}".format(mlp.score(X_test, y_test)))>>>
Accuracy on training set: 0.991
Accuracy on testing  set: 1.000

然后我们将第二个隐层的节点数调整为100

mlp = MLP(solver='lbfgs', random_state=1, hidden_layer_sizes=[10,100], max_iter=1000)
mlp.fit(X_train, y_train)print("Accuracy on training set: {:.3f}".format(mlp.score(X_train, y_train)))
print("Accuracy on testing  set: {:.3f}".format(mlp.score(X_test, y_test)))
>>>
Accuracy on training set: 0.982
Accuracy on testing  set: 1.000

精度有所下降,如果我们再将层数改为三层,每层10个节点的话

mlp = MLP(solver='lbfgs', random_state=1, hidden_layer_sizes=[10,10,10], max_iter=1000)
mlp.fit(X_train, y_train)print("Accuracy on training set: {:.3f}".format(mlp.score(X_train, y_train)))
print("Accuracy on testing  set: {:.3f}".format(mlp.score(X_test, y_test)))>>>
Accuracy on training set: 0.795
Accuracy on testing  set: 0.605

发现精度大大下降,层数变多,计算的次数变多,拟合更加困难。

3. 调整隐层个数

我们通过调整隐层个数从1到50,来分析精度的变化

hidden_lst_up = []
train_score_up = []
test_score_up = []
h = [10]
for hidden in range(51):mlp = MLP(solver='lbfgs', random_state=1, hidden_layer_sizes=h, max_iter=5000)mlp.fit(X_train, y_train)hidden_lst_up.append(hidden)train_score_up.append(mlp.score(X_train, y_train))test_score_up.append(mlp.score(X_test, y_test))h.append(10)
plt.figure(figsize=(10, 5))
plt.plot(hidden_lst_up, train_score_up,label="train_score")
plt.plot(hidden_lst_up, test_score_up,label="test_score ")
plt.ylabel("Accuracy")
plt.xlabel("hidden_num")
plt.legend()
print("Max accuracy of train set: {0}, min accuracy: {1}, mean accuracy: {2}".format(max(train_score_up), min(train_score_up),\np.mean(train_score_up)))
print("Max accuracy of test  set: {0}, min accuracy: {1}, mean accuracy: {2}".format(max(test_score_up), min(test_score_up),\np.mean(test_score_up)))>>>
Max accuracy of train set: 0.9910714285714286, min accuracy: 0.30357142857142855, mean accuracy: 0.5087535014005601
Max accuracy of test  set: 1.0, min accuracy: 0.23684210526315788, mean accuracy: 0.43292053663570695


当隐层个数增大是,精度下降的十分明显,但也有意外之处,当隐层个数为26时,精度意外的好。

print(np.array(list(enumerate(train_score_up))))
print(np.array(list(enumerate(test_score_up))))>>>
[[ 0.          0.98214286][ 1.          0.99107143][ 2.          0.79464286][ 3.          0.97321429][ 4.          0.97321429][ 5.          0.99107143][ 6.          0.96428571][ 7.          0.41071429][ 8.          0.875     ][ 9.          0.33035714][10.          0.30357143][11.          0.875     ][12.          0.63392857][13.          0.33035714][14.          0.5625    ][15.          0.36607143][16.          0.36607143][17.          0.58035714][18.          0.69642857][19.          0.36607143][20.          0.33035714][21.          0.69642857][22.          0.69642857][23.          0.36607143][24.          0.36607143][25.          0.36607143][26.          0.97321429][27.          0.36607143][28.          0.36607143][29.          0.36607143][30.          0.36607143][31.          0.36607143][32.          0.36607143][33.          0.36607143][34.          0.36607143][35.          0.36607143][36.          0.36607143][37.          0.36607143][38.          0.36607143][39.          0.36607143][40.          0.36607143][41.          0.36607143][42.          0.36607143][43.          0.36607143][44.          0.36607143][45.          0.36607143][46.          0.36607143][47.          0.36607143][48.          0.36607143][49.          0.36607143][50.          0.36607143]]
>>>
[[ 0.          1.        ][ 1.          1.        ][ 2.          0.60526316][ 3.          1.        ][ 4.          1.        ][ 5.          1.        ][ 6.          1.        ][ 7.          0.28947368][ 8.          0.78947368][ 9.          0.34210526][10.          0.42105263][11.          1.        ][12.          0.76315789][13.          0.34210526][14.          0.71052632][15.          0.23684211][16.          0.23684211][17.          0.68421053][18.          0.57894737][19.          0.23684211][20.          0.34210526][21.          0.57894737][22.          0.57894737][23.          0.23684211][24.          0.23684211][25.          0.23684211][26.          0.94736842][27.          0.23684211][28.          0.23684211][29.          0.23684211][30.          0.23684211][31.          0.23684211][32.          0.23684211][33.          0.23684211][34.          0.23684211][35.          0.23684211][36.          0.23684211][37.          0.23684211][38.          0.23684211][39.          0.23684211][40.          0.23684211][41.          0.23684211][42.          0.23684211][43.          0.23684211][44.          0.23684211][45.          0.23684211][46.          0.23684211][47.          0.23684211][48.          0.23684211][49.          0.23684211][50.          0.23684211]]

小结

神经网络模型可以达到很高的精度,但这依赖于参数的调整,特别是隐层数目和隐层节点数目;
神经网络的非线性拟合功能很强,虽然是基于梯度下降法来获取系数权重,但是还是有点黑箱功能。

基于简单神经网络模型的鸢尾花分类问题相关推荐

  1. 深度学习修炼(五)——基于pytorch神经网络模型进行气温预测

    文章目录 5 基于pytorch神经网络模型进行气温预测 5.1 实现前的知识补充 5.1.1 神经网络的表示 5.1.2 隐藏层 5.1.3 线性模型出错 5.1.4 在网络中加入隐藏层 5.1.5 ...

  2. 图片2分类卷积神经网络模型训练、分类预测案例全过程(1)

    图片2分类卷积神经网络模型训练.分类预测案例全过程(1) 前言 (1)尽管目前有关卷积神经网络深度学习的相关材料较多,但深度学习牵涉到数据预处理.模型构建.模型调用等环节,我也是一个初学者,中间有很多 ...

  3. 基于循环神经网络模型(GRU)的新型冠状病毒肺炎流行趋势预测

    资源下载地址:https://download.csdn.net/download/sheziqiong/85639079 资源下载地址:https://download.csdn.net/downl ...

  4. 图片2分类卷积神经网络模型训练、分类预测案例全过程(2)

    上一篇博客内容讲述了卷积神经网络模型构建.训练以及模型的保存,包括训练样本数据的预处理和喂给网络. 本篇博客内容讲述训练好的模型的应用和实际图片数据的分类预测. 图片2分类卷积神经网络模型训练.分类预 ...

  5. 基于K-最近邻算法构建鸢尾花分类模型

    基于K-最近邻算法构建鸢尾花分类模型 一 任务描述 鸢尾花(Iris)数据集是机器学习中一个经典的数据集.假设有一名植物学爱好者收集了150朵鸢尾花的测量数据:花瓣的长度和宽度以及花萼的长度和宽度,这 ...

  6. 基于神经网络模型的二分类--以Creditcard DataSet数据集为例

    数据集特点:Creditcard数据集包含711个样本:29个特征:1个标签(0.1表示信用卡是否出现问题) 问题定义 单标签二分类问题 标签的取值只有两种,并且只有一个需要预测的标签 解决方法:构建 ...

  7. bp神经网络预测模型_基于BP神经网络模型的河南省严重精神障碍患者服药依从性影响因素分析...

    发表文章 文章发表背景1 严重精神障碍主要包括精神分裂症.分裂情感性障碍.偏执性精神病等6种疾病,目前此类患者的主要治疗方法是社区抗精神病药维持治疗,虽然各类抗精神病药不断问世,但是患者服药依从性情况 ...

  8. 基于Tensorflow的卷积神经网络模型实现水果分类识别(实践案例)

    前言 写这篇博客的目的,就是记录下实现Fruit Dataset Image Classification Network的过程,所以从头开始写.这里感谢下会飞的小咸鱼提供了思路,文章内容主要翻译自K ...

  9. 刚刚涉足神经网络,基于TensorFlow2.0以实现鸢尾花分类为例总结神经网络代码实现的几个步骤,附代码详细讲解

    前言 总体来看,一个简单的神经网络,在准备数据和参数定义后就已经被搭建起来了,这便是神经网络的骨架.我们后面补入的参数优化是为了让这个神经网络能够朝着我们希望的方向进行迭代,最后能获取到符合我们预期的 ...

最新文章

  1. linux shell 日期比较大小,在Shell中使用日期运算和比较详解
  2. 关于级数∑(x n-x n-1)一致收敛性的一点儿理解
  3. 配置token_Nginx常用的配置
  4. SpringBatch适配器详解
  5. 互联网人的恶梦是加班?不,是饥荒!
  6. Android 中的拿来主义(编译,反编译,AXMLPrinter2,smali,baksmali)!
  7. android+ndk+r9+x64下载,Win7 64位中文旗舰版上Cocos2d-x 3.0的Android开发调试环境架设
  8. 【JAVA基础】HashSet、LinkedHashSet、TreeSet使用区别
  9. cmd怎么使用post请求’_flutter中dio的post请求方式使用总结
  10. J2EE技术-Hibernate
  11. Mac中无法运行旧版本印象笔记:版本太旧 你的本地印象笔记数据是由新版印象笔记管理
  12. python中θ符号怎么打出来_Python打印特殊符号及对应编码解析
  13. TVS与ESD的区别
  14. 大学生bootstrap框架网页作业成品 bootstrap响应式网页制作模板 学生海贼王动漫bootstrap框架网站作品
  15. 帝国cms系统使用初级教程二(较全面)
  16. 本地网络出现了一个意外的情况,不能完成所有你在设置中所要求的更改?
  17. 学习编程需要什么英语基础?
  18. 3个月测试员自述:4个影响我职业生涯的重要技能
  19. drupal 6初始安装,中文汉化 简明教程
  20. 手机bootstrap搜索框_这些桌面小部件,Android 手机可不能错过

热门文章

  1. 关于TP5静态文件加载不出来
  2. 浙江农林大学第二十二届程序设计竞赛部分题解
  3. fpga.一些学习感悟以及细节方面
  4. 北京 春暖花开沙拂面
  5. 搜题公众号制作简单教学
  6. 树脂除杂在锂溶液中的应用、硫酸锂除钙镁方法
  7. 日常生活开支记账明细_教你记账管理家庭日常生活收入支出明细的实例
  8. 城市垃圾类毕业论文文献有哪些?
  9. 基于Matlab的Robotics Toolbox工具箱的机器人仿真函数介绍(空间位姿表示与动力学)
  10. 物联网技术周报第 126 期: 使用 Yocto 构建 Raspberry Pi 系统