目录

1.原文完整代码

1.1 模型运行参数总结

1.2模型训练效果

​编辑2.模型的保存

3.读取模型model

4.使用模型进行图片预测

5.补充 如何查看保存模型参数

5.1 model_weights

5.2 optimizer_weights


使用之前一篇代码:

原文链接:Tensorflow2 图像分类-Flowers数据及分类代码详解

这篇文章中,经常有人问到怎么保存模型?怎么读取和应用模型进行数据预测?这里做一下详细说明。

1.原文完整代码

完整代码如下,做了少量修改:

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tffrom tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential"""flower_photo/daisy/dandelion/roses/sunflowers/tulips/"""
import pathlibdataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
print(data_dir)
print(type(data_dir))
data_dir = pathlib.Path(data_dir)
print(data_dir)
print(type(data_dir))image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
roses = list(data_dir.glob('roses/*'))
img0 = PIL.Image.open(str(roses[0]))
plt.imshow(img0)
plt.show()batch_size = 32
img_height = 180
img_width = 180# It's good practice to use a validation split when developing your model.
# Let's use 80% of the images for training, and 20% for validation.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)# 图片可视化
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):for i in range(30):ax = plt.subplot(3, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.show()for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)breakAUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)normalization_layer = layers.experimental.preprocessing.Rescaling(1. / 255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))data_augmentation = keras.Sequential([layers.experimental.preprocessing.RandomFlip("horizontal",input_shape=(img_height,img_width,3)),layers.experimental.preprocessing.RandomRotation(0.1),layers.experimental.preprocessing.RandomZoom(0.1),]
)
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):for i in range(9):augmented_images = data_augmentation(images)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_images[0].numpy().astype("uint8"))plt.axis("off")num_classes = 5
model = Sequential([data_augmentation,layers.experimental.preprocessing.Rescaling(1. / 255),layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(128, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Dropout(0.15),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(num_classes)
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.summary()epochs = 15
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)model.save("./model/Flowers_1227.h5")   #保存模型#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)img = keras.preprocessing.image.load_img(sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batchpredictions = pre_model.predict(img_array)
score = tf.nn.softmax(predictions[0])print("This image most likely belongs to {} with a {:.2f} percent confidence.".format(class_names[np.argmax(score)], 100 * np.max(score))
)

修改的代码包含:(1)修改了模型,增加了一个卷积层;(2)增加模型保存代码;(3)增加模型读取代码,并使用读取到的模型预测图片

1.1 模型运行参数总结

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
sequential (Sequential)      (None, 180, 180, 3)       0
_________________________________________________________________
rescaling_1 (Rescaling)      (None, 180, 180, 3)       0
_________________________________________________________________
conv2d (Conv2D)              (None, 180, 180, 16)      448
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 90, 90, 16)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 90, 90, 32)        4640
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 45, 45, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 45, 45, 64)        18496
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 22, 22, 64)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 22, 22, 128)       73856
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 11, 11, 128)       0
_________________________________________________________________
dropout (Dropout)            (None, 11, 11, 128)       0
_________________________________________________________________
flatten (Flatten)            (None, 15488)             0
_________________________________________________________________
dense (Dense)                (None, 128)               1982592
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645
=================================================================
Total params: 2,080,677
Trainable params: 2,080,677
Non-trainable params: 0

1.2模型训练效果

15次epoch有75.61%的精度,增加训练次数应该还有一定提升空间。

