基于TensorFlow训练花朵识别模型的源码和Demo

转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684

下面就通过对现有的 Google Inception-V3 模型进行 retrain ,对 5 种花朵样本数据的进行训练,来完成一个可以识别五种花朵的模型,并将新训练的模型进行测试部属,让大家体验一下完整的流程。

安装 TensorFlow (Mac 为例)

其他平台可以直接参考官网说明:Installing TensorFlow

首先检查系统是否安装了 Python

要安装 TensorFlow ,你的系统必须依据安装了以下任一 Python 版本:

  • Python 2.7
  • Python 3.3+

如果做数据处理较多的话,建议安装Anaconda, Anaconda 是一种Python语言的免费增值开源发行版 ,用于进行大规模数据处理, 预测分析, 和科学计算, 致力于简化包的管理和部署。Anaconda使用软件包管理系统Conda进行包管理。安装完成后输入shell下输入python即可查看Anaconda对应的Python 版本,我使用的是Python 2.7.14:

➜  ~ python
Python 2.7.14 |Anaconda, Inc.| (default, Dec  7 2017, 11:07:58)
[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

如果你的系统还没有安装符合以上版本的 Python,现在安装。

通过 pip 安装 TensorFlow

# Python 2
➜ pip install tensorflow
# Python 3
➜ pip3 install tensorflow

通过官方样例测试 TensorFlow 是否正常安装

进入 Python 环境后输入以下代码,当出现 “Hello, TensorFlow!” 时表明已经安装成功,可正常使用 TensorFlow 了。

➜ python
import tensorflow as tf
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))
Hello, TensorFlow!

准备训练样本

现在我们要训练花朵的识别模型,这是 Google 在TensorFlow里面提供的一个例子,其中包含了5类花朵的训练图片。可以新建个flower_demo文件夹,用于存放数据和训练的模型。

下载并解压得到训练样本

cd flower_demo
# 下载和解压花朵训练数据
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz

打开训练样本文件夹 flower_photos ,里面有 5 种类别的花:daisy(雏菊), dandelion(蒲公英), roses(玫瑰), sunflowers(向日葵) , tulips(郁金香),总共3672张,每个类别的大概有 600-900 张训练样本图片,具体如下:

cd flower_photos
for dir in `find ./ -maxdepth 1 -type d`;do echo -n -e "$dir\t";find $dir -type f|wc -l ;done;
./      3672
.//roses         641
.//sunflowers        699
.//daisy         633
.//dandelion         898
.//tulips        799

开始训练

下载训练模型使用的 retrain 脚本 
该脚本会自动下载 google Inception v3 模型相关文件,retrain.py 是 Google 提供的以ImageNet图片分类模型为基础模型,利用flower_photos数据迁移训练花朵识别模型的脚本。

 cd flower_democurl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

启动训练脚本,开始训练模型

在运行 retrain.py 脚本时,需要配置一些运行命令参数,比如指定模型输入输出相关名称和其他训练要求的配置。其中--how_many_training_steps=4000配置代表训练迭代次数,默认值为4000,如果机器较差,可以适当减少这个值。

➜ cd flower_demo
➜ python3 retrain.py \--bottleneck_dir=bottlenecks \--how_many_training_steps=4000 \--model_dir=inception \--summaries_dir=training_summaries/basic \--output_graph=retrained_graph.pb \--output_labels=retrained_labels.txt \--image_dir=flower_photos

这里我们训练4000steps,时间不是很久,我在配备4.2 GHz Intel Core i7处理器的iMac上,不适用GPU大概就5分钟就能训练完成。模型训练完成后,可以看到测试集上Final test accuracy = 92.1%,也就是说我们训练的5类花朵识别模型,在测试集上已经有92%的识别准确率了。其中生成的 retrained_labels.txt 和 retrained_graph.pb 这两个是模型相关文件。

