训练一个好的卷积神经网络模型进行图像分类不仅需要计算资源还需要很长的时间。特别是模型比较复杂和数据量比较大的时候。普通的电脑动不动就需要训练几天的时间。为了能够快速地训练好自己的花朵图片分类器,我们可以使用别人已经训练好的模型参数,在此基础之上训练我们的模型。这个便属于迁移学习。本文提供训练数据集和代码下载。

原理:卷积神经网络模型总体上可以分为两部分,前面的卷积层和后面的全连接层。卷积层的作用是图片特征的提取,全连接层作用是特征的分类。我们的思路便是在inception-v3网络模型上,修改全连接层,保留卷积层。卷积层的参数使用的是别人已经训练好的,全连接层的参数需要我们初始化并使用我们自己的数据来训练和学习。

上面inception-v3模型图红色箭头前面部分是卷积层,后面是全连接层。我们需要修改修改全连接层,同时把模型的最终输出改为5。

由于这里使用了tensorflow框架,所以,我们需要获取上图红色箭头所在位置的张量BOTTLENECK_TENSOR_NAME(最后一个卷积层激活函数的输出值,个数为2048)以及模型最开始的输入数据的张量JPEG_DATA_TENSOR_NAME。获取这两个张量的作用是,图片训练数据通过JPEG_DATA_TENSOR_NAME张量输入模型,通过BOTTLENECK_TENSOR_NAME张量获取通过卷积层之后的图片特征。

BOTTLENECK_TENSOR_SIZE = 2048
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

通过下面的链接下载inception-v3模型,其中包含已经训练好的参数。

模型下载链接:地址

or https://pan.baidu.com/s/1LxBK5annrmiWSXE_jajOJQ

训练数据花朵图片下载:地址

通过下面的代码加载模型,同时获取上面所述的两个张量。

   # 读取已经训练好的Inception-v3模型。with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

由于我们模型的功能是对五种花进行分类,所以,我们需要修改全连接层,这里,我们只增加一个全连接层。全连接层的输入数据便是BOTTLENECK_TENSOR_NAME张量。

 # 定义一层全链接层with tf.name_scope('final_training_ops'):weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))biases = tf.Variable(tf.zeros([n_classes]))logits = tf.matmul(bottleneck_input, weights) + biasesfinal_tensor = tf.nn.softmax(logits)

最后便是定义交叉熵损失函数。模型使用反向传播训练,而训练的参数并不是模型的所有参数,仅仅是全连接层的参数,卷积层的参数是不变的。

    # 定义交叉熵损失函数。cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input)cross_entropy_mean = tf.reduce_mean(cross_entropy)train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)

那么接下来的是如何给我们的模型输入数据了,这里提供了几个操作数据的函数。由于训练数据集比较小,先把所有的图片通过JPEG_DATA_TENSOR_NAME张量输入模型,然后获取BOTTLENECK_TENSOR_NAME张量的值并保存到硬盘中。在模型训练的时候,从硬盘中读取所保存的BOTTLENECK_TENSOR_NAME张量的值作为全连接层的输入数据。因为一张图片可能会被使用多次。

# 输入图片并获取`BOTTLENECK_TENSOR_NAME`张量的值
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor)# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于训练
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于测试。
def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)

不到5分钟就可以训练好我们的模型,精确度还蛮高的。下图是本人运行的结果。

源码地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning

迁移学习CNN图像分类模型 - 花朵图片分类相关推荐

  1. Pytorch:利用迁移学习做图像分类

    **Pytorch:利用迁移学习做图像分类** 数据准备 数据扩充 数据加载 迁移学习 训练 验证 推理/分类 在这一篇文章中,我们描述了如何在 pytorch中进行图像分类.我们将使用Caltech ...

  2. 【转】[caffe]深度学习之图像分类模型AlexNet解读

    [caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097 本文章已收录于:  深 ...

  3. Xception迁移学习:玉米叶片病害识别分类

    Xception迁移学习:玉米叶片病害识别分类 数据集:来自网上公开的PlantVillage数据集中的玉米叶片部分. 运行环境:Tensorflow深度学习开源框架,选用Python 3.6.12作 ...

  4. 深度学习-第T2周——彩色图片分类

    深度学习-第T2周--彩色图片分类 深度学习-第P1周--实现mnist手写数字识别 一.前言 二.我的环境 三.前期工作 1.导入依赖项并设置GPU 2.导入数据集 3.归一化 4.可视化图片 四. ...

  5. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

  6. Pytorch模型迁移和迁移学习,导入部分模型参数

    Pytorch模型迁移和迁移学习 目录 Pytorch模型迁移和迁移学习 1. 利用resnet18做迁移学习 2. 修改网络名称并迁移学习 3.去除原模型的某些模块 1. 利用resnet18做迁移 ...

  7. 不懂得如何优化CNN图像分类模型?这有一份综合设计指南请供查阅

    对于计算机视觉任务而言,图像分类是其中的主要任务之一,比如图像识别.目标检测等,这些任务都涉及到图像分类.而卷积神经网络(CNN)是计算机视觉任务中应用最为广泛且最为成功的网络之一.大多数深度学习研究 ...

  8. 迁移学习基础知识(一)——分类及应用

    适合入门:机器学习的明天--迁移学习 一.迁移学习分类: 按照目标域有无标签,迁移学习可以分为监督迁移学习,半监督迁移学习,无监督迁移学习: 按照学习方法分类,迁移学习可以分为基于样本的迁移学习方法( ...

  9. T5,一个探索迁移学习边界的模型

    作者 | Ajit Rajasekharan 译者 | 夕颜 出品 | AI科技大本营(ID:rgznai100) [导读]10月,Google 在<Exploring the Limits o ...

最新文章

  1. Bioinformatics: Assembling Genomes (week 1-2)
  2. 高精度矢量汉字的一种填充方法_使用PS中的钢笔工具制作一只蝴蝶矢量插画
  3. C语言 使用指针计算两个整数的和与差
  4. ORA-08176 错误的一个案例
  5. Java7任务并行执行神器:ForkJoin框架
  6. 移动端,fixed bottom问题
  7. 微软建议的ASP性能优化28条守则
  8. 用计算机打cf,CF能用的特殊符号有什么 CF特殊符号怎么打
  9. 去中心化 去区块链_基于区块链的去中心化应用的四种架构模式候选
  10. pythons实现信号分帧
  11. opencv-python人脸识别初探
  12. win10清理_win10系统怎么一键清理系统垃圾
  13. Win10 安装 Rational Rose
  14. html网页框架分割三部分,Dreamweaver用框架建立网站把浏览器的显示空间分割为几个部分...
  15. python多久可以入门_python自学要多久能学会
  16. IDEA使用教程(三) 功能面板
  17. 一文看懂:网址,URL,域名,IP地址,DNS,域名解析
  18. 多IP服务器怎么样?多IP服务器有什么优势?
  19. 通过PostMessage/SendMessage实现模拟键盘鼠标按键,发送不成功或出现重复按键的可参考本文
  20. 同事能力比你强怎么办

热门文章

  1. python 隐马尔科夫_隐马尔可夫模型原理和python实现
  2. mysql8.0卸载出现问题,Windows环境下MySQL 8.0 的安装、配置与卸载
  3. python扫描端口脚本_Python实现的端口扫描功能示例
  4. mysql json匹配key为数值_干货篇:一篇文章让你——《深入解析MySQL索引原理》
  5. Apache Beam的架构概览
  6. C#如何使用REST接口读写数据
  7. Flume Sinks官网剖析(博主推荐)
  8. Oracle-01033错误处理
  9. springmvc前台String转后台Date
  10. /proc/meminfo详解 = /nmon analysis --MEM