下例是用tensorflow写的2层cnn+2层fc的一个卷积神经网络做mnist的分类例子,旨在简单明了,过一遍TF代码的分类流程。

实例只有两个文件:

train.py:数据读取,模型训练。

# coding=utf-8
import tensorflow as tf
import model
import os
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('dataset/', one_hot=True)tf.app.flags.DEFINE_integer('image_width', 28, 'width of image')
tf.app.flags.DEFINE_integer('image_height', 28, 'height of image')
tf.app.flags.DEFINE_integer('channel', 1, 'channel of image')
tf.app.flags.DEFINE_float('keep_drop', 1.0, 'keep drop out')
tf.app.flags.DEFINE_float('lr', 0.001, 'learning rate')
tf.app.flags.DEFINE_integer('batch_size', 32, 'batch size')
tf.app.flags.DEFINE_integer('epochs', 100, 'num of epoch')
tf.app.flags.DEFINE_integer('num_classes', 10, 'num of class')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_boolean('continue_training', False, 'continue')
FLAGS = tf.app.flags.FLAGSdef main(_):input = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_width*FLAGS.image_height])output = tf.placeholder(dtype=tf.int32, shape=[None, FLAGS.num_classes])# Control GPU resource utilizationconfig = tf.ConfigProto(allow_soft_placement=True)config.gpu_options.allow_growth = Truesess = tf.Session(config=config)# build networklogits = model.build(input, FLAGS.image_height, FLAGS.image_width, FLAGS.channel, FLAGS.keep_drop, True)# losscross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))# optimitertrain_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(cross_entropy)# evalutioncorrect_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))with sess.as_default():# initialsaver = tf.train.Saver(max_to_keep=1000)sess.run(tf.global_variables_initializer())# Restore weights fileif FLAGS.continue_training:saver.restore(sess, FLAGS.checkpoints)# begin trainfor epoch in range(FLAGS.epochs):for k in range(int(mnist.train.num_examples / FLAGS.batch_size)):batch = mnist.train.next_batch(FLAGS.batch_size)_, network, loss, acc = sess.run([train_op, logits, cross_entropy, accuracy], feed_dict={input: batch[0], output: batch[1]})print('loss : %f accuracy : %f'% (loss, acc))print('精确率:', accuracy.eval({input: mnist.test.images, output: mnist.test.labels}))# Create directories if neededif not os.path.isdir("checkpoints"):os.makedirs("checkpoints")saver.save(sess, "%s/model.ckpt" % ("checkpoints"))if __name__ == '__main__':tf.app.run()

model.py:网络搭建。

import tensorflow as  tf
import numpy as npdef weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')def build(inputs, height, width, channel, keep_drop, train):x_image = tf.reshape(inputs, [-1, height, width, channel])# block_1weight_1 = weight_variable(([5, 5, 1, 32]))bias_1 = bias_variable([32])conv_1 = tf.nn.relu(conv2d(x_image, weight_1) + bias_1)pool_1 = max_pool(conv_1)# block_2weight_2 = weight_variable([5, 5, 32, 64])bias_2 = bias_variable([64])conv_2 = tf.nn.relu(conv2d(pool_1, weight_2) + bias_2)pool_2 = max_pool(conv_2)# fc_1fc_weight_1 = weight_variable([7 * 7 * 64, 1024])fc_bias_1 = bias_variable([1024])flat = tf.reshape(pool_2, [-1, 7 * 7 * 64])fc_1 = tf.nn.relu(tf.matmul(flat, fc_weight_1) + fc_bias_1)# Dropoutif train == True:fc_1 = tf.nn.dropout(fc_1, keep_prob=keep_drop)fc_weight_2 = weight_variable([1024, 10])fc_bias_2 = bias_variable([10])logits = tf.nn.softmax(tf.matmul(fc_1, fc_weight_2) + fc_bias_2)return logits

运行结果:

loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
精确率: 0.9924

本例子结构非常简单,如有细节上或其他问题欢迎留言讨论。

tensorflow随笔——简单的卷积神经网络分类实例相关推荐

  1. tensorflow随笔——简单的循环神经网络分类实例

    继上一篇用简单的卷积神经网络做mnist分类之后,本篇文章采用RNN替换CNN写了一个mnist分类实例.实例中包含两个文件: train.py:数据加载和训练代码. # coding=utf-8 i ...

  2. 深度学习之利用TensorFlow实现简单的卷积神经网络(MNIST数据集)

    卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度学习 ...

  3. tensorflow 搭建简单的卷积神经网络,输入二维数组完成分类

    目录 一.数据处理 二.搭建cnn模型 三.训练并测试模型 一.数据处理 导入数据 #导入数据 X=pd.read_csv('data.csv',header=None) X=np.array(tes ...

  4. Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务

    关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...

  5. CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用

    CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用 目录 基于tensorflow框架采用CNN(改进 ...

  6. 简单的卷积神经网络,实现手写英文字母识别

    简单的卷积神经网络,实现手写英文字母识别 1 搭建Python运行环境(建议用Anaconda),自学Python程序设计 安装Tensorflow.再安装Pycharm等环境.(也可用Pytorch ...

  7. 吴恩达深度学习课程笔记(四):卷积神经网络2 实例探究

    吴恩达深度学习课程笔记(四):卷积神经网络2 实例探究 吴恩达深度学习课程笔记(四):卷积神经网络2 实例探究 2.1 为什么要进行实例探究 2.2 经典网络 LeNet-5 AlexNet VGG- ...

  8. ubuntu16.04 简单的卷积神经网络 cpu和gpu训练时间对比

    我的电脑配置: cpu:i5-4200H gpu:gtx 950M 昨天测试了训练一般的神经网络使用cpu和gpu各自的速度,使用gpu比使用cpu大概能节省42%的时间,当时我以为这么个程度已经很不 ...

  9. 【机器学习】百度飞桨AI Studio平台项目:基于卷积神经网络分类方法的人脸颜值打分

    基于卷积神经网络分类方法的人脸颜值打分 说在前面 实验数据 解决过程 1.Precondition 2.Dataset Preparation 3.Network Configuration 4.Mo ...

最新文章

  1. 【090】Excel VBA 基础
  2. 笔记本卡顿不流畅是什么原因_电脑卡顿不流畅是什么原因
  3. hdu3697(贪心+暴力)
  4. Thread Join 讲解
  5. 天玥运维安全网关_智慧灯杆网关
  6. mysql 5.5 5.6 备份库_mysql5.5备份数据库里面除系统库外的所有数据库
  7. u盘循环冗余能修复吗_激素脸怎么办?激素脸还能改善修复好吗?
  8. html pre标签增加行号,vue使用highlight.js 添加行号
  9. Mac 系统SourceTree 配置VSCode代码对比工具
  10. ISO14001环境管理体系认证好处
  11. 33个地区发iPhone5,老外纳闷中国没人排队_-Chaz-_新浪博客
  12. linux的XDG(X Desktop Group)基本目录规范
  13. Excel中使用条件格式(比较两列将内容不同用颜色标识)
  14. ODBC连接数据库使用动态密码
  15. oracle占用io,解决 oracle IO占用率很高的问题
  16. Quality-Estimation1 (翻译质量评价-复现 WMT2018 阿里论文结果)
  17. 咨询答疑:从产品设计到康威定律
  18. VLC Media Player
  19. 20162316刘诚昊 第三周学习总结
  20. 新手学编程,如何入门?

热门文章

  1. opencv 实现图像高斯金字塔
  2. C++ STL 容器元素排列之next_permutation和prev_permutation的使用示范
  3. js系列教程13-原型、原型链、作用链、闭包全解
  4. python网络爬虫系列教程——python中lxml库应用全解(xpath表达式)
  5. jquery系列教程5-动画操作全解
  6. 用自己数据集训练Mask_RCNN代码
  7. pmod ad2 digilent 提供的pmodad2.c和pmodad2.h
  8. AD画电路的踩坑点——继电器线圈的正负性以G6K-2F-Y为例
  9. windows安装解压版mysql
  10. redis--服务器与客户端