最近总想写点东西,把以前的这些网络都翻出来自己实现一遍。计划上从经典的分类网络到目标检测再到目标分割的都过一下。这篇从最简单的LeNet写起。

先上一张经典的LeNet模型结果图:

该网络结构包含2个Conv,2个pooling,2个fully connection外加一个softmax分类。至于每层的卷积核尺寸,多少个卷积核这些可以直接从代码中了解,在这就不细说了。

网络特点:相比之前的MLP神经网络,LeNet采用卷积做特征提取,采用池化做降采样,采用激活函数做非线性变换。利用卷积神经网络的稀疏连接和权值共享的特性大幅度减少参数数量。忽然发现之前写了篇简单的CNN分类的文章就是用的LeNet做mnist识别。那我们这篇就用TF的高级库slim重写下,对man/woman进行二分类,图片尺寸与上图略有不同。话不多说先上代码,本次例程分为三个文件:

read_data.py:用于数据的读取和整理,包括随机打乱,拆分训练测试级等

import os
from skimage import io, transform
import glob
import numpy as nplabel_map = {'man': 0, 'woman': 1}def get_data_list(path):category = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]images = []labels = []category.sort()for idx, folder in enumerate(category):for img in glob.glob(folder + '/*.jpg'):images.append(img)labels.append(idx)image_list = np.asarray(images, np.str)label_list = np.asarray(labels, np.int32)return image_list, label_listdef shuffle_data(images, labels):num_example = len(images)arr = np.arange(num_example)np.random.shuffle(arr)images = images[arr]labels = labels[arr]return images, labelsdef divide_dataset(images, labels, ratio):num_example = images.shape[0]split = np.int(num_example * ratio)train_images = images[:split]train_labels = labels[:split]val_images = images[split:]val_labels = labels[split:]return train_images, train_labels, val_images, val_labelsdef mini_batches(inputs=None, targets=None, batch_size=None, height=224, width=224):assert len(inputs) == len(targets)indices = np.arange(len(inputs))for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):excerpt = indices[start_idx:start_idx + batch_size]images = []for i in excerpt:img = io.imread(inputs[i])img = transform.resize(img, (height, width))images.append(img)images = np.asarray(images, np.float32)yield images, targets[excerpt]def one_hot(labels, num_classes):n_sample = len(labels)n_class = num_classesonehot_labels = np.zeros((n_sample, n_class))onehot_labels[np.arange(n_sample), labels] = 1return onehot_labels

LeNet.py:简单的LeNet五层网络,直接用slim实现。

import tensorflow as tf
import tensorflow.contrib.slim as slimdef build_LeNet(inputs, num_classes):# block_1with tf.variable_scope('conv_layer_1'):conv1 = slim.conv2d(inputs, num_outputs=32, kernel_size=[5, 5], scope='conv')pool1 = slim.max_pool2d(conv1, kernel_size=[2, 2], stride=2, scope='pool')with tf.variable_scope('conv_layer_2'):conv2 = slim.conv2d(pool1, num_outputs=64, kernel_size=[5, 5], scope='conv')pool2 = slim.max_pool2d(conv2, kernel_size=[2, 2], stride=2, scope='pool')with tf.variable_scope('fatten'):feature_shape = pool2.get_shape()fatten_shape = feature_shape[1].value * feature_shape[2].value * feature_shape[3].valuepool2_flatten = tf.reshape(pool2, [-1, fatten_shape])with tf.variable_scope('fc_layer_1'):fc1 = slim.fully_connected(pool2_flatten, num_outputs=1024, scope='fc')with tf.variable_scope('fc_layer_2'):fc2 = slim.fully_connected(fc1, num_outputs=1024, scope='fc')with tf.variable_scope('output'):logits = slim.fully_connected(fc2, num_outputs=num_classes, activation_fn=tf.nn.softmax, scope='fc')return logits

最后是训练代码train.py,与之前的文章中基本相同。

