一、【深度学习图像识别课程】tensorflow迁移学习系列:VGG16花朵分类

转自:https://blog.csdn.net/weixin_41770169/article/details/80330581

花朵数据库介绍

种类5种:daisy雏菊,dandelion蒲公英,rose玫瑰,sunflower向日葵,tulips郁金香

数量:     633,            898,                   641,          699,                799

总数量:3670

实战:VGGNet实现花朵分类

1、读入VGG16模型
[python] view plaincopy
  1. from urllib.request import urlretrieve
  2. from os.path import isfile, isdir
  3. from tqdm import tqdm
  4. vgg_dir = 'tensorflow_vgg/'
  5. # Make sure vgg exists
  6. if not isdir(vgg_dir):
  7. raise Exception("VGG directory doesn't exist!")
  8. class DLProgress(tqdm):
  9. last_block = 0
  10. def hook(self, block_num=1, block_size=1, total_size=None):
  11. self.total = total_size
  12. self.update((block_num - self.last_block) * block_size)
  13. self.last_block = block_num
  14. if not isfile(vgg_dir + "vgg16.npy"):
  15. with DLProgress(unit='B', unit_scale=True, miniters=1, desc='VGG16 Parameters') as pbar:
  16. urlretrieve(
  17. 'https://s3.amazonaws.com/content.udacity-data.com/nd101/vgg16.npy',
  18. vgg_dir + 'vgg16.npy',
  19. pbar.hook)
  20. else:
  21. print("Parameter file already exists!")

下载了如下标亮文件:vgg16.npy

2、读入图像库
[python] view plaincopy
  1. import tarfile
  2. dataset_folder_path = 'flower_photos'
  3. class DLProgress(tqdm):
  4. last_block = 0
  5. def hook(self, block_num=1, block_size=1, total_size=None):
  6. self.total = total_size
  7. self.update((block_num - self.last_block) * block_size)
  8. self.last_block = block_num
  9. if not isfile('flower_photos.tar.gz'):
  10. with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Flowers Dataset') as pbar:
  11. urlretrieve(
  12. 'http://download.tensorflow.org/example_images/flower_photos.tgz',
  13. 'flower_photos.tar.gz',
  14. pbar.hook)
  15. if not isdir(dataset_folder_path):
  16. with tarfile.open('flower_photos.tar.gz') as tar:
  17. tar.extractall()
  18. tar.close()

下载如下高亮文件:flower_photos.tar.gz

3、卷积代码

参考的源码:[html] view plain cop

[html] view plaincopy
  1. self.conv1_1 = self.conv_layer(bgr, "conv1_1")
  2. self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
  3. self.pool1 = self.max_pool(self.conv1_2, 'pool1')
  4. self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
  5. self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
  6. self.pool2 = self.max_pool(self.conv2_2, 'pool2')
  7. self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
  8. self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
  9. self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
  10. self.pool3 = self.max_pool(self.conv3_3, 'pool3')
  11. self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
  12. self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
  13. self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
  14. self.pool4 = self.max_pool(self.conv4_3, 'pool4')
  15. self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
  16. self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
  17. self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
  18. self.pool5 = self.max_pool(self.conv5_3, 'pool5')
  19. self.fc6 = self.fc_layer(self.pool5, "fc6")
  20. self.relu6 = tf.nn.relu(self.fc6)
  21. with tf.Session() as sess:
  22. vgg = vgg16.Vgg16()
  23. input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
  24. with tf.name_scope("content_vgg"):
  25. vgg.build(input_)
  26. feed_dict = {input_: images}
  27. codes = sess.run(vgg.relu6, feed_dict=feed_dict)

tensorflow中vgg_16采用的上述结构。本项目代码如下:

[python] view plaincopy
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. from tensorflow_vgg import vgg16
  5. from tensorflow_vgg import utils
[python] view plaincopy
  1. data_dir = 'flower_photos/'
  2. contents = os.listdir(data_dir)
  3. classes = [each for each in contents if os.path.isdir(data_dir + each)]

将图像批量batches通过VGG模型,将输出作为新的输入:

[python] view plaincopy
  1. # Set the batch size higher if you can fit in in your GPU memory
  2. batch_size = 10
  3. codes_list = []
  4. labels = []
  5. batch = []
  6. codes = None
  7. with tf.Session() as sess:
  8. vgg = vgg16.Vgg16()
  9. input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
  10. with tf.name_scope("content_vgg"):
  11. vgg.build(input_)
  12. for each in classes:
  13. print("Starting {} images".format(each))
  14. class_path = data_dir + each
  15. files = os.listdir(class_path)
  16. for ii, file in enumerate(files, 1):
  17. # Add images to the current batch
  18. # utils.load_image crops the input images for us, from the center
  19. img = utils.load_image(os.path.join(class_path, file))
  20. batch.append(img.reshape((1, 224, 224, 3)))
  21. labels.append(each)
  22. # Running the batch through the network to get the codes
  23. if ii % batch_size == 0 or ii == len(files):
  24. images = np.concatenate(batch)
  25. feed_dict = {input_: images}
  26. codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)
  27. # Here I'm building an array of the codes
  28. if codes is None:
  29. codes = codes_batch
  30. else:
  31. codes = np.concatenate((codes, codes_batch))
  32. # Reset to start building the next batch
  33. batch = []
  34. print('{} images processed'.format(ii))

4、模型建立和测试

图像处理代码和标签:

[python] view plaincopy
  1. # read codes and labels from file
  2. import csv
  3. with open('labels') as f:
  4. reader = csv.reader(f, delimiter='\n')
  5. labels = np.array([each for each in reader if len(each) > 0]).squeeze()
  6. with open('codes') as f:
  7. codes = np.fromfile(f, dtype=np.float32)
  8. codes = codes.reshape((len(labels), -1))
4.1 图像预处理
[python] view plaincopy
  1. from sklearn.preprocessing import LabelBinarizer
  2. lb = LabelBinarizer()
  3. lb.fit(labels)
  4. labels_vecs = lb.transform(labels)

对标签进行one-hot编码:daisy雏菊  dandelion蒲公英  rose玫瑰  sunflower向日葵 tulips郁金香

daisy雏菊        1                0                        0                 0                     0

dandelion蒲公英        0                1                        0                 0                     0

rose玫瑰        0                0                        1                 0                     0

sunflower向日葵        0                0                        0                 1                     0

tulips郁金香        0                0                        0                 0                     1

随机拆分数据集(之前那种直接把集中的部分图像拿出来验证/测试不管用,这里的数据集是每个种类集中放的,如果直接拿出其中的一部分,会导致验证集或者测试集是同一种花)。scikit-learn中的函数StratifiedShuffleSplit可以做到。我们这里,随机拿出20%的图像用来验证和测试,然后验证集和测试集再各占一半。

[python] view plaincopy
  1. from sklearn.model_selection import StratifiedShuffleSplit
  2. ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
  3. train_idx, val_idx = next(ss.split(codes, labels))
  4. half_val_len = int(len(val_idx)/2)
  5. val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]
  6. train_x, train_y = codes[train_idx], labels_vecs[train_idx]
  7. val_x, val_y = codes[val_idx], labels_vecs[val_idx]
  8. test_x, test_y = codes[test_idx], labels_vecs[test_idx]
[python] view plaincopy
  1. print("Train shapes (x, y):", train_x.shape, train_y.shape)
  2. print("Validation shapes (x, y):", val_x.shape, val_y.shape)
  3. print("Test shapes (x, y):", test_x.shape, test_y.shape)

总数量:3670,则训练图像:3670*0.8=2936,验证图像:3670*0.2*0.5=367,测试图像:3670*0.2*0.5=367。

4.2 层

在上述vgg的基础上,增加一个256个元素的全连接层,最后加上一个softmax层,计算交叉熵进行最后的分类。

[python] view plaincopy
  1. inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
  2. labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])
  3. fc = tf.contrib.layers.fully_connected(inputs_, 256)
  4. logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1], activation_fn=None)
  5. cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_, logits=logits)
  6. cost = tf.reduce_mean(cross_entropy)
  7. optimizer = tf.train.AdamOptimizer().minimize(cost)
  8. predicted = tf.nn.softmax(logits)
  9. correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
  10. accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
