本文主要讲解在现有常用模型基础上,如何微调模型,减少训练时间,同时保持模型检测精度。

首先介绍下Slim这个Google公布的图像分类工具包,可在github链接:modules and examples built with tensorflow 中找到slim包。

上面这个链接目录下主要包含:

official models(这个是用Tensorflow高层API做的例子模型集,建议初学者可尝试);

research models(这个是很多研究者利用tensorflow做的模型集,这个不是官方提供的,是研究者个人在维护的);

samples folder (包含代码片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代码呈现);

而我说的slim工具包就在research文件夹下。


Slim库结构

不仅定义了很多接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。


下面用slim工具包中的文件来对自己的数据集做训练,训练可分为利用已有的模型架构(如常见的VGG,Inception等的卷积,池化这些结构)来全新训练权重文件或是微调权重文件。由于很多已有的imagenet图像数据覆盖面已经很广,基于此训练的网络权重已经能提取大致的目标特征(从低微像素到高维的结构特征),所以可使用fine-tune只训练框架中某些层的权重,当然根据自己数据集做全部权重重新训练的检测效果理论会更好些,需要权衡时间成本和检测精度的需求了;

下面会依据成熟网络结构Incvption V3分别做权重文件的全部重新训练部分重新训练(即fine-tune)来介绍;

(前提是你将slim工具库下载下来,安装了必要的tensorflow等框架;并且根据训练图像制作完成了tfrecord文件)

有关tfrecord训练文件的制作请参考:将图像制作成tfrecord

step1:定义新的datasets数据集文件

在slim/datasets/文件夹下 添加一个python文件,直接复制一份flowers.py,重命名为“satellite.py”(这个名字可根据你实际的数据集名字来更改,我用的是何大神的航拍图数据集)

需要对赋值生成后的satellite.py内容做如下修改:

_FILE_PATTERN = 'flowers_%s_*.tfrecord'

更改为

_FILE_PATTERN = 'satellite_%s_*.tfrecord'      #这个主要是根据你之前制作的tfrecord文件名来改的,我制作的训练文件为satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,验证文件为satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord。

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}

更改为

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}  #这个根据自己训练和验证样本数量来改,我的训练数据是800张图/类,共6类,验证集时200张/类,共6类;

_NUM_CLASSES = 5

更改为

_NUM_CLASSES = 6       #实际训练类别为6类;

还需要对satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),这行代码做更改,由于用的数据集源文件都是XXXX.jpg格式,因此将默认的图像格式转为jpg,更改后为'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,对satellite.py文件完成制作与更改(其源码如下):

satellite.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import tensorflow as tffrom datasets import dataset_utilsslim = tf.contrib.slim_FILE_PATTERN = 'satellite_%s_*.tfrecord'SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}_NUM_CLASSES = 6_ITEMS_TO_DESCRIPTIONS = {'image': 'A color image of varying size.','label': 'A single integer between 0 and 4',
}def get_split(split_name, dataset_dir, file_pattern=None, reader=None):"""Gets a dataset tuple with instructions for reading flowers.Args:split_name: A train/validation split name.dataset_dir: The base directory of the dataset sources.file_pattern: The file pattern to use when matching the dataset sources.It is assumed that the pattern contains a '%s' string so that the splitname can be inserted.reader: The TensorFlow reader type.Returns:A `Dataset` namedtuple.Raises:ValueError: if `split_name` is not a valid train/validation split."""if split_name not in SPLITS_TO_SIZES:raise ValueError('split name %s was not recognized.' % split_name)if not file_pattern:file_pattern = _FILE_PATTERNfile_pattern = os.path.join(dataset_dir, file_pattern % split_name)# Allowing None in the signature so that dataset_factory can use the default.if reader is None:reader = tf.TFRecordReaderkeys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),}items_to_handlers = {'image': slim.tfexample_decoder.Image(),'label': slim.tfexample_decoder.Tensor('image/class/label'),}decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)labels_to_names = Noneif dataset_utils.has_labels(dataset_dir):labels_to_names = dataset_utils.read_label_file(dataset_dir)return slim.dataset.Dataset(data_sources=file_pattern,reader=reader,decoder=decoder,num_samples=SPLITS_TO_SIZES[split_name],items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,num_classes=_NUM_CLASSES,labels_to_names=labels_to_names)

step2:注册数据库

接下来对slim/datasets/dataset_factory.py文件做更改,注册下satellite数据库;修改之处如下(添加了两行红色字体代码):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
    
}

step3:准备训练文件夹

在slim文件夹下新建如下目录文件夹,并将对应的文件放在相应目录下

slim/
    satellite/
              data/
                   satellite_train_00000-of-00002.tfrecord
                   satellite_train_00001-of-00002.tfrecord
                   satellite_validation_00000-of-00002.tfrecord
                   satellite_validation_00001-of-00002.tfrecord
                   label.txt
              pretrained/
                   inception_v3.ckpt
              train_dir/

