参考博客:【TensorFlow】迁移学习(使用Inception-v3),非常感谢这个博主的这篇博客,我这篇博客的框架来自于这位博主,然后我针对评论区的问题以及自己的实践增加了一些内容以及解答。

github:代码

知识储备

  • 迁移学习是将一个数据集上训练好的网络模型快速转移到另外一个数据集上,可以保留训练好的模型中倒数第一层之前的所有参数,替换最后一层即可,在最后层之前的网络层称之为瓶颈层。
  • 迁移学习,首先尝试了Inception-V3,直接使用pool_3层的输出,接上一个全连接的分类层,使用softmax进行分类,使用Inception-V3的默认输入。

一、准备工作

1、数据集下载
2、Inception-v3模型下载

  • 官方下载地址
    【科学上网】

  • 百度网盘
    提取码:zmrl

  • 数据集解压后的目录:

flower_photos/daisy/dandelion/roses/sunflowers/tulips/

数据集文件夹包含5个子文件,每一个子文件夹的名称为一种花的名称,代表了不同的类别。平均每一种花有734张图片,每一张图片都是RGB色彩模式,大小也不相同,程序将直接处理没有整理过的图像数据。

  • 模型解压后的目录:
imagenet_comp_graph_label_strings.txt
tensorflow_inception_graph.pb

3、目录结构

  • 需要自行创建transfer-learning/data/tmp/bottlenec/model/train.py/eval.py文件。
transfer-learning/data/flower_photos/......tmp/bottleneck/......model/imagenet_comp_graph_label_strings.txttensorflow_inception_graph.pbtrain.pyeval.py

二、代码实现

需要写两个文件:1、train.py 2、eval.py

1、train.py

