1、概述

卷积部分的知识点在博客:TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来已经写过,所以不再赘述。这一节简单聊聊tensorflow的编程基础。

2、会话Session

Tensorflow有“图”和“会话”的概念,“图”定义一个计算任务,计算则是在“会话”中完成。从一个“Hello world”理解这个概念吧。比如在python中,想打印“Hello world”,代码如下:

str = 'Hello world!'
print(str)

运行结果:

Hello world!

Tensorflow定义常量的函数是tf.constant()

我模仿上面的代码如下:

import tensorflow as tf
a = tf.constant('hello world!')
print(a)

运行结果:

Tensor("Const:0", shape=(), dtype=string)

意不意外,惊不惊喜?

将上面的代码改成下面的代码,

import tensorflow as tf
a = tf.constant('hello world!')
sess = tf.Session()
print(sess.run(a))
sess.close()

运行结果:

hello world!

这才是正确的操作姿势。图只是定义了应该怎么做,而会话Session才是真正的执行。但是一般不会像上面那种写法,而是使用with语法,如下:

a = tf.constant('hello world!')
with tf.Session() as sess:print(sess.run(a))

这样做的好处是如果忘记手动关闭会话,或者程序崩溃时,会自动关闭会话。

那我们上面打印的Tensor("Const:0", shape=(), dtype=string)是什么鬼?

为了弄清这个东西,我们再写个小demo,如下

import tensorflow as tf
b = tf.constant([[1, 1, 1], [2, 2, 2]], name='wei')
c = tf.constant([[3, 3, 3], [4, 4, 4]], name='fang')
d = b + cprint(b)
print(c)
print(d)

运行结果:

Tensor("wei:0", shape=(2, 3), dtype=int32)

Tensor("fang:0", shape=(2, 3), dtype=int32)

Tensor("add:0", shape=(2, 3), dtype=int32)

可以看到,第一个参数其实就是这个常量的名字,如果不设置名字的话,会默认给它一个名字,而第二个参数就是它的形状,第三个参数就是类型。

这小节只需要记住,图只是定义操作,计算在会话中进行就好了。

3、占位符placeholder

placeholder是用来保存数据的,在session运行阶段,通过feed_dict的字典结构给placeholder喂数据。占位符在tensorflow中非常常见的,可以看看前面的MNIST例子,我们在这里也写个demo,

import tensorflow as tfa = tf.placeholder(tf.float32, [2])
b = tf.placeholder(tf.float32, [2])
c = tf.add(a, b)with tf.Session() as sess:print(sess.run(c, feed_dict={a:[1.0, 2.0], b:[3.0, 4.0]}))

运行结果:

[4. 6.]

4、指定GPU运算

如果想指定用第一个GPU运算,只要将其放在with tf.device('/gpu:0'):里即可,还是用上面的例子,代码如下:

import tensorflow as tfa = tf.placeholder(tf.float32, [2])
b = tf.placeholder(tf.float32, [2])
c = tf.add(a, b)config = tf.ConfigProto(log_device_placement=True)
with tf.Session(config=config) as sess:with tf.device('/gpu:0'):print(sess.run(c, feed_dict={a:[1.0, 2.0], b:[3.0, 4.0]}))

运行结果:

Device mapping:

/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 950, pci bus id: 0000:01:00.0, compute capability: 5.2

Add: (Add): /job:localhost/replica:0/task:0/device:GPU:0

Placeholder_1: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0

Placeholder: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0

[4. 6.]

5、保存和加载模型

如果花了几天辛辛苦苦训练了一个模型,却没有保存训练结果,相当于杨白劳,所以要会保存模型的训练结果,以及使用训练结果。我们前面训练的MNIST和CIFAR10都有这方面的操作的,可以回去看看。

5.1、保存模型

5.1.1、tf.train.Saver

我们定义一个变量a,初始化为32,然后将它保存,代码如下:

import tensorflow as tf
a = tf.Variable(32, tf.float32, name='Nini')
saver = tf.train.Saver()
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(a))saver.save(sess, 'wilf/wilf.cpkt')

运行结果:

32

看看wilf目录,生成了四个文件,

这里面怎么存数据呢?我们把它打印出来看看,代码如下:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file('wilf/wilf.cpkt', None, True)

运行结果:

tensor_name:  nini

32

5.1.2、tf.train.MonitoredTrainingSession

我们在分析官方CIFAR10的代码时,还遇到了一个自动保存训练结果的方法,tf.train.MonitoredTrainingSession,这里就不说了,可以去看前面的例子。

5.2、加载模型

将上面的代码稍加修改,为了看到效果,我们将a的初始化值改为12,如下:

