深度学习分类步骤——鸢尾花分类
1.数据集的介绍
以鸢尾花数据集为例,共有150组,每组包括花萼长、花萼宽、花瓣长、花瓣宽4个输入特征。同时给出了这一组特征对应的鸢尾花的类别。类别包括狗尾草鸢尾、杂色鸢尾以及弗吉尼亚鸢尾,分别用0,1,2表示。
数据集读入:从sklearn包datasets读入数据集,如下:
from sklearn.datasets import load_irisx_data = load_iris().data # 返回iris数据集所有输入特征
y_data = load_iris().target # 返回iris数据集所有标签
from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pdx_data = load_iris().data # 返回iris数据集所有输入特征
y_data = load_iris().target # 返回iris数据集所有标签
print("x_data from datasets:", x_data)
print("y_data from datasets", y_data)x_data = DataFrame(x_data, columns=['花萼长', '花萼宽', '花瓣长', '花瓣宽'])
pd.set_option('display.unicode.east_asian_width', True) # 设置列名对齐
print(x_data)x_data['类别'] = y_data # 新加一列,列标签‘类别’,数据为y_data
print("x_data add a column: \n", x_data)
2.鸢尾花分类
- 准备数据
- 数据集读入
- 数据集乱序
- 生成永不相见的训练集和测试集(即x_train / y_train, x_test / y_test)
- 配成(输入特征,标签)对每次读入一小撮(batch)
2.搭建网络
定义神经网络中所有可训练参数
3. 参数优化
嵌套循环迭代,with结构更新参数,显示当前loss
4. 测试效果
计算房前参数前向传播后的准确率,显示当前acc
5.acc / loss可视化
3.鸢尾花分类案例
完整程序:
from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pd
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as pltimport os
import PySide2
dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path# 定义超参数和画图用的两个存数据的空列表
lr = 0.1
train_loss_results = [] # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = [] # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 300
loss_all = 0 # 每轮分为4个step(因为一共有120个训练数据,每个batch有32个样本,所以epoch迭代一次120个数据需要4个batch),loss_all记录四个step生成的4个loss的和# ____________________________数据准备______________________________
# 1.数据集的读入
x_data = load_iris().data # 返回iris数据集所有输入特征
y_data = load_iris().target # 返回iris数据集所有标签
# print("x_data from datasets:", x_data)
# print("y_data from datasets", y_data)# 2.数据集乱序
np.random.seed(116) # 使用相同的种子seed,使得乱序后的数据特征和标签仍然可以对齐
np.random.shuffle(x_data) # 打乱数据集
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)# 3.数据集分出永不相见的训练集和测试集
x_train = x_data[:-30] # 前120个数据作为训练集
y_train = y_data[:-30] # 前120个标签作为训练集标签
x_test = x_data[-30:] # 后30个数据集作为测试集
y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘时会因为数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)# 配成【输入特征, 标签】对,每次喂入一小撮(batch)(把数据集分为批次,每批次32组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# ____________________________定义神经网络______________________________
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1)) # 4表示输入的4的特征,3表示3分类
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1)) # 3表示3分类# ____________________________训练部分:嵌套循环迭代_______________________
for epoch in range(epoch): # 数据集级别迭代for step, (x_train, y_train) in enumerate(train_db): # batch级别迭代with tf.GradientTape() as tape: # 在with结构中计算前向传播y以及计算总损失lossy = tf.matmul(x_train, w1) + b1 # 神经网络乘加运算y = tf.nn.softmax(y) # 使输出y符合概率分布(此操作后与独热码同量级,可以相减求loss)y_ = tf.one_hot(y_train, depth=3) # 将标签值转换为独热码格式,方便计算loss和accloss = tf.reduce_mean(tf.square(y_ - y)) # 采用均值方差损失函数MSEloss_all += loss.numpy() # 将每个step计算出loss累加,为后续求loss平均值提供数据# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1]) # 损失函数loss分别对参数w1和b1计算偏导数# 实现梯度更新 w1 = w1 - lr * w1_grad b = b - lr * b_gradw1.assign_sub(lr * grads[0]) # 参数w1自更新b1.assign_sub(lr * grads[1]) # 参数b1自更新# 求出每个epoch的平均损失print("Epoch {}, loss:{}".format(epoch, loss_all/4))train_loss_results.append(loss_all / 4) # 将4个step的loss求平均记录在此变量中loss_all = 0 # loss_all归零为记录下一个epoch的loss做准备# ____________________________测试部分:识别准确率______________________________total_correct, total_number = 0, 0for x_test, y_test in test_db:y = tf.matmul(x_test, w1) + b1 # y为预测结果y = tf.nn.softmax(y) # y符合概率分布pred = tf.argmax(y, axis=1) # 返回y中最大值的索引,即预测的分类pred = tf.cast(pred, dtype=y_test.dtype) # 调整数据类型与标签一致,即为把pred预测值转换为y_test数据类型correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32) # 如果真实值与预测值相同,就正确correct = tf.reduce_sum(correct) # 将每个batch的correct加起来total_correct += int(correct) # 将所有batch中的correct数加起来total_number += x_test.shape[0]# 总的准确率等于total_correct / total_numberacc = total_correct / total_numbertest_acc.append(acc)print("test_acc", acc)print("__________________________")# ____________________________acc / loss 可视化___________________________
# 绘制loss曲线
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.plot(train_loss_results, label="$Loss$") #逐点画出test_acc值并连线
plt.legend()
plt.show()# 绘制Accuracy曲线
plt.title("Acc Curve")
plt.xlabel("Epoch")
plt.ylabel("Acc")
plt.plot(test_acc, label="$Accuracy$") #逐点画出test_acc值并连线
plt.legend()
plt.show()
深度学习分类步骤——鸢尾花分类相关推荐
- 深度学习经典应用——鸢尾花分类(基于sklearn包)+三维成果可视化
文章目录 任务描述 数据集 Step1. 数据准备 Step2. 模型搭建 Step3.模型训练 Step4. 模型评估 Step5. 模型使用 3D可视化模型 任务描述 构建一个模型,根据鸢尾花的花 ...
- ECG分析:基于深度学习的ECG心律失常分类入门(3)
ECG分析:基于深度学习的ECG心律失常分类入门(3) 数据库的Python读取 本次读取数据,用的是一款专门读取MITAB数据的工具--WFDB-python,WFDB包下载 ,全称是 Python ...
- ECG分析:基于深度学习的ECG心律失常分类入门(4)
ECG分析:基于深度学习的ECG心律失常分类入门(4) 在搭建模型之前,讲一下本次任务需要区分的类别,MITAB根据心拍类型划分了14个小类: 也可以用wfdb查看: wfdb.show_ann_la ...
- 【信号识别】基于matlab深度学习CNN信号调制分类【含Matlab源码 2066期】
⛄一.深度学习CNN信号调制分类概述 1 背景介绍 在通信信号处理领域, 特别是在非协作通信信号盲解调研究领域, 每时隙突发信号的调制方式不同, 必须进行信号的调制方式自动识别.信号的调制方式识别效果 ...
- Data Augmentation for Deep Learning-based Radio ModulationClassification解读(基于深度学习的无线电调制分类数据扩充)
摘要:深度学习最近被应用于自动分类接收无线电信号的调制类别,而无需人工经验.然而,训练深度学习模型需要大量的数据.训练数据不足会导致严重的过度拟合问题,降低分类精度.为了处理小数据集,数据增强被广泛应 ...
- 三维深度学习中的目标分类与语义分割
(转载的文章,仅供学习,如有侵权请随时联系删帖) 在过去的几年中,基于RGB的深度学习已经在目标分类与语义分割方面取得了非常好的效果,也促进了很多技术的发展,深度学习在现实生活中的应用也越来越多.但是 ...
- 利用深度学习(Keras)进行癫痫分类-Python案例
目录 癫痫介绍 数据集 Keras深度学习案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:903290195 癫痫介绍 癫痫,即俗称"羊癫风",是由多种 ...
- 用深度学习解决大规模文本分类问题
用深度学习解决大规模文本分类问题 人工智能头条 2017-03-27 22:14:22 淘宝 阅读(228) 评论(0) 声明:本文由入驻搜狐公众平台的作者撰写,除搜狐官方账号外,观点仅代表作者 ...
- 【NLP】相当全面:各种深度学习模型在文本分类任务上的应用
论文标题:Deep Learning Based Text Classification:A Comprehensive Review 论文链接:https://arxiv.org/pdf/2004. ...
最新文章
- 山景智能创始人黄勇:银行要从数据智能转向业务智能,今天的金融服务难以支撑未来 | MEET2021...
- Git详解之五 分布式Git
- 嵌入式SQL程序的VC+SQL server 2000实现的环境配置
- c#中的奇异递归模式
- 2018上半年软件设计师上午题参考答案
- python 判断该地址 文件创建时间2020年10月14日14时25分32秒 文件最后一次访问时间 文件最后一次修改时间
- Tomcat启动log:SLF4J: Class path contains multiple SLF4J bindings.
- 聊天窗口,新加的内容直接 往上顶
- Hadoop2 自己动手编译Hadoop的eclipse插件
- SpingMVC 注解@RequestMapping、@SuppressWarnings、@Scheduled 定时器
- 一分钟搞懂 分布式与集群
- Python 源代码代码打包成 whl 文件
- [常用工具] 搜索引擎的常用技巧总结
- 大量数据表的优化方案
- 将一个数组中的值按逆序重新存放。例如,原来顺序为8,6,5,4,1。要求改为1,4,5,6,8
- 软件工程——实体关系图 + 状态转换图 + 数据流图
- 窗口置顶工具v2.1.0
- linux学习篇 之 复制、黏贴、删除、撤销
- django项目中实现excel表数据导入
- git操作总结(1):常用操作流程之SSH、上传、下载和改名字
热门文章
- 上采样和下采样_首次采样带回 嫦娥五号为什么是中国航天史上最复杂任务?...
- C/C++ sleep函数使用方法
- Oracle编码格式为US7ASCII中文乱码如何解决
- spring框架漏洞整理(Spring Data漏洞)
- 服务器系统增加蓝牙功能,技术指导:怎么搭建简易蓝牙定位系统
- 天津公务员 计算机水平,天津公务员考试报考这6类职位,上岸几率更大!
- Adobe Acrobat Reader也能用书签了
- printf花式输出
- 流媒体网络协议 -- RTSP
- 使用Python的requests库爬取网页表情包