Epoch 1/15
92/92 [==============================] - 99s 1s/step - loss: 1.3126 - accuracy: 0.4087 - val_loss: 1.0708 - val_accuracy: 0.5477
Epoch 2/15
92/92 [==============================] - 88s 957ms/step - loss: 1.0561 - accuracy: 0.5562 - val_loss: 0.9844 - val_accuracy: 0.5872
Epoch 3/15
92/92 [==============================] - 89s 966ms/step - loss: 0.9517 - accuracy: 0.6117 - val_loss: 1.0068 - val_accuracy: 0.6035
Epoch 4/15
92/92 [==============================] - 84s 913ms/step - loss: 0.8743 - accuracy: 0.6550 - val_loss: 0.8538 - val_accuracy: 0.6580
Epoch 5/15
92/92 [==============================] - 82s 891ms/step - loss: 0.8065 - accuracy: 0.6809 - val_loss: 0.8371 - val_accuracy: 0.6703
Epoch 6/15
92/92 [==============================] - 82s 892ms/step - loss: 0.7623 - accuracy: 0.7115 - val_loss: 0.8203 - val_accuracy: 0.7016
Epoch 7/15
92/92 [==============================] - 94s 1s/step - loss: 0.7309 - accuracy: 0.7245 - val_loss: 0.7539 - val_accuracy: 0.7057
Epoch 8/15
92/92 [==============================] - 90s 982ms/step - loss: 0.6928 - accuracy: 0.7262 - val_loss: 0.7811 - val_accuracy: 0.7166
Epoch 9/15
92/92 [==============================] - 88s 955ms/step - loss: 0.6840 - accuracy: 0.7333 - val_loss: 0.8314 - val_accuracy: 0.6703
Epoch 10/15
92/92 [==============================] - 81s 877ms/step - loss: 0.6591 - accuracy: 0.7565 - val_loss: 0.7585 - val_accuracy: 0.7153
Epoch 11/15
92/92 [==============================] - 83s 899ms/step - loss: 0.6195 - accuracy: 0.7633 - val_loss: 0.7600 - val_accuracy: 0.7125
Epoch 12/15
92/92 [==============================] - 86s 934ms/step - loss: 0.6006 - accuracy: 0.7657 - val_loss: 0.6871 - val_accuracy: 0.7262
Epoch 13/15
92/92 [==============================] - 86s 934ms/step - loss: 0.5736 - accuracy: 0.7762 - val_loss: 0.6955 - val_accuracy: 0.7452
Epoch 14/15
92/92 [==============================] - 82s 897ms/step - loss: 0.5523 - accuracy: 0.7871 - val_loss: 0.7513 - val_accuracy: 0.7234
Epoch 15/15
92/92 [==============================] - 86s 935ms/step - loss: 0.5379 - accuracy: 0.7956 - val_loss: 0.6591 - val_accuracy: 0.7561

  图片数据增强后的效果图:

2.模型的保存

训练模型的保存实际上只需一行代码就行,在模型训练完成之后,我们将模型保存到指定的路径并给模型命名。模型保存的格式是.h5后缀的格式,这种文件是hdf5格式的数据,我们可以使用专门的软件打开查看模型相关参数。

在model.fit()训练完模型之后,保存模型到model文件夹下:

model.save("./model/Flowers_1227.h5")   #保存模型

运行完成之后在项目文件下可以看到model文件夹,文件中可以看到我们保存的模型:

模型大小有45.7M.

3.读取模型model

读取代码也只需要一行,如下:

#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")

4.使用模型进行图片预测

根据上面读取到的模型直接进行图片预测。

可以省去前面的数据训练部分,直接使用后面的部分代码,读取模型然后进行图片预测。

继续运行上文中的代码后面部分,即最后面的部分是预测一张图片属于什么类型的。

运行结果是:

This image most likely belongs to sunflowers with a 98.13 percent confidence.

