实例描述

有一组照片,每个文件夹的名称为具体的年龄,里面放的是该年纪的人物图片。

微调 TF-Hub 库,让模型学习这些样本,找到其中的规律,可以根据具体人物的图片来评估人物的年龄。

即便是通过人眼来观察他人的外表,也不能准确判断出被观察人的性别和年纪。所以在应用中,模型的准确度应该与用人眼的估计值来比对,并不能与被测目标的真实值来比对。

一、准备样本

本实例所用的样本来自于 IMDB-WIKI 数据集。IMDB-WIKI 数据集中包含与年龄匹配应的人物图片。

因为该数据集相对粗糙(有些年纪对应的图片特别少),所以需要在该数据集的基础上做一些简单的调整:

  • 补充了一些与年龄匹配的人物图片。
  • 删掉了若干不合格的样本。

整理后的图片一共有 105500 张。

读者可以直接使用本书配套的数据集,将该数据集(IMBD-WIKI 文件夹)放到当前代码的本地同级文件夹下即可使用。

二、下载 TF-Hub 库中的模型

安装 TF-Hub 库后,可以按照以下步骤进行操作。

1. 找到 TF-Hub 库中的模型下载链接

在 GitHub 网站中找到 TF-Hub 库中所提供的模型及下载地址,具体网址如下(国内可能访问不了,请读者自行想办法): https://tfhub.dev/

打开该网页后,可以看到在列表中有很多模型及下载链接,如图 1 所示。


图 1 预训练模型列表

在图 1 可以分为 3 部分,具体如下:

  • 最顶端是搜索框。可以通过该搜索框搜索想要下载的预训练模型。
  • 左侧是模型的分类目录。将 TF-Hub 库中的预训练模型按照文本、图像、视频、发布者进行分类。
  • 右侧是具体的模型列表。其中列出每个模型的具体说明和下载链接。

因为本例需要图像方面的预训练模型,所以重点介绍左侧分类目录中 image 下的内容。在 image 分类下方还有 4 个子菜单,具体含义如下:

  • Classification:是一个分类器模型的分类。该类模型可以直接输出图片的预测结果。用于端到端的使用场景。
  • Feature_vector:一个特征向量模型的分类。该类模型是在分类器模型基础上去掉了最后两个网络层,只输出图片的向量特征,以便在预训练时使用。
  • Generator:一个生成器模型的分类。该类别的模型可以完成合成图片相关的任务。
  • Other:一个有关图像模型的其他分类。

2. 在 TF-Hub 库中搜索预训练模型

在图 1 中的搜索框里输入“mobilenet”并按 Enter 键,即可显示出与 MobileNet 相关的模型,如图 2 所示。


图 2 搜索 MobileNet 预训练模型

在图 2 右侧的列表部分,可以找到 MobileNet 模型。以 MobileNet_v2_100_224 模型为例(图 2 右侧列表中的最下方 2 行),该模型有两个版本:classification 与 feature_vector。

单击图 2 右侧列表中的最后下面一行,进入 MobileNet_v2_100_224 模型 classification 版本的详细说明页面,如图 3 所示。


图 3 NASNet_Mobile 模型 feature_vector 版本的详细说明页

在如图 3 所示的页面中,可以看到该网页介绍了 MobileNet_v2_100_224 模型的来源、训练、使用、微调,以及历史日志等方面的内容。在页面的右上角有一个“Copy URL”按钮,该按钮可以复制模型的下载,方便下载使用。

3. 在 TF-Hub 库中下载 MobileNet_V2 模型

下载 TF-Hub 库中的模型方法有两种:自动下载和手动下载。

  • 自动下载:单击图 3 中的“Copy URL”按钮,复制下载的 URL 地址,并将该地址填入调用 TF-Hub 库时的参数中。具体做法见 5.5.3 小节。
  • 手动下载:从图 3 所示页面中复制的 URL 地址不能直接使用,需要将其前半部分的“ https://tfhub.dev ”换成“ https://storage.googleapis.com/tfhub-modules”,并在 URL 后加上“.tar.gz”。