python3 train.pyimport glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile# 数据参数
MODEL_DIR = 'model/'  # inception-v3模型的文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb'  # inception-v3模型文件名
CACHE_DIR = 'data/tmp/bottleneck'  # 图像的特征向量保存地址
INPUT_DATA = 'data/flower_photos'  # 图片数据文件夹
VALIDATION_PERCENTAGE = 10  # 验证数据的百分比
TEST_PERCENTAGE = 10  # 测试数据的百分比# inception-v3模型参数
BOTTLENECK_TENSOR_SIZE = 2048  # inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 图像输入张量对应的名称# 神经网络的训练参数
LEARNING_RATE = 0.01
STEPS = 1000
BATCH = 100
CHECKPOINT_EVERY = 100
NUM_CHECKPOINTS = 5# 从数据文件夹中读取所有的图片列表并按训练、验证、测试分开
def create_image_lists(validation_percentage, test_percentage):result = {}  # 保存所有图像。key为类别名称。value也是字典,存储了所有的图片名称sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]  # 获取所有子目录is_root_dir = True  # 第一个目录为当前目录,需要忽略# 分别对每个子目录进行操作for sub_dir in sub_dirs:if is_root_dir:is_root_dir = Falsecontinue# 获取当前目录下的所有有效图片extensions = {'jpg', 'jpeg', 'JPG', 'JPEG'}file_list = []  # 存储所有图像dir_name = os.path.basename(sub_dir)  # 获取路径的最后一个目录名字for extension in extensions:file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)file_list.extend(glob.glob(file_glob))if not file_list:continue# 将当前类别的图片随机分为训练数据集、测试数据集、验证数据集label_name = dir_name.lower()  # 通过目录名获取类别的名称training_images = []testing_images = []validation_images = []for file_name in file_list:base_name = os.path.basename(file_name)  # 获取该图片的名称chance = np.random.randint(100)  # 随机产生100个数代表百分比if chance < validation_percentage:validation_images.append(base_name)elif chance < (validation_percentage + test_percentage):testing_images.append(base_name)else:training_images.append(base_name)# 将当前类别的数据集放入结果字典result[label_name] = {'dir': dir_name,'training': training_images,'testing': testing_images,'validation': validation_images}# 返回整理好的所有数据return result# 通过类别名称、所属数据集、图片编号获取一张图片的地址
def get_image_path(image_lists, image_dir, label_name, index, category):label_lists = image_lists[label_name]  # 获取给定类别中的所有图片category_list = label_lists[category]  # 根据所属数据集的名称获取该集合中的全部图片mod_index = index % len(category_list)  # 规范图片的索引base_name = category_list[mod_index]  # 获取图片的文件名sub_dir = label_lists['dir']  # 获取当前类别的目录名full_path = os.path.join(image_dir, sub_dir, base_name)  # 图片的绝对路径return full_path# 通过类别名称、所属数据集、图片编号获取特征向量值的地址
def get_bottleneck_path(image_lists, label_name, index, category):return get_image_path(image_lists, CACHE_DIR, label_name, index,category) + '.txt'# 使用inception-v3处理图片获取特征向量
def run_bottleneck_on_image(sess, image_data, image_data_tensor,bottleneck_tensor):bottleneck_values = sess.run(bottleneck_tensor,{image_data_tensor: image_data})bottleneck_values = np.squeeze(bottleneck_values)  # 将四维数组压缩成一维数组return bottleneck_values# 获取一张图片经过inception-v3模型处理后的特征向量
def get_or_create_bottleneck(sess, image_lists, label_name, index, category,jpeg_data_tensor, bottleneck_tensor):# 获取一张图片对应的特征向量文件的路径label_lists = image_lists[label_name]sub_dir = label_lists['dir']sub_dir_path = os.path.join(CACHE_DIR, sub_dir)if not os.path.exists(sub_dir_path):os.makedirs(sub_dir_path)bottleneck_path = get_bottleneck_path(image_lists, label_name, index,category)# 如果该特征向量文件不存在,则通过inception-v3模型计算并保存if not os.path.exists(bottleneck_path):image_path = get_image_path(image_lists, INPUT_DATA, label_name, index,category)  # 获取图片原始路径image_data = gfile.FastGFile(image_path, 'rb').read()  # 获取图片内容bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor,bottleneck_tensor)  # 通过inception-v3计算特征向量# 将特征向量存入文件bottleneck_string = ','.join(str(x) for x in bottleneck_values)with open(bottleneck_path, 'w') as bottleneck_file:bottleneck_file.write(bottleneck_string)else:# 否则直接从文件中获取图片的特征向量with open(bottleneck_path, 'r') as bottleneck_file:bottleneck_string = bottleneck_file.read()bottleneck_values = [float(x) for x in bottleneck_string.split(',')]# 返回得到的特征向量return bottleneck_values# 随机获取一个batch图片作为训练数据
def get_random_cached_bottlenecks(sess, n_classes, image_lists, how_many,category, jpeg_data_tensor,bottleneck_tensor):bottlenecks = []ground_truths = []for _ in range(how_many):# 随机一个类别和图片编号加入当前的训练数据label_index = random.randrange(n_classes)label_name = list(image_lists.keys())[label_index]image_index = random.randrange(65535)bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, image_index, category,jpeg_data_tensor, bottleneck_tensor)ground_truth = np.zeros(n_classes, dtype=np.float32)ground_truth[label_index] = 1.0bottlenecks.append(bottleneck)ground_truths.append(ground_truth)return bottlenecks, ground_truths# 获取全部的测试数据
def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor,bottleneck_tensor):bottlenecks = []ground_truths = []label_name_list = list(image_lists.keys())# 枚举所有的类别和每个类别中的测试图片for label_index, label_name in enumerate(label_name_list):category = 'testing'for index, unused_base_name in enumerate(image_lists[label_name][category]):bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, index, category,jpeg_data_tensor, bottleneck_tensor)ground_truth = np.zeros(n_classes, dtype=np.float32)ground_truth[label_index] = 1.0bottlenecks.append(bottleneck)ground_truths.append(ground_truth)return bottlenecks, ground_truthsdef main(_):# 读取所有的图片image_lists = create_image_lists(VALIDATION_PERCENTAGE, TEST_PERCENTAGE)n_classes = len(image_lists.keys())with tf.Graph().as_default() as graph:# 读取训练好的inception-v3模型with tf.gfile.GFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())# 加载inception-v3模型,并返回数据输入张量和瓶颈层输出张量bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def,return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])# 定义新的神经网络输入bottleneck_input = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE],name='BottleneckInputPlaceholder')# 定义新的标准答案输入ground_truth_input = tf.placeholder(tf.float32, [None, n_classes], name='GroundTruthInput')# 定义一层全连接层解决新的图片分类问题with tf.name_scope('final_training_ops'):weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.1))biases = tf.Variable(tf.zeros([n_classes]))logits = tf.matmul(bottleneck_input, weights) + biasesfinal_tensor = tf.nn.softmax(logits)# 定义交叉熵损失函数cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=ground_truth_input)cross_entropy_mean = tf.reduce_mean(cross_entropy)train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)# 计算正确率with tf.name_scope('evaluation'):correct_prediction = tf.equal(tf.argmax(final_tensor, 1), tf.argmax(ground_truth_input, 1))evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 训练过程with tf.Session(graph=graph) as sess:init = tf.global_variables_initializer().run()# 模型和摘要的保存目录import timetimestamp = str(int(time.time()))out_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs', timestamp))print('\nWriting to {}\n'.format(out_dir))# 损失值和正确率的摘要loss_summary = tf.summary.scalar('loss', cross_entropy_mean)acc_summary = tf.summary.scalar('accuracy', evaluation_step)# 训练摘要train_summary_op = tf.summary.merge([loss_summary, acc_summary])train_summary_dir = os.path.join(out_dir, 'summaries', 'train')train_summary_writer = tf.summary.FileWriter(train_summary_dir,sess.graph)# 开发摘要dev_summary_op = tf.summary.merge([loss_summary, acc_summary])dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)# 保存检查点checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))checkpoint_prefix = os.path.join(checkpoint_dir, 'model')if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)saver = tf.train.Saver(tf.global_variables(), max_to_keep=NUM_CHECKPOINTS)for i in range(STEPS):# 每次获取一个batch的训练数据train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(sess, n_classes, image_lists, BATCH, 'training',jpeg_data_tensor, bottleneck_tensor)_, train_summaries = sess.run([train_step, train_summary_op],feed_dict={bottleneck_input: train_bottlenecks,ground_truth_input: train_ground_truth})# 保存每步的摘要train_summary_writer.add_summary(train_summaries, i)# 在验证集上测试正确率if i % 100 == 0 or i + 1 == STEPS:validation_bottlenecks, validation_ground_truth = get_random_cached_bottlenecks(sess, n_classes, image_lists, BATCH, 'validation',jpeg_data_tensor, bottleneck_tensor)validation_accuracy, dev_summaries = sess.run([evaluation_step, dev_summary_op],feed_dict={bottleneck_input: validation_bottlenecks,ground_truth_input: validation_ground_truth})print('Step %d : Validation accuracy on random sampled %d examples = %.1f%%'% (i, BATCH, validation_accuracy * 100))# 每隔checkpoint_every保存一次模型和测试摘要if i % CHECKPOINT_EVERY == 0:dev_summary_writer.add_summary(dev_summaries, i)path = saver.save(sess, checkpoint_prefix, global_step=i)print('Saved model checkpoint to {}\n'.format(path))# 最后在测试集上测试正确率test_bottlenecks, test_ground_truth = get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)test_accuracy = sess.run(evaluation_step,feed_dict={bottleneck_input: test_bottlenecks,ground_truth_input: test_ground_truth})print('Final test accuracy = %.1f%%' % (test_accuracy * 100))# 保存标签output_labels = os.path.join(out_dir, 'labels.txt')with tf.gfile.FastGFile(output_labels, 'w') as f:keys = list(image_lists.keys())for i in range(len(keys)):keys[i] = '%2d -> %s' % (i, keys[i])f.write('\n'.join(keys) + '\n')if __name__ == '__main__':tf.app.run()