2018-06-02 15:47:00.266119: Step 3950: Train accuracy = 94.0%
2018-06-02 15:47:00.266159: Step 3950: Cross entropy = 0.135385
2018-06-02 15:47:00.327843: Step 3950: Validation accuracy = 93.0% (N=100)
2018-06-02 15:47:00.976543: Step 3960: Train accuracy = 94.0%
2018-06-02 15:47:00.976591: Step 3960: Cross entropy = 0.234760
2018-06-02 15:47:01.038559: Step 3960: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:01.667255: Step 3970: Train accuracy = 97.0%
2018-06-02 15:47:01.667372: Step 3970: Cross entropy = 0.167394
2018-06-02 15:47:01.731935: Step 3970: Validation accuracy = 87.0% (N=100)
2018-06-02 15:47:02.355780: Step 3980: Train accuracy = 96.0%
2018-06-02 15:47:02.355818: Step 3980: Cross entropy = 0.151201
2018-06-02 15:47:02.418314: Step 3980: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:03.042364: Step 3990: Train accuracy = 99.0%
2018-06-02 15:47:03.042402: Step 3990: Cross entropy = 0.094383
2018-06-02 15:47:03.103718: Step 3990: Validation accuracy = 91.0% (N=100)
2018-06-02 15:47:03.667861: Step 3999: Train accuracy = 99.0%
2018-06-02 15:47:03.667899: Step 3999: Cross entropy = 0.106797
2018-06-02 15:47:03.729215: Step 3999: Validation accuracy = 94.0% (N=100)
Final test accuracy = 92.1% (N=353)

测试训练完成后的模型

同样的,我们先下载测试模型的脚本 label_image.py,然后从flower_photos/daisy/文件夹下选择图片488202750_c420cbce61.jpg,测试我们训练后的模型的识别准确率,当然你也可以百度搜索一张5类花朵的任意一张图测试识别效果,从下图可以看出,我们训练的算法模型认为这张图属于daisy的概率高达98.9%.

➜ cd flower_demo
➜ curl -L https://goo.gl/3lTKZs > label_image.py
➜ python label_image.py flower_photos/daisy/488202750_c420cbce61.jpgdaisy (score = 0.98921)
sunflowers (score = 0.00948)
dandelion (score = 0.00088)
tulips (score = 0.00038)
roses (score = 0.00005)


有人说label_image.py无法下载,代码如下:

import os, sys
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# change this as you see fit
image_path = sys.argv[1]# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line in tf.gfile.GFile("retrained_labels.txt")]# Unpersists graph from file
with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())tf.import_graph_def(graph_def, name='')with tf.Session() as sess:# Feed the image_data as input to the graph and get first predictionsoftmax_tensor = sess.graph.get_tensor_by_name('final_result:0')predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})# Sort to show labels of first prediction in order of confidencetop_k = predictions[0].argsort()[-len(predictions[0]):][::-1]for node_id in top_k:human_string = label_lines[node_id]score = predictions[0][node_id]print('%s (score = %.5f)' % (human_string, score))

我们随便从百度搜索一张蒲公英(dandelion)的图,保存到test/WechatIMG383.jpg,测试结果显示属于蒲公英的概率为99.59%.

python label_image.py test/WechatIMG383.jpgdandelion (score = 0.99592)
sunflowers (score = 0.00359)
daisy (score = 0.00042)
tulips (score = 0.00005)
roses (score = 0.00001)

以上基本是模型训练和测试的全部过程,希望能让大家对深度学习的完整项目有个大致的了解。

启动 TensorBoard 
TensorBoard 是 TensorFlow 自带的训练效果可视化的分析工具,我们可以利用此工具检测和分析模型的收敛情况,比如查看loss的下降、acc的提升和查看可视化的网络结构图等。在我们建的工程目录下,启动tensorboard的具体命令如下:

➜ cd flower_demo
➜ tensorboard --logdir training_summaries

启动 TensorBoard 会占用系统 6006 端口 ,再启动一个新的 TensorBoard 之前,必须要 kill 已在运行的 TensorBoard 任务。

➜ pkill -f "tensorboard

启动浏览器查看 TensorBoard

启动TensorBoard后,可以启动浏览器,在地址栏中输入 localhost:6006 来查看训练进度以及loss和准确度的变化,分析模型等。

