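This time we use slim to build a somewhat larger network, VGG16. VGG16 and VGG19 are nearly the same, so the code in this example uses VGG16 for a 5-class flower classification task.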

Compared with the earlier LeNet and AlexNet, VGG introduces the following features:

1. Stacked small 3×3 convolution kernels replace the large 5×5 and 7×7 kernels.

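A 5×5 kernel has a large receptive field, but it also has many parameters. Two stacked 3×3 convolutions cover the same receptive field as one 5×5 convolution and apply two nonlinear transformations instead of one. To summarize: compared with a large kernel, a stack of small kernels has fewer parameters and applies more nonlinear mappings, which increases the expressive power of the network.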
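The claim above can be checked with a little arithmetic. The sketch below is a minimal illustration, assuming stride-1 convolutions, no biases, and a channel count of 64 chosen only for the example; the helper names are hypothetical.

```python
def conv_params(kernel_size, in_channels, out_channels):
    """Number of weights in one conv layer (biases ignored)."""
    return kernel_size * kernel_size * in_channels * out_channels


def stacked_receptive_field(kernel_size, num_layers):
    """Receptive field of num_layers stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)


channels = 64  # assumed channel count, for illustration only

params_5x5 = conv_params(5, channels, channels)
params_two_3x3 = 2 * conv_params(3, channels, channels)

# Both configurations see a 5x5 window of the input.
print(stacked_receptive_field(3, 2), stacked_receptive_field(5, 1))
# The stacked version needs 18/25 of the weights of the single 5x5 layer.
print(params_two_3x3, params_5x5)
```

By the same formula, three stacked 3×3 layers reach the 7×7 receptive field mentioned in point 1.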

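2. A deeper network. Setting aside drawbacks such as deep networks being harder to train or suffering from vanishing gradients, a deeper network fits the data distribution better than a shallow one in terms of feature abstraction and expressive power.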

3. The original VGG work also relies on tricks such as data augmentation and image preprocessing.

Now for the code. The project is split into three files:

vgg.py: builds the 16-layer VGG network.

import tensorflow as tf
import tensorflow.contrib.slim as slim


def build_vgg(rgb, num_classes, keep_prob, train=True):
    # Every conv and fully connected layer uses ReLU and batch norm
    # unless the layer overrides them.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm):
        # block_1
        net = slim.repeat(rgb, 2, slim.conv2d, 64, [3, 3], padding='SAME', scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')

        # block_2
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], padding='SAME', scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')

        # block_3
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], padding='SAME', scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')

        # block_4
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')

        # block_5
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')

        # flatten: height * width * channels of the pool5 feature map
        feature_shape = net.get_shape()
        flattened_shape = feature_shape[1].value * feature_shape[2].value * feature_shape[3].value
        pool5_flatten = tf.reshape(net, [-1, flattened_shape])

        # fc6
        net = slim.fully_connected(pool5_flatten, 4096, scope='fc6')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout6')

        # fc7
        net = slim.fully_connected(net, 4096, scope='fc7')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout7')

        # fc8: return raw logits (activation_fn=None), because
        # tf.nn.softmax_cross_entropy_with_logits in train.py applies softmax itself
        net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc8')
        return net
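To see where the "16" comes from and how large the flattened pool5 feature is, the layer layout of build_vgg can be traced in plain Python. This is a rough sketch, assuming a 224×224 input, 'SAME' padding on the convolutions, and stride-2 pooling; the `blocks` table and `trace_vgg16` helper are hypothetical names introduced only for this illustration.

```python
# (number of 3x3 conv layers, output channels) for the five blocks in build_vgg
blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def trace_vgg16(input_size=224):
    """Return (conv layer count, final spatial size, flattened feature length)."""
    size, conv_layers, channels = input_size, 0, 3
    for n_convs, out_channels in blocks:
        conv_layers += n_convs   # 'SAME' padding keeps the spatial size
        channels = out_channels
        size //= 2               # each 2x2 max pool with stride 2 halves it
    return conv_layers, size, size * size * channels


conv_layers, final_size, flattened = trace_vgg16(224)
weight_layers = conv_layers + 3  # plus fc6, fc7, fc8
print(weight_layers, final_size, flattened)
```

The flattened length is what `flattened_shape` evaluates to in build_vgg, so it changes with the input resolution: a 256×256 input would give an 8×8×512 feature map instead of 7×7×512.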

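tfrecords.py: encodes and decodes the data. Unlike the earlier posts, which fed data to the network through feed_dict, this example encodes the dataset in TensorFlow's own TFRecord format.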

import tensorflow as tf
import numpy as np
import os
import glob
from PIL import Image

path_tfrecord = '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/'


def convert_to_tfrecord(images, labels, filename):
    print("Converting data into %s ..." % filename)
    writer = tf.python_io.TFRecordWriter(path_tfrecord + filename)
    for index, img in enumerate(images):
        img_raw = Image.open(img)
        # skip grayscale / RGBA images so every record is 256 * 256 * 3 bytes
        if img_raw.mode != "RGB":
            continue
        img_raw = img_raw.resize((256, 256))
        img_raw = img_raw.tobytes()
        label = int(labels[index])
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())
    writer.close()


def read_and_decode(filename, is_train=None):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=400)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    # raw bytes -> uint8 vector -> 256x256x3 image scaled to [-0.5, 0.5]
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    if is_train:
        # training: random 224x224 crop plus random augmentation
        img = tf.random_crop(img, [224, 224, 3])
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=63)
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        img = tf.image.per_image_standardization(img)
    else:
        # validation: deterministic central 224x224 crop
        img = tf.image.resize_image_with_crop_or_pad(img, 224, 224)
        img = tf.image.per_image_standardization(img)
    label = tf.cast(features['label'], tf.int32)
    return img, label


def get_file(path):
    # each sub-directory of `path` is one class; its index becomes the label
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    images = []
    labels = []
    for idx, folder in enumerate(cate):
        for img in glob.glob(folder + '/*.jpg'):
            print('reading the images:%s' % (img))
            images.append(img)
            labels.append(idx)
    image_list = np.asarray(images, np.string_)
    label_list = np.asarray(labels, np.int32)

    # shuffle
    num_example = image_list.shape[0]
    arr = np.arange(num_example)
    np.random.shuffle(arr)
    image_list = image_list[arr]
    label_list = label_list[arr]

    # divide train_data and val_data (80% / 20%)
    num_example = image_list.shape[0]
    split = int(num_example * 0.8)
    train_images = image_list[:split]
    train_labels = label_list[:split]
    val_images = image_list[split:]
    val_labels = label_list[split:]
    return train_images, train_labels, val_images, val_labels


if __name__ == '__main__':
    train_images, train_labels, val_images, val_labels = get_file('/home/danny/chenwei/CSDN_blog/VGG/datasets/')
    convert_to_tfrecord(images=train_images, labels=train_labels, filename="train.tfrecords")
    convert_to_tfrecord(images=val_images, labels=val_labels, filename="test.tfrecords")
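The encode/decode pair above only works because the writer and the reader agree on the image size and dtype: the record stores bare bytes with no shape information. The following NumPy-only sketch mimics that round trip without TensorFlow. It is an illustration under assumptions: a random uint8 array stands in for a resized photo, and `np.frombuffer` plays the role of `tf.decode_raw`.

```python
import numpy as np

# A random uint8 array stands in for a resized 256x256 RGB photo.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

raw = image.tobytes()                          # what goes into the bytes feature
decoded = np.frombuffer(raw, dtype=np.uint8)   # flat vector, like tf.decode_raw
restored = decoded.reshape(256, 256, 3)        # shape must be supplied by hand

# Same scaling as read_and_decode: uint8 [0, 255] -> float32 around [-0.5, 0.5]
scaled = restored.astype(np.float32) * (1.0 / 255) - 0.5

print(len(raw), restored.shape, np.array_equal(restored, image))
```

Each record therefore carries 256 × 256 × 3 bytes of pixel data, which is why convert_to_tfrecord skips images that are not in RGB mode: a grayscale image would produce a byte string of a different length and the fixed reshape in read_and_decode would fail.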

train.py: the training script. What differs from the earlier posts is that it fetches the training data through queues using multiple threads.

# -*- coding: utf-8 -*-
import tensorflow as tf
from utils.tfrecords import *
from model.vgg import *

tf.app.flags.DEFINE_integer('num_classes', 5, 'classification number.')
# read_and_decode crops every image to 224x224, so the placeholder must match
tf.app.flags.DEFINE_integer('crop_width', 224, 'width of input image.')
tf.app.flags.DEFINE_integer('crop_height', 224, 'height of input image.')
tf.app.flags.DEFINE_integer('channels', 3, 'channel number of image.')
tf.app.flags.DEFINE_integer('batch_size', 2, 'num of each batch')
tf.app.flags.DEFINE_integer('num_epochs', 400, 'number of epoch')
tf.app.flags.DEFINE_bool('continue_training', False, 'whether is continue training')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('dataset_path', './datasets/', 'path of dataset')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_string('train_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/train.tfrecords', 'train tfrecord')
tf.app.flags.DEFINE_string('test_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/test.tfrecords', 'test tfrecord')

FLAGS = tf.app.flags.FLAGS


def main(_):
    # data process
    train_images, train_labels = read_and_decode(FLAGS.train_tfrecords, True)
    val_images, val_labels = read_and_decode(FLAGS.test_tfrecords, False)

    train_labels = tf.one_hot(indices=tf.cast(train_labels, tf.int32), depth=FLAGS.num_classes)
    train_images_batch, train_labels_batch = tf.train.shuffle_batch(
        [train_images, train_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)  # the number of threads is set here

    val_labels = tf.one_hot(indices=tf.cast(val_labels, tf.int32), depth=FLAGS.num_classes)
    val_images_batch, val_labels_batch = tf.train.shuffle_batch(
        [val_images, val_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)  # the number of threads is set here

    # define network input
    input = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width, FLAGS.channels], name='input')
    output = tf.placeholder(tf.int32, shape=[FLAGS.batch_size, FLAGS.num_classes], name='output')

    # control GPU resource utilization
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # build network
    logits = build_vgg(input, FLAGS.num_classes, 0.5, True)

    # loss
    cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    loss = cross_entropy_loss + regularization_loss

    # optimizer
    train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

    # calculate correct
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with sess.as_default():
        # init all parameters (local variables hold the epoch counter of the input queue)
        saver = tf.train.Saver(max_to_keep=1000)
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # restore weight
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)

        # start the queue runner threads that fill the input queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        epoch = 0  # counts training steps
        try:
            while not coord.should_stop():
                # begin training
                train_images, train_labels = sess.run([train_images_batch, train_labels_batch])
                _, err, acc = sess.run([train_op, loss, accuracy], feed_dict={input: train_images, output: train_labels})
                print("[Train] Step: %d, loss: %.4f, accuracy: %.4f%%" % (epoch, err, acc))
                epoch += 1
                if epoch % 10 == 0 or (epoch + 1) == FLAGS.num_epochs:
                    val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                    val_err, val_acc = sess.run([loss, accuracy], feed_dict={input: val_images, output: val_labels})
                    print("[validation] Step: %d, loss: %.4f, accuracy: %.4f%%" % (epoch, val_err, val_acc))
                if (epoch + 1) == FLAGS.num_epochs:
                    checkpoint_path = FLAGS.checkpoints
                    saver.save(sess, save_path=checkpoint_path, global_step=epoch)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()


if __name__ == '__main__':
    tf.app.run()
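The queue-based input used in train.py follows a producer/consumer pattern: several reader threads push examples into a bounded queue while the training loop pulls one batch at a time. The sketch below shows only that pattern with Python's standard library. It is not how TensorFlow implements its queue runners, it leaves out the shuffling that `tf.train.shuffle_batch` performs, and all names and sizes in it are hypothetical.

```python
import queue
import threading


def producer(worker_id, items, out_queue):
    """Push (worker_id, item) pairs into the shared queue."""
    for item in items:
        out_queue.put((worker_id, item))  # blocks while the queue is full


def run_pipeline(num_workers=4, items_per_worker=5, batch_size=2):
    q = queue.Queue(maxsize=8)  # bounded, similar in spirit to `capacity`
    workers = [
        threading.Thread(target=producer, args=(i, range(items_per_worker), q))
        for i in range(num_workers)
    ]
    for w in workers:
        w.start()

    total = num_workers * items_per_worker
    batches = []
    # The consumer (the training loop) dequeues one batch at a time.
    for _ in range(total // batch_size):
        batches.append([q.get() for _ in range(batch_size)])

    for w in workers:
        w.join()
    return batches


batches = run_pipeline()
print(len(batches), batches[0])
```

Because the producers run concurrently, the order of items inside the batches can differ from run to run, which is one reason the training loop treats each dequeued batch as an arbitrary sample of the dataset.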

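Training result: accuracy reaches roughly 96%. The script prints accuracy as a fraction, so 0.9688 in the log below means 96.88%.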

[Train] Step: 19985, loss: 1.1098, accuracy: 1.0000%
[Train] Step: 19986, loss: 1.1302, accuracy: 1.0000%
[Train] Step: 19987, loss: 1.1232, accuracy: 1.0000%
[Train] Step: 19988, loss: 1.1299, accuracy: 1.0000%
[Train] Step: 19989, loss: 1.1220, accuracy: 1.0000%
[validation] Step: 19990, loss: 1.1634, accuracy: 0.9688%
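The two quantities in the log, loss and accuracy, can be reproduced by hand on a toy batch. The NumPy sketch below is a rough stand-in for the TensorFlow ops in train.py, with made-up scores for 2 samples and 5 classes. It also illustrates why the last layer should hand raw scores to the loss: `tf.nn.softmax_cross_entropy_with_logits` applies softmax internally, so values that have already been through softmax get normalized twice, and for 5 classes the loss can then never fall below log(1 + 4/e).

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy_from_logits(logits, one_hot):
    """Mean cross-entropy; softmax is applied inside, as in the TF op."""
    return float(-(one_hot * np.log(softmax(logits))).sum(axis=1).mean())


def accuracy(scores, one_hot):
    """Fraction of rows whose arg-max matches the label."""
    return float((scores.argmax(axis=1) == one_hot.argmax(axis=1)).mean())


# Toy batch (made-up values): both samples are classified correctly and confidently.
logits = np.array([[9.0, 0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 9.0, 0.0, 0.0]])
labels = np.eye(5)[[0, 2]]

loss_ok = cross_entropy_from_logits(logits, labels)
# Feeding probabilities instead of raw scores applies softmax twice.
loss_double = cross_entropy_from_logits(softmax(logits), labels)
floor = float(np.log(1 + 4 / np.e))  # lower bound of the double-softmax loss for 5 classes

print(accuracy(logits, labels), loss_ok, loss_double, floor)
```

With raw scores the loss goes toward zero as the predictions become confident, while the double-softmax variant stays above the bound even when every prediction is correct.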