以 MobileNet_v2_100_224(简称 MobileNet_V2)模型的 classification 版本为例,手动下载的步骤如下。

(1)单击 5-8 中的“Copy URL”按钮,所得到的 URL 地址如下:
https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2

(2)将其改成正常下载的地址。具体如下:
https://storage.googleapis.com/tfhub-modules/google/imagenet/mobilenet_v2_100_224/feature_vector/2.tar.gz

(3)用下载工具按照(2)中的地址进行下载。

三、代码实现:测试 TF-Hub 库中的 MobileNet_V2 模型

为了验证 TF-Hub 库中的模型效果,本小节将使用与第 3 章类似的代码:将 3 张图片输入 MobileNet_V2 模型的 classification 版本中,观察其输出结果。

编写代码载入 MobileNet_V2 模型,具体代码如下:

代码 1 测试 TF-Hub 库中的 NASNet_Mobile 模型

from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hubwith open('中文标签.csv','r+') as f:               #打开文件           labels =list( map(lambda x:x.replace(',',' '),list(f))  )print(len(labels),type(labels),labels[:5])        #显示输出中文标签sample_images = ['hy.jpg', 'ps.jpg','72.jpg']           #定义待测试图片路径#加载分类模型
module_spec = hub.load_module_spec("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2")
#获得模型的输入图片尺寸
height, width = hub.get_expected_image_size(module_spec) input_imgs = tf.placeholder(tf.float32, [None, height,width,3])#定义占位符
images = 2 *( input_imgs / 255.0)-1.0                      #归一化图片module = hub.Module(module_spec)                     #将模型载入张量图logits = module(images)   #获得输出张量,其形状为 [batch_size, num_classes]y = tf.argmax(logits,axis = 1)                           #获得结果的输出节点
with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(tf.tables_initializer())def preimg(img):                                      #定义图片预处理函数return np.asarray(img.resize((height, width)), dtype=np.float32).reshape(height, width,3) #获得原始图片与预处理图片batchImg = [ preimg( Image.open(imgfilename) ) for imgfilename in sample_images ]orgImg = [  Image.open(imgfilename)  for imgfilename in sample_images ]#将样本输入模型yv,img_norm = sess.run([y,images], feed_dict={input_imgs: batchImg})    print(yv,np.shape(yv))                                    #显示输出结果         def showresult(yy,img_norm,img_org):                   #定义显示图片函数plt.figure()  p1 = plt.subplot(121)p2 = plt.subplot(122)p1.imshow(img_org)                                #显示图片p1.axis('off') p1.set_title("organization image")p2.imshow((img_norm * 255).astype(np.uint8))      #显示图片p2.axis('off') p2.set_title("input image")  plt.show()print(yy,labels[yy])for yy,img1,img2 in zip(yv,batchImg,orgImg):          #显示每条结果及图片showresult(yy,img1,img2)

