如何用最小代价重新训练Google Inception V3神经网络模型,以用于新的分类

郑灵翔

2017.06.03

运行本文的代码要求已安装TensorFlow,若没安装的请参考TensorFlow官方参考指南(EN)

或者我的翻译TensorFlow--和-TFlearn Windows安装指南 (仅提供windows翻译,Linux和MAC的翻译已经很多了)

人类视觉看起来是轻而易举的任务。人们不费吹灰之力就能区分一只狮子和一只美洲虎,看一个标志,或认出一个人的脸。但对计算机来说,这些是一个难题:他们看起来很容易是因为我们的大脑能非常好地理解图像。

近几年机器学习领域在解决这些困难问题上取得了巨大进步。尤其是深度卷积神经网络的模型可以在困难的视觉识别任务上达到不错的性能 - 在某些领域接近或超越人类表现。

研究人员已经用ImageNet(一个计算机视觉的基准测试数据集)证明了计算机视觉的稳步进展。并创造出许多很不错的模型,例如:QuocNet,AlexNet,Inception(GoogLeNet),BN-Inception-v2,描述这些模型的论文尽管都已发表,但结果仍难以重现。而Google开源了他们最新的Inception-v3图像识别模型的代码和训练好的模型,这使得我们可以轻松的利用它来完成许多图像分类的任务。

Inception-v3使用2012年的ImageNet数据进行了训练。这是计算机视觉的一项标准任务,它将所有图像分为1000个类,如“斑马”,“斑点狗”和“洗碗机”等。

使用Python API运行Inception-v3

classify_image.py 将从tensorflow.org下载训练好的模型。当程序第一次运行的时候,你将需要差不多200M的磁盘空间。

首先从GitHub克隆[TensorFlow models 仓库](https://github.com/tensorflow/models)。 运行以下命令:

cd models/tutorials/image/imagenet

python classify_image.py

上述命令将对所大熊猫的图像进行分类。

如果模型运行正常,它将输出以下结果:

giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)

indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)

lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)

custard apple (score = 0.00149)

earthstar (score = 0.00127)

你也可以用--image_file让它对指定的JPEG图像进行分类。该模型默认输入为299x299像素的RGB图像,对于其他图像可以用input_width

和input_height参数来指定图像大小。

用迁移学习将Inception-v3用于自定义图像分类

尽管Inception-v3能够区分1000多类的图像,但是这些分类并不完全,对于某些具体任务它的结果可能不是太好,例如让它识别下面的玫瑰花。

python classify_image.py --model_dir=inception --image_file=flower_photos\roses\110472418_87b6a3aa98_m.jpg

velvet (score = 0.19078)

candle, taper, wax light (score = 0.07208)

quilt, comforter, comfort, puff (score = 0.07156)

jack-o'-lantern (score = 0.05729)

red wine (score = 0.04733)

为了让它能够认识这个玫瑰花,我们需要对模型重新进行训练。现代深度神经网络目标识别模型具有数百万个参数,可能需要几周才能完成训练。 这需要巨大的计算能力,对于大多数人来说成本太高,如果我们能够对现有的模型进行少量修改,就使其能适应新的分类任务将大大减少工作量。迁移学习就是这样一种利用现有模型快速完成训练的技术,它在一个已经完全训练,能识别已有分类(如ImageNet)的模型基础上,重新进行少量训练,使其适应新的分类需求。 在这个例子中,我们将从头开始重新培训Google Inception V3模型的最后一层的神经网络权重,同时保持模型中所有其他层的参数不变。 有关该方法的更多信息,您可以阅读Decaf的论文 。

虽然它不如完全重新训练效果好,但对于许多应用来说,这是种非常高效的方法,它让你可以在笔记本电脑上训练几分钟(训练时间与CPU性能有关),而无需GPU。本教程将向您展示如何在自己的图像集上运行示例脚本,形成适用你自己图像集的新的分类型,并将介绍一些帮助控制训练过程的选项。

新的花卉分类器