基于TensorFlow训练花朵识别模型的源码和Demo相关推荐

  1. matlab和投影寻踪,基于遗传算法的投影寻踪模型Matlab源码

    基于遗传算法的投影寻踪模型Matlab源码 %% "投影寻踪+遗传算法优化"的主仿真程序 % GreenSim团队原创作品,转载请注明 % Email:greensim@http: ...

  2. 【邮政编码识别】基于计算机视觉实现邮政编码识别含Matlab源码

    1 简介 邮政包裹的自动分拣可以使邮政部门节省大量的人力物力,有效地提高邮政部门的邮件分拣效率,具有广阔的应用前景.该文对邮政包裹地址标签上的邮政编码识别进行了比较深入的研究,在简化模型的基础之上,详 ...

  3. python 深度学习源码_「深度学习」用TensorFlow实现人脸识别(附源码,快速get技能)...

    本文将会带你使用python码一个卷积神经网络模型,实现人脸识别,操作难度比较低,动手跟着做吧,让你的电脑认出你那帅气的脸. 由于代码篇幅较长,而且最重要的缩进都没了,建议直接打开源码或者点击分享-& ...

  4. 【水果识别】基于计算机视觉实现水果识别含Matlab源码

    1 简介 自"农业 4.0"时代的来临,以"互联网+"为驱动的农业技术已成为发展农业强有力的支撑.在果蔬业中,果蔬分类通常由经过训练的人员人工评估农产品或农作物 ...

  5. 推荐30个以上比较好的命名实体识别模型github源码?

    命名实体识别是自然语言处理中的一个重要任务,也是比较经典的应用.这里推荐几个比较流行的命名实体识别模型的GitHub源码: BERT-NER:基于BERT的命名实体识别模型,使用了CRF层来解码,在很 ...

  6. 【水果识别】基于形态学实现水果识别含Matlab源码

    1 简介 数学形态学操作可以分为二值形态学和灰度形态学,灰度形态学由二值形态学扩展而来.数学形态学有2个基本的运算,即腐蚀和膨胀,而腐蚀和膨胀通过结合又形成了开运算和闭运算. 开运算就是先腐蚀再膨胀, ...

  7. 【条形码识别】基于计算机视觉实现二维条形码识别含Matlab源码

    1 简介 在信息时代的今天,随着计算机技术的发展,条形码作为一种简单.方便.廉价.高速的信息保存和传输技术,在世界各地应用广泛,是商品进入国际市场的通行证. 本论文的研究基于一种全新的购物理念,即无需 ...

  8. 【人脸识别】基于KL变换人脸识别含Matlab源码

    1 简介 系统的设计是利用奇异值分解确定KL变换系数,并对人脸训练样本和待识别样本进行KL变换,对变换向量进行最小距离判别决策.对ORL人脸数据库的实验结果表明正确识别率随着变换系数维数的增加而增加, ...

  9. Opencv基于改进VGG19的表情识别系统(源码&Fer2013&教程)

    1.研究背景 在深度学习中,传统的卷积神经网络对面部表情特征的提取不充分以及计算参数量较大的问题,导致分类准确率偏低.因此,提出了一种基于改进的VGG19网络的人脸表情识别算法.首先,对数据进行增强如 ...

最新文章

  1. Java_中快速获取系统时间
  2. 搞懂Kafka的这个问题,你离大厂就不远了!
  3. j2se学习中的一些零碎知识点2之基础知识
  4. 数据库相关的系统巡检参考项
  5. Linux下CMake简明教程(三)同一目录下多个源文件
  6. ADO学习(九)如何阅读ADO文档
  7. mongodb php 扩展 linux,CentOS Linux 安装PHP的MongoDB扩展
  8. adt变频器故障代码ol2_误诊实例换来的变频器维修经验
  9. 区块链java语言,基于Java语言构建区块链(一)—— 基本原型
  10. php 多任务,PHP并行多任务研究(笔记)
  11. 导出手机缓存的B站视频或者在PC电脑端下载B站视频到本地
  12. PS批量修改文件大小及类型
  13. 学习虚幻4(一)U3D与UE4的比较
  14. Unable to open shape_predictor_68_face_landmarks.dat
  15. .htaccess rewrite 规则详细说明
  16. @18. 自幂数、水仙花数、四叶玫瑰数等等是什么?
  17. 解决国产电脑微信卡顿问题的脚本
  18. 3、乐趣国学—“色难”
  19. 服务器4个网口只显示2个,服务器4个网口的作用
  20. 【工具】复制别人的CSDN博客文章到本地

热门文章

  1. Numbers创建堆叠柱状图
  2. 解决react@v18的useEffect函数执行两遍的问题
  3. 使用gojs展示设备配置过程
  4. UE4官方课程笔记(1)——游戏设计师的蓝图与游戏玩法
  5. HTML页面如何添加ICO图标?
  6. 计算机丢失xmlLite.dll,电脑总是提醒应用程序错误或DLL C:windowssystem32xmllite.dll为无效的windows映像。。。这是为什么...
  7. Esp32cam通过巴法云订阅控制拍摄照片并上传图云
  8. 缅怀 Delphi,缅怀 Borland
  9. 用Python实现简易音乐播放器(mp3类型)3
  10. STM32单片机串口空闲中断接收不定长数据