在代码第 14 行,用 TF-Hub 库中的 load_module_spec 函数加载 MobileNet_V2 模型。该步骤是通过将 TF-Hub 库中的模型链接(Module URL=“ https://tfhub.dev /google/imagenet/mobilenet_v2_100_224/classification/2”)传入函数 load_module_spec 中来完成的。

在链接里可以找到该模型文件的名字:mobilenet_v2_100_224。TF-Hub 库中的命名都非常规范,从名字上便可了解该模型的相关信息:

  • 模型是 MobileNet_V2。
  • 神经元节点是 100%(无裁剪)。
  • 输入的图片尺寸是 224。

得到模型之后,便将模型文件载入图中(见代码第 21 行),并获得输出张量(见代码第 23 行),然后通过会话(session)完成模型的输出结果。

运行代码后,显示以下结果:

在显示的结果中,可以分为两部分内容:

  • 第 1 行是标签内容。
  • 从第 2 行开始,所有以“INFO:”开头的信息都是模型加载具体参数时的日志信息。
    在每条信息中都能够看到一个相同的路径:“checkpoint b’C:\Users\ljh\AppData\Local\ Temp\tfhub_modules\bb6444e8248f8c581b7a320d5ff53061e4506c19”,这表示系统将 mobilenet_v2_100_224 模型下载到 C:\Users\ljh\AppData\Local\Temp\tfhub_modules\ bb6444e8248f8c581b7a320d5ff53061e4506c19 目录下。
    如果想要让模型缓存到指定的路径下,则需要在系统中设置环境变量 TFHUB_CACHE_DIR。例如,以下语句表示将模型下载到当前目录下的 my_module_cache 文件夹中。
TFHUB_CACHE_DIR=./my_module_cache
提示 1:
如果由于网络原因导致模型无法下载成功,还可以将本书的配套模型资源复制到当前代码同级目录下,并传入当前模型文件的路径。具体操作是,将代码第 14 行换为以下代码:
module_spec = hub.load_module_spec("mobilenet_v2_100_224")

提示 2:在最后一条的 INFO 信息之后便是模型的预测结果。

如果感觉输出的 INFO 内容太多,则可以在代码的最前面加上“tf.logging.set_verbosity (tf.logging.ERROR)”来关闭 info 信息输出。

四、用 TF-Hub 库微调 MobileNet_V2 模型

在 TF-Hub 库的 GitHub 网站上提供了微调模型的代码文件,运行该代码可以直接微调现有模型。该文件的地址如下:
https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py

将代码文件下载后,直接用命令行的方式运行,便可以对模型进行微调。

1. 修改 TF-Hub 库中的代码 BUG

当前代码存在一个隐含的 BUG:在某一类的数据样本相对较少的情况下,运行时会产生错误。需要将其修改后才可以正常运行。

在“ retrain.py ”代码文件中的函数 get_random_cached_bottlenecks 里添加代码(见代码第 2 行,书中第 477 行),当程序在产生错误时,让其再去执行一次随机选取类别的操作(见代码第 15~25 行,书中第 515~525 行)。具体代码如下:

代码 retrain(片段)

…
def get_random_cached_bottlenecks(sess, image_lists, how_many, category,bottleneck_dir, image_dir, jpeg_data_tensor,decoded_image_tensor, resized_input_tensor,bottleneck_tensor, module_name):
……class_count = len(image_lists.keys())bottlenecks = []ground_truths = []filenames = []if how_many >= 0:# Retrieve a random sample of bottlenecks.for unused_i in range(how_many):IsErr = True             #添加检测异常标志while IsErr==True:       #如果出现异常就再运行一次try:label_index = random.randrange(class_count)label_name = list(image_lists.keys())[label_index]image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)image_name = get_image_path(image_lists, label_name, image_index,image_dir, category)IsErr = False  #没有异常except ZeroDivisionError: continue             #出现异常,再运行一次
…

2. 用命令行运行微调程序

将代码文件“ retrain.py ”与 5.5.1 小节准备的样本数据、5.5.2 小节下载的 MobileNet_V2 模型文件一起放到当前代码的同级目录下。在命令行窗口中输入以下命令:

python retrain.py     --image_dir ./IMBD-WIKI   --tfhub_module  mobilenet_v2_100_224_feature_vector

也可以输入以下命令,直接从网上下载 MobileNet_V2 模型,并进行微调。

python retrain.py --image_dir ./IMBD-WIKI --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2

程序运行之后,会显示如图 5 所示界面。


图 5 微调 MobileNet_V2 模型结束

从图 5 中可以看到,生成的模型被放在默认路径下(根目录下的 tmp 文件夹里)。来到该路径下(作者本地的路径是“G:\tmp”),可以看到微调模型程序所生成的文件,如图 6 所示。


图 6 微调 MobileNet_V2 模型后生成的文件

在图 6 中可以看到有两个文件夹。

  • bottleneck:用预训练模型 MobileNet_V2 将图片转化成的特征值文件。
  • retrain_logs:微调模型过程中的日志文件。该文件可以通过 TensorBoard 显示出来(TensorBoard 的使用方法见 13.3.2 小节)。

其他的文件是训练后生成的模型。每个模型文件的具体意义在第 6 章会有介绍。

提示:
本实例只是一个例子,重点在演示 TF-Hub 的使用。因为实例中所使用的数据集质量较低,所以训练效果并不是太理想。读者可以按照本实例的方法使用更优质的数据集训练出更好的模型。

3. 支持更多的命令行操作

代码文件“ retrain.py ”是一个很强大的训练脚本。在使用时,还可以通过修改参数实现更多的配置。

本实例只演示了部分参数的使用,其他的参数都用默认值,例如:迭代训练 4000 次,学习率为 0.01,批次大小为 100,训练集占比为 80%,测试集与验证集各占比 10% 等。

可以通过以下命令获得该脚本的全部参数说明。

python retrain.py -h

五、代码实现:用模型评估人物的年龄

用代码文件“ retrain.py ”微调后的模型是以扩展名为“pb”的文件存在的(在图 6 中,第 2 行的左数第 1 个)。该模型文件属于冻结图文件。冻结图的知识在第 13 章会详细讲解。

将冻结图格式的模型载入内存,便可以人评估物的年纪。

1. 找到模型中的输入、输出节点

冻结图文件中只有模型的具体参数。如果想使用它,则还需要知道与模型文件对应的输入和输出节点。

这两个节点都可以在代码文件“ retrain.py ”中找到。以输入节点为例,具体代码如下:

代码 retrain(片断)

…
def create_module_graph(module_spec):
…… height, width = hub.get_expected_image_size(module_spec)with tf.Graph().as_default() as graph:resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3])m = hub.Module(module_spec)bottleneck_tensor = m(resized_input_tensor)wants_quantization = any(node.op in FAKE_QUANT_OPSfor node in graph.as_graph_def().node)return graph, bottleneck_tensor, resized_input_tensor, wants_quantization
…

