1.数据集的介绍

以鸢尾花数据集为例,共有150组,每组包括花萼长、花萼宽、花瓣长、花瓣宽4个输入特征。同时给出了这一组特征对应的鸢尾花的类别。类别包括狗尾草鸢尾、杂色鸢尾以及弗吉尼亚鸢尾,分别用0,1,2表示。

数据集读入:从sklearn包datasets读入数据集,如下:

from sklearn.datasets import load_irisx_data = load_iris().data  # 返回iris数据集所有输入特征
y_data = load_iris().target  # 返回iris数据集所有标签
from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pdx_data = load_iris().data  # 返回iris数据集所有输入特征
y_data = load_iris().target  # 返回iris数据集所有标签
print("x_data from datasets:", x_data)
print("y_data from datasets", y_data)x_data = DataFrame(x_data, columns=['花萼长', '花萼宽', '花瓣长', '花瓣宽'])
pd.set_option('display.unicode.east_asian_width', True)  # 设置列名对齐
print(x_data)x_data['类别'] = y_data  # 新加一列,列标签‘类别’,数据为y_data
print("x_data add a column: \n", x_data)

2.鸢尾花分类

  1. 准备数据
  • 数据集读入
  • 数据集乱序
  • 生成永不相见的训练集和测试集(即x_train / y_train, x_test / y_test)
  • 配成(输入特征,标签)对每次读入一小撮(batch)

 2.搭建网络

定义神经网络中所有可训练参数

3. 参数优化

嵌套循环迭代,with结构更新参数,显示当前loss

    4. 测试效果

计算房前参数前向传播后的准确率,显示当前acc

    5.acc / loss可视化 

3.鸢尾花分类案例

完整程序:

from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pd
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as pltimport os
import PySide2
dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path# 定义超参数和画图用的两个存数据的空列表
lr = 0.1
train_loss_results = []  # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = []   # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 300
loss_all = 0  # 每轮分为4个step(因为一共有120个训练数据,每个batch有32个样本,所以epoch迭代一次120个数据需要4个batch),loss_all记录四个step生成的4个loss的和# ____________________________数据准备______________________________
# 1.数据集的读入
x_data = load_iris().data  # 返回iris数据集所有输入特征
y_data = load_iris().target  # 返回iris数据集所有标签
# print("x_data from datasets:", x_data)
# print("y_data from datasets", y_data)# 2.数据集乱序
np.random.seed(116)  # 使用相同的种子seed,使得乱序后的数据特征和标签仍然可以对齐
np.random.shuffle(x_data)  # 打乱数据集
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)# 3.数据集分出永不相见的训练集和测试集
x_train = x_data[:-30]  # 前120个数据作为训练集
y_train = y_data[:-30]  # 前120个标签作为训练集标签
x_test = x_data[-30:]   # 后30个数据集作为测试集
y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘时会因为数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)# 配成【输入特征, 标签】对,每次喂入一小撮(batch)(把数据集分为批次,每批次32组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# ____________________________定义神经网络______________________________
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))  # 4表示输入的4的特征,3表示3分类
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))   # 3表示3分类# ____________________________训练部分:嵌套循环迭代_______________________
for epoch in range(epoch):  # 数据集级别迭代for step, (x_train, y_train) in enumerate(train_db):  # batch级别迭代with tf.GradientTape() as tape:  # 在with结构中计算前向传播y以及计算总损失lossy = tf.matmul(x_train, w1) + b1   # 神经网络乘加运算y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可以相减求loss)y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accloss = tf.reduce_mean(tf.square(y_ - y))  # 采用均值方差损失函数MSEloss_all += loss.numpy()  # 将每个step计算出loss累加,为后续求loss平均值提供数据# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1])  # 损失函数loss分别对参数w1和b1计算偏导数# 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_gradw1.assign_sub(lr * grads[0])  # 参数w1自更新b1.assign_sub(lr * grads[1])  # 参数b1自更新# 求出每个epoch的平均损失print("Epoch {}, loss:{}".format(epoch, loss_all/4))train_loss_results.append(loss_all / 4)  # 将4个step的loss求平均记录在此变量中loss_all = 0  # loss_all归零为记录下一个epoch的loss做准备# ____________________________测试部分:识别准确率______________________________total_correct, total_number = 0, 0for x_test, y_test in test_db:y = tf.matmul(x_test, w1) + b1  # y为预测结果y = tf.nn.softmax(y)  # y符合概率分布pred = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类pred = tf.cast(pred, dtype=y_test.dtype)  # 调整数据类型与标签一致,即为把pred预测值转换为y_test数据类型correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)  # 如果真实值与预测值相同,就正确correct = tf.reduce_sum(correct)  # 将每个batch的correct加起来total_correct += int(correct)  # 将所有batch中的correct数加起来total_number += x_test.shape[0]# 总的准确率等于total_correct / total_numberacc = total_correct / total_numbertest_acc.append(acc)print("test_acc", acc)print("__________________________")# ____________________________acc / loss 可视化___________________________
# 绘制loss曲线
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.plot(train_loss_results, label="$Loss$")  #逐点画出test_acc值并连线
plt.legend()
plt.show()#  绘制Accuracy曲线
plt.title("Acc Curve")
plt.xlabel("Epoch")
plt.ylabel("Acc")
plt.plot(test_acc, label="$Accuracy$")  #逐点画出test_acc值并连线
plt.legend()
plt.show()

