鸢尾花分类问题是机器学习领域一个非常经典的问题,本文将利用神经网络来实现鸢尾花分类

实验环境:Windows10、TensorFlow2.0、Spyder

参考资料:人工智能实践:TensorFlow笔记第一讲

1、鸢尾花分类问题描述

根据鸢尾花的花萼、花瓣的长度和宽度可以将鸢尾花分成三个品种

我们可以使用以下代码读取鸢尾花数据集

from sklearn.datasets import load_iris
x_data = load_iris().data
y_data = load_iris().target

该数据集含有150个样本,每个样本由四个特征和一个标签组成,四个特征分别为:

  • 花萼长度
  • 花萼宽度
  • 花瓣长度
  • 花瓣宽度

标签值为:

标签值 0 1 2
鸢尾花品种 山鸢尾 变色鸢尾 维吉尼亚鸢尾

2、基于神经网络的解决方法

本文搭建的神经网络由一个输入层(包含4个输入节点)、一个输出层(包含3个输出节点)组成。

神经网络

单个神经元的结构如下

本方法略去了激活函数,直接将加权结果作为输出,加权公式为:
y = x ∗ w + b y=x*w+b y=x∗w+b
其中 x x x为 1 × 4 1\times 4 1×4矩阵, y y y为 1 × 3 1\times 3 1×3矩阵, w w w为 4 × 3 4\times 3 4×3的矩阵, b b b为 1 × 3 1\times 3 1×3矩阵

损失函数

用来描述预测值 y y y与真实标签 y . ‾ y\underline{.} y.​的差距,本文方法使用均方误差来描述损失函数。
M S E ( y , y . ‾ ) = ∑ k = 0 n ( y − y . ‾ ) 2 n MSE(y,y\underline{.})=\frac{\sum_{k=0}^n(y-y\underline{.})^2}{n} MSE(y,y.​)=n∑k=0n​(y−y.​)2​

参数优化

为了找到一组参数 w w w和 b b b使损失函数最小,本文使用梯度下降法进行参数优化

梯度下降法:沿损失函数梯度下降的方向,寻找损失函数的最小值,得到最优参数的方法。

梯度下降法即对损失函数中的各个参量求偏导,得到的结果即为损失函数梯度下降的方向。公式如下
w t + 1 = w t − l r ⋅ ∂ l o s s ∂ w t b t + 1 = b t − l r ⋅ ∂ l o s s ∂ b t w t + 1 ⋅ x + b t + 1 → y w_{t+1}=w_t-lr\cdot \frac{\partial loss}{\partial w_t}\\ b_{t+1}=b_t-lr\cdot \frac{\partial loss}{\partial b_t}\\ w_{t+1}\cdot x+b_{t+1}\to y wt+1​=wt​−lr⋅∂wt​∂loss​bt+1​=bt​−lr⋅∂bt​∂loss​wt+1​⋅x+bt+1​→y
其中 l r lr lr表示学习率,不同的学习率会对参数更新造成不同的影响,如学习率过小,会造成参数更新过慢;学习率过大,会造成损失函数震荡。

3、程序实现

完整程序如下:

# -*- coding: utf-8 -*-
"""
Created on Thu Apr  9 11:01:13 2020
"""
# 鸢尾花分类
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt# 读入数据集
from sklearn.datasets import load_iris
x_data = load_iris().data
y_data = load_iris().target# 打乱数据集
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)# 选择倒数第30之前的数据作为训练集
x_train = x_data[:-30]
y_train = y_data[:-30]
# 选择倒数第30之后的数据作为测试集
x_test  = x_data[-30:]
y_test  = y_data[-30:]x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)# 分批处理
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 随机初始化待更新的参数
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1))lr = 0.2 # 学习率/步长
epoch = 300 # 迭代总次数
loss_all = 0 # 每次迭代的损失
loss_list = [] # 存储每一次迭代的损失
acc_list = [] # 存储每一次迭代结果的准确率
for epoch in range(epoch):# 训练 # 更新权重for step, (x_train, y_train) in enumerate(train_db):with tf.GradientTape() as tape:# 前向传播得到当前权值下的推理结果 y = x * w1 + b1y = tf.matmul(x_train, w1) + b1;# 使用softmax将推理结果转换到[0, 1]之间y = tf.nn.softmax(y)# 将标签转换为独热码,即0:0 0 1, 1:0 1 0, 2:1 0 0y_ = tf.one_hot(y_train, depth=3)# 求均方误差loss = tf.reduce_mean(tf.square(y_ - y))loss_all += loss.numpy()# 分别对损失函数的w1、b1求偏导grads = tape.gradient(loss, [w1, b1])# 更新w1、b1 w1 = w1 - lr * w1_grad  b1 = b1 - lr * b1_gradw1.assign_sub(lr * grads[0])b1.assign_sub(lr * grads[1])# 打印此次迭代的损失print("Ecoph:{}, Loss:{}".format(epoch, loss_all / 4))loss_list.append(loss_all / 4)loss_all = 0# 测试# 计算此次迭代结果的正确率# 在真实训练时可以略过,这里只是为了画出正确率曲线total_correct, total_number = 0, 0for x_test, y_test in test_db:# 前向传播得到当前权值下的推理结果 y = x * w1 + b1y = tf.matmul(x_test, w1) + b1# 使用softmax将预测结果转换到[0, 1]之间y = tf.nn.softmax(y)# 找到最大值的索引pred = tf.argmax(y, axis=1)pred = tf.cast(pred, dtype=y_test.dtype)# 将预测结果与真实标签对比correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)correct = tf.reduce_sum(correct)total_correct += int(correct)total_number += x_test.shape[0]acc = total_correct / total_numberacc_list.append(acc)print("Acc:", acc)print("-------------------")# 画出损失函数曲线
plt.plot(loss_list)
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
# 画出正确率曲线
plt.plot(acc_list)
plt.title("Acc Curve")
plt.xlabel("Epoch")
plt.ylabel("Acc")
plt.show()

