前言

之前写代码的时候都要预先初始化权重,还得担心变量是否会出现被重复定义的错误,但是看网上有直接用tf.layers构建网络,很简洁的方法。

这里主要尝试了不预定义权重,是否能够实现正常训练、模型保存和调用,事实证明阔以。

验证

训练与模型保存

很简洁的代码直接五十行实现了手写数字的网络训练

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("./TensorFlow-Examples-master/examples/3_NeuralNetworks/tmp",one_hot=True)steps=5000
batch_size=100
def conv_network(x):x=tf.reshape(x,[-1,28,28,1])#第一层卷积conv1=tf.layers.conv2d(inputs=x,filters=32,kernel_size=[5,5],activation=tf.nn.relu)conv1=tf.layers.max_pooling2d(conv1,pool_size=[2,2],strides=[2,2])#第二层卷积conv2=tf.layers.conv2d(inputs=conv1,filters=64,kernel_size=[3,3],activation=tf.nn.relu)conv2=tf.layers.max_pooling2d(inputs=conv2,pool_size=[2,2],strides=[2,2])#第三层卷积conv3=tf.layers.conv2d(inputs=conv2,filters=32,kernel_size=[3,3],activation=tf.nn.relu)conv3=tf.layers.max_pooling2d(inputs=conv3,pool_size=[2,2],strides=[2,2])#全连接fc1=tf.layers.flatten(conv3)fc1=tf.layers.dense(fc1,500,activation=tf.nn.relu)#输出fc2=tf.layers.dense(fc1,10)fc2=tf.nn.softmax(fc2) #因为loss里面用了softmax_cross_enrtopy,所以此行去掉return fc2input_img=tf.placeholder(dtype=tf.float32,shape=[None,28*28],name='X')
input_lab=tf.placeholder(dtype=tf.int32,shape=[None,10])#损失函数
output_lab=conv_network(input_img)
logit_loss=tf.nn.softmax_cross_entropy_with_logits_v2(labels=input_lab,logits=output_lab)
loss=tf.reduce_mean(tf.cast(logit_loss,tf.float32)) #可以去掉,因为softmax_cross_entroy自带求均值
optim=tf.train.AdamOptimizer(0.001).minimize(loss)
#评估函数
pred_equal=tf.equal(tf.arg_max(output_lab,1),tf.arg_max(input_lab,1))
accuracy=tf.reduce_mean(tf.cast(pred_equal,tf.float32))init=tf.global_variables_initializer()
saver=tf.train.Saver()
tf.add_to_collection('pred',output_lab)
with tf.Session() as sess:sess.run(init)for step in range(steps):data_x,data_y=mnist.train.next_batch(batch_size)sess.run(optim,feed_dict={input_img:data_x,input_lab:data_y})if step%100==0 or step==1:accuracy_val=sess.run(accuracy,feed_dict={input_img:data_x,input_lab:data_y})print('step'+str(step)+' ,loss '+'{:.4f}'.format(accuracy_val))print('training finished!!')saver.save(sess,'./layermodel/CNN_layer')

【更新日志】 2019-9-2
学艺不精,上面由于损失函数用的softmax_cross_entropy_with_logits_v2,所以输出会被归一化,得分也是一个batch的损失均值,因而构建网络的时候,没必要用最后下面两句话:

loss=tf.reduce_mean(tf.cast(logit_loss,tf.float32))
fc2=tf.nn.softmax(fc2)

调用模型

实现单张手写数字的识别

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
saver=tf.train.import_meta_graph('./layermodel/CNN_layer.meta')
sess=tf.Session()
saver.restore(sess,'./layermodel/CNN_layer')
graph=tf.get_default_graph()
print(graph.get_all_collection_keys())
#['pred', 'train_op', 'trainable_variables', 'variables']
print(graph.get_collection('trainable_variables'))
prediction=graph.get_collection('pred')
X=graph.get_tensor_by_name('X:0')
#读取图片
image=cv2.imread('./mnist/test/2/2_2.png')
image=cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
plt.imshow(image)
plt.show()
#显示图片
input_img=np.reshape(image,[1,28*28])
result=sess.run(prediction,feed_dict={X:input_img})
print(result)
#[array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)]

后记

其实主要是为了后续使用tf.layers里面的其它结构比如BN做准备,因为代码越复杂,写起来越恶心,不如现在看看如何简化代码,第一步就是去除了权重的预定义,后续慢慢研究其它的。