4.3 训练:batches和epoches
[python] view plaincopy
  1. def get_batches(x, y, n_batches=10):
  2. """ Return a generator that yields batches from arrays x and y. """
  3. batch_size = len(x)//n_batches
  4. for ii in range(0, n_batches*batch_size, batch_size):
  5. # If we're not on the last batch, grab data with size batch_size
  6. if ii != (n_batches-1)*batch_size:
  7. X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size]
  8. # On the last batch, grab the rest of the data
  9. else:
  10. X, Y = x[ii:], y[ii:]
  11. # I love generators
  12. yield X, Y
[python] view plaincopy
  1. epochs = 10
  2. iteration = 0
  3. saver = tf.train.Saver()
  4. with tf.Session() as sess:
  5. sess.run(tf.global_variables_initializer())
  6. for e in range(epochs):
  7. for x, y in get_batches(train_x, train_y):
  8. feed = {inputs_: x,
  9. labels_: y}
  10. loss, _ = sess.run([cost, optimizer], feed_dict=feed)
  11. print("Epoch: {}/{}".format(e+1, epochs),
  12. "Iteration: {}".format(iteration),
  13. "Training loss: {:.5f}".format(loss))
  14. iteration += 1
  15. if iteration % 5 == 0:
  16. feed = {inputs_: val_x,
  17. labels_: val_y}
  18. val_acc = sess.run(accuracy, feed_dict=feed)
  19. print("Epoch: {}/{}".format(e, epochs),
  20. "Iteration: {}".format(iteration),
  21. "Validation Acc: {:.4f}".format(val_acc))
  22. saver.save(sess, "checkpoints/flowers.ckpt")

验证集的正确率达到90%,很高了已经。

4.4 测试
[python] view plaincopy
  1. with tf.Session() as sess:
  2. saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
  3. feed = {inputs_: test_x,
  4. labels_: test_y}
  5. test_acc = sess.run(accuracy, feed_dict=feed)
  6. print("Test accuracy: {:.4f}".format(test_acc))

[python] view plaincopy
  1. %matplotlib inline
  2. import matplotlib.pyplot as plt
  3. from scipy.ndimage import imread
[python] view plaincopy
  1. test_img_path = 'flower_photos/roses/10894627425_ec76bbc757_n.jpg'
  2. test_img = imread(test_img_path)
  3. plt.imshow(test_img)

[python] view plaincopy
  1. with tf.Session() as sess:
  2. input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
  3. vgg = vgg16.Vgg16()
  4. vgg.build(input_)

[python] view plaincopy
  1. with tf.Session() as sess:
  2. img = utils.load_image(test_img_path)
  3. img = img.reshape((1, 224, 224, 3))
  4. feed_dict = {input_: img}
  5. code = sess.run(vgg.relu6, feed_dict=feed_dict)
  6. saver = tf.train.Saver()
  7. with tf.Session() as sess:
  8. saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
  9. feed = {inputs_: code}
  10. prediction = sess.run(predicted, feed_dict=feed).squeeze()
[python] view plaincopy
  1. plt.imshow(test_img)

[python] view plaincopy
  1. plt.barh(np.arange(5), prediction)
  2. _ = plt.yticks(np.arange(5), lb.classes_)

上图的花最有可能是Rose,有小概率是Tulips。