训练结果:

Writing to D:\New Scenery-AI\detection_flowers\runs\1548057429Step 0 : Validation accuracy on random sampled 100 examples = 19.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-0Step 100 : Validation accuracy on random sampled 100 examples = 75.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-100Step 200 : Validation accuracy on random sampled 100 examples = 81.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-200Step 300 : Validation accuracy on random sampled 100 examples = 78.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-300Step 400 : Validation accuracy on random sampled 100 examples = 87.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-400Step 500 : Validation accuracy on random sampled 100 examples = 90.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-500Step 600 : Validation accuracy on random sampled 100 examples = 90.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-600Step 700 : Validation accuracy on random sampled 100 examples = 89.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-700Step 800 : Validation accuracy on random sampled 100 examples = 93.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-800Step 900 : Validation accuracy on random sampled 100 examples = 83.0%
Saved model checkpoint to D:\New Scenery-AI\detection_flowers\runs\1548057429\checkpoints\model-900Step 999 : Validation accuracy on random sampled 100 examples = 82.0%
Final test accuracy = 85.6%

2、eval.py

import tensorflow as tf
import numpy as np# 模型目录
CHECKPOINT_DIR = './runs/1548061861/checkpoints'
INCEPTION_MODEL_FILE = 'model/tensorflow_inception_graph.pb'# inception-v3模型参数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 图像输入张量对应的名称# 测试数据
file_path = './data/flower_photos/roses/295257304_de893fc94d.jpg'
# file_path = './data/flower_photos/roses/12240303_80d87f77a3_n.jpg'
# file_path = './data/flower_photos/dandelion/7355522_b66e5d3078_m.jpg'
# file_path = './data/flower_photos/dandelion/16159487_3a6615a565_n.jpg'
# file_path = './data/flower_photos/sunflowers/6953297_8576bf4ea3.jpg'
# file_path = './data/flower_photos/sunflowers/40410814_fba3837226_n.jpg'
# file_path = './data/flower_photos/tulips/11746367_d23a35b085_n.jpg'
y_test = [1]# 读取数据
image_data = tf.gfile.GFile(file_path, 'rb').read()# 评估
checkpoint_file = tf.train.latest_checkpoint(CHECKPOINT_DIR)
with tf.Graph().as_default() as graph:with tf.Session().as_default() as sess:# 读取训练好的inception-v3模型with tf.gfile.GFile(INCEPTION_MODEL_FILE, 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())# 加载inception-v3模型,并返回数据输入张量和瓶颈层输出张量bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def,return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])# 使用inception-v3处理图片获取特征向量bottleneck_values = sess.run(bottleneck_tensor,{jpeg_data_tensor: image_data})# 将四维数组压缩成一维数组,由于全连接层输入时有batch的维度,所以用列表作为输入bottleneck_values = [np.squeeze(bottleneck_values)]# 加载元图和变量saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_file))saver.restore(sess, checkpoint_file)# 通过名字从图中获取输入占位符input_x = graph.get_operation_by_name('BottleneckInputPlaceholder').outputs[0]# 我们想要评估的tensorspredictions = graph.get_operation_by_name('evaluation/ArgMax').outputs[0]# 收集预测值all_predictions = []all_predictions = sess.run(predictions, {input_x: bottleneck_values})# 如果提供了标签则打印正确率
if y_test is not None:correct_predictions = float(sum(all_predictions == y_test))print('\nTotal number of test examples: {}'.format(len(y_test)))print('Accuracy: {:g}'.format(correct_predictions / float(len(y_test))))