从代码文件“ retrain.py ”的第 6 行(书中第 305 行)代码可以看到,输入节点的张量是一个占位符——placeholder。

提示:
直接使用 print( placeholder.name ) 和 print(final_result.name) 两行代码即可将输入节点和输出节点的名称打印出来。

将输入节点和输出节点的名称记下来,填入代码文件“5-6 用微调后的 mobilenet_v2 模型评估人物的年龄.py”中,便可以实现模型的使用。

更多有关张量的介绍可以参考《深度学习之 TensorFlow——入门、原理与进阶实战》的 4.4.2 小节。

2. 加载模型并评估结果

将本书的配套图片样例文件“22.jpg”和“tt2t.jpg”放到代码的同级目录下,用于测试模型。同时把生成的模型文件夹“tmp”也复制到本地代码的同级目录下。

这部分代码可以分为 3 部分。

  • 样本文件加载部分(见代码第 1~34 行):这部分重用了本书 4.7 节的代码。
  • 加载冻结图(见代码第 35~69 行):读者可以先有一个概念,在第 13 章还有详细讲解。
  • 图片结果显示部分(见代码第 70~94 行):这部分重用了本书 3.4 节中显示部分的代码。

完整的代码如下:
代码 2 用模型评估人物的年龄

from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tffrom sklearn.utils import shuffle
import os def load_sample(sample_dir,shuffleflag = True):'''递归读取文件。只支持一级。返回文件名、数值标签、数值对应的标签名'''print ('loading sample  dataset..')lfilenames = []labelsnames = []for (dirpath, dirnames, filenames) in os.walk(sample_dir):       for filename in filenames:                           #遍历所有文件名#print(dirnames)filename_path = os.sep.join([dirpath, filename]) lfilenames.append(filename_path)                #添加文件名labelsnames.append( dirpath.split('\\')[-1] )#添加文件名对应的标签lab= list(sorted(set(labelsnames)))                #生成标签名称列表labdict=dict( zip( lab  ,list(range(len(lab)))  ))    #生成字典labels = [labdict[i] for i in labelsnames]if shuffleflag == True:return shuffle(np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)else:return (np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)#载入标签
data_dir = 'IMBD-WIKI\\'                                 #定义文件的路径
_,labels = load_sample(data_dir,False)                 #载入文件的名称与标签
print(labels)                           #输出 load_sample 返回的标签字符串sample_images = ['22.jpg', 'tt2t.jpg']                     #定义待测试图片的路径tf.logging.set_verbosity(tf.logging.ERROR)
tf.reset_default_graph()
#分类模型
thissavedir= 'tmp'
PATH_TO_CKPT = thissavedir +'/output_graph.pb'
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:serialized_graph = fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def, name='')fenlei_graph = tf.get_default_graph()height,width = 224,224with tf.Session(graph=fenlei_graph) as sess:result = fenlei_graph.get_tensor_by_name('final_result:0')input_imgs = fenlei_graph.get_tensor_by_name('Placeholder:0')y = tf.argmax(result,axis = 1)   def preimg(img):                                            #定义图片的预处理函数reimg = np.asarray(img.resize((height, width)), dtype=np.float32).reshape(height, width,3) normimg = 2 *( reimg / 255.0)-1.0 return normimg#获得原始图片与预处理图片batchImg = [ preimg( Image.open(imgfilename) ) for imgfilename in sample_images ]orgImg = [  Image.open(imgfilename)  for imgfilename in sample_images ]yv = sess.run(y, feed_dict={input_imgs: batchImg})    #输入模型print(yv) print(yv,np.shape(yv))                                   #显示输出结果         def showresult(yy,img_norm,img_org):                   #定义显示图片的函数plt.figure()  p1 = plt.subplot(121)p2 = plt.subplot(122)p1.imshow(img_org)                               #显示图片p1.axis('off') p1.set_title("organization image")img = ((img_norm+1)/2)*255p2.imshow(  np.asarray(img,np.uint8)      )       #显示图片p2.axis('off') p2.set_title("input image")  plt.show()print(" 索引:",yy,","," 年纪:",labels[yy])for yy,img1,img2 in zip(yv,batchImg,orgImg):           #显示每条结果及图片showresult(yy,img1,img2)

