文章目录

fashion_mnist相比于mnist数据集,要复杂的多,识别衣服鞋子等物品

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics# 构造预处理函数
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x, y# x:[60k,28,28] x_test:[10k,28,28]
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape)
print(x_test.shape, y_test.shape)batchsz = 128
# 构造数据集
db = tf.data.Dataset.from_tensor_slices((x, y))
# 预处理,只需要传入函数而无需要传入函数的调用方式
db = db.map(preprocess).shuffle(10000).batch(batchsz)db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# 预处理,只需要传入函数而无需要传入函数的调用方式
db_test = db_test.map(preprocess).batch(batchsz)db_iter = iter(db)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)# 构建多层网络-最后一层一般不需要激活函数
model = Sequential([layers.Dense(256, activation=tf.nn.relu),  # [b,784] => [b,256]layers.Dense(128, activation=tf.nn.relu),  # [b,256] => [b,128]layers.Dense(64, activation=tf.nn.relu),  # [b,128] => [b,64]layers.Dense(32, activation=tf.nn.relu),  # [b,64] => [b,32]layers.Dense(10),  # [b,32] => [b,10]   330 = 32*10 + 10
])# 构建输入维度
model.build(input_shape=[None, 28 * 28])
model.summary()# 构建优化器
# 更新权值- w = w -lr*grad
optimizer = optimizers.Adam(lr=1e-3)def main():for epoch in range(30):for step, (x, y) in enumerate(db):# x[b,28,28] => x[b,784]x = tf.reshape(x, [-1, 28 * 28])with tf.GradientTape() as tape:# 构建前向传播# [b,784] => [b,10]logits = model(x)y = tf.one_hot(y, depth=10)loss_mse = tf.reduce_mean(tf.losses.MSE(y, logits))loss_ce = tf.losses.categorical_crossentropy(y, logits, from_logits=True)loss_ce = tf.reduce_mean(loss_ce)grads = tape.gradient(loss_ce, model.trainable_variables)# 利用优化器统一原地更新optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))# testtotal_correct = 0total_num = 0for x, y in db_test:x = tf.reshape(x, [-1, 28 * 28])# [b,784] => [b,10]logits = model(x)prob = tf.nn.softmax(logits, axis=1)# [b,10] => [b]pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)# pred:[b]  y:[b]correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))# tensor => numpytotal_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
batch: (128, 28, 28) (128,)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                multiple                  200960
_________________________________________________________________
dense_1 (Dense)              multiple                  32896
_________________________________________________________________
dense_2 (Dense)              multiple                  8256
_________________________________________________________________
dense_3 (Dense)              multiple                  2080
_________________________________________________________________
dense_4 (Dense)              multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
0 0 loss: 2.303929090499878 0.242684006690979
0 100 loss: 0.5860539674758911 22.32927703857422
0 200 loss: 0.5546977519989014 27.239099502563477
0 300 loss: 0.3866368532180786 25.4809627532959
0 400 loss: 0.3865927457809448 28.153907775878906
0 test acc: 0.8482
1 0 loss: 0.4001665711402893 26.901527404785156
1 100 loss: 0.523040771484375 32.702239990234375
1 200 loss: 0.2962474822998047 37.23288345336914
1 300 loss: 0.5228046774864197 29.088302612304688
1 400 loss: 0.4941650331020355 31.387100219726562
1 test acc: 0.853
2 0 loss: 0.3662468492984772 31.163776397705078
2 100 loss: 0.2790403962135315 37.754085540771484
2 200 loss: 0.3591204583644867 36.052513122558594
2 300 loss: 0.30174970626831055 35.038909912109375
2 400 loss: 0.45119309425354004 33.016700744628906
2 test acc: 0.8629
3 0 loss: 0.3489121198654175 36.01388931274414
3 100 loss: 0.32831019163131714 46.060543060302734
3 200 loss: 0.22170956432819366 41.36198425292969
3 300 loss: 0.284996896982193 34.6197624206543
3 400 loss: 0.3549075722694397 48.20280838012695
3 test acc: 0.8682...29 0 loss: 0.1280389428138733 230.05148315429688
29 100 loss: 0.10448945313692093 262.2647705078125
29 200 loss: 0.11969716101884842 239.57110595703125
29 300 loss: 0.10605543851852417 246.7599639892578
29 400 loss: 0.09774138778448105 235.69046020507812
29 test acc: 0.8915