4、运行结果

  • 损失函数曲线如下,可以看到损失函数随着迭代次数的增加逐渐减小
  • 正确率曲线

如有谬误,敬请指正!

机器学习 | 使用TensorFlow搭建神经网络实现鸢尾花分类相关推荐

  1. Tensorflow搭建神经网络八股及实现鸢尾花数据集分类

    tensorflow笔记系列文章均参考自中国大学Mooc上北京大学软件与微电子学院曹建老师的<Tensorflow笔记2>课程.曹建老师讲的非常棒,受益良多,强烈建议tensorflow初 ...

  2. 神经网络实现鸢尾花分类(Tensorflow2.0)

    1鸢尾花数据 回顾鸢尾花数据集,其提供了150组鸢尾花数据,每组包括鸢尾花的花萼长.花萼宽.花瓣长.花瓣宽 4个输入特征,同时还给出了这一组特征对应的鸢尾花类别.类别包括狗尾鸢尾.杂色鸢尾.弗吉尼亚鸢 ...

  3. Tensorflow2.x框架-神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据准备     1.数据集读入 2.数据集乱序 3.生成训练集和测试集(即 x_train / y_train,x_test / y_test) 4.配成(输入特征,标签 ...

  4. 神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据集介绍 共有数据150组,每组包括花萼长.花萼宽.花瓣长.花瓣宽4个输入特征. 同时给出了,这一组特征对应的鸢尾花类别.类别包括Setosa Iris(狗尾草 鸢尾), ...

  5. [python] 深度学习基础------人工神经网络实现鸢尾花分类(一)

    ​​​​​​​人工神经网络实现鸢尾花分类(一) 人工神经网络实现鸢尾花分类(二) 人工神经网络实现鸢尾花分类(三) 人工神经网络实现鸢尾花分类(四) 人工神经网络实现鸢尾花分类(五) 目录 人工智能主 ...

  6. 猿创征文|深度学习基于前馈神经网络完成鸢尾花分类

    大家我是猿童学!这次给大家带来的是基于前馈神经网络完成鸢尾花分类! 在本实验中,我们使用的损失函数为交叉熵损失:优化器为随机梯度下降法:评价指标为准确率. 一.小批量梯度下降法 在梯度下降法中,目标函 ...

  7. 最简单的单层神经网络实现鸢尾花分类

    一,知识背景 鸢尾花的分类由四个数据定义,分别是花萼长.花萼宽.花瓣长.花瓣宽.我们把这样的一组数据称为是一组特征,根据特征可以分为三类鸢尾花. 二,神经元模型 神经元采用最简单的简化MP(麦卡洛克- ...

  8. 2、python机器学习基础教程——K近邻算法鸢尾花分类

    一.第一个K近邻算法应用:鸢尾花分类 import numpy as np from sklearn.datasets import load_iris from sklearn.model_sele ...

  9. Python每日一练(机器学习)——第43天:鸢尾花分类

    文章目录 1. 鸢尾花分类(1) 2. 鸢尾花分类_2 <100天精通Python>专栏推荐白嫖80g Python全栈视频 废话少说速度上号刷题卷起来 1. 鸢尾花分类(1) 描述: 请 ...

最新文章

  1. Java泛型使用需要小心
  2. ASP.NET 下载文件方式
  3. 计算机视觉开源库OpenCV之图像翻转
  4. 关于COPC后台配置的几个关键步骤及其事务代码
  5. 作为一名Java开发者应该掌握的基础知识汇总!
  6. qss样式表笔记大全(一):qss名词解析(包含相关示例)
  7. 如何用最短的时间学会C语言,并掌握C语言的精髓所在?
  8. dumpsys gfxinfo packacges计算帧率
  9. 如何编辑PDF文件?分享几种编辑PDF文件方法
  10. 计算机室管理员考核细则,宿舍管理员量化考核细则
  11. VIN码识别技术在移动端的应用
  12. 【2015NOIP模拟】【Ocd】【Mancity】【Captcha】10.31总结
  13. android 正则句子按照标点符号断句,正则Pattern;
  14. 了解ES6 The Dope Way第五部分:类,转译ES6代码和更多资源!
  15. 编辑重命名文件夹下多个文件名,一键操作技巧
  16. java项目业绩怎么写,GitHub已标星16k
  17. NOKIA N8 和 Nokia Qt SDK
  18. Android 直播 播放器 IJK播放器低延时120ms
  19. pyspark之dataframe当前行与上一行值求差
  20. win7_oracle11g_64位连接32位PLSQL_Develop

热门文章

  1. 图片查看器viewer
  2. MAC系统安装Hadoop
  3. 动态规划经典题目:最大连续子序列和、最大不连续子序列和
  4. STM32----IIC详解
  5. 测试工程师python常见面试题_测试人员python面试题
  6. 在线式测斜仪是一款新型的、智能的、适应多种行业应用的三轴智能测斜仪
  7. Visual Tracking Using Attention-Modulated Disintegration and Integration
  8. 身体知道LGG益生菌酸奶营养高不高?
  9. AT89C51+ULN2003A+中断=控制(跑马灯+步进电机)
  10. java如何解决写者优先问题_第二类读者写者问题(写者优先)的信号量及PV操作解决方案...