tensorflow微调vgg16 程序代码汇总相关推荐

  1. 20210112.使用字典来创建并分类汇总物品清单的python程序代码

    20210112.使用字典来创建并分类汇总物品清单的python程序代码 #这段代码使用字典来创建并分类汇总物品清单.为<python编程快速上手--让繁琐工作自动化>一书中的5.6.1实 ...

  2. ASP.NET程序中常用代码汇总-1[转]

    相关链接: ASP.NET程序中常用代码汇总-1 ASP.NET程序中常用代码汇总-2[转] ASP.NET程序中常用代码汇总-3[转] ASP.NET程序中常用代码汇总-4[转] ASP.NET程序 ...

  3. 迁移学习——使用Tensorflow和VGG16预训模型进行预测

    使用Tensorflow和VGG16预训模型进行预测 from:https://zhuanlan.zhihu.com/p/28997549 fast.ai的入门教程中使用了kaggle: dogs v ...

  4. CV Code | 本周新出计算机视觉开源代码汇总(含自动驾驶目标检测、医学图像分割、风格迁移、语义分割、目标跟踪等)...

    点击我爱计算机视觉标星,更快获取CVML新技术 刚刚过去的一周含五一假期,工作日第一天,CV君汇总了过去一周计算机视觉领域新出的开源代码,涉及到自动驾驶目标检测.医学图像分割.风格迁移.神经架构搜索. ...

  5. Tensorflow 入门手册(代码与原理释义)

    ·人工智能与深度学习 -人工智能={机器学习,........else} -机器学习={深度学习(表示学习),........else} ·神经网络 ·卷积神经网络(Convolutional Neu ...

  6. 2 万字全面测评深度学习框架 PaddlePaddle、TensorFlow 和 Keras | 程序员硬核评测

    [CSDN 编者按]人工智能想入门深度学习?却苦恼网上的入门教程太零碎,不知道用什么框架好?本文作者用两万字手分别从百度的PaddlePaddle深度学习框架.Google的TensorFlow深度学 ...

  7. 3-3 uniapp、HTML5+、Native.js 功能代码汇总

    3-3 uniapp.HTML5+.Native.js 功能代码汇总 本文只适用于 APP 代码汇总 Android平台 监听手机锁屏,解锁,开屏 var receiver; mui.plusRead ...

  8. Java 从虚拟机层面看程序代码是怎么运行起来的

    专栏原创出处:github-源笔记文件 ,github-源码 ,欢迎 Star,转载请附上原文出处链接和本声明. Java JVM-虚拟机专栏系列笔记,系统性学习可访问个人复盘笔记-技术博客 Java ...

  9. Window环境运行Tensorflow目标识别示例程序

    Tensorflow提供了目标识别的API来支持通过各种深度学习网络实现目标识别的功能.通过访问Github项目https://github.com/tensorflow/models 可以看到Ten ...

最新文章

  1. 如何将风险应用加入白名单_将微信服务器、API接口的IP列表加入宝塔防火墙IP白名单...
  2. 华为DUA-AL00真机android studio识别不出
  3. python清空字典保留变量_python彻底清除字典数据,clear方法使用
  4. 关于正则表达式匹配任意字符
  5. webview 修改html,使用自定义CSS在WebView中呈现HTML
  6. asp.net截取指定长度的字符串内容
  7. ubuntu创建wifi热点plasma-nm
  8. crontab自动执行任务,失败原因记录
  9. php兴趣爱好复选框如何取值,php checkbox 取值详细说明
  10. js 对象数组和对象的使用
  11. 高级排序之分割法(以某数为基准分割)
  12. 线上python课程一般多少钱-python培训班一般多少钱?一篇文章告诉你
  13. 数据传递型情景下事件机制与消息机制的架构设计剖析(目录)
  14. html标签的补充—— b,strong标签
  15. Objective-C中的单例模式
  16. java课程设计-简单学生签到系统-桌面小程序的实现
  17. 深入浅出MFC学习笔记(第三章:MFC六大关键技术之仿真:命令传递) .
  18. HPE Gen9 使用 RESTful API 管理服务器
  19. VS2015卸载不完全与安装问题
  20. java flag 用法_Java中一些常用的方法

热门文章

  1. 取消小米笔记本插入耳机后弹框
  2. 微信分享官方第三方接入(图片及文字)
  3. 技术平台与业务平台的区别
  4. 中国图书分类号-自动化_计算机
  5. 读稻盛和夫《活法》-现代人的修行之路
  6. dnsmasq.conf配置
  7. thinkcmf需要的php版本,升级指导 · ThinkCMF5开发手册 · 看云
  8. 更新微信 7.0,你后悔了吗?
  9. 锦城学院和锦江计算机,四川大学锦城学院怎么样_是几本?和四川大学锦江学院哪个更好?...
  10. 任鸟飞FPS类型游戏绘制和游戏安全,反外挂研究(二)