1鸢尾花数据
回顾鸢尾花数据集,其提供了150组鸢尾花数据,每组包括鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽 4个输入特征,同时还给出了这一组特征对应的鸢尾花类别。类别包括狗尾鸢尾、杂色鸢尾、弗吉尼亚鸢尾三类, 分别用数字0、1、2表示。可以使用sklearn来导入其数据。
2. 流程分析
① 准备数据。数据集读入;数据集乱序;生成训练集和测试集;配对。
② 搭建网络。定义神经网络中所有可训练的参数。
③ 参数优化。嵌套循环迭代,with结构更新参数,显示当前loss。
④ 测试效果。计算当前参数前向传播后的准确率,显示当前acc
3. 代码实现(手撕神经网络)

import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
def demo16():"""神经网络实现鸢尾花分类:return:"""# 数据集读入data_iris = load_iris()# 获取鸢尾花数据集的特征矩阵x_data = data_iris.data# 获取鸢尾花数据集的目标值y_data = data_iris.target# 数据集乱序np.random.seed(116)  # 使用相同的seed,使特征/标签一一对应。np.random.shuffle(x_data)np.random.seed(116)np.random.shuffle(y_data)tf.random.set_seed(116)# 分出训练集和测试集# x_train, x_test, y_train, y_test = train_test_split(x_data,y_data,test_size=0.2,random_state=116)x_train = x_data[:-30]y_train = y_data[:-30]x_test = x_data[-30:]y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘会因数据类型不一致而报错x_train = tf.cast(x_train, tf.float32)x_test = tf.cast(x_test, tf.float32)# 特征值和目标值配对  每次喂入一小batchtrain_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 生成神经网络的参数,4个输入特征,故输入层为4个输入节点;因为3分类,故输出层为3个神经元w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))lr = 0.1  # 学习率为0.1train_loss_results = []  # 将每轮的loss记录在此列表中,方便后续画图test_acc = []  # 将每轮的acc记录在此列表中,方便后续画图epoch = 500  # 循环500轮loss_all = 0  # 每轮分4个step,loss_all记录着4个step生成的4个loss的和# 训练部分for epoch in range(epoch):  # 数据集级别的循环for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环with tf.GradientTape() as tape:  # 记录梯度信息y = tf.matmul(x_train, w1) + b1  # 神经网络的乘加运算y = tf.nn.softmax(y)  # 使y符合概率分布y_ = tf.one_hot(y_train, depth=3)loss = tf.reduce_mean(tf.square(tf.subtract(y_, y)))  # 采用均方差损失函数loss_all += loss.numpy()  # 将每一步计算的loss累加# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1])# 实现梯度更新w1.assign_sub(lr * grads[0])b1.assign_sub(lr * grads[1])# 每个epoch,打印loss信息print("Epoch{},loss:{}".format(epoch, loss_all / 4))train_loss_results.append(loss_all / 4)  # 将4step的loss求的平均记录在此变量中loss_all = 0  # 归零# 测试部分total_corrent, total_number = 0, 0for x_test, y_test in test_db:  # 遍历batch# 使用更新后的参数进行预测y = tf.matmul(x_test, w1) + b1y = tf.nn.softmax(y)pred = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类# 将pred转换为y_test数据类型pred = tf.cast(pred, dtype=y_test.dtype)correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)# 将每个batch的correct数加起来correct = tf.reduce_sum(correct)total_corrent += int(correct)total_number += x_test.shape[0]acc = total_corrent / total_numbertest_acc.append(acc)print("Test_acc", acc)print("-----------------------------")# 绘制loss曲线plt.title("loss function curve")plt.xlabel("epoch")plt.ylabel("loss")plt.plot(train_loss_results, label="$loss$")plt.legend()plt.show()# 绘制acc曲线plt.title("acc curve")plt.xlabel("epoch")plt.ylabel("acc")plt.plot(test_acc, label="$acc$")plt.legend()plt.show()if __name__ == '__main__':demo16()

结果如下:

准确率如图所示:

损失函数值如下图所示:

神经网络实现鸢尾花分类(Tensorflow2.0)相关推荐

  1. Tensorflow2.x框架-神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据准备     1.数据集读入 2.数据集乱序 3.生成训练集和测试集(即 x_train / y_train,x_test / y_test) 4.配成(输入特征,标签 ...

  2. [python] 深度学习基础------人工神经网络实现鸢尾花分类(一)

    ​​​​​​​人工神经网络实现鸢尾花分类(一) 人工神经网络实现鸢尾花分类(二) 人工神经网络实现鸢尾花分类(三) 人工神经网络实现鸢尾花分类(四) 人工神经网络实现鸢尾花分类(五) 目录 人工智能主 ...

  3. 神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据集介绍 共有数据150组,每组包括花萼长.花萼宽.花瓣长.花瓣宽4个输入特征. 同时给出了,这一组特征对应的鸢尾花类别.类别包括Setosa Iris(狗尾草 鸢尾), ...

  4. 猿创征文|深度学习基于前馈神经网络完成鸢尾花分类

    大家我是猿童学!这次给大家带来的是基于前馈神经网络完成鸢尾花分类! 在本实验中,我们使用的损失函数为交叉熵损失:优化器为随机梯度下降法:评价指标为准确率. 一.小批量梯度下降法 在梯度下降法中,目标函 ...

  5. 机器学习 | 使用TensorFlow搭建神经网络实现鸢尾花分类

    鸢尾花分类问题是机器学习领域一个非常经典的问题,本文将利用神经网络来实现鸢尾花分类 实验环境:Windows10.TensorFlow2.0.Spyder 参考资料:人工智能实践:TensorFlow ...

  6. Python 基于BP神经网络的鸢尾花分类

    本文用Python实现了BP神经网络分类算法,根据鸢尾花的4个特征,实现3种鸢尾花的分类. 算法参考文章:纯Python实现鸢尾属植物数据集神经网络模型 2020.07.21更新: 增加了分类结果可视 ...

  7. 最简单的单层神经网络实现鸢尾花分类

    一,知识背景 鸢尾花的分类由四个数据定义,分别是花萼长.花萼宽.花瓣长.花瓣宽.我们把这样的一组数据称为是一组特征,根据特征可以分为三类鸢尾花. 二,神经元模型 神经元采用最简单的简化MP(麦卡洛克- ...

  8. 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(安装TensorFlow2.0)

    创建环境并激活 conda create --name tensorflow2.0 python==3.7 activate tensorflow2.0 安装相关软件包(conda命令或pip命令2选 ...

  9. 神经网络与深度学习理论,tensorflow2.0教程,cnn

    *免责声明: 1\此方法仅提供参考 2\搬了其他博主的操作方法,以贴上路径. 3* 场景一:神经网络与深度学习理论 场景二:tensorflow的安装 场景三:numpy包介绍 场景四:机器学习基础 ...

最新文章

  1. Java-----applet小程序简介
  2. h3c 3600 交换机配置Telnet登录
  3. python下载文件到指定目录-Python获取指定文件夹下的文件名的方法
  4. Windows10安装Ubuntu子系统+docker教程说明
  5. ubuntu下面的git服务器搭建
  6. 十年架构师:我是这样手写Spring的,用300行代码体现优雅之道
  7. FreeEIM 与飞鸽传书的区别
  8. 算法竞赛入门经典题解目录
  9. 如何解决电脑横屏问题
  10. 深度探索c++对象模型(5):ctor、dtor、copy
  11. 数据库学习2 排序检索数据
  12. java broken pipe_java.net.SocketException: Broken pipe问题解决
  13. API入门系列之五 -一个正儿八经的SDK程序
  14. 在贵州大数据峰会上,马云再次语出惊人!
  15. 安装TortoiseGit后别忘了这一步
  16. 计算数的三次方根(Java)
  17. 第一章、安装、登录CentOS7
  18. 企业如何做好业务监控​?
  19. 无法配置在此计算机的硬件上运行6,“Windows安装程序无法将Windows配置未在此计算机的硬件上运行”解决方案 | 秋收稻田...
  20. 实施CMMI具体要做什么——点评

热门文章

  1. mv单位是什么意思_ayawawa经常说的pu MV是什么意思 怎么mv是什么意思算
  2. 修改STM32CuBeMX生成文件
  3. mysql8 2058_SQLyog连接MySQL8.0报2058错误的解决方案
  4. 微分恒等式(助于找到均值、方差和其他矩)
  5. Unity3D 对于在VR中普通摄像头和VR摄像头同时存在——分屏
  6. java 累加函数_请你编写一个方法(函数),功能要求从参数x累加到y,并返回累加后的整数结果。...
  7. angularjs1-3,工具方法,bootstrap,多个module,引入jquery
  8. 优化屏蔽广告.提高浏览体验
  9. 当使用广告拦截器时,有些页面无法查看,应该怎样解决?
  10. 贝壳找房APP安装包瘦身