学习目标:

TensorFlow由谷歌开源的机器学习框架,其对常见机器学习算法的包装性好,“开箱即用”,让开发者能够轻松地构建和部署各类机器学习模型,并可直接用于生产系统。TensorBoard是TensorFlow配套的一个可视化工具,它可以用来展示网络图、指标变化、参数分布情况等。特别是在训练网络的时候,我们可以设置不同的参数(比如:权重、偏置、卷积层数、全连接层数等),使用TensorBoard可以很直观的帮我们进行参数的选择。它通过运行一个本地服务器,监听6006端口。在浏览器发出请求时,分析训练时记录的数据,绘制训练过程中的图像。 本文将进行一个简单的使用演示。搭建一个全连接层的机器学习模型,来预测图片的分类。 本次演示的数据集采用Fashion-MNIST,Fashion-MNIST包含了10个类别的图像,分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴)。

了解TensorFlow 2 模型构建方法,掌握keras。同时结合tensorboard图像化展示,进而进行模型的调优等。


模型目标:预测图片分类

预测任务:
预测像素点为(28,28)的灰度照片的分类。
数据集为 tf.keras.datasets.fashion_mnist的分类数据


学习内容1:构建模型

构建方法:采用keras中的layer,一层层堆叠,然后compile
代码:

import tensorflow as tffashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()print('训练数据的输入特征维度是:',train_images.shape)
print('训练数据的标签维度是:',train_labels.shape)
print('测试数据的输入特征维度是:',test_images.shape)
print('测试数据的标签维度是:',test_labels.shape)class_names = ['T_shirt(T恤)','Trouser(裤子)','Pullover(套衫)','Dress(裙子)','Coat(外套)','Sandal(凉鞋)','Shirt(汗衫)','Sneaker(运动鞋)','Bag(包)','Ankle_boot(踝靴)']import matplotlib.pyplot as plt
plt.figure(figsize=(10,10))
i = 0
for (image,label) in zip(test_images,test_labels):# image = image.reshape((28,28))plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(image,cmap=plt.cm.binary)plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = Falseplt.xlabel(class_names[label])i+=1if i==25:break
plt.show()input_xs = tf.keras.Input(shape=(28,28))
flat = tf.keras.layers.Flatten()(input_xs)dense_1 = tf.keras.layers.Dense(256,activation='relu',name='dense_1')(flat)
dense_2 = tf.keras.layers.Dense(128,activation='relu',name='dense_2')(dense_1)
dense_3 = tf.keras.layers.Dense(512,activation='relu',name='dense_3')(dense_2)
logits = tf.keras.layers.Dense(10,activation='softmax',name='predict')(dense_3)model = tf.keras.Model(inputs = input_xs,outputs=logits)
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy','mse'])
print(model.summary())
tf.keras.utils.plot_model(model)
tensorboard = tf.keras.callbacks.TensorBoard(histogram_freq=1)
model.fit(x=train_images,y=train_labels,epochs=500,batch_size=128,callbacks=[tensorboard])
score = model.evaluate(x=test_images,y=test_labels)
print('last loss is :' + str(score[0]))
print('last accuracy is :' + str(score[1]))
print('last mse is :' + str(score[2]))

运行结果:

可以看到采用3层全连接层,最终的准确度达到88.32%


tensorboard图像展示:

根据上述的内容,产生的tensorboard如下:
模型可以看到了3个全连接层。

损失函数随着迭代次数的减小情况:

准确度:


对比CNN模型:

CNN模型构建如下:

import tensorflow as tffashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()print('训练数据的输入特征维度是:',train_images.shape)
print('训练数据的标签维度是:',train_labels.shape)
print('测试数据的输入特征维度是:',test_images.shape)
print('测试数据的标签维度是:',test_labels.shape)class_names = ['T_shirt(T恤)','Trouser(裤子)','Pullover(套衫)','Dress(裙子)','Coat(外套)','Sandal(凉鞋)','Shirt(汗衫)','Sneaker(运动鞋)','Bag(包)','Ankle_boot(踝靴)']train_images = tf.expand_dims(train_images,axis=3)
test_images = tf.expand_dims(test_images,axis=3)
print('修正后训练数据的输入特征维度是:',train_images.shape)
print('修正后测试数据的输入特征维度是:',test_images.shape)input_xs = tf.keras.Input(shape=(28,28,1))
cov1 = tf.keras.layers.Conv2D(32,kernel_size=(3,3),activation='relu',padding='SAME')(input_xs)
pool1 = tf.keras.layers.MaxPool2D(strides=[2,2])(cov1)
norm1 = tf.keras.layers.BatchNormalization()(pool1)cov2 = tf.keras.layers.Conv2D(64,kernel_size=(3,3),activation='relu',padding='SAME')(norm1)
pool2 = tf.keras.layers.MaxPool2D(strides=[2,2])(cov2)
norm2 = tf.keras.layers.BatchNormalization()(pool2)cov3 = tf.keras.layers.Conv2D(128,kernel_size=(3,3),activation='relu',padding='SAME')(norm2)
pool3 = tf.keras.layers.MaxPool2D(strides=[2,2])(cov3)
norm3 = tf.keras.layers.BatchNormalization()(pool3)flatten = tf.keras.layers.Flatten()(norm3)
dense_1 = tf.keras.layers.Dense(32,activation='relu',name='dense_1')(flatten)
dense_2 = tf.keras.layers.Dropout(rate=0.2)(dense_1)
dense_2 = tf.keras.layers.Dense(64,activation='relu',name='dense_2')(dense_1)
dense_2 = tf.keras.layers.Dropout(rate=0.5)(dense_2)
logits = tf.keras.layers.Dense(10,activation='softmax',name='predict')(dense_2)model = tf.keras.Model(inputs = input_xs,outputs=logits)
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy','mse'])
print(model.summary())
tf.keras.utils.plot_model(model)tensorboard = tf.keras.callbacks.TensorBoard(histogram_freq=1)model.fit(x=train_images,y=train_labels,epochs=50,batch_size=256,verbose=2,callbacks=[tensorboard])
score = model.evaluate(x=test_images,y=test_labels)
print('last loss is :' + str(score[0]))
print('last accuracy is :' + str(score[1]))
print('last mse is :' + str(score[2]))

tensorboard图像展示:

根据上述的内容,产生的tensorboard如下:

模型可以看到了3个卷积层和2个全连接层。

运行结果如下:
损失函数随着迭代次数的减小情况:

精确度:


准确率达到91.53%。

从tensorboard的显示的趋势走向来看,模型是有效的,包括学习率的选取等。
而且,CNN比传统的全连接神经网络的准确度高3%。简直是一次重大的突破。

作为对比,我们训练一个异常模型,也就是某些参数设置不合理,导致模型无法收敛的。对比看模型参数的变化。

我们可以看到模型失效时,accuracy的精度在随机跳动,不收敛,模型参数也没有得到有效的更新。
本文只是简单的演示TensorBoard的使用方法,大概介绍其原理和初步印象,如果要深入研究,还有很多可以研究的点,总之,有了TensorBoard可以帮我更加直观的看清模型结构,判断模型训练参数是否有效,并进行可视化展示。

TensorFlow2快速模型构建及tensorboard初体验相关推荐

  1. Vue快速上手笔记1 - 使用初体验

    Vue快速上手笔记1 - 使用初体验 博主:李俊才 邮箱:291148484@163.com 若本文中存在的错误请告知博主更正 希望对大家有所帮助 专题目录:https://blog.csdn.net ...

  2. Tensorflow2.0模型构建与训练

    模型构建 class Encoder(layers.Layer):def __init__(self, latent_dim=32, intermediate_dim=64, name="e ...

  3. APICloud入门初体验

    APP快速开发平台APICloud之初体验 借助APICloud开发平台,使用web开发技术制作的APP,和我们使用原生开发的APP,几乎没有任何区别,这是因为除了使用HTML.JS这些web开发技术 ...

  4. 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别

    "我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...

  5. 我的Go+语言初体验——初学者的快速入门

    "我的Go+语言初体验" | 征文活动进行中- 一.Go+ 是什么 听说Go+发布已经有几天了,之前有接触过Go.python.php.c等语言,所以看到Go+时感觉并不陌生. 那 ...

  6. 3.2 实战项目二(手工分析错误、错误标签及其修正、快速地构建一个简单的系统(快速原型模型)、训练集与验证集-来源不一致的情况(异源问题)、迁移学习、多任务学习、端到端学习)

    手工分析错误 手工分析错误的大多数是什么 猫猫识别,准确率90%,想提升,就继续猛加材料,猛调优?     --应该先做错误分析,再调优! 把识别出错的100张拿出来, 如果发现50%是"把 ...

  7. Keras与Tensorflow2.0入门(6)模型可视化与tensorboard的使用

    文章目录 1. 前言 1.1 Plot_model 1.2 History 1.3 自定义评估函数 PRF值的计算方法 AUC的计算方法 2. tensorboard 2.1 tensorboard是 ...

  8. 【虚幻引擎UE】UE5 AR初体验之静态动态模型加载

    UE5的AR初体验之静态动态两种模型加载 基于配置好AR环境(参考另一篇文章) 先## 标题简单了解一下它的项目结构 这里的brush就是我们的操作空间范围 官方模板可以实现平面识别,控制对象的旋转和 ...

  9. show-busy-java-threads脚本初体验,快速排查Java的CPU性能问题

    前言 之前写过一篇文章,通过top和jstack命令来排查CPU使用率高的问题,详见:https://blog.csdn.net/yougou_sully/article/details/842624 ...

最新文章

  1. 如何利用【百度地图API】,制作房产酒店地图?(下)——结合自己的数据库...
  2. error LNK2001: 无法解析的外部符号 “void __cdecl cv::cvtColor
  3. 老电脑安装matlab 2018卡不卡,软件装C盘会变卡?Windows系统要怎么用才不卡?
  4. 势头迅猛的儿童手表:恐陷下一个文曲星之地?
  5. 图片压缩但质量不减,这个工具很不错
  6. c oracle日志分析,oracle 日志分析
  7. python画图代码-Python实战小程序利用matplotlib模块画图代码分享
  8. 计算已知经纬度两点的距离_python
  9. PyQt5实现局域网聊天工具
  10. android app 后台运行,安卓APP锁定后台运行的方法
  11. CAXA 数控车编程视频教程 CAXA车床绘图教程
  12. 汽车营销与保险【1】
  13. oracle数据库实例改名,如何修改数据库实例及数据库名
  14. 计算机八进制 算法视频,八进制转二进制计算器
  15. PCB安规设计是怎样的?怎样设计高压电源的安规?ECM设计间距是怎样的,CAF设计间距要求是怎样的?电源的PCB间距设计指南,安规标准有哪些?380V电源安规设计
  16. web邮箱和客户端的区别
  17. 法制晚报记者采访王杰律师就“给女主播“添麻烦”方静亮相 间谍传言不攻自破 ”发表法律评论
  18. 在Windows系统上对hfds中的文件进行操作
  19. Spark 1.6 SparkSQL实践
  20. 使用Nodejs创建一个Web服务器应如何操作?以及路由相关知识了解

热门文章

  1. mysql5.7.12 my.ini文件_MySQL5.7缺少my.ini文件的解决方法
  2. linux oracle 运维_Oracle查询当前的crs/has自启动状态实例教程
  3. 划分VLAN,以及VLAN间通信
  4. creatdep oracle_Oracle数据库自带表
  5. mysql 中文 问号 utf8_[MySql] 设置了UTF8,中文存数据库中仍然出现问号
  6. Android网络请求开源框架retrofit的基本GET用法(2.4版本)
  7. php7.0 yield,PHP7中生成器的新特性 yield-from amp;amp; return-values
  8. php7.1 aes 加密解密,PHP7.1中AES加密解密方法 mcrypt_module_open()替换方案
  9. C#单例---饿汉式和懒汉式
  10. ?Web开发者需要知道的CSS Tricks