构造图像训练集

在开始训练之前,您需要一组图像来教神经网络了解您想要识别的新类别。 稍后部分将介绍如何准备自己的图片,但为了方便我们,使用Tensorflow提供的花卉照片。

http://download.tensorflow.org/example_images/flower_photos.tgz

将这些照片解压后,您现在应该可以在工作目录中找到可用的照片副本。

这个数据集包括5种类型的花卉,总大小212M,为了节省时间,我们将数据集的照片减少一些,每种类型的照片保留300张左右,来训练分类器

重新训练模型

在开始之前可以先启动tensorboard

tensorboard --logdir training_summaries &

python retrain.py --bottleneck_dir=bottlenecks --how_many_training_steps=500 --model_dir=inception --summaries_dir=training_summaries/basic --output_graph=retrained_graph.pb --output_labels=retrained_labels.txt --image_dir=flower_photos

一旦TensorBoard运行,浏览网页localhost:6006查看TensorBoard。

脚本将默认将TensorBoard摘要记录到/tmp/retrain_logs 。 您可以使用--summaries_dir标志更改--summaries_dir 。

TensorBoard README有更多关于TensorBoard使用情况的信息,包括提示和技巧以及调试信息。

重新训练Inception模型

接下来,我们将重新训练google 的Inception v3神经网络模型。

Inception是一个巨大的图像分类模型,具有数百万个参数,可以区分大量类型的图像。 我们只重新训练该网络的最后一层,所以训练将在合理的时间内结束。

接下来使用一个大命令开始图像重新训练:

python /tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir=/tf_files/bottlenecks --how_many_training_steps 500 --model_dir=/tf_files/inception --output_graph=/tf_files/retrained_graph.pb --output_labels=/tf_files/retrained_labels.txt --image_dir /tf_files/flower_photos

此脚本加载预训练的Inception v3模型,删除旧的最后一层,并在我们下载的花朵照片上训练一个新的模型。

ImageNet最初没有任何这些花卉的数据。 然而,Inception v3通过ImageNet数据进行训练,能够区分1000个类的信息,也有助于区分其他对象。 通过使用这个预训练的网络,我们使用该信息作为区分我们的花类的最终分类层的输入。

使用重新训练的模型

Retrain.py脚本将输出一个包含最后一层重新训练为识别我们给定类别的新的Inception v3网络模型,output_graph.pb和一个包含output_labels.txt的标签的文本文件。

这些文件都是C ++和Python图像分类示例可以使用的格式,因此您可以立即开始使用您的新模型。

你现在有两个选择:Python或C ++。

使用Python

(编译TensorFlow在C ++中可能是一个漫长的过程。)

下面是Python加载你的新模型文件和预测脚本。

label_image.py

import tensorflow as tf

# change this as you see fit

image_path = sys.argv[1]

# Read in the image_data

image_data = tf.gfile.FastGFile(image_path, 'rb').read()

# Loads label file, strips off carriage return

label_lines = [line.rstrip() for line

in tf.gfile.GFile("/tf_files/retrained_labels.txt")]

# Unpersists graph from file

with tf.gfile.FastGFile("/tf_files/retrained_graph.pb", 'rb') as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

_ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:

# Feed the image_data as input to the graph and get first prediction

softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

predictions = sess.run(softmax_tensor, \

{'DecodeJpeg/contents:0': image_data})

# Sort to show labels of first prediction in order of confidence

top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

for node_id in top_k:

human_string = label_lines[node_id]

score = predictions[0][node_id]

