一、数据集

import tensorflow as tf
tf.__version__
'2.6.0'
# Load the MNIST handwritten-digit dataset (downloaded on first use).
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
train_images.shape,test_images.shape,train_labels.shape
((60000, 28, 28), (10000, 28, 28), (60000,))
# 展示图片
import matplotlib.pyplot as pltdef plot_image(image):plt.imshow(image,cmap='binary')plt.show()plot_image(train_images[0])

# Split the original 60k training images into train/validation partitions.
total_num = len(train_images)
split_valid = 0.2  # fraction held out for validation
train_num = int(total_num * (1 - split_valid))

# Training partition (first 80%).
train_x = train_images[:train_num]
train_y = train_labels[:train_num]
# Validation partition (remaining 20%).
valid_x = train_images[train_num:]
valid_y = train_labels[train_num:]
# Test partition: the official MNIST test split, unchanged.
test_x = test_images
test_y = test_labels

# Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1].
train_x = tf.cast(train_x.reshape(-1, 784) / 255.0, dtype=tf.float32)
valid_x = tf.cast(valid_x.reshape(-1, 784) / 255.0, dtype=tf.float32)
test_x = tf.cast(test_x.reshape(-1, 784) / 255.0, dtype=tf.float32)

# One-hot encode the integer labels into length-10 vectors.
train_y = tf.one_hot(train_y, 10)
valid_y = tf.one_hot(valid_y, 10)
test_y = tf.one_hot(test_y, 10)

二、模型

# Build the model parameters: a 784 -> 64 -> 10 fully-connected network.
Input_Dim = 784
H1_NN = 64
Output_Dim = 10

# Hidden layer: random-normal weights, zero biases.
W1 = tf.Variable(tf.random.normal(shape=(Input_Dim, H1_NN)), dtype=tf.float32)
B1 = tf.Variable(tf.zeros(H1_NN), dtype=tf.float32)

# Output layer: random-normal weights, zero biases.
W2 = tf.Variable(tf.random.normal(shape=(H1_NN, Output_Dim)), dtype=tf.float32)
B2 = tf.Variable(tf.zeros(Output_Dim), dtype=tf.float32)

# Parameter lists passed to the forward pass and the optimizer.
W = [W1, W2]
B = [B1, B2]
# 定义模型的前向计算
def model(w,x,b):x = tf.matmul(x,w[0]) + b[0]x = tf.nn.relu(x)x = tf.matmul(x,w[1]) + b[1]return tf.nn.softmax(x)
# 损失函数
def loss(w,x,y,b):pred = model(w,x,b)loss_ = tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)return tf.reduce_mean(loss_)
# 准确率
def accuracy(w,x,y,b):pred = model(w,x,b)acc = tf.equal(tf.argmax(pred,axis=1),tf.argmax(y,axis=1))return tf.reduce_mean(tf.cast(acc,dtype=tf.float32))
# 计算梯度
def grad(w,x,y,b):with tf.GradientTape() as tape:loss_ = loss(w,x,y,b)return tape.gradient(loss_,[w[0],b[0],w[1],b[1]])

三、训练

# Training hyperparameters.
train_epochs = 20
learning_rate = 0.01
batch_size = 50
total_steps = train_num // batch_size  # mini-batches per epoch

# Per-epoch metric histories for plotting later.
# NOTE(review): the 'trian'/'valide' misspellings are kept on purpose —
# the plotting cells below reference these exact names.
train_loss_list = []
trian_acc_list = []
valid_loss_list = []
valide_acc_list = []