结果:

python3 eval.pyTotal number of test examples: 1
Accuracy: 1

注意:

问:更换了图片,为什么accuracy总是为0的问题?怎么解决?
答:\detection_flowers\runs\1548061861,在你自己的目录下去找一个label.txt的文件
0 -> tulips 1 -> roses 2 -> daisy 3 -> dandelion 4 -> sunflowers
这个标签对应的就是你要测试的图片的标签,如果你选用的是roses,那么你的标签就为1。y_test = [1],1是标签,意思就是测试的图片是这一类,也就是roses 。如果你的标签与你的图片不是对应的,那么就会出现一直是0的情况,意思就是测试错误。

结束。希望对你有用~

【Inception-v3模型】迁移学习 实战训练 花朵种类识别相关推荐

  1. 深度学习100例 | 第33天:迁移学习-实战案例教程

    在本教程中,你将学习如何使用迁移学习通过预训练网络对猫和狗的图像进行分类. 预训练模型是一个之前基于大型数据集(通常是大型图像分类任务)训练的已保存网络. 迁移学习通常应用在数据集过少以至于无法有效完 ...

  2. Deep Learning:基于pytorch搭建神经网络的花朵种类识别项目(内涵完整文件和代码)—超详细完整实战教程

    基于pytorch的深度学习花朵种类识别项目完整教程(内涵完整文件和代码) 相关链接:: 超详细--CNN卷积神经网络教程(零基础到实战) 大白话pytorch基本知识点及语法+项目实战 文章目录 基 ...

  3. 基于pytorch搭建神经网络的花朵种类识别(深度学习)

    基于pytorch搭建神经网络的花朵种类识别(深度学习) 文章目录 基于pytorch搭建神经网络的花朵种类识别(深度学习) 一.知识点 1.特征提取.神经元逐层判断 2.中间层(隐藏层) 3.学习权 ...

  4. Xception迁移学习:玉米叶片病害识别分类

    Xception迁移学习:玉米叶片病害识别分类 数据集:来自网上公开的PlantVillage数据集中的玉米叶片部分. 运行环境:Tensorflow深度学习开源框架,选用Python 3.6.12作 ...

  5. 【毕业设计_课程设计】基于深度学习网络模型训练的车型识别系统

    文章目录 0 项目说明 1 简介 2 模型训练精度 3 扫一扫识别功能 4 技术栈 5 模型训练 6 最后 0 项目说明 基于深度学习网络模型训练的车型识别系统 提示:适合用于课程设计或毕业设计,工作 ...

  6. 迁移学习实战 | 快速训练残差网络 ResNet-101,完成图像分类与预测,精度高达 98%!...

    作者 | AI 菌 出品 | CSDN博客 头图 | CSDN付费下载自视觉中国 前言 笔者在实现ResNet的过程中,由于电脑性能原因,不得不选择层数较少的ResNet-18进行训练.但是很快发现, ...

  7. 手动搭建的VGG16网络结构训练数据和使用ResNet50微调(迁移学习)训练数据对比(图像预测+前端页面显示)

    文章目录 1.VGG16训练结果: 2.微调ResNet50之后的训练结果: 3.结果分析: 4.实验效果: (1)VGG16模型预测的结果: (2)在ResNet50微调之后预测的效果: 5.相关代 ...

  8. EnforceLearning:迁移学习-监督训练与非监督训练

    前言 CNN刷分ImageNet以来,迁移学习已经得到广泛的应用,不过使用ImageNet预训练模型迁移到特定数据集是一个全集到子集的迁移,不是标准定义的迁移学习(模型迁移),而是"模型移动 ...

  9. 【pytorch】MobileNetV2迁移学习+可视化+训练数据保存

    一.前言 由于写论文,不单单需要可视化数据,最好能将训练过程的完整数据全部保存下来.所以,我又又又写了篇迁移学习的文章,主要的改变是增加了训练数据记录的模块,可以将训练全过程的数据记录为项目路径下的E ...