深度学习分类步骤——鸢尾花分类相关推荐

  1. 深度学习经典应用——鸢尾花分类(基于sklearn包)+三维成果可视化

    文章目录 任务描述 数据集 Step1. 数据准备 Step2. 模型搭建 Step3.模型训练 Step4. 模型评估 Step5. 模型使用 3D可视化模型 任务描述 构建一个模型,根据鸢尾花的花 ...

  2. ECG分析:基于深度学习的ECG心律失常分类入门(3)

    ECG分析:基于深度学习的ECG心律失常分类入门(3) 数据库的Python读取 本次读取数据,用的是一款专门读取MITAB数据的工具--WFDB-python,WFDB包下载 ,全称是 Python ...

  3. ECG分析:基于深度学习的ECG心律失常分类入门(4)

    ECG分析:基于深度学习的ECG心律失常分类入门(4) 在搭建模型之前,讲一下本次任务需要区分的类别,MITAB根据心拍类型划分了14个小类: 也可以用wfdb查看: wfdb.show_ann_la ...

  4. 【信号识别】基于matlab深度学习CNN信号调制分类【含Matlab源码 2066期】

    ⛄一.深度学习CNN信号调制分类概述 1 背景介绍 在通信信号处理领域, 特别是在非协作通信信号盲解调研究领域, 每时隙突发信号的调制方式不同, 必须进行信号的调制方式自动识别.信号的调制方式识别效果 ...

  5. Data Augmentation for Deep Learning-based Radio ModulationClassification解读(基于深度学习的无线电调制分类数据扩充)

    摘要:深度学习最近被应用于自动分类接收无线电信号的调制类别,而无需人工经验.然而,训练深度学习模型需要大量的数据.训练数据不足会导致严重的过度拟合问题,降低分类精度.为了处理小数据集,数据增强被广泛应 ...

  6. 三维深度学习中的目标分类与语义分割

    (转载的文章,仅供学习,如有侵权请随时联系删帖) 在过去的几年中,基于RGB的深度学习已经在目标分类与语义分割方面取得了非常好的效果,也促进了很多技术的发展,深度学习在现实生活中的应用也越来越多.但是 ...

  7. 利用深度学习(Keras)进行癫痫分类-Python案例

    目录 癫痫介绍 数据集 Keras深度学习案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:903290195 癫痫介绍 癫痫,即俗称"羊癫风",是由多种 ...

  8. 用深度学习解决大规模文本分类问题

     用深度学习解决大规模文本分类问题 人工智能头条 2017-03-27 22:14:22 淘宝 阅读(228) 评论(0) 声明:本文由入驻搜狐公众平台的作者撰写,除搜狐官方账号外,观点仅代表作者 ...

  9. 【NLP】相当全面:各种深度学习模型在文本分类任务上的应用

    论文标题:Deep Learning Based Text Classification:A Comprehensive Review 论文链接:https://arxiv.org/pdf/2004. ...

最新文章

  1. 山景智能创始人黄勇:银行要从数据智能转向业务智能,今天的金融服务难以支撑未来 | MEET2021...
  2. Git详解之五 分布式Git
  3. 嵌入式SQL程序的VC+SQL server 2000实现的环境配置
  4. c#中的奇异递归模式
  5. 2018上半年软件设计师上午题参考答案
  6. python 判断该地址 文件创建时间2020年10月14日14时25分32秒 文件最后一次访问时间 文件最后一次修改时间
  7. Tomcat启动log:SLF4J: Class path contains multiple SLF4J bindings.
  8. 聊天窗口,新加的内容直接 往上顶
  9. Hadoop2 自己动手编译Hadoop的eclipse插件
  10. SpingMVC 注解@RequestMapping、@SuppressWarnings、@Scheduled 定时器
  11. 一分钟搞懂 分布式与集群
  12. Python 源代码代码打包成 whl 文件
  13. [常用工具] 搜索引擎的常用技巧总结
  14. 大量数据表的优化方案
  15. 将一个数组中的值按逆序重新存放。例如,原来顺序为8,6,5,4,1。要求改为1,4,5,6,8
  16. 软件工程——实体关系图 + 状态转换图 + 数据流图
  17. 窗口置顶工具v2.1.0
  18. linux学习篇 之 复制、黏贴、删除、撤销
  19. django项目中实现excel表数据导入
  20. git操作总结(1):常用操作流程之SSH、上传、下载和改名字

热门文章

  1. 上采样和下采样_首次采样带回 嫦娥五号为什么是中国航天史上最复杂任务?...
  2. C/C++ sleep函数使用方法
  3. Oracle编码格式为US7ASCII中文乱码如何解决
  4. spring框架漏洞整理(Spring Data漏洞)
  5. 服务器系统增加蓝牙功能,技术指导:怎么搭建简易蓝牙定位系统
  6. 天津公务员 计算机水平,天津公务员考试报考这6类职位,上岸几率更大!
  7. Adobe Acrobat Reader也能用书签了
  8. printf花式输出
  9. 流媒体网络协议 -- RTSP
  10. 使用Python的requests库爬取网页表情包