# Adam optimizer consumes the gradients produced by grad().
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
for epoch in range(train_epochs):for step in range(total_steps):xs = train_x[step*batch_size:(step+1)*batch_size]ys = train_y[step*batch_size:(step+1)*batch_size]grads = grad(W,xs,ys,B)optimizer.apply_gradients(zip(grads,[W[0],B[0],W[1],B[1]]))trian_loss = loss(W,train_x,train_y,B).numpy()valid_loss = loss(W,valid_x,valid_y,B).numpy()train_accuracy = accuracy(W,train_x,train_y,B).numpy()valid_accuracy = accuracy(W,valid_x,valid_y,B).numpy()trian_acc_list.append(train_accuracy)valide_acc_list.append(valid_accuracy)train_loss_list.append(trian_loss)valid_loss_list.append(valid_loss)print(f'{epoch+1}:trian_loss:{trian_loss}valid_loss:{valid_loss}train_accuracy:{train_accuracy}valid_accuracy:{valid_accuracy}')
1:trian_loss:4.090484142303467valid_loss:4.0961079597473145train_accuracy:0.7324583530426025valid_accuracy:0.731083333492279
2:trian_loss:3.873914957046509valid_loss:3.8966963291168213train_accuracy:0.7461875081062317valid_accuracy:0.7425000071525574
3:trian_loss:3.698087215423584valid_loss:3.7547082901000977train_accuracy:0.7597083449363708valid_accuracy:0.7545833587646484
4:trian_loss:2.0992202758789062valid_loss:2.1797149181365967train_accuracy:0.8577708601951599valid_accuracy:0.8530833125114441
5:trian_loss:2.0091030597686768valid_loss:2.1187283992767334train_accuracy:0.8645208477973938valid_accuracy:0.8534166812896729
6:trian_loss:2.05008864402771valid_loss:2.162834405899048train_accuracy:0.8585000038146973valid_accuracy:0.8494166731834412
7:trian_loss:1.9510189294815063valid_loss:2.0553224086761475train_accuracy:0.8664166927337646valid_accuracy:0.8576666712760925
8:trian_loss:1.9326006174087524valid_loss:2.050128221511841train_accuracy:0.8680833578109741valid_accuracy:0.8569999933242798
9:trian_loss:1.9068089723587036valid_loss:2.024397850036621train_accuracy:0.8706041574478149valid_accuracy:0.8599166870117188
10:trian_loss:0.4595804512500763valid_loss:0.5628429651260376train_accuracy:0.9586874842643738valid_accuracy:0.949999988079071
11:trian_loss:0.3590681552886963valid_loss:0.5005843043327332train_accuracy:0.9663333296775818valid_accuracy:0.9556666612625122
12:trian_loss:0.29265761375427246valid_loss:0.46133357286453247train_accuracy:0.9728958606719971valid_accuracy:0.9575833082199097
13:trian_loss:0.3250505030155182valid_loss:0.49780264496803284train_accuracy:0.9699791669845581valid_accuracy:0.9567499756813049
14:trian_loss:0.329074889421463valid_loss:0.4836892783641815train_accuracy:0.9683958292007446valid_accuracy:0.9536666870117188
15:trian_loss:0.2734844386577606valid_loss:0.46817922592163086train_accuracy:0.9743750095367432valid_accuracy:0.9578333497047424
16:trian_loss:0.3187606930732727valid_loss:0.5206401944160461train_accuracy:0.9695624709129333valid_accuracy:0.952750027179718
17:trian_loss:0.23391176760196686valid_loss:0.46213391423225403train_accuracy:0.9774166941642761valid_accuracy:0.9605000019073486
18:trian_loss:0.2218097299337387valid_loss:0.41849949955940247train_accuracy:0.9789999723434448valid_accuracy:0.9635000228881836
19:trian_loss:0.2505856156349182valid_loss:0.45410531759262085train_accuracy:0.9771875143051147valid_accuracy:0.9606666564941406
20:trian_loss:0.2279120683670044valid_loss:0.45335933566093445train_accuracy:0.9788125157356262valid_accuracy:0.9618333578109741
accuracy(W,test_x,test_y,B).numpy()
0.959
# Loss curves per epoch: red = training loss, blue = validation loss.
plt.plot(train_loss_list,'r')
plt.plot(valid_loss_list,'b')
[<matplotlib.lines.Line2D at 0x7f78f32a60d0>]

# Accuracy curves per epoch: red = training, blue = validation.
plt.plot(trian_acc_list,'r')
plt.plot(valide_acc_list,'b')
[<matplotlib.lines.Line2D at 0x7f78f328cfd0>]

四、预测

def predict(x,w,b):pred = model(w,x,b)pred_ = tf.argmax(pred,axis=1)return pred_
import numpy as np

# Spot-check one random test sample: compare prediction with ground truth.
idx = np.random.randint(0, len(test_x))  # renamed from `id`: avoid shadowing the builtin
# Predicted class for the sampled image.
pred = predict(test_x, W, B)[idx]
# True label from the raw (non-one-hot) test labels.
true = test_labels[idx]
print(true, pred.numpy())
1 1
import sklearn.metrics as sm
# NOTE(review): R^2 is a regression metric; computing it between one-hot labels
# and softmax probabilities is unusual for a classifier — accuracy or
# cross-entropy would be the conventional report. Kept as-is from the original.
print(f'r2:{sm.r2_score(test_y,model(W,test_x,B))}')
r2:0.9126431934513113