训练代码:链接:https://pan.baidu.com/s/1gmX-YBkz4nNG3RpJ_rEBKQ 密码:o8u2

测试代码:链接:https://pan.baidu.com/s/1ME9pgyM9TNQadmzMeURlNg 密码:5z7k

【TensorFlow-windows】学习笔记八——简化网络书写相关推荐

  1. Docker学习笔记(八)Docker0网络解析

    Docker0 清空环境 清空所有的images docker rmi -f $(docker images -qa) 这里使用的我的阿里云服务器 ip addr查看网卡 Linux ip 命令与 i ...

  2. 黑马程序员_java自学学习笔记(八)----网络编程

    黑马程序员_java自学学习笔记(八)----网络编程 android培训. java培训.期待与您交流! 网络编程对于很多的初学者来说,都是很向往的一种编程技能,但是很多的初学者却因为很长一段时间无 ...

  3. TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络

    TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络 转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnote ...

  4. python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑

    python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑 许多人在安装Python第三方库的时候, 经常会为一个问题困扰:到底应该下载什么格式的文件? 当我们点开下载页时, 一 ...

  5. Linux+javaEE学习笔记之Linux网络环境配置

    Linux+javaEE学习笔记之Linux网络环境配置 网络知识简单介绍: Ip地址是:IP地址是IP协议提供的一种统一的地址格式,它为互联网上的每一个网络和每一台主机分配一个逻辑地址,以此来屏蔽物 ...

  6. python3第三方库手册_python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑...

    python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑 许多人在安装Python第三方库的时候, 经常会为一个问题困扰:到底应该下载什么格式的文件? 当我们点开下载页时, 一 ...

  7. OpenGL学习笔记(八):进一步理解VAO、VBO和SHADER,并使用VAO、VBO和SHADER绘制一个三角形

    原博主博客地址:http://blog.csdn.net/qq21497936 本文章博客地址:http://blog.csdn.net/qq21497936/article/details/7888 ...

  8. ReactJS学习笔记八:动画

    ReactJS学习笔记八:动画 分类: react学习笔记 javascript2015-07-06 20:27 321人阅读 评论(0) 收藏 举报 react动画 目录(?)[+] 这里只讨论Re ...

  9. TensorFlow Lite学习笔记

    TensorFlow Lite学习笔记 目录 TensorFlow Lite学习笔记 Tensorflow LIte Demo 模型固化freeze_graph和模型优化optimize_for_in ...

最新文章

  1. 网络编程3之TCP/IP协议
  2. wxWidgets:wxArray<T>类用法
  3. vue require动态路径图片报错_Vue 动态生成路由结构
  4. 如何处理错误信息 Pricing procedure could not be determined
  5. FireDAC 中文字段过滤问题
  6. FreeSql (二十五)延时加载
  7. Centos 下PHP编译安装fileinfo扩展
  8. brew报错:in `initialize‘: Version value must be a string; got a NilClass () (TypeError)
  9. python 颜色空间转换_python opencv入门 颜色空间转换(9)
  10. 红帽RHCE培训-课程3笔记内容2
  11. 【经验】聊自己非计算机专业做程序员的经验
  12. Linux驱动的ioctl函数简要说明
  13. 设置布局默认为LinearLayout,却成了RelativeLayout
  14. mac HBux连接夜神模拟器
  15. 关于胶囊检测的思考-代码实现
  16. 天翼云服务器迁移阿里云_Cloudops:云迁移的被忽略的部分
  17. realme支持鸿蒙系统,骁龙888+首批搭载安卓12,realme真我GT真香售价2499元起
  18. 带缓冲的输入/输入流
  19. android 常用加密,分享一下Android各种类型的加密
  20. bzoj 4755: [Jsoi2016]扭动的回文串 manachar+hash+二分

热门文章

  1. C语言编程序输出SCHAR_MAX的,运用堆栈把十进制变换成二进制
  2. 深度学习之基于CNN实现天气识别
  3. oracle管理认证方式,关于Oracle数据库管理员认证方法简述
  4. 计算机信息导论论文,电子信息导论论文2000字
  5. python题库选择填空_python练习题4.18猴子选大王
  6. 办公室自动化系统_RPA:办公自动化的下一站
  7. don't run elasticsearch as root
  8. shell函数日期之间的操作(日期转秒,日期间隔秒,日期间隔天)
  9. 【labelme】改造labelme
  10. VS2013出现“无法找到“xxx.exe”的调试信息,或者调试信息不匹配”错误解决方案