1.训练图片分类模型的三种方法

(1).从无到有,先确定好算法框架,准备好需要训练的数据集,从头开始训练,参数一开始也是初始化的随机值,一个批次一个批次地进行训练。

(2).准备好已经训练好的模型,权值参数也都已经确定,只训练最后一层,因为前面的参数都是经过大量图片的训练来的,所以参数都比较好,比如卷积层主要的作用的对图像特征的提取,我们要做自己的分类模型的话也得对图像进行特征提取,做特征提取的话直接使用训练好的权值也行。

(3).跟2差不多,不同的是都之前的参数也做微调。

2.retrain图片分类模型

(1). https://github.com/tensorflow/tensorflow 下载官方包 有一些官方提供的案例,里面有后面要用到的retrain.py文件。

(2).下载图片集

网址:http://www.robots.ox.ac.uk/~vgg/data/

(3)然后写批处理文件

activate py3 ^
python E:/graduate_student/deep_learning/tensorflow-master/tensorflow-master/tensorflow/examples/image_retraining/retrain.py ^
--bottleneck_dir bottleneck ^
--how_many_training_steps 200 ^
--model_dir D:/software/mycodes/python35/py3/inception_model/ ^
--output_graph output_graph.pb ^
--output_labels output_labels.txt ^
--image_dir imagedata/
pause

bottleneck 瓶颈,图片预处理的时候把这个值算出来

output_graph 输出训练好的模型到当前文件夹

output_labels 输出训练好的标签到当前文件夹

image的格式需要里面文件夹的名字代表分类类型。(里面图片不能由大写字母也不能由中文!!)

运行批处理文件,得到:

(4)测试训练好的模型:

代码:

# coding: utf-8# In[1]:import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt# In[2]:lines = tf.gfile.GFile('E:/graduate_student/deep_learning/a-tensorflow/9/retain/output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :#去掉换行符line=line.strip('\n')uid_to_human[uid] = linedef id_to_string(node_id):if node_id not in uid_to_human:return ''return uid_to_human[node_id]#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('E:/graduate_student/deep_learning/a-tensorflow/9/retain/output_graph.pb', 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())tf.import_graph_def(graph_def, name='')with tf.Session() as sess:softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')#遍历目录for root,dirs,files in os.walk('E:/graduate_student/deep_learning/a-tensorflow/9/retain/images/'):for file in files:#载入图片image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式predictions = np.squeeze(predictions)#把结果转为1维数据#打印图片路径及名称image_path = os.path.join(root,file)print(image_path)#显示图片img=Image.open(image_path)plt.imshow(img)plt.axis('off')plt.show()#排序top_k = predictions.argsort()[::-1]print(top_k)for node_id in top_k:     #获取分类名称human_string = id_to_string(node_id)#获取该分类的置信度score = predictions[node_id]print('%s (score = %.5f)' % (human_string, score))print()

结果:

使用finetune的优点:1.训练速度快,计算量少,只计算最后一层。2.迭代周期少,因为训练的权值少。3.需要使用到图片的数据量比较少。

3.重头开始训练图片识别模型

(1)tensorflow官方包里下载:https://github.com/tensorflow/models

(2)准备好分类图片

(3)图像预处理,生成tfrecord文件。

程序:

# coding: utf-8# In[2]:import tensorflow as tf
import os
import random
import math
import sys# In[3]:#验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "E:/graduate_student/deep_learning/a-tensorflow\9/retain/imagedata/"
#标签文件名字
LABELS_FILENAME = "E:/graduate_student/deep_learning/a-tensorflow/9/retain/labels.txt"#定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)return os.path.join(dataset_dir, output_filename)#判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):for split_name in ['train', 'test']:for shard_id in range(_NUM_SHARDS):#定义tfrecord文件的路径+名字output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)if not tf.gfile.Exists(output_filename):return Falsereturn True#获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):#数据目录directories = []#分类名称class_names = []for filename in os.listdir(dataset_dir):#合并文件路径path = os.path.join(dataset_dir, filename)#判断该路径是否为目录if os.path.isdir(path):#加入数据目录directories.append(path)#加入类别名称class_names.append(filename)photo_filenames = []#循环每个分类的文件夹for directory in directories:for filename in os.listdir(directory):path = os.path.join(directory, filename)#把图片加入图片列表photo_filenames.append(path)return photo_filenames, class_namesdef int64_feature(values):if not isinstance(values, (tuple, list)):values = [values]return tf.train.Feature(int64_list=tf.train.Int64List(value=values))def bytes_feature(values):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))def image_to_tfexample(image_data, image_format, class_id):#Abstract base class for protocol messages.return tf.train.Example(features=tf.train.Features(feature={'image/encoded': bytes_feature(image_data),'image/format': bytes_feature(image_format),'image/class/label': int64_feature(class_id),}))def write_label_file(labels_to_class_names, dataset_dir,filename=LABELS_FILENAME):labels_filename = os.path.join(dataset_dir, filename)with tf.gfile.Open(labels_filename, 'w') as f:for label in labels_to_class_names:class_name = labels_to_class_names[label]f.write('%d:%s\n' % (label, class_name))#把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):assert split_name in ['train', 'test']#计算每个数据块有多少数据num_per_shard = int(len(filenames) / _NUM_SHARDS)with tf.Graph().as_default():with tf.Session() as sess:for shard_id in range(_NUM_SHARDS):#定义tfrecord文件的路径+名字output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:#每一个数据块开始的位置start_ndx = shard_id * num_per_shard#每一个数据块最后的位置end_ndx = min((shard_id+1) * num_per_shard, len(filenames))for i in range(start_ndx, end_ndx):try:sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))sys.stdout.flush()#读取图片image_data = tf.gfile.FastGFile(filenames[i], 'r').read()#获得图片的类别名称class_name = os.path.basename(os.path.dirname(filenames[i]))#找到类别名称对应的idclass_id = class_names_to_ids[class_name]#生成tfrecord文件example = image_to_tfexample(image_data, b'jpg', class_id)tfrecord_writer.write(example.SerializeToString())except IOError as e:print("Could not read:",filenames[i])print("Error:",e)print("Skip it\n")sys.stdout.write('\n')sys.stdout.flush()if __name__ == '__main__':#判断tfrecord文件是否存在if _dataset_exists(DATASET_DIR):print('tfcecord文件已存在')else:#获得所有图片以及分类photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)#把分类转为字典格式,类似于{'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}class_names_to_ids = dict(zip(class_names, range(len(class_names))))#把数据切分为训练集和测试集random.seed(_RANDOM_SEED)random.shuffle(photo_filenames)training_filenames = photo_filenames[_NUM_TEST:]testing_filenames = photo_filenames[:_NUM_TEST]#数据转换_convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)_convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)#输出labels文件labels_to_class_names = dict(zip(range(len(class_names)), class_names))write_label_file(labels_to_class_names, DATASET_DIR)

得到:

(4)批处理文件:

1.在slim里面加入image文件夹,里面放入图片tfrecord文件

2.在slim文件夹里的datasets文件夹里新建

程序如下:

"""Provides data for the flowers dataset.The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import tensorflow as tffrom datasets import dataset_utilsslim = tf.contrib.slim_FILE_PATTERN = 'image_%s_*.tfrecord'SPLITS_TO_SIZES = {'train': 1000, 'test': 500}_NUM_CLASSES = 5_ITEMS_TO_DESCRIPTIONS = {'image': 'A color image of varying size.','label': 'A single integer between 0 and 4',
}def get_split(split_name, dataset_dir, file_pattern=None, reader=None):"""Gets a dataset tuple with instructions for reading flowers.Args:split_name: A train/validation split name.dataset_dir: The base directory of the dataset sources.file_pattern: The file pattern to use when matching the dataset sources.It is assumed that the pattern contains a '%s' string so that the splitname can be inserted.reader: The TensorFlow reader type.Returns:A `Dataset` namedtuple.Raises:ValueError: if `split_name` is not a valid train/validation split."""if split_name not in SPLITS_TO_SIZES:raise ValueError('split name %s was not recognized.' % split_name)if not file_pattern:file_pattern = _FILE_PATTERNfile_pattern = os.path.join(dataset_dir, file_pattern % split_name)# Allowing None in the signature so that dataset_factory can use the default.if reader is None:reader = tf.TFRecordReaderkeys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),}items_to_handlers = {'image': slim.tfexample_decoder.Image(),'label': slim.tfexample_decoder.Tensor('image/class/label'),}decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)labels_to_names = Noneif dataset_utils.has_labels(dataset_dir):labels_to_names = dataset_utils.read_label_file(dataset_dir)return slim.dataset.Dataset(data_sources=file_pattern,reader=reader,decoder=decoder,num_samples=SPLITS_TO_SIZES[split_name],items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,num_classes=_NUM_CLASSES,labels_to_names=labels_to_names)

然后在datasets文件夹里的里修改:

添加myimages。

3.在slim目录下新建批处理文件。 其中,train_image_classifier.py是在(1)里下载的文件里的程序,在slim文件夹里面,slim文件夹需要拷贝到当前目录。

D:\Anaconda2\envs\PY3\python E:/graduate_student/deep_learning/models-master/models-master/research/slim/train_image_classifier.py ^
--train_dir=D:/software/mycodes/python3/python35/captcha/model/ ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=D:/software/mycodes/python3/python35/captcha/image ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause

然后运行批处理文件。缓慢运行。

【tensorflow 深度学习】8.训练图片分类模型相关推荐

  1. [深度学习] 自然语言处理 --- 文本分类模型总结

    文本分类 包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMO,BERT等)的文本分类 fastText 模型 textCNN 模型 charCNN 模型 Bi-LSTM 模型 ...

  2. 动手学深度学习--课堂笔记图片分类数据集

    softmax是一个非线性函数,但softmax回归是一个线性模型(linear model):是不是线性的是由决策面是否是线性函数决定的,不是由拟合的数据分布决定的.softmax只是对数据分布做了 ...

  3. 基于tensorflow深度学习的猫狗分类识别

  4. 转:tensorflow深度学习实战笔记(二):把训练好的模型进行固化

    原文地址:https://blog.csdn.net/chenyuping333/article/details/82106863 目录 一.导出前向传播图 二.对模型进行固化 三.pb文件转tfli ...

  5. python 训练识别验证码_python使用tensorflow深度学习识别验证码

    本文介绍了python使用tensorflow深度学习识别验证码 ,分享给大家,具体如下: 除了传统的PIL包处理图片,然后用pytessert+OCR识别意外,还可以使用tessorflow训练来识 ...

  6. 《深度学习案例精粹:基于TensorFlow与Keras》深度学习常用训练案例合集

    #好书推荐##好书奇遇季#<深度学习案例精粹:基于TensorFlow与Keras>,京东当当天猫都有发售.本书配套示例源码.PPT课件.思维导图.数据集.开发环境与答疑服务. <深 ...

  7. 深度学习如何训练出好的模型

    深度学习在近年来得到了广泛的应用,从图像识别.语音识别到自然语言处理等领域都有了卓越的表现.但是,要训练出一个高效准确的深度学习模型并不容易.不仅需要有高质量的数据.合适的模型和足够的计算资源,还需要 ...

  8. 张量模型并行详解 | 深度学习分布式训练专题

    随着模型规模的扩大,单卡显存容量无法满足大规模模型训练的需求.张量模型并行是解决该问题的一种有效手段.本文以Transformer结构为例,介绍张量模型并行的基本原理. 模型并行的动机和现状 我们在上 ...

  9. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

最新文章

  1. iResearch_2008年度中国互联网经济及核心行业核心数据发布
  2. .NET中的枚举(Enum)
  3. java中的异常种类和区别以及处理机制和区别
  4. pmp最近5题(2022年3月23日)
  5. oracle复合索引介绍(多字段索引)
  6. Chrome让人失望,是时候转到Firefox或Edge?
  7. C++学习笔记(11) 重载流插入运算符和流提取运算符,以及自动类型转换
  8. java gzip 文件夹_Java GZip 基于磁盘实现压缩和解压的方法
  9. [Book]《云计算核心技术剖析》读书笔记
  10. c语言单项选择题标准化考试系统,c语言课程设计(单项选择题标准化考试系统)分析报告.doc...
  11. kron matlab_使用kron来实现repmat, repelem的功能
  12. Ubuntu 18.04双系统卸载教程,不借助第三方软件(UEFI)
  13. 企业的五种组织架构模式
  14. 人工智能之模式识别(一)
  15. 浅谈微信小程序和微信公众平台
  16. 前端实习工作找不到,怎么增长实战经验
  17. android手机配什么蓝牙耳机,Airpods Pro搭配安卓手机+Windows电脑服用指南(避坑指南)...
  18. 如何快速搭建手游平台?
  19. # MASA MAUI Plugin (十)iOS消息推送(原生APNS方式)
  20. linux sd卡驱动流程图,SD卡驱动(详细介绍,不明白的人可以仔细看看了.有流程图)-转-OpenEdv-开源电子网...

热门文章

  1. win10 bat脚本设置软件的开机自启动
  2. Arduino 串口
  3. ssm框架中利用pagehelper分页,完成模糊查询与select条件查询
  4. 为什么GSM下行的频率要比上行的频率高呢
  5. 计算机bootmgr丢失,电脑bootmgr is missing怎么办_bootmgr is missing修复-太平洋IT百科
  6. wayos计费系统easyradius使用小记
  7. 又一个Mac特洛伊木马被发现!苹果用户要警惕
  8. 一种新的具有多尺度结构感知的全色锐化方法
  9. 如何屏蔽掉电脑上因下载软件捆绑的广告(烦人的广告让人十分尴尬)
  10. 【Mac使用技巧】QuickTime Player 如何让声音和视频同步加速播放