data文件夹下存放你制作的tfrecord训练测试文件和标签名;

pretrained文件夹下存放官网训练的权重文件;下载地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz      

train_dir文件夹下存放你训练得到的模型和日志;

step4-1:在现有模型结构上fine-tune

开始训练,在slim文件夹下,运行如下指令可开始训练(主要是训练逻辑层):

python train_image_classifier.py \--train_dir=satellite/train_dir \--dataset_name=satellite \--dataset_split_name=train \--dataset_dir=satellite/data \--model_name=inception_v3 \--checkpoint_path=satellite/pretrained/inception_v3.ckpt \--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--max_number_of_steps=100000 \--batch_size=32 \--learning_rate=0.001 \--learning_rate_decay_type=fixed \--save_interval_secs=300 \--save_summaries_secs=2 \--log_every_n_steps=10 \--optimizer=rmsprop \--weight_decay=0.00004

命令参数解析如下:

--trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先来解 释参数trainable_scopes 的作用,因为非常重要。 trainable_scopes 规定了在模型中fine-tune变量的范围 。 这里的设定表示只对 InceptionV3/Logits, Inception V3/ AuxLogits 两个变量进行微调,其他变量都保持不动 。 Inception V3/Logits,Inception V3/ AuxLogits 就相当于在网络中的 fc8 ,它们是 Inception V3的“末端层” 。 如果不设定 trainable_scopes , 就会对模型中所有的参数进行训练。

• --train_dir=satellite/train_dir:表明会在 satellite/train_dir目录下保存日志和checkpoint。

--dataset_name=satellite、 --dataset_split_ name=train: 指定训练的数据集。

--dataset_dit=satellite/data:指定训练数据集保存的位置。

--model_ name=inception _ v3 :使用的模型名称。

--checkpoint_path=satellite/pretrained/inception_v3.ckpt:预训练模型的保存位置。

--checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是 Inception V3 模型的末端层,对应着 ImageNet 数据集的 1000 类,和相当前的数据集不符,因此不要去恢复它。

--max_number_of_steps 100000:最大的执行步数。

--batch_size=32:每步使用的 batch 数量。

--learning_rate=0.001 : 学习率。

• --learning_rate_decay_type=fixed:学习率是否自动下降,此处使用固定的学习率。

• --save_interval_secs=300:每隔 300s,程序会把当前模型保存到train_dir中。 此处就是目录 satellite/train_dir。

• --save_summaries_secs=2:每隔 2s,就会将日志写入到 train_dir 中。可以用 TensorBoard 查看该日志。此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔。

• --log_every_n_steps=10:每隔10步,就会在屏上打出训练信息。

--optimizer=msprop:表示选定的优化器。

• --weight_decay=0.00004:选定的 weight_decay 值。 即模型中所高参数的 二次正则化超参数。

以上命令是只训练末端层 InceptionV3/Logits,Inception V3/ AuxLogits ,还 可以使用以下命令对所高层进行训练:

step4-2:训练整个模型权重数据

使用以下命令对所有层进行训练:
去掉 了--trainable_scopes 参数

python train_image_classifier.py \--train_dir=satellite/train_dir \--dataset_name=satellite \--dataset_split_name=train \--dataset_dir=satellite/data \--model_name=inception_v3 \--checkpoint_path=satellite/pretrained/inception_v3.ckpt \--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--max_number_of_steps=100000 \--batch_size=32 \--learning_rate=0.001 \--learning_rate_decay_type=fixed \--save_interval_secs=300 \--save_summaries_secs=2 \--log_every_n_steps=10 \--optimizer=rmsprop \--weight_decay=0.00004