import tensorflow as tfa = tf.Variable(12, tf.float32, name='Nini')
saver = tf.train.Saver()
with tf.Session() as sess:saver.restore(sess, 'wilf/wilf.cpkt')print(sess.run(a))

运行结果:

32

6、变量

6.1、tf.Variable上面的例子其实已经使用过tf.Variable创建变量,我想强调的是,使用变量之前,一定要先使用

sess.run(tf.global_variables_initializer())对变量进行初始化,否则会出错,代码如下:import tensorflow as tf
a = tf.Variable(32, tf.float32, name='Nini')
with tf.Session() as sess:# sess.run(tf.global_variables_initializer())print(sess.run(a))

运行结果:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value weifang

[[Node: _retval_weifang_0_0 = _Retval[T=DT_INT32, index=0, _device="/job:localhost/replica:0/task:0/device:CPU:0"](weifang)]]

6.2、修改变量

既然说道变量,那怎么修改变量里的值呢?不可能直接用a = 11吧?Tensorflow提供了tf.assign()函数来修改变量,代码如下:

import tensorflow as tfa = tf.Variable(32, tf.float32, name='Nini')
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("before:", sess.run(a))print("after:", sess.run(tf.assign(a, 11)))

运行结果:

('before:', 32)

('after:', 11)

6.3、tf.get_variable

tf.get_variable也是定义变量的,比如:

var1 = tf.get_variable('wei', [1], dtype=tf.float32)

那么它跟tf.Variable有什么不同呢?我们写个demo,通过tf.Variable和tf.get_variable分别定义两个名字一样的变量:

import tensorflow as tfvar1 = tf.Variable(1, name='NI')
print(var1.name)
var2 = tf.Variable(2, name='NI')
print(var2.name)var3 = tf.get_variable('ni', [1], dtype=tf.float32)
print(var3.name)
var4 = tf.get_variable('ni', [1], dtype=tf.float32)
print(var4.name)

运行结果:

NI:0

NI_1:0

ni:0

ValueError: Variable fang already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at

可以看到,运行到var4的时候程序崩了,说明tf.Variable可以定义两个重名的变量,如果重名,会自动在名字后加上”_N”N是数字,而tf.get_variable不行。而且tf.get_variable定义的时候必须设置名字。那有什么方法用tf.get_variable设置两个名字一样的变量呢?这就要指定变量的作用域了。

6.4、tf.variable_scope

tf.variable_scope函数可以指定作用域,代码如下:

import tensorflow as tfwith tf.variable_scope("scope1"):var1 = tf.get_variable("ni", [1], dtype=tf.float32)with tf.variable_scope('scope2'):var2 = tf.get_variable("ni", [1], dtype=tf.float32)print(var1.name)
print(var2.name)

运行结果:

scope1/wei:0

scope2/wei:0

作用域还可以嵌套使用,将上面例子’scope2’作用域往后缩进试试,代码如下:

import tensorflow as tfwith tf.variable_scope("scope1"):var1 = tf.get_variable("Ni", [1], dtype=tf.float32)with tf.variable_scope('scope2'):var2 = tf.get_variable("Ni", [1], dtype=tf.float32)print(var1.name)
print(var2.name)

运行结果:

scope1/wei:0

scope1/scope2/wei:0

6.5、共享变量

说了这么多,这个get_variable到底有什么用呢?答案是共享变量。在某些情况下,一个模型可能需要其他模型创建的变量,这时候就用到共享变量了。

Tf.variable_scope有一个reuse=True的属性,表示,如果已经定义过该变量,就不再创建,而是在图中找到这个变量直接使用。实例代码如下:

import tensorflow as tfwith tf.variable_scope('scope1'):var1 = tf.get_variable('Ni', [1], initializer=tf.constant_initializer(123))with tf.variable_scope('scope2'):var2 = tf.get_variable('ni', [1], initializer=tf.constant_initializer(234))with tf.variable_scope('scope1', reuse=True):var3 = tf.get_variable('Ni', [1], initializer=tf.constant_initializer(888))with tf.variable_scope('scope2', reuse=True):var4 = tf.get_variable('ni', [1], initializer=tf.constant_initializer(999))with tf.Session() as sess:sess.run(tf.global_variables_initializer())print(sess.run(var1))print(sess.run(var2))print(sess.run(var3))print(sess.run(var4))

运行结果:

[123.]

[234.]

[123.]

[234.]

可以看到,虽然var3和var4初始化变量值为888和999,但是输出的结果依然是var1和var2的初始值,说明并没有重新创建,也可以打印它名字来看,其实是一样的。一定要加上reuse=True属性!