深度学习2.0-18.随机梯度下降之手写数字问题实战(层)相关推荐

  1. 深度学习(32)随机梯度下降十: 手写数字识别问题(层)

    深度学习(32)随机梯度下降十: 手写数字识别问题(层) 1. 数据集 2. 网络层 3. 网络模型 4. 网络训练 本节将利用前面介绍的多层全连接网络的梯度推导结果,直接利用Python循环计算每一 ...

  2. 深度学习(33)随机梯度下降十一: TensorBoard可视化

    深度学习(33)随机梯度下降十一: TensorBoard可视化 Step1. run listener Step2. build summary Step3.1 fed scalar(监听标量) S ...

  3. 深度学习(31)随机梯度下降九: Himmelblau函数优化实战

    深度学习(31)随机梯度下降九: Himmelblau函数优化实战 1. Himmelblau函数 2. 函数优化实战 1. Himmelblau函数 Himmelblau函数是用来测试后话算法的常用 ...

  4. 深度学习(30)随机梯度下降七: 多层感知机梯度(反向传播算法)

    深度学习(30)随机梯度下降八: 多层感知机梯度(反向传播算法) 1. 多层感知机模型 2. 多层感知机梯度 3. 传播规律小结 tens Recap Chain Rule Multi-output ...

  5. 深度学习(28)随机梯度下降六: 多输出感知机梯度

    深度学习(28)随机梯度下降六: 多输出感知机梯度 1. Multi-output Perceptron 2. Derivative 3. 代码 Perceptron 单输出感知机梯度 ∂E∂wj0= ...

  6. 深度学习(27)随机梯度下降五: 单输出感知机梯度

    深度学习(27)随机梯度下降五: 单输出感知机梯度 1. Perceptrnon with Sigmoid + MSE 2. Derivative 3. 代码 Recap y=XW+by=XW+by= ...

  7. 深度学习(26)随机梯度下降四: 损失函数的梯度

    深度学习(26)随机梯度下降四: 损失函数的梯度 1. Mean Squared Error(MSE) 2. Cross Entropy Loss CrossEntropy 3. Softmax (1 ...

  8. 深度学习(25)随机梯度下降三: 激活函数的梯度

    深度学习(25)随机梯度下降三: 激活函数的梯度 1. Activation Functions 2. Deriative 3. Sigmoid/Logistic (1) Derivative (2) ...

  9. 深度学习(23)随机梯度下降一: 随机梯度下降简介

    深度学习(23)随机梯度下降一: 随机梯度下降简介 1. What's Gradient? 2. What does it mean? 3. How to search? 4. For instanc ...

  10. 深度学习(24)随机梯度下降二: 常见函数的梯度

    深度学习(24)随机梯度下降二: 常见函数的梯度 Common Functions 1. y=xw+by=xw+by=xw+b 2. y=xw2+b2y=xw^2+b^2y=xw2+b2 3. y=x ...

最新文章

  1. python中init和setup有什么区别_python – 为什么setup.py在安装期间运行模块__init__.py?...
  2. mysql5.7 rmp_linux MySQL5.7 rpm安装(转)
  3. JavaScript常用函数
  4. Swift URL含有中文的处理
  5. java nio 追加写文件_Java NIO在文件末尾追加数据
  6. 恒生电子实施怎么样_蓝思科技今年来涨幅超166%,消费电子主题基金如何挑选?...
  7. Java 泛形通配符 ?
  8. 大学数学不好是一种什么体验?
  9. 表单-图片浏览上传-单选框(二)
  10. 电子商务专业实习总结
  11. 如何实现XA式、非XA式Spring分布式事务
  12. const数据成员的初始化
  13. 如何使用飞秋FeiQ实现两电脑通信(或传输文件)
  14. 显示器的 VGA、HDMI、DVI 和DisplayPort接口有什么区别?
  15. 互联网日报 | 华为西南地区首家旗舰店开业;高德打车企业版入驻飞书;马蜂窝发布“北极星攻略”品牌...
  16. 笔记:扩展一个数字的位表示 无符号数的零扩展 补码数的符号扩展
  17. 扫雷游戏 P2670 [NOIP2015 普及组]
  18. 地震勘探原理(四)之频谱分析概述
  19. saiku连mysql 使用_Saiku的基本使用介绍(三)
  20. 关闭Java11中即将移除Nashorn引擎的警告Warning: Nashorn engine is planned to be removed from a future JDK release

热门文章

  1. oracle的热备份和冷备份
  2. [MATLAB]MATLAB中SIMULINK常用命令表
  3. 上周Asp.net源码(11.5-11.10)免费下载列表
  4. Echo团队Alpha冲刺随笔 - 第八天
  5. sql 分页查询 (每次6行 )
  6. J2EE学习笔记-第二章(Web应用初步)
  7. 黑马程序员——java基础---IO(input output)流字符流
  8. ESP8266-01/01S配对阿里云生活物联网教程(超详细)
  9. TX2开发板Ubuntu16.04安装中文输入法
  10. 基准对象object中的基础类型----元组 (五)