当train_image_classifier.py程序启动后,如果训练文件夹(即satellite/train_dir)里没再已经保存的模型,就会加载 checkpoint_path 中的预训练模型,紧接着,程序会把初始模型保存到 train_dir中 ,命名为 model.ckpt-0, 0 表示第 0 步。 这之后,每隔 5min (参数一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序还会把当前模型保存到同样的文件夹中 , 命名恪式和第一次保存的格式一样。 因为模型比较大,程序只会保留最新的 5 个模型。
此外,如果中断了程序并再次运行,程序会首先检查 train_dir 中有无已经保存的模型,如果有,就不会去加载 checkpoint_path 中的预训练模型, 而是直接加载 train_dir 中已经训练好的模型,并以此为起点进行训练。 Slim 之所以这样设计,是为了在微调网络的时候,可以方便地按阶段手动调整学习率等参数。

至此用slim工具包做fine-tune或重新训练的步骤就完成了。


相似文章参考:https://blog.csdn.net/chaipp0607/article/details/74139895

【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型相关推荐

  1. 福利 | Python、深度学习、机器学习、TensorFlow 好书推荐

    在上次的送书活动中,营长做了个调查问卷,结果显示大家更喜欢深度学习.Python以及TensorFlow方面的书,所以这期送书活动一并满足大家.本期图书选自人民邮电出版社图书,包括:近期AI圈儿比较流 ...

  2. 【深度学习】Keras和Tensorflow框架使用区别辨析

    [深度学习]Keras和Tensorflow框架使用区别辨析 文章目录 1 概述 2 Keras简介 3 Tensorflow简介 4 使用tensorflow的几个小例子 5 Keras搭建CNN ...

  3. DL:深度学习框架Pytorch、 Tensorflow各种角度对比

    DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...

  4. 深度学习(16)TensorFlow高阶操作五: 张量限幅

    深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...

  5. 深度学习(6)TensorFlow基础操作二: 创建Tensor

    深度学习(6)TensorFlow基础操作二: 创建Tensor 一. 创建方式 1. From Numpy,List 2. zeros,ones (1) tf.zeros() (2) tf.zero ...

  6. 深度学习与 Spark 和 TensorFlow

    2019独角兽企业重金招聘Python工程师标准>>> 深度学习与 Spark 和 TensorFlow 在过去几年中,神经网络领域的发展非常迅猛,也是现在图像识别和自动翻译领域中最 ...

  7. 深度学习 第三章 tensorflow手写数字识别

    深度学习入门视频-唐宇迪 (笔记加自我整理) 深度学习 第三章 tensorflow手写数字识别 1.tensorflow常见操作 这里使用的是tensorflow1.x版本,tensorflow基本 ...

  8. CV:Win10下深度学习框架安装之Tensorflow/tensorflow_gpu+Cuda+Cudnn(最清楚/最快捷)之详细攻略(图文教程)

    CV:Win10下深度学习框架安装之Tensorflow/tensorflow_gpu+Cuda+Cudnn(最清楚/最快捷)之详细攻略(图文教程) 导读 本人在Win10下安装深度学习框架Tenso ...

  9. 深度学习框架Caffe, MXNet, TensorFlow, Torch, CNTK性能测试报告

    香港浸会大学对于深度学习框架Caffe, MXNet, TensorFlow, Torch, CNTK性能测试报告 http://dlbench.comp.hkbu.edu.hk/

  10. 深度学习(17)TensorFlow高阶操作六: 高阶OP

    深度学习(17)TensorFlow高阶操作六: 高阶OP 1. Where(tensor) 2. where(cond, A, B) 3. 1-D scatter_nd 4. 2-D scatter ...

最新文章

  1. 互联网协议 — IP 网络的 QoS 服务模型
  2. 将RGB值转换为灰度值的简单算法(转)
  3. java jar包图片_jar包的图片不显示 求解
  4. 【Unity3D技巧】一个简单的Unity-UI框架的实现
  5. sqlserver 查询中使用Union或Union All
  6. AttributeError: module 're' has no attribute 'sub'
  7. linux内存分配堆栈数据段代码段,linux – LD_PRELOAD堆栈和数据段内存分配
  8. 惠普服务器eth0的位置,HPUX下定位网卡位置
  9. 户口所在地代码查询_毕业生如何查询档案存放地及存档问题?
  10. 如何将SQL Profiler Trace读入到SQL的表中?
  11. sola病毒doc变exe批量恢复方法
  12. 强化学习实战(九) Linux下配置星际争霸Ⅱ环境
  13. 选中物体高亮显示(MR开发日志)
  14. 又学一招——Chrome 插件安装技巧
  15. 输入一个整数,判断它是几位数
  16. Datawhale组队学习21期_学术前沿趋势分析Task2_论文作者统计
  17. python与金融工程的区别_科研进阶 | 纽约大学 | 金融工程、量化金融、商业分析:Python金融工程分析...
  18. 开放经济中的货币-中国视角下的宏观经济
  19. linux igmp 属于那层协议,igmp协议属于哪一层
  20. 播放器android版最新官方版下载安装,万能播放器安卓版下载

热门文章

  1. 软考证书有效期多久?
  2. 《图解机器学习-杉山将著》读书笔记---CH3
  3. 让 wls 拥有可视化功能
  4. 加密芯片具体是要保护什么
  5. 会员情况下,如果购物总金额大于等于200;则享受会员价75折,小于200,打八折;如果不是会员,,购物金额大于等于100,享受打九折优惠
  6. Yandex – 俄罗斯无限免费空间、免费相册、免费邮箱、免费网盘
  7. IDEA 使用VUE框架
  8. Python操作Redis
  9. python难学吗-python难学吗
  10. 朝鲜国家黑客组织Lazarus 被指攻击IT供应链