print('%s (score = %.5f)' % (human_string, score))`

接下来,我们进行一些测试

菊花

# python /tf_files/label_image.py /tf_files/flower_photos/daisy/21652746_cc379e0eea_m.jpg

玖瑰

# python /tf_files/label_image.py /tf_files/flower_photos/roses/2414954629_3708a1a04d.jpg

关于Bottlenecks

第一阶段分析磁盘上的所有图像,并计算每个图像的Bottlenecks。什么是Bottlenecks?

Inception v3模型由许多层叠在一起的顶层组成(参见本文中的图片)。这些层是预先训练的,并且在找到和总结将有助于分类大多数图像的信息方面已经非常有价值。对于这个codelab,你只训练最后一层;先前的层保持它们已经训练的状态。

因此,“Bottlenecks”是我们经常用于刚好在实际进行分类的最终输出层之前的层的非正式术语。

每个图像在训练期间重复使用多次。计算每个图像的Bottlenecks后面的层需要大量的时间。通过将较低层的输出缓存在磁盘上,它们不必重复地重新计算。默认情况下,它们存储在/ tmp / bottleneck目录中。如果您重新运行脚本,它们将被重用,因此您不必再次等待此部分。

关于训练

你会看到Bottlenecks的训练,完成之后,将开始网络的最后一层的实际训练。

您将看到一系列步骤输出,每个步骤输出显示训练准确性,验证准确性和交叉熵:

训练精度表示在当前训练批次中使用的,用正确类别标记的图像的百分比。

验证精度:验证精度是来自不同集合的随机选择的图像组的精度(正确标记的图像的百分比)。

交叉熵是一个损失函数,它可以了解学习过程的进展情况。 (较低的数字在这里更好。)

网络性能的真实度量是测量其在不在训练数据中的数据集上的性能。此性能是使用验证准确度测量的。如果训练准确度高但验证准确性保持低,这意味着网络过度拟合,并且网络正在记忆训练图像中的特定特征,这些特征不能帮助其更一般地分类图像。

训练的目的是使交叉熵尽可能小,所以你可以通过观察损失是否保持向下趋势,忽略短期噪声来判断学习是否有效。

默认情况下,此脚本运行4000个训练步骤。每个步骤从训练集中随机选择10个图像,从缓存中找到它们的Bottlenecks,并将它们馈送到最后一层以获得预测。然后将那些预测与实际标签进行比较,以通过反向传播过程更新最终层的权重。

随着过程的继续,您应该看到报告的精度提高。在所有训练步骤完成之后,脚本对与训练和验证图片分离的一组图像运行最终测试精度评估。该测试评估提供了训练模型将如何对分类任务执行的最佳估计。

您应该看到的准确度值在85%和99%之间,虽然确切的值会随运行而变化,因为在训练过程中存在随机性。 (如果你只训练两个类,你应该期望更高的准确性。)此数值表示在模型完全训练后给出正确标签的测试集中的图像的百分比。

本文是在以下参考文章基础上修改而成:

神经网络python分类识别图片_教现有的神经网络图像模型识别新的图像类别相关推荐

  1. python人脸识别神器_教你用Python人脸识别自动开机,值得收藏

    这里将告诉您教你用Python人脸识别自动开机,值得收藏,具体操作方法:是不是厌烦了每次回家都要点击按钮打开电脑的操作? 你如果有看过我以前的推送,是不是厌烦了每次回家都要喊 "echo,t ...

  2. DL之AlexNet:利用卷积神经网络类AlexNet实现猫狗分类识别(图片数据增强→保存h5模型)

    DL之AlexNet:利用卷积神经网络类AlexNet实现猫狗分类识别(图片数据增强→保存h5模型) 目录 利用卷积神经网络类AlexNet实现猫狗分类识别(图片数据增强→保存h5模型) 设计思路 处 ...

  3. python人工智能文字识别软件_怎么用Python人工智能识别图片-百度AI文字识别使用方法分享 - Iefans...

    如果你是个Python初学者,那么你可以试着做做看这个教程,毕竟编程技能都是在实战中成长的,这篇教程是教你如何用Python来进行人工智能识别图片,可以帮助你解决日常办公时遇到的图片转换文字的问题. ...

  4. DL之VGG16:基于VGG16迁移技术实现猫狗分类识别(图片数据量调整→保存h5模型)

    DL之VGG16:基于VGG16迁移技术实现猫狗分类识别(图片数据量调整→保存h5模型) 目录 基于VGG16迁移技术实现猫狗分类识别(图片数据量调整→保存h5模型) 设计思路 输出结果 1488/1 ...

  5. python 制定识别图片的某些区域_python批量识别图片指定区域文字内容

    Python批量识别图片指定区域文字内容,供大家参考,具体内容如下 简介 对于一张图片,需求识别指定区域的内容 1.截取原始图上的指定图片当做模板 2.根据模板相似度去再原始图片上识别准确坐标 3.根 ...

  6. python新手入门教程思路-Python新手入门教程_教你怎么用Python做数据分析

    Python新手入门教程_教你怎么用Python做数据分析 跟大家讲了这么多期的Python教程,有小伙伴在学Python新手教程的时候说学Python比较复杂的地方就是资料太多了,比较复杂.很多网上 ...

  7. Python+Tesseract-OCR识别图片文字并保存到word文档

    目录 使用Python+Tesseract-OCR识别图片文字并保存到word文档 安装Tesseract-OCR 配置Tesseract-OCR 通过CMD验证Tesseract-OCR工作 安装p ...

  8. CV之FR(H+k机器学习):基于每人几张人脸图片训练H+k模型实现(国内外明星)新人脸图像的姓名预测(准确度高达100%)

    CV之FR(H+k机器学习):基于每人几张人脸图片训练H+k模型实现(国内外明星)新人脸图像的姓名预测(准确度高达100%) 目录 输出结果 设计思路 输出结果 设计思路

  9. python 文字识别 准确率_关于OCR图片文本检测、推荐一个 基于深度学习的Python 库!...

    大家好,我是 zeroing~ 1,前言 之前谈到图片文本 OCR 识别时,写过一篇文章介绍了一个 Python 包 pytesseract ,具体内容可参考 介绍一个Python 包 ,几行代码可实 ...

最新文章

  1. 心中有火,前方有光,致敬所有智能车后浪
  2. c++怎么输入带有空格的字符串_CCF CSP-J/S中常用的输入输出 总结(三)
  3. STM32 HAL库--串口的DMA(发送、接收)和esp8266 wifi模组发送和接收封装函数
  4. Python特殊语法:filter、map、reduce、lambda
  5. mysql upgrade 失败_`mysql_upgrade`失败,没有给出真正的理由
  6. 设计模式--程序猿必备面向对象设计原则
  7. linux 编译hadoop,linux centos 安装编译hadoop2.7.1
  8. 第 2-3 课:抽象类和接口 + 面试题
  9. Linux tshark发送抓取的数据到kafka
  10. oracle10.2.0.4 dbca,在rhel5上oracle 10.2.0.4 上dbca silent删除数据库
  11. mysql8默认字符编码_mysql默认字符编码问题
  12. Windows Azure Virtual Machine (33) Azure虚拟机删除重建
  13. 全国夜间灯光指数数据、GDP密度分布、人口密度分布、土地利用数据、降雨量数据
  14. Python如何运行单个.py文件而不是unittest
  15. 沙盘模拟软件_ERP企业经营模拟第一次培训
  16. 手把手教你App推广时如何能找到100个以上渠道!
  17. 弘辽科技:如何写出自带流量的标题
  18. 设置行与行的间隔(行间距)
  19. MFC用户名和密码的登录界面设计
  20. 什么是数据科学?数据科学的基本内容

热门文章

  1. 小小数学家(python)
  2. Make a difference with Dragon Board410c(1)
  3. 光流 — Optical Flow
  4. 谷歌“不支持”CDMA:实为ASOP ROM变动
  5. UDP 错误 10054 : 远程主机强迫关闭了一个现有的连接
  6. 浏览器指纹是什么?浏览器指纹伪装如何才有效果?
  7. C#为图片添加水印,生成缩略图
  8. 解决“尝试执行未经授权的操作”问题
  9. del服务器的型号,del服务器
  10. Java简单项目实例---统计部门员工的平均工资