TensorFlow精进之路(九):TensorFlow编程基础相关推荐

  1. TensorFlow精进之路(十二):随时间反向传播BPTT

    1.概述 上一节介绍了TensorFlow精进之路(十一):反向传播BP,这一节就简单介绍一下BPTT. 2.网络结构 RNN正向传播可以用上图表示,这里忽略偏置. 上图中, x(1:T)表示输入序列 ...

  2. TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来

    1.概述 自从开了专栏<TensorFlow精进之路>关于对TensorFlow的整理思路更加清晰.上两篇讲到Softmax回归模型和两层卷积神经网络模型训练MNIST,虽然使用神经网络能 ...

  3. TensorFlow精进之路(八):神经元

    1.概述 喝完奶茶继续干,通过前面的学习,对深度学习似乎有那么点感觉了,本来想继续往下学学一些应用的例子的,但是现在我想还是系统的先把一些深度学习的基本概念总结一下,以及先系统的学习一下tensorf ...

  4. TensorFlow精进之路(十六):使用slim模型库对图片分类

    1.概述 TF-slim是tensorflow的一个轻量级库,它将很多常见tensorflow函数进行封装,使的模型的构建.训练.测试都更加简洁,特别适用于构建结构复杂的深度神经网络.github地址 ...

  5. tensorflow精进之路(二十七)——人脸识别(中)(MTCNN人脸检查和人脸对齐+FaceNet模型)

    1.概述 上一讲,我们讲了人脸识别的基本原理,这一讲,我们用tensorflow来实现它. 2.下载LFW人脸数据集 2.1.LFW数据集简介 LFW人脸数据集主要用来研究非受限情况下的人脸识别问题, ...

  6. TensorFlow精进之路(一):Softmax回归模型训练MNIST

    1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...

  7. Tensorflow精进之路(二):两层卷积神经网络模型训练MNIST

    这段时间,打算把TensorFlow再补补,提升一下技术水平~ 希望我能坚持下来,抽空把这本书刷下来吧~ 导入数据 下面的代码会直接下载数据,如果没有那个文件夹的话,但是,如果有这个文件夹而且里面有那 ...

  8. tensorflow精进之路(十九)——python3网络爬虫(下)

    1.概述 这一节,我们将在百度图片中爬取需要训练的图片数据:猪.蛇.狗.大象.老虎. 2.打开待爬取网页 打开百度图片首页: http://image.baidu.com/ 在搜索框中输入" ...

  9. TensorFlow精进之路(十五):深度神经网络简介

    1.概述 本来想用卷积神经网络来预测点东西,但是效果嘛......,还是继续学习图像类的应用吧-前面学习的神经网络都是一些基础的结构,这些网络在各自的领域中都有一定效果,但是解决复杂问题肯定不够的,这 ...

最新文章

  1. Linux虚拟机基本操作
  2. cgroup代码浅析(2)
  3. boost::core模块实现分配const void指针
  4. word表格图片自动适应表格大小_几招教你快速解决word文字、图片、表格排版问题,你肯定遇到过...
  5. MYSQL出错代码列表大全(中文)
  6. python根据相关系数绘制热力图
  7. Android系统(23)---Android 应用分屏
  8. ALEIS,啊,累死
  9. 15.10. Session/Cookie
  10. 拓扑排序Topological Sorting
  11. 转--计算几何常用算法概览
  12. HTML实现在线取色器
  13. mysql中转换日期格式,MySQL日期格式转换
  14. 基于HTML/CSS/JS的动态元素周期表
  15. 终于搞清楚了:java的long的小l和大L区别
  16. 目标跟踪常用算法——EKF篇
  17. 5D论文PMF及改进
  18. Kubernetes系列(一)基于CentOS8部署Kubernetes1.19集群
  19. 蓝牙 - 如何在Windows下抓取蓝牙数据
  20. 未来城市空中交通——NASA Embraces Urban Air Mobility, Calls for Market Study

热门文章

  1. postgresql grant权限解释
  2. lol云顶之奕助手_云顶之奕小小英雄介绍 除了棋子以外它也很重要!
  3. OpenStack-Icehouse(nova-network)多节点基础环境部署
  4. js中做数字运算时出现的异常,期望值比实际值小太多太多
  5. 职教高中计算机专业知识,新课改背景下计算机专业教学(职教)三维目标设计初探...
  6. XenApp 6安装过程中的两个常见错误
  7. WPF 与 摄像头资料
  8. 腾讯QQ2010安装时提示“C:\windows\Installer\QQ2010.msi时发生网络错误”的解决方 ......
  9. 让nginx支持php
  10. outermost shell_outermost是什么意思_outermost怎么读_outermost翻译_用法_发音_词组_同反义词_最外面的_离中心最远的-新东方在线英语词典...