最新文章

  1. 20180829-Java多线程编程
  2. Python实训day09pm【Python处理Excel实际应用】
  3. 分布与并行计算—生命游戏(Java)
  4. my batis plus 小数没有0_大黄蜂3号Plus,妈咪保贝的强劲对手!
  5. html弹出层全覆盖滚动条,JS弹出层遮罩,隐藏背景页面滚动条细节优化分析
  6. JavaWeb JDBC初步连接和JDBC连接规范化
  7. CoreBluetooth Central模式 Swift版
  8. ⚡自组织映射(SOM)神经网络⚡Python实现 |Python技能树征题
  9. ubuntu报错:RuntimeError : unexcepted EOF, excepted 2599001 more bytes. The file might be corrupted.
  10. 智能优化算法:斑点鬣狗优化算法-附代码
  11. GitHub界面各个页签作用
  12. python存数据库c读数据库喷码加工_python图片文字识别
  13. 【小蜜蜂蓝桥杯笔记】DS18B20温度传感器的使用
  14. 化学公式编辑器怎么画聚合物?
  15. 腾讯,竞争力 和 用户体验
  16. 已知抛物线与直线相交两点和抛物线顶点,求抛物线和直线所围成的面积?
  17. 这才是真正的Git——Git内部原理了解
  18. navicat 链接mysql异常 2005 - Unknown MySQL server host ‘xxxxxxxxx‘(11001)
  19. 聊聊旷厂黑科技 | 室内视觉定位导航,补齐导航的最后一块拼图
  20. ccd自动检测机优点是什么?

热门文章

  1. GOF(五)-原型模式【推荐】
  2. java isterminated_Java线程池,isShutDown、isTerminated的作用与区别
  3. SwiftUI iPadOS如何实现快捷键功能 KeyboardShortcut (教程含源码)
  4. 电脑鼠标不能正常使用
  5. VT系列二:检测是否支持虚拟化
  6. mysql45讲学习笔记
  7. ODOO longpoll解析
  8. Idea工程中,找不到JDK的类
  9. 【2个月录用】网络安全、人工智能、物联网、数字研究领域,1区,SCI在检,正刊
  10. Java求ijk+kji=1534有几种方法(for循环用法巩固)(数据都是可以自行修改的,本题是以1534为例)