# -*- coding: utf-8 -*-
import tensorflow as tf
from model.LeNet import *
from utils.read_data import *tf.app.flags.DEFINE_integer('num_classes', 2, 'classification number.')
tf.app.flags.DEFINE_integer('crop_width', 256, 'width of input image.')
tf.app.flags.DEFINE_integer('crop_height', 256, 'height of input image.')
tf.app.flags.DEFINE_integer('channels', 3, 'channel number of image.')
tf.app.flags.DEFINE_integer('batch_size', 32, 'num of each batch')
tf.app.flags.DEFINE_integer('num_epochs', 400, 'number of epoch')
tf.app.flags.DEFINE_bool('continue_training', False, 'whether is continue training')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('dataset_path', './datasets/', 'path of dataset')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
FLAGS = tf.app.flags.FLAGSdef main(_):# dataset processimages, labels = get_data_list(FLAGS.dataset_path)images, labels = shuffle_data(images, labels)train_images, train_labels, val_images, val_labels = divide_dataset(images, labels, ratio=0.8)# define network inputinput = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.crop_width, FLAGS.crop_height, FLAGS.channels], name='input')output = tf.placeholder(tf.int32, shape=[FLAGS.batch_size, FLAGS.num_classes], name='output')# Control GPU resource utilizationconfig = tf.ConfigProto(allow_soft_placement=True)config.gpu_options.allow_growth = Truesess = tf.Session(config=config)# build networklogits = build_LeNet(input, FLAGS.num_classes)# losscross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))loss = cross_entropy_loss + regularization_loss# optimizertrain_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)# calculate correctcorrect_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))with sess.as_default():# init all paramterssaver = tf.train.Saver(max_to_keep=1000)sess.run(tf.global_variables_initializer())# restore weights fileif FLAGS.continue_training:saver.restore(sess, FLAGS.checkpoints)# begin trainingfor epoch in range(FLAGS.num_epochs):print("===epoch %d===" % epoch)# trainingtrain_loss, train_acc, n_batch = 0, 0, 0for train_images_batch, train_labels_batch in mini_batches(train_images, train_labels, FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width):train_labels_batch = one_hot(train_labels_batch, FLAGS.num_classes)_, err, acc = sess.run([train_op, loss, accuracy], feed_dict={input: train_images_batch, output: train_labels_batch})train_loss += errtrain_acc += accn_batch += 1print("   train loss: %f" % (np.sum(train_loss) / n_batch))print("   train acc: %f" % (np.sum(train_acc) / n_batch))# validationval_loss, val_acc, n_batch = 0, 0, 0for val_images_batch, val_labels_batch in mini_batches(val_images, val_labels, FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width):val_labels_batch = one_hot(val_labels_batch, FLAGS.num_classes)err, acc = sess.run([loss, accuracy], feed_dict={input: val_images_batch, output: val_labels_batch})val_loss += errval_acc += accn_batch += 1print("   validation loss: %f" % (np.sum(val_loss) / n_batch))print("   validation acc: %f" % (np.sum(val_acc) / n_batch))# Create directories if neededif not os.path.isdir("checkpoints"):os.makedirs("checkpoints")saver.save(sess, "%s/model.ckpt" % ("checkpoints"))sess.close()if __name__ == '__main__':tf.app.run()

运行结果:

   validation loss: 1.332082validation acc: 0.572750

估计是网络太小,而且这几篇文章旨在过下分类网络,对于数据增扩,调参技巧放在以后写。