代码第 41 行,指定了要加载的模型动态图文件。
代码第 53 行,指定了与模型文件对应的输入节点“final_result:0”。
代码第 54 行,指定了与模型文件对应的输出节点“Placeholder:0”。

代码运行后显示以下结果:

输出结果可以分为两部分:

  • 第 1 部分是标签的内容。
  • 第 2 部分是评估的结果。

在第 2 部分中,每张图片的下面都会显示这个图片的评估结果,其中包括:在模型中的标签索引、该索引对应的标签名称。

TensorFlow 工程实战(一):在TFhub中下载预训练的pb文件,并使用 TF-Hub 库微调模型评估人物年龄相关推荐

  1. Tensorflow【实战Google深度学习框架】预训练与微调含代码(看不懂你来打我)

    文章目录 1.前言 2.什么是预训练和微调 3.预训练和微调的作用 4.在一个新任务上微调一个预训练的模型代码实现 1.前言 预训练(pre-training/trained)和微调(fine tun ...

  2. PyTorch在NLP任务中使用预训练词向量

    在使用pytorch或tensorflow等神经网络框架进行nlp任务的处理时,可以通过对应的Embedding层做词向量的处理,更多的时候,使用预训练好的词向量会带来更优的性能.下面分别介绍使用ge ...

  3. Tensorflow基于pb模型进行预训练(pb模型转CKPT模型)

    Tensorflow基于pb模型进行预训练(pb模型转CKPT模型) 在网上看到很多教程都是tensorflow基于pb模型进行推理,而不是进行预训练.最近在在做项目的过程中发现之前的大哥只有一个pb ...

  4. 在Keras的Embedding层中使用预训练的word2vec词向量

    文章目录 1 准备工作 1.1 什么是词向量? 1.2 获取词向量 2 转化词向量为keras所需格式 2.1 获取所有词语word和词向量 2.2 构造"词语-词向量"字典 2. ...

  5. 神经网络 Embedding层理解; Embedding层中使用预训练词向量

    1.Embedding层理解 高维稀疏特征向量到低维稠密特征向量的转换:嵌入层将正整数(下标)转换为具有固定大小的向量:把一个one hot向量变为一个稠密向量 参考:https://zhuanlan ...

  6. Pytorch中更改预训练权重文件的下载位置

    目录 1. 参考链接 2. 更改方法 3. 一个小技巧 1. 参考链接 Pytorch更改预训练权重下载位置 pytorch---修改预训练模型下载路径 2. 更改方法 在线加载的预训练权重默认存放位 ...

  7. 使用C#把Tensorflow训练的.pb文件用在生产环境

    训练了很久的Tf模型,终于要到生产环境中去考验一番了.今天花费了一些时间去研究tf的模型如何在生产环境中去使用.大概整理了这些方法. 继续使用分步骤保存了的ckpt文件 这个貌似脱离不了tensorf ...

  8. 使用Chinese-Word-Vectors作为pytorch中的预训练向量

    如何在深度学习中使用开源Chinese Word Vectors 摘要:Chinese-Word-Vectors开源项目提供了100多种预训练模型,但在深度学习中使用时,加载预训练向量存在词表重复项问 ...

  9. 如何在深度学习过程中使用预训练的词表征(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 本文介绍在深度学习中如何应用预训练的词表征(word2vec等),应用到的框架包括numpy.PyTorch和TensorFlow 不同形式,见到了就补充总结一下. ...

最新文章

  1. mysql团队开发工具_最棒的10款MySQL GUI工具
  2. [BZOJ4259]残缺的字符串
  3. CLion 输出遇到乱码解决办法,GBK和utf-8的转换
  4. python怎么另起一行阅读答案_使用Python+Dlib构建人脸识别系统(在Nvidia Jetson Nano 2GB开发板上)...
  5. 工作分流是什么意思_【嘉陵特装要闻】重庆嘉陵召开持续推进职工分流安置工作布置会...
  6. GStreamer基础教程04 - 动态连接Pipeline
  7. 像数据科学家一样思考:12步指南(上) 1
  8. 微信小程序遇到的那些坑
  9. Applet类的方法
  10. Hyperledger Fabric MSP Identity Validity Rules——MSP身份验证规则
  11. 今年暑假不ac (c语言版)
  12. CentOs7下Zabbix安装教程——准备工作
  13. ERwin Data Modeler数据库建模工具使用纪要
  14. mysql五日均线_怎么设置五日均线?
  15. android毕业论文结论,毕业论文经典结束语
  16. web课程设计网页规划与设计~在线阅读小说网页共6个页面(HTML+CSS+JavaScript+Bootstrap)...
  17. Web前端(15)_input表单
  18. eNSP 路由器配置-静态路由和缺省路由
  19. Linux 设置开机自启动
  20. ubantu16.04下安装omnet5.4.1,inet3.6.4,veins4.7.1 和 sumo0.32.0

热门文章

  1. 零基础HTML教程(10)--写一个画龙点睛的标题
  2. 渗透测试 | 几款常用的CMS识别「Web指纹识别」扫描脚本工具(含下载地址)
  3. python外星人入侵游戏代码_黄哥Python:猜数字游戏代码
  4. Android MediaPlay的使用以及实现视频播放器
  5. python自动登录qq客户端_Python自动登录QQ的实现示例
  6. 【SHOI2007/BZOJ1933】书柜的尺寸 dp
  7. 如何将亚马逊广告添加到您的 WordPress 网站(3 种方法)
  8. 做个精致的电子工程师
  9. oracle 数据库truncate,详解Oracle DELETE和TRUNCATE 的区别
  10. 【操作系统】如何在linux系统下运行C程序