读取模型进行预测的代码如下:

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tffrom tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential"""flower_photo/daisy/dandelion/roses/sunflowers/tulips/"""
import pathlibbatch_size = 32
img_height = 180
img_width = 180dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)img = keras.preprocessing.image.load_img(sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batchpredictions = pre_model.predict(img_array)
score = tf.nn.softmax(predictions[0])
print(score)print("This image most likely belongs to {} with a {:.2f} percent confidence.".format(class_names[np.argmax(score)], 100 * np.max(score))
)

运行结果:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
2022-12-27 22:34:13.075000: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
2022-12-27 22:34:14.205000: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
tf.Tensor([1.1012822e-04 5.2587932e-04 4.4165729e-03 9.8130155e-01 1.3645952e-02], shape=(5,), dtype=float32)
This image most likely belongs to sunflowers with a 98.13 percent confidence.

我们可以看到预测结果中,score = tf.nn.softmax(predictions[0]),代表的是该图片属于每种类型的概率大小,第四个是最大的,即第四个对应的是sunflowers类别,因此可以认为预测的结果就是sunflowers。

这里预测的图片是使用在线下载的一张图片进行预测的,实际上我们可以读取我们本地路径下的文件进行大批量的预测,并将每张图片预测结果保存到文本文件中用于后续的分析。

5.补充 如何查看保存模型参数

使用HDFView软件查看.H5后缀的文件。下载链接:HDFViiew-2.11.0-win64.exe-桌面系统文档类资源-CSDN下载

网上其他地方也有免费下载的,之前是在国外网站下载的,有时间查找的同学可以花点时间去找一下。

可以看到,该模型主要有两部分,model_weights和optimizer_weights.即模型权重系数和优化器权重系数参数。

我们点击展开这两个文件夹,我们可以看到里面的文件层次和我们的模型层次是一致的。

卷积层中的参数可以查看到如下:

5.1 model_weights

model_weigths参数展开如下,从dropout开始是空的。

5.2 optimizer_weights

optimizer_weights参数如下:

Tensorflow2 图像分类-Flowers数据深度学习模型保存、读取、参数查看和图像预测相关推荐

  1. DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏)

    DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏) 目录 神经网络的参数调优 1.神经网络的通病-各种参数随机性 2.评估模型学习能力

  2. 深度学习模型保存_Web服务部署深度学习模型

    本文的目的是介绍如何使用Web服务快速部署深度学习模型,虽然TF有TFserving可以进行模型部署,但是对于Pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦):因此,本文 ...

  3. 深度学习模型保存_解读计算机视觉的深度学习模型

    作者 | Dipanjan(DJ)Sarkar 来源 | Medium 编辑 | 代码医生团队 介绍 人工智能(AI)不再仅限于研究论文和学术界.业内不同领域的企业和组织正在构建由AI支持的大规模应用 ...

  4. 手工计算深度学习模型中的参数数量

    https://www.toutiao.com/a6649299295855968782/ 2019-01-22 20:37:14 计算深度学习模型的可训练参数的数量被认为太微不足道了,因为您的代码已 ...

  5. Tensorflow2 图像分类-Flowers数据及分类代码详解

    目录 1.基本步骤 2.1 环境 2.2 数据说明 2.3 数据查看代码 3.创建数据集 3.1 数据集创建函数 3.2 创建数据集 4.数据可视化 5.模型训练 5.1 配置数据集 5.2 数据标准 ...

  6. 深度学习模型中的参数数量(备忘)

    原文地址:huay' blog/模型中的参数数量(备忘) 记录模型参数数量的计算方法 最早使用 tensorflow 的时候没怎么注意这个问题: 后面高级 API 用的多了,有点忘记怎么计算模型的参数 ...

  7. [自动调参]深度学习模型的超参数自动化调优详解

    向AI转型的程序员都关注了这个号

  8. 基于web端和C++的两种深度学习模型部署方式

    深度学习Author:louwillMachine Learning Lab 本文对深度学习两种模型部署方式进行总结和梳理.一种是基于web服务端的模型部署,一种是基... 深度学习 Author:l ...

  9. 如何在Keras中检查深度学习模型(翻译)

    本文翻译自:How to Check-Point Deep Learning Models in Keras 深度学习模型可能需要数小时,数天甚至数周才能进行训练. 如果意外停止运行,则可能会丢失大量 ...

最新文章

  1. linux网络编程二十:socket选项:SO_RCVTIMEO和SO_SNDTIMEO
  2. python前端学习-------Flask框架基础(建议收藏)
  3. SAP 电商云 Spartacus UI 设置 delivery mode 在 3G 慢速网络下的排队效果
  4. 机器学习之支持向量机(SVM)总结
  5. arm shellcode 编写详析2
  6. CXF生成本地ws调用代码测试webservice
  7. ekf pose使用方法 ros_【百川小课堂】第13课—ROS学习(二)
  8. 用例驱动的需求过程实践
  9. Node:使用node-postgre时,使用async、await查询
  10. android4.0.3校准屏幕和隐藏statusbar
  11. 网页打开软件显示无法连接服务器,Safari 浏览器无法打开网页怎么办
  12. Perl中Data::Dumper模块用法
  13. mysql状态表 历史记录设计表_常见数据库设计(2)——历史数据问题之单记录变更...
  14. 知云文献翻译打不开_学用系列|自带翻译功能的PDF文献阅读器——知云文献翻译3.0...
  15. 6.数据仓库搭建之数据仓库设计
  16. *4-2 CCF 2014-12-2 Z字形扫描
  17. 原画 机器人总动员_《机器人瓦力》导演执导 科幻史诗巨制《火星上的约翰·卡特》首支震撼预告...
  18. RISC-V特权级寄存器及指令文档
  19. 深度学习:GAN优化方法-DCGAN案例
  20. 在画电路图时,想问下几种地之间的区别? power-GND singal-GND GND

热门文章

  1. Java编程基础语句
  2. 基于概率的循环地图 Unlit Shader
  3. main函数参数(关于argc与argv)
  4. C1任务03 北院314-张本龙
  5. Lenovo G480笔记本安装OS X Mavericks 10.9,升级10.9.1,驱动安装,DSDT修改全过程
  6. 在Firefly AIO-3399ProC搭建rknn环境
  7. 使用 BigDecimal 的正确方式
  8. 体验了基于ChatGPT的谷歌翻译插件后,我把其他翻译插件移除了
  9. Mybatis的SQL注入隐患操作复现防止SQL注入
  10. 身份证号码前六位所代表的省,市,区, 以及地区编码下载