tensorflow随笔——LeNet网络相关推荐

  1. tensorflow随笔——VGG网络

    这次用slim搭个稍微大一点的网络VGG16,VGG16和VGG19实际上差不多,所以本例程的代码以VGG16来做5类花的分类任务. VGG网络相比之前的LeNet,AlexNet引入如下几个特点: ...

  2. tensorflow 随笔-----------VGG网络的模型的复现

    推荐个课程北京大学的tensorflow 笔记 VGG网络是谷歌千分类用的,实现的是对图像的识别 由于模型过大,需要的可以私聊我(13072509383微信) vgg16.py 网络对参数的读取 im ...

  3. tensorflow之lenet训练手写字及应用

    我的目标是用tensorflow实现视频质量诊断,但是馒头还是需要一个一个吃,先从工程应用的角度实现用python训练手写字,并在C#中调用识别自己写的手写字. 思路如下: 使用lenet网络训练完数 ...

  4. 使用Keras训练Lenet网络来进行手写数字识别

    使用Keras训练Lenet网络来进行手写数字识别 这篇博客将介绍如何使用Keras训练Lenet网络来进行手写数字识别. LeNet架构是深度学习中的一项开创性工作,演示了如何训练神经网络以端到端的 ...

  5. TensorFlow创建DeepDream网络

    TensorFlow创建DeepDream网络 Google 于 2014 年在 ImageNet 大型视觉识别竞赛(ILSVRC)训练了一个神经网络,并于 2015 年 7 月开放源代码. 该网络学 ...

  6. Ubuntu 14.04 64位机上用Caffe+MNIST训练Lenet网络操作步骤

    1.      将终端定位到Caffe根目录: 2.      下载MNIST数据库并解压缩:$ ./data/mnist/get_mnist.sh 3.      将其转换成Lmdb数据库格式:$ ...

  7. 使用Tensorflow实现残差网络ResNet-50

    这篇文章讲解的是使用Tensorflow实现残差网络resnet-50. 侧重点不在于理论部分,而是在于代码实现部分.在github上面已经有其他的开源实现,如果希望直接使用代码运行自己的数据,不建议 ...

  8. LeNet网络配置文件 lenet_train_test.prototxt

    .prototxt文件 定义了网络的结构,我们可以通过它了解网络是如何设计的,也可以建立属于自己的网络.这种格式来源于Google的Protocol Buffers,后来被开源,主要用于海量数据存储. ...

  9. 【深度学习】使用tensorflow实现VGG19网络

    转载注明出处:http://blog.csdn.net/accepthjp/article/details/70170217 接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络 ...

最新文章

  1. C#中选择文件的例子
  2. oracle reverse 反转函数
  3. python画车辆轨迹图,在python中绘制轨道轨迹
  4. python调用random失败_python怎么调用random
  5. oracle数据库日期格式的运算,Oracle时间类型date,timestamp时间差计算
  6. 8个树莓派超级计算机_6 个可以尝试的树莓派教程
  7. R语言第八讲 评估模型之交叉验证法分析案例
  8. 工程数学(数值分析)第三讲:求解线性代数方程组
  9. python基础教程-Python入门教程完整版(懂中文就能学会)
  10. Windows里Anaconda-Navigator无法打开(解决)
  11. 阶段3 2.Spring_10.Spring中事务控制_6 spring基于XML的声明式事务控制-配置步骤
  12. android中的MotionEvent 及其它事件处理
  13. xposed 修改手机定位
  14. ewb交通灯报告和文件_简易交通灯控制逻辑电路设计报告
  15. 特征工程(Feature Engineering)
  16. uni-app本地打包配置过程中出现的问题
  17. Vbox中Ubuntu和win7主机文件共享
  18. 头条号发视频为什么没有收益,在头条号发搞笑视频收益大吗
  19. Python学习笔记——eofs.standard的使用
  20. python语言下同一个类下有多个函数,其中一个函数想调用另外一个函数里面的变量怎么调用

热门文章

  1. MySQL Workbench 使用 (3):数据库备份与恢复
  2. 贺利坚老师汇编课程42笔记:DIV除法指令
  3. ScheduledThreadPoolExecutor Usage
  4. 2019.2.2牛客寒假算法基础集训营
  5. dev 域名与 Chrome
  6. linux下搭建python机器学习环境
  7. STL之ForwordList
  8. 探寻 JavaScript 逻辑运算符(与、或)的真谛
  9. 主题切换时如何主动去刷新一些资源?
  10. 对Spring的一些理解