最近看到tensorflow训练cifar10数据集,说实话相比于mnist数据集,cifar10有了一个质的飞跃,从单通道灰度图像转变到三通道彩色图像。

cifar10

下面来简单介绍下cifar10数据集,该数据集共有60000张彩色图像,这些图像是32*32*3,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。Tensorflow自带有cifar的例子,可以在线下载cifar数据集,也可以离线下载,然后读取数据,在这里主要讲解如何搭建训练工程。下面请看代码:

import cifar10,cifar10_input
import tensorflow as tf
import numpy as np
import timemax_steps = 3000
batch_size = 128
data_dir = 'C:\\Users\\new\\Desktop\\cifar-10-batches-bin'def variable_with_weight_loss(shape, stddev, wl):var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))if wl is not None:weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')tf.add_to_collection('losses', weight_loss)return vardef loss(logits, labels):
#      """Add L2Loss to all the trainable variables.
#      Add summary for "Loss" and "Loss/avg".
#      Args:
#        logits: Logits from inference().
#        labels: Labels from distorted_inputs or inputs(). 1-D tensor
#                of shape [batch_size]
#      Returns:
#        Loss tensor of type float.
#      """
#      # Calculate the average cross entropy loss across the batch.
#将labels数据格式转换为int64labels = tf.cast(labels, tf.int64)cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels, name='cross_entropy_per_example')cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')tf.add_to_collection('losses', cross_entropy_mean)# The total loss is defined as the cross entropy loss plus all of the weight# decay terms (L2 loss).return tf.add_n(tf.get_collection('losses'), name='total_loss')###images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=batch_size)images_test, labels_test = cifar10_input.inputs(eval_data=True,data_dir=data_dir,batch_size=batch_size)
#images_train, labels_train = cifar10.distorted_inputs()
#images_test, labels_test = cifar10.inputs(eval_data=True)image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])#logits = inference(image_holder)weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, wl=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
#将w*x和b加起来
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME')
#LRN为局部响应归一化,一般在激活或者池化后使用,让强信号更强,弱信号更弱,通常很少使用,被dropout等方法替代
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, wl=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME')#在这里通过reshape函数把结构化数据转变成向量数据格式,这一步就是把卷积层转换为全连接层
reshape = tf.reshape(pool2, [batch_size, -1])#这里是因为数据是以batch_size个存储的,不是单个,其实就是batch_size*单个数据
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, wl=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, wl=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1/192.0, wl=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)
#总的损失函数
loss = loss(logits, label_holder)train_op = tf.train.AdamOptimizer(1e-3).minimize(loss) #0.72top_k_op = tf.nn.in_top_k(logits, label_holder, 1)sess = tf.InteractiveSession()
tf.global_variables_initializer().run()tf.train.start_queue_runners()
###
for step in range(max_steps):start_time = time.time()image_batch,label_batch = sess.run([images_train,labels_train])_, loss_value = sess.run([train_op, loss],feed_dict={image_holder: image_batch, label_holder:label_batch})duration = time.time() - start_timeif step % 10 == 0:examples_per_sec = batch_size / durationsec_per_batch = float(duration)format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))###
num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:image_batch,label_batch = sess.run([images_test,labels_test])predictions = sess.run([top_k_op],feed_dict={image_holder: image_batch,label_holder:label_batch})true_count += np.sum(predictions)step += 1precision = true_count / total_sample_count
print('precision @ 1 = %.3f' % precision)

Tensorflow训练CIFAR10源代码相关推荐

  1. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  2. 使用猫狗大战数据集进行一次完整的TensorFlow训练

    1.简介 一直想将图片制作成tfrecords文件,然后在模型中运行一下.最初想用的数据集是mnist,但是跑的过程中一直出现问题.找到这一篇知乎上的博客,写的非常不错. 原博客地址:https:// ...

  3. TensorFlow基于cifar10数据集实现进阶的卷积网络

    TensorFlow基于cifar10数据集实现进阶的卷积网络 学习链接 CIFAR10模型及数据集介绍 综述 CIFAR10数据集介绍 CIFAR10数据集可视化 CIFAR10模型 CIFAR10 ...

  4. 将TensorFlow训练的模型移植到Android手机

    2019独角兽企业重金招聘Python工程师标准>>> 前言 本文中出现的TF皆为TensorFlow的简称. 先说两句题外话吧,TensorFlow 前两天热热闹闹的发布了正式版r ...

  5. 将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)

    将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite) 1. 写在前面   最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址 ...

  6. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  7. Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题

    Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 参考文章: (1)Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 (2)http ...

  8. 使用TensorFlow训练WDL模型性能问题定位与调优

    简介 TensorFlow是Google研发的第二代人工智能学习系统,能够处理多种深度学习算法模型,以功能强大和高可扩展性而著称.TensorFlow完全开源,所以很多公司都在使用,但是美团点评在使用 ...

  9. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  10. 使用PaddleFluid和TensorFlow训练RNN语言模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

最新文章

  1. 性能媲美BERT,参数量仅为1/300,谷歌最新的NLP模型
  2. Linux移植之auto.conf、autoconf.h、Mach-types.h的生成过程简析
  3. 在ListView中使用BaseAdapter进行适配
  4. Tomcat发布Web项目的两种方式
  5. Windows内核原理-同步IO与异步IO
  6. 在ASP.NET MVC3项目中,自定义404错误页面
  7. Quadratic equation(二次剩余)2019牛客多校第九场
  8. 阿里腾讯面试梳理个人成长经历分享
  9. java的恐怖推理游戏_胆小勿入!盘点一下2019年所有的恐怖游戏
  10. mysql安装配置jdbc_JDBC环境配置
  11. ie调取摄像头抓拍解决方案
  12. 批发表情包,掏出了python 3分钟爬取表情包素材,分享给你
  13. 启用IIS服务(运行中输入inetmgr打不开IIS管理器的解决办法)
  14. 无涂层无胶纸(UWF)的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
  15. 回归老博客(no zuo no dead)
  16. 数字电路为什么是低电平有效的多
  17. 下载keep运动软件_Keep下载_Keep安卓版下载_Keep app下载-太平洋下载中心
  18. 广州搬家公司 居民搬家 公司搬迁 事业单位搬迁全天服务
  19. 根治偏头痛及各种头痛病症
  20. awesome php

热门文章

  1. TensorFlow 安装
  2. 数据结构之排序算法Java实现(8)—— 线性排序之计数排序算法
  3. Shell 批量复制文件名相近的文件到指定文件名中
  4. 通过debug过程分析Struts2什么时候将Action对象放入了值栈ValueStack中
  5. 【原创】软件测试基础流程
  6. 如何将Eclipse中Web项目打成war包
  7. 干货干货:px和毫米之间的转换
  8. Python selenium报错:selenium.common.exceptions.ElementClickInterceptedException
  9. 03. 绝对不要以多态(polymorphically)方式处理数组
  10. struts2之自定义拦截器及拦截器生命周期分析