TensorFlow基础8-实现单层神经网络
记录TensorFlow听课笔记
文章目录
- 记录TensorFlow听课笔记
- 一,神经网络的设计
- 二,实现单层神经网络
一,神经网络的设计
二,实现单层神经网络
导入库
加载数据
数据预处理
设置超参数和显示间隔
设置模型参数初始值
训练模型
结果可视化
#单层神经网络
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
#下载鸢尾花数据集
TRAIN_URL="http://download.tensorflow.org/data/iris_training.csv"
train_path=tf.keras.utils.get_file(TRAIN_URL.split('/')[-1],TRAIN_URL)
TEST_URL="http://download.tensorflow.org/data/iris_test.csv"
test_path=tf.keras.utils.get_file(TRAIN_URL.split('/')[-1],TEST_URL)
#读取数据集
df_iris_train=pd.read_csv(train_path,header=0)
df_iris_test=pd.read_csv(test_path,header=0)
#将数据集变为np数组
iris_train=np.array(df_iris_train)
iris_test=np.array(df_iris_test)
#取出数据集前四列特征以及第五列标签
x_train=iris_train[:,0:4]
y_train=iris_train[:,4]
x_test=iris_test[:,0:4]
y_test=iris_test[:,4]
#归一化中心化
x_train=x_train-np.mean(x_train,axis=0)
x_test=x_test-np.mean(x_test,axis=0)
#转换数据类型并标签独热编码
X_train=tf.cast(x_train,tf.float32)
Y_train=tf.one_hot(tf.constant(y_train,dtype=tf.int32),3)
X_test=tf.cast(x_test,tf.float32)
Y_test=tf.one_hot(tf.constant(y_test,dtype=tf.int32),3)
#设置超参数和显示间隔
learn_rate=0.5
iter=50
display_step=10
#设置模型参数初始值
np.random.seed(612)
W=tf.Variable(np.random.randn(4,3),dtype=tf.float32)
B=tf.Variable(np.zeros([3]),dtype=tf.float32)
#训练模型
acc_train=[] #训练准确率
acc_test=[]
cce_train=[] #训练交叉熵
cce_test=[]
for i in range(0,iter+1):with tf.GradientTape() as tape:PRED_train=tf.nn.softmax(tf.matmul(X_train,W)+B) Loss_train=tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true=Y_train,y_pred=PRED_train))PRED_test=tf.nn.softmax(tf.matmul(X_test,W)+B) Loss_test=tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true=Y_test,y_pred=PRED_test))Accuracy_train=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_train.numpy(),axis=1),y_train),tf.float32))Accuracy_test=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_test.numpy(),axis=1),y_test),tf.float32))acc_train.append(Accuracy_train)acc_test.append(Accuracy_test)cce_train.append(Loss_train)cce_test.append(Loss_test)grads=tape.gradient(Loss_train,[W,B])W.assign_sub(learn_rate*grads[0])B.assign_sub(learn_rate*grads[1])if i % display_step == 0:print("i:%i, TrainAcc:%f,TrainLoss:%f, TestAcc:%f,TestLoss:%f"%(i,Accuracy_train,Loss_train,Accuracy_test,Loss_test))
#损失函数和准确率可视化
plt.figure(figsize=(10,3))\plt.subplot(121)
plt.plot(cce_train,color="blue",label="train")
plt.plot(cce_test,color="red",label="test")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()plt.subplot(122)
plt.plot(acc_train,color="blue",label="train")
plt.plot(acc_test,color="red",label="test")
plt.xlabel("Iteration")
plt.ylabel("Accuracy")
plt.legend()plt.show()
TensorFlow基础8-实现单层神经网络相关推荐
- TensorFlow随笔-多分类单层神经网络softmax
#!/usr/bin/env python2 # -*- coding: utf-8 -*-import tensorflow as tf from tensorflow.examples.tutor ...
- 译文 | 与TensorFlow的第一次接触 第四章:单层神经网络
北京 | 深度学习与人工智能研修 12月23-24日 再设经典课程 重温深度学习阅读全文> 正文共7865个字,27张图,预计阅读时间:20分钟. 在前言中,已经提到经常使用深度学习的领域就是模 ...
- 华南理工深度学习与神经网络期末考试_深度学习基础:单层神经网络之线性回归...
3.1 线性回归 线性回归输出是一个连续值,因此适用于回归问题.回归问题在实际中很常见,如预测房屋价格.气温.销售额等连续值的问题.与回归问题不同,分类问题中模型的最终输出是一个离散值.我们所说的图像 ...
- 深度学习基础--SOFTMAX回归(单层神经网络)
深度学习基础–SOFTMAX回归(单层神经网络) 最近在阅读一本书籍–Dive-into-DL-Pytorch(动手学深度学习),链接:https://github.com/newmonkey/Div ...
- 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(八)(TensorFlow基础))
[神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(八)(TensorFlow基础)) 8 TensorFlow基础 8.1 TensorFlow2.0特性 8.1.1 Tenso ...
- TensorFlow基础之模型建立与训练:线性回归、MLP多层感知机、卷积神经网络
TensorFlow基础之模型建立与训练 模型建立与训练:简单的线性回归 MLP多层感知机 数据获取.预处理 模型搭建 训练与评估 卷积神经网络 高效建模 Keras Sequential高效建模 F ...
- 飞桨深度学习零基础入门(一)——使用飞桨(Paddle)单层神经网络预测波士顿房价
系列文章往期回顾 飞桨深度学习零基础入门(序)--Python实现梯度下降 使用飞桨(Paddle)构建单层神经网络 系列文章往期回顾 一.导入相关依赖包 二.构建单层神经网络回归类 三.设置参数 四 ...
- MOOC网神经网络与深度学习TensorFlow实践3——数字图像处理、TensorFlow基础
数字图像处理 数字图像基本概念 pillow图像处理库 手写数字数据集MNIST TensorFlow基础 TensorFlow2.0特性 创建张量 维度变换 部分采样 张量运算
- TensorFlow基础剖析
TensorFlow基础剖析 一.概述 TensorFlow 是一个使用数据流图 (Dataflow Graph) 表达数值计算的开源软件库.它使用节点表示抽象的数学计算,并使用 OP 表达计算的逻辑 ...
最新文章
- 31 天重构学习笔记28. 为布尔方法命名
- centos7安装redis的正确姿势
- 【Java进阶】云存储-创建子模块作为第三方整合模块
- python字符编码使用_python – Numpy字符串编码
- python的三个特性_Python3.9的7个特性
- 360私有云平台Elasticsearch服务初探
- xps15u盘装linux,Dell XPS 15 9560 安装 Ubuntu 18.04
- index mysql_mysql 原理~ index的详解
- python3爬取青年文摘999篇精选文章
- 说说Android的广播(4) - 前台队列为什么比后台队列快?
- IDEA快捷键之搜索查询
- Openwrt安装transmission离线下载
- YDUI Touch InfiniteScroll无限加载数据测试
- Deepin和Windows10双系统,如何修改默认启动项
- php随机一句话,PHP简单实现一言 / 随机一句功能
- 节假日读取接口_节假日API接口,2018年,直接计算好的
- 【C#】分享一个可携带附加消息的增强消息框MessageBoxEx
- FFmpeg c++ 报错合集
- ㉓AW-H3 Linux驱动开发之mipi camera(CSI)驱动程序
- 瑞吉外卖第一篇(1):搭建环境之创建数据库