TensorFlow 从入门到精通(5)—— 多层神经网络与应用相关推荐

  1. tensorflow从入门到精通100讲(七)-TensorFlow房价预估使用Keras快速构建模型

    前言 这篇文章承接上一篇tensorflow从入门到精通100讲(二)-IRIS数据集应用实战 https://wenyusuran.blog.csdn.net/article/details/107 ...

  2. Tensorflow系列 | Tensorflow从入门到精通(二):附代码实战

    作者 | AI小昕 编辑 | 安可 [导读]:本文讲了Tensorflow从入门到精通.欢迎大家点击上方蓝字关注我们的公众号:深度学习与计算机视觉. Tensor介绍 Tensor(张量)是Tenso ...

  3. TensorFlow 从入门到精通(14)—— 初识循环神经网络

    hello,大家好,我又回来了,如约,更新循环神经网络. 最近好像事情变少了,但是状态还是很差.新生班级要展示了,希望51班大哥们能拿个好名次.这篇博客,是用LSTM/RNN来对影评进行分析,这个网络 ...

  4. TensorFlow 从入门到精通(11)—— DeepDream(上)

    这节课,我翘了两天,原因是最近压力比较大. 大家可能对卷积云里雾里,这节课我们就可视化一下卷积层.通过噪声图像起点单层网络单通道/单层网络多通道/多层网络全通道 来生成几幅图像,让大家看一下卷积神经网 ...

  5. TensorFlow官方入门实操课程-卷积神经网络

    知识点 卷积:用原始像素数据与过滤器中的值相乘,以后加起来. 如下是增强水平特征的过滤器. MaxPooling:每次卷积结束以后用一个MaxPooling用来增强图像的特征. 可以看出经过MaxPo ...

  6. PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)

    卷积神经网络就是含卷积层的网络.介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1].这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经 ...

  7. tensorflow从入门到精通100讲(三)-谈谈Estimator在Tensorflow中的应用

    前言 正如Tensorflow的官网所示:TensorFlow 提供一个包含多个 API 层的编程堆栈其架构图如下.用户可以任意选择不同级别的API进行自己模型的构建.而本篇文章就最高级别的API-- ...

  8. tensorflow从入门到精通100讲(六)-在TensorFlow Serving/Docker中做keras 模型部署

    前言 不知道大家研究过没有,tensorflow模型有三种保存方式: 训练时我们会一般会将模型保存成:checkpoint文件 为了方便python,C++或者其他语言部署你的模型,你可以将模型保存成 ...

  9. tensorflow从入门到精通100讲(二)-IRIS数据集应用实战

    前言 TensorFlow 2.0 即将问世,很多API该删的删,该改的改.在这篇文章中我就2.0 版本中以下两点更新,为大家做一下预热(注意:博主使用的是tensorflow1.9版). 在tens ...

最新文章

  1. java mysql乐观锁_java乐观锁使用
  2. Bullet 物理引擎 详细分析 Dbvt (4)
  3. 服务器2016修改时间,服务器时间错误
  4. MyBatis的概述及使用
  5. NYoj 蛇形填数
  6. 大家来找茬源码(微擎) -- 流量主
  7. 1GB等于2的多少次方
  8. codelite解决中文乱码问题
  9. matlab if嵌套函数,MATLAB嵌套函数的应用
  10. javascript如何获取html中的控件,Javascript-dom总结(获取页面控件)
  11. CAD 批量打印,输出pdf,plt的工具
  12. 机器学习之决策树实践:隐形眼镜类型预测
  13. python提取每个单词首字母_如何将字符串中每个单词的首字母大写(Python)?
  14. Oracle分区表管理
  15. 腾讯游戏业务竟然是这样利用低代码平台的 | ArchSummit
  16. 微型计算机标致寄存器实验报告,微机原理实验报告(2013).doc
  17. 用友从“新”出发:“新”在哪里?
  18. 移动导入表/导入表注入(注入导入表后EXE无法运行的BUG解决方案)
  19. 内存泄漏 内存溢出 踩内存 malloc底层实现原理
  20. windows10快速打开回收站(Recycle Bin)

热门文章

  1. 音乐APP攻防战:QQ、酷狗、酷我、网易云,谁会成为下一个虾米?
  2. QtCreator里添加外部第三库、头文件路径的方法(.pro文件)
  3. 了解Linux实时内核
  4. vc++ 自定义消息和WM_NOTIFY消息实现
  5. vs2010自定义消息与vc6.0不太一样
  6. (番茄插件)Visual AssistX 安装教程
  7. learning from the Trenches 12-16 用看板管理大型项目
  8. 0.5mm的焊锡丝能吃多大电流_貌似简单无奇的操作步骤 或许就能让你前功尽弃
  9. 数据库审计部署方式有哪些?哪种比较好?
  10. 桌面总是出现计算机内存不足,为什么会出现电脑内存不足?该如何处理?