https://tensorflow.google.cn/tutorials/keras/classification

解决方案

#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@version: 0.0.1
author: ShenTuZhiGang
@time: 2021/01/25 16:33
@file: 12.py
@function:
@modify:
"""from tensorflow import keras
import tensorflow as tf
import mnist_reader
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import summary
import datetimecurrent_time = str(datetime.datetime.now().timestamp())
train_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/train/' + current_time
test_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/test/' + current_time
val_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/val/' + current_time
train_summary_writer = summary.create_file_writer(train_log_dir)
val_summary_writer = summary.create_file_writer(val_log_dir)
test_summary_writer = summary.create_file_writer(test_log_dir)
(train_images, train_labels), (test_images, test_labels) = mnist_reader.load_data('../data/fashion')
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
train_images = train_images / 255.0test_images = test_images / 255.0plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]])
plt.show()class FashionMnistModel(keras.Model):def __init__(self, **kwargs):super().__init__(**kwargs)self.input_ = keras.layers.Flatten(input_shape=[28, 28])self.hidden1 = keras.layers.Dense(128, activation="relu")self.main_output = keras.layers.Dense(10)def call(self, inputs, **kwargs):input_a = self.input_(inputs)hidden1 = self.hidden1(input_a)output = self.main_output(hidden1)return outputmodel = FashionMnistModel()
model.build(input_shape=(0, 28, 28))
model.summary()
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=10)
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
with test_summary_writer.as_default():summary.scalar('loss', test_loss, 10)summary.scalar('accuracy', test_acc, 10)
print('\nTest accuracy:', test_acc)
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)
print(predictions[0])
print(np.argmax(predictions[0]))
print(test_labels[0])def plot_image(i, predictions_array, true_label, img):predictions_array, true_label, img = predictions_array, true_label[i], img[i]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(img, cmap=plt.cm.binary)predicted_label = np.argmax(predictions_array)if predicted_label == true_label:color = 'blue'else:color = 'red'plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],100*np.max(predictions_array),class_names[true_label]),color=color)def plot_value_array(i, predictions_array, true_label):predictions_array, true_label = predictions_array, true_label[i]plt.grid(False)plt.xticks(range(10))plt.yticks([])thisplot = plt.bar(range(10), predictions_array, color="#777777")plt.ylim([0, 1])predicted_label = np.argmax(predictions_array)thisplot[predicted_label].set_color('red')thisplot[true_label].set_color('blue')i = 0
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions[i], test_labels)
plt.show()i = 12
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()# Plot the first X test images, their predicted labels, and the true labels.
# Color correct predictions in blue and incorrect predictions in red.
num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):plt.subplot(num_rows, 2*num_cols, 2*i+1)plot_image(i, predictions[i], test_labels, test_images)plt.subplot(num_rows, 2*num_cols, 2*i+2)plot_value_array(i, predictions[i], test_labels)
plt.tight_layout()
plt.show()# Grab an image from the test dataset.
img = test_images[1]print(img.shape)# Add the image to a batch where it's the only member.
img = (np.expand_dims(img,0))print(img.shape)predictions_single = probability_model.predict(img)print(predictions_single)plot_value_array(1, predictions_single[0], test_labels)
_ = plt.xticks(range(10), class_names, rotation=45)print(np.argmax(predictions_single[0]))

参考文章

TensorFlow——本地加载fashion-mnist数据集

TensorFlow 教程——基本分类:对服装图像进行分类

TensorFlow——基于Keras子类API的fashion-mnist数据集图像分类相关推荐

  1. TensorFlow(Keras) 一步步实现Fashion MNIST衣服鞋子图片分类 (2) Coursera深度学习教程分享

    @[TOC](Coursera TensorFlow(Keras) 一步步手写体Fashion Mnist识别分类(2) Tensorflow和ML, DL 机器学习/深度学习Coursera教程分享 ...

  2. fashionmnist数据集_Keras实现Fashion MNIST数据集分类

    本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...

  3. python cnn程序_python cnn训练(针对Fashion MNIST数据集)

    本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...

  4. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  5. 基于jupyter notebook的python编程-----MNIST数据集的的定义及相关处理学习

    基于jupyter notebook的python编程-----MNIST数据集的相关处理 一.MNIST定义 1.什么是MNIST数据集 2.python如何导入MNIST数据集并操作 3.接下来, ...

  6. Fashion MNIST数据集的处理——“...-idx3-ubyte”文件解析

    Fashion MNIST MNIST数据集可能是计算机视觉所接触的第一个图片数据集.而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可 ...

  7. Tensorflow 笔记 XIII——“百无聊赖”:深挖 mnist 数据集与 fashion-mnist 数据集的读取原理,经典数据的读取你真的懂了吗?

    文章目录 数据集简介 Mnist 出门右转 Fashion-Mnist 数据集制作需求来源 写给专业的机器学习研究者 获取数据 类别标注 读取原理 原理获取 TRAINING SET LABEL FI ...

  8. 深度学习之利用TensorFlow实现简单的卷积神经网络(MNIST数据集)

    卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度学习 ...

  9. 基于Python实现的神经网络分类MNIST数据集

    神经网络分类MNIST数据集 目录 神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 神经网络简介 1 前馈神经网络模型: 1 1.2 MINST 数据说明 4 1.3 TensorFlo ...

最新文章

  1. 如何优雅地使用pdpipe与Pandas构建管道?
  2. python之⾯向对象-继承
  3. mysql int和bigdecimal,mysql的 int 类型,刨析返回类型为BigDicemal 类型的奇怪现象
  4. c语言如何实现阶乘,求10000的阶乘(c语言代码实现)
  5. 网站发布问题及使用Web Deployment Projects
  6. Duo js 一个非常酷的前端打包工具
  7. Spring websocket 使用@Autowired 出现null
  8. soapui 测试soap_使用SoapUI调用不同的安全WCF SOAP服务-基本身份验证,第二部分
  9. linux实现自动互信,Linux 使用shell脚本实现自动SSH互信功能
  10. 一个神奇的测试_这4个在线黑科技工具拥有神奇的魔法,值得收藏!
  11. python语言继承6.3节例6-1中的person_第6.3节 Python动态执行之动态编译的compile函数...
  12. python3.7降级3.6_请问一下Mac python3.7.1怎么降低到3.6版本?
  13. windows 监控
  14. java基础学习(4)
  15. java 全双工串口,Java实现全双工串口通信
  16. 推荐丨全球主要城市TOD数据
  17. Java 对象排序完整版
  18. drools快速入门:简介、语法和结构
  19. carla学习笔记(八)
  20. 聚焦五大亮点,神策数据 A/B 测试功能全新发布!

热门文章

  1. Java时断时续之——正则表达式
  2. mac地址厂商对应表_网络工程师一分钟搞懂MAC地址表知识点全部内容,建议收藏...
  3. unity 批量导入模型工具_零基础的Unity图形学笔记3:使用多模型UV与优化模型导出...
  4. c++ 弧形面如何逆时针排序_环形导轨如何实现拐弯?
  5. linux一切对象皆文件,为什么说Linux下“一切皆文件”?
  6. 列表逆序排序_Python零基础入门学习05:容器数据类型:列表和元组
  7. python的datetime举例_Python datetime.timedelta()用法及代码示例
  8. 化工热力学重修补考第三章重点内容
  9. 四十五、Gtihub+Hexo+icarus搭建自己的博客
  10. 九十一、Python的GUI系列 | QT组件篇