#coding=utf-8

import cifar10,cifar10_input

import tensorflow as tf

import numpy as np

import time

# define max_iter_step batch_size

max_iter_step = 1000

batch_size = 128

# define variable_with_weight_loss

# 和之前定义的weight有所不同,

# 这里定义附带loss的weight,通过权重惩罚避免部分权重系数过大,导致overfitting

def variable_with_weight_loss(shape,stddev,w1):

var = tf.Variable(tf.truncated_normal(shape,stddev=stddev))

if w1 is not None:

weight_loss = tf.multiply(tf.nn.l2_loss(var),w1,name='weight_loss')

tf.add_to_collection('losses',weight_loss)

return var

# 下载数据集 - 调用cifar10函数下载并解压

cifar10.maybe_download_and_extract()

cifar_dir = '/tmp/cifar10_data/cifar-10-batches-bin'

# 采用 data augmentation进行数据处理

# 生成训练数据,训练数据通过cifar10_input的distort变化

images_train, labels_train = cifar10_input.distorted_inputs(data_dir=cifar_dir,batch_size=batch_size)

# 测试数据(eval_data 测试数据)

images_test,labels_test = cifar10_input.inputs(eval_data=True,data_dir=cifar_dir,batch_size=batch_size)

# 创建输入数据,采用 placeholder

x_input = tf.placeholder(tf.float32,[batch_size,24,24,3])

y_input = tf.placeholder(tf.int32,[batch_size])

# 创建第一个卷积层 input:3(channel) kernel:64 size:5*5

weight1 = variable_with_weight_loss(shape=[5,5,3,64],stddev=5e-2,w1=0.0)

bias1 = tf.Variable(tf.constant(0.0,shape=[64]))

conv1 = tf.nn.conv2d(x_input,weight1,[1,1,1,1],padding='SAME')

relu1 = tf.nn.relu(tf.nn.bias_add(conv1,bias1))

pool1 = tf.nn.max_pool(conv1,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME')

norm1 = tf.nn.lrn(pool1,4,bias=1.0,alpha=0.001/9.0,beta=0.75)

# 创建第二个卷积层 input:64 kernel:64 size:5*5

weight2 = variable_with_weight_loss(shape=[5,5,64,64],stddev=5e-2,w1=0.0)

bias2 = tf.Variable(tf.constant(0,1,shape=[64]))

conv2 = tf.nn.conv2d(norm1,weight2,[1,1,1,1],padding='SAME')

relu2 = tf.nn.relu(tf.nn.bias_add(conv2,bias2))

norm2 = tf.nn.lrn(relu2,4,bias=1.0,alpha=0.001/9.0,beta=0.75)

pool2 = tf.nn.max_pool(norm2,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME')

# 创建第三个层-全连接层 output:384

reshape = tf.reshape(pool2,[batch_size,-1])

dim = reshape.get_shape()[1].value

weight3 = variable_with_weight_loss(shape=[dim,384],stddev=0.04,w1=0.004)

bias3 = tf.Variable(tf.constant(0.1,shape=[384]))

local3 = tf.nn.relu(tf.matmul(reshape,weight3)+bias3)

# 创建第四个层-全连接层 output:192

weight4 = variable_with_weight_loss(shape=[384,192],stddev=0.04,w1=0.004)

bias4 = tf.Variable(tf.constant(0.1,shape=[192]))

# 最后一层 output:10

weight5 = variable_with_weight_loss(shape=[192,10],stddev=1/192.0,w1=0.0)

bias5 = tf.Variable(tf.constant(0.0,shape=[10]))

results = tf.add(tf.matmul(local4,weight5),bias5)

# 定义loss

def loss(results,labels):

labels = tf.cast(labels,tf.int64)

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=results,labels=labels,name='cross_entropy_per_example')

cross_entropy_mean = tf.reduce_mean(cross_entropy,name='cross_entropy')

tf.add_to_collection('losses',cross_entropy_mean)

return tf.add_n(tf.get_collection('losses'),name='total_loss')

# 计算loss

loss = loss(results,y_input)

train_op = tf.train.AdamOptimizer(1e-3).minimize(loss) # Adam

top_k_op = tf.nn.in_top_k(results,y_input,1) # top1 准确率

sess = tf.InteractiveSession() # 创建session

tf.global_variable_initializer().run() # 初始化全部模型

tf.train.start_queue_runners() # 启动多线程加速

# 开始训练

for step in range(max_steps):

start_time = time.time()

image_batch,label_batch = sess.run([images_train,labels_train])

_,loss_value = sess.run([train_op,loss],

feed_dict={x_input:image_batch, y_input:label_batch})

duration = time.time() - start_time

if step % 10 == 0:

examples_per_sec = batch_size/duration

sec_per_batch = float(duration)

format_str = ('step%d,loss=%.2f(%.1fexamples/sec;%.3fsec/batch')

print(format_str % (step,loss_value,examples_per_sec,sec_per_batch))

# 评测模型在测试集上的准确度

num_examples = 10000

import math

num_iter = int(math.ceil(num_examples/batch_size))

true_count = 0

total_sample_count = num_iter * batch_size

step = 0

while step < num_iter:

image_batch,label_batch = sess.run([images_test,labels_test])

predictions = sess.run([top_k_op],feed_dict={x_input:image_batch,y_input:label_batch})

true_count += np.sum(predictions)

step += 1

# 打印结果

precision = true_count / total_sample_count

print('precision @ 1 =%.3f' % precision)

tensorflow 读取cifar_浅入浅出TensorFlow 4 - 训练CIFAR数据相关推荐

  1. 浅入浅出深度学习理论实践

    全文共9284个字,40张图,预计阅读时间30分钟. 前言 之前在知乎上看到这么一个问题:在实际业务里,在工作中有什么用得到深度学习的例子么?用到 GPU 了么?,回头看了一下自己写了这么多东西一直围 ...

  2. [科普]浅入浅出Liunx Shellcode

    创建时间:2008-05-13 文章属性:原创 文章提交: pr0cess  (pr0cess_at_cnbct.org) 浅入浅出Liunx Shellcode /*---------------- ...

  3. Java 注解 (Annotation)浅入深出

    Java 注解 (Annotation)浅入深出 本文主要参考与借鉴frank909 文章,但更为简单,详细. Annotation 中文译过来就是注解.标释的意思.Annotation是一种应用于类 ...

  4. 「游戏引擎 浅入浅出」项目介绍

    「游戏引擎 浅入浅出」是一本开源电子书,PDF/随书代码/资源下载: https://github.com/ThisisGame/cpp-game-engine-book 项目介绍 README 本书 ...

  5. 浅入浅出Oracle Spatial GeoRaster 10g影像数据管理(2)

    浅入浅出Oracle Spatial GeoRaster  10g 影像数据管理(2)--物理存储 1.物理存储方式概要      在上个部分<浅入浅出Oracle Spatial GeoRas ...

  6. 浅入浅出Javac编译原理——爪哇岛探险(1)

    浅入浅出Javac编译原理 Java语言是当今程序员中使用最广的语言,不光是从语言本身来说,还包括了与Java相关的一些概念.例如JDK,J2EE,JVM等等.还不断有新的语言出现,如groove,s ...

  7. 「游戏引擎 浅入浅出」前言

    「游戏引擎 浅入浅出」是一本开源电子书,Github地址: https://github.com/ThisisGame/cpp-game-engine-book 为什么写这本书? 在与同事沟通时,会提 ...

  8. 编译原理代码生成器java_浅入浅出Javac编译原理

    浅入浅出Javac编译原理 Java语言是当今程序员中使用最广的语言,不光是从语言本身来说,还包括了与Java相关的一些概念.例如JDK,J2EE,JVM等等.还不断有新的语言出现,如groove,s ...

  9. 浅入深出之Java集合框架(上)

    Java中的集合框架(上) 由于Java中的集合框架的内容比较多,在这里分为三个部分介绍Java的集合框架,内容是从浅到深,如果已经有java基础的小伙伴可以直接跳到浅入深出之Java集合框架(下). ...

  10. 浅入深出之Java集合框架(中)

    Java中的集合框架(中) 由于Java中的集合框架的内容比较多,在这里分为三个部分介绍Java的集合框架,内容是从浅到深,如果已经有java基础的小伙伴可以直接跳到浅入深出之Java集合框架(下). ...

最新文章

  1. MySQL面试题 | 附答案解析(十八)
  2. ubuntu12.04编译android4.0源代码Deug2
  3. linux下 zip解压 tar解压 gz解压 bz2等各种解压文件命令
  4. 2017-2018-1 20155332实验三 实时系统报告
  5. java io之图片存取
  6. 美团延长旅行订单免费取消保障政策至2月29日
  7. 卸载idea2020删除以前的配置_推荐一款只有5M大小的绿色良心的卸载工具!
  8. C# 以MDF文件连接数据库
  9. 译码器(24译码器,38译码器)笔记
  10. 在vue中使用html表格
  11. 【LeetCode】75. Sort Colors(颜色排序)-C++实现的两种方法及超详细图解
  12. Smart Link概述
  13. Unity UGUI Inputfield 回车submit 按下Enter回车完成
  14. 将数组分成两部分,使得 |sum1 - sum2| 最小. LeetCode - 1049
  15. 所谓键位冲突和无冲突的各种原理
  16. 【本人秃顶程序员】程序员不要去这样的公司
  17. Python语言程序设计 - 北京理工大学 网课所有资料(源码,pdf,ppt课件,视频等)
  18. 2022DASCTF Apr X FATE 防疫挑战赛 部分web复现
  19. RISC-V CSR 相关指令集
  20. 汽车之家字体加密破解(CSS样式反爬)

热门文章

  1. 基于AVR和MT8870的远程家电控制系统设计
  2. suse mysql ERROR1045_Suse发生了错误Access denied for user #39;#39;@#39;localhost#39; toamp;...
  3. 【控制】《多智能体系统的协同群集运动控制》陈杰老师-第6章-参数不确定的高阶非线性多智能体系统一致性控制
  4. 3.6 权值初始化-机器学习笔记-斯坦福吴恩达教授
  5. 内核启动流程分析(一)编译体验
  6. u-boot分析之makefile分析(二)
  7. Xilinx IP核之FIFO
  8. supervisor
  9. extern C 在c 与 cxx间的使用
  10. IT兄弟连 JavaWeb教程 EL表达式获取对象的属性以及数组的元素