Tensorflow2.x.x最基础的神经网络(ANN)
Tensorflow2.x.x最基础的神经网络(ANN)
本章节主要使用Tensorflow2.x.x来搭建ANN神经网络。
ANN原理
这里直接放上小伙伴ANN的原理博客~
实现
使用ANN实现对MNIST数据集的分类。
import tensorflow as tf
# mnist数据集
from tensorflow.keras.datasets import mnist
# Adam优化器
from tensorflow.keras.optimizers import Adam
# 交叉熵损失函数,一般用于多分类
from tensorflow.keras.losses import CategoricalCrossentropy
# 模型和网络层
from tensorflow.keras import Model, layers# 批次大小
BATCH_SIZE = 128
# 迭代次数
EPOCHS = 10
# 加载mnist的训练、测试数据集
train, test = mnist.load_data()
# 数据集的预处理
@tf.function
def preprocess(x, y):# 将x一维数据转为3维灰度图x = tf.reshape(x, [28, 28, 1])# 将x的范围由[0, 255]为[0, 1]x = tf.image.convert_image_dtype(x, tf.float32)# 将y数字标签进行独热编码y = tf.one_hot(y, 10)# 返回处理后的x和yreturn x, y# 使用Dataset来减少内存的使用
train = tf.data.Dataset.from_tensor_slices(train)
# 对数据进行预处理并且给定BATCH_SIZE
train = train.map(preprocess).batch(BATCH_SIZE)# test数据集同理
test = tf.data.Dataset.from_tensor_slices(test)
test = test.map(preprocess).batch(BATCH_SIZE)# 搭建模型(只是其中的一种搭建方式而已)
x = layers.Input(shape=(28, 28, 1)) # 输入为x, 大小为 28*28*1
y = layers.Flatten()(x) # 将高维数据扁平化
y = layers.Dense(1024, activation='relu')(y) # 输出1024个神经元的全网络层
y = layers.Dense(512, activation='relu')(y) # 输出512个神经元的全网络层
y = layers.Dense(256, activation='relu')(y) # 输出256个神经元的全网络层
y = layers.Dense(128, activation='relu')(y) # 输出128个神经元的全网络层
y = layers.Dense(64, activation='relu')(y) # 输出64个神经元的全网络层
y = layers.Dense(32, activation='relu')(y) # 输出32个神经元的全网络层
y = layers.Dense(10, activation='softmax')(y) # 输出10个神经元的全网络层,最后一层使用了softmax进行激活,原因是我们希望提前[0, 1]之间的概率# 创建模型
ann = Model(x, y)
# 编译模型,选择优化器、评估标准、损失函数
ann.compile(optimizer=Adam(), metrics=['acc'], loss=CategoricalCrossentropy())
# 进行模型训练
history = ann.fit(train, epochs=EPOCHS)
# 测试集的评估
score = ann.evaluate(test)
# 打印评估成绩
print('loss: {0}, acc: {1}'.format(score[0], score[1])) # loss: 0.11106619730560828, acc: 0.9769999980926514# 绘制训练过程中每个epoch的loss和acc的折线图
import matplotlib.pyplot as plt
# history对象中有history字典, 字典中存储着“损失”和“评估标准”
epochs = range(EPOCHS)
fig = plt.figure(figsize=(15, 5), dpi=100)ax1 = fig.add_subplot(1, 2, 1)
ax1.plot(epochs, history.history['loss'])
ax1.set_title('loss graph')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss val')ax2 = fig.add_subplot(1, 2, 2)
ax2.plot(epochs, history.history['acc'])
ax2.set_title('acc graph')
ax2.set_xlabel('epochs')
ax2.set_ylabel('acc val')fig.show()
输出结果如下:
Tensorflow2.x.x最基础的神经网络(ANN)相关推荐
- 人工神经网络ANN建模基础须知
链接文章:机器学习基础须知.神经网络建模实践,其他博文 人工神经网络ANN 0.感知机:包括输入节点.输出节点两部分,输入节点和输出节点用一个表示权重的值连接.感知机的输出值是计算输入节点的加权和,减 ...
- TensorFlow2.0(九)--Keras实现基础卷积神经网络
Keras实现基础卷积神经网络 1. 卷积神经网络基础 2. Keras实现卷积神经网络 2.1 导入相应的库 2.2 数据集的加载与处理 2.3 构建模型 2.4 模型的编译与训练 2.5 学习曲线 ...
- DL之ANN/DNN: 人工神经网络ANN/DNN深度神经网络算法的简介、应用、经典案例之详细攻略
DL之ANN/DNN: 人工神经网络ANN/DNN深度神经网络算法的简介.应用.经典案例之详细攻略 相关文章 DL:深度学习(神经网络)的简介.基础知识(神经元/感知机.训练策略.预测原理).算法分类 ...
- ann matlab,人工神经网络ann及其matlab仿真.ppt
人工神经网络ann及其matlab仿真 人工神经网络 的研究方法及应用刘 长 安2004. 12. 31 引 言 利用机器模仿人类的智能是长期以来人们认识自然.改造自然和认识自身的理想. 研究ANN目 ...
- 人工神经网络 ANN
卷积神经网络CNN图解 本文参考人工神经网络ANN 神经网络是一门重要的机器学习技术.它是深度学习的基础. 神经网络是一种模拟人脑的神经网络以期望能够实现人工智能的机器学习技术.人脑中的神经网络是一个 ...
- 福利 | 从生物学到神经元:人工神经网络 ( ANN ) 简介
文末有数据派THU福利哦 [ 导读 ] 我们从鸟类那里得到启发,学会了飞翔,从牛蒡那里得到启发,发明了魔术贴,还有很多其他的发明都是被自然所启发.这么说来看看大脑的组成,并期望因此而得到启发来构建智能 ...
- Python基于PyTorch实现BP神经网络ANN回归模型项目实战
说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...
- Python实现BP神经网络ANN单隐层回归模型项目实战
说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 20世纪80年代中期,David Runelhart.Geoff ...
- Python实现BP神经网络ANN单隐层分类模型项目实战
说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 BP(back propagation)神经网络是1986年由R ...
最新文章
- Mac是大脑,iPad是四肢 如何实现的呢?右键而已
- principle导出html5,让Principle成为生产力工具(二)单页面中的联动
- 用计算机计算板书,用计算器计算教案板书设计
- MT6580热设计要求
- 【通信原理课程设计】8PSK调制解调技术的设计与仿真(MATLAB)
- STM32F1笔记(二)GPIO输入
- 于变局中开新局!《2021中国SaaS市场研究报告》报告发布
- C/C++ OpenCV读取视频与调用摄像头
- __FILE__, __LINE__, __FUNCTION__
- Hibernate的学习详解(4)
- 简易java电子词典_使用Android简单实现有道电子词典
- 3.Python 进阶知识
- 计算机考研复试题目大全
- MV88DE3010之ffmpeg与m3u8
- 进阶实验5-3.2 新浪微博热门话题 (30 分)
- Python(2)模块和数据类型
- ctf之7z文件爆破
- 正睿OIday4总结
- sybase 珍藏(二)
- 【LTE】Qualcomm LTE Packets log 分析(一)LTE Access Stratum Log Analysis 1_PSS 2_RACH