作者 | AI小昕

编辑 | 安可

【导读】:本文讲了Tensorflow从入门到精通。欢迎大家点击上方蓝字关注我们的公众号:深度学习与计算机视觉

Tensor介绍

Tensor(张量)是Tensorflow中最重要的数据结构,用来表示Tensorflow程序中的所有数据。Tensor本是广泛应用在物理、数学领域中的一个物理量。那么在Tensorflow中该如何理解Tensor的概念呢?

实际上,我们可以把Tensor理解成N维矩阵(N维数组)。其中零维张量表示的是一个标量,也就是一个数;一维张量表示的是一个向量,也可以看作是一个一维数组;二维张量表示的是一个矩阵;同理,N维张量也就是N维矩阵。

在计算图模型中,操作间所传递的数据都可以看做是Tensor。那Tensor的结构到底是怎样的呢?我们可以通过程序1-1更深入的了解一下Tensor。

程序1-1:

#导入tensorflow模块
import tensorflow as tf
a = tf.constant([[2.0,3.0]] ,name="a")
b = tf.constant([[1.0],[4.0]] ,name="b")
result = tf.matmul(a,b,name="mul")
print(result)>> Tensor("mul_3:0", shape=(1, 1), dtype=float32)

程序1-1的输出结果表明:构建图的运算过程输出的结果是一个Tensor,且其主要由三个属性构成:Name、Shape和Type。Name代表的是张量的名字,也是张量的唯一标识符,我们可以在每个op上添加name属性来对节点进行命名,Name的值表示的是该张量来自于第几个输出结果(编号从0开始),上例中的“mul_3:0”说明是第一个结果的输出。Shape代表的是张量的维度,上例中shape的输出结果(1,1)说明该张量result是一个二维数组,且每个维度数组的长度是1。最后一个属性表示的是张量的类型,每个张量都会有唯一的类型,常见的张量类型如图1-1所示。

图1-1 常用的张量类型

我们需要注意的是要保证参与运算的张量类型相一致,否则会出现类型不匹配的错误。如程序1-2所示,当参与运算的张量类型不同时,Tensorflow会报类型不匹配的错误:

程序1-2:

import tensorflow as tf
m1 = tf.constant([5,1])
m2 = tf.constant([2.0,4.0])
result = tf.add(m1,m2)TypeError: Input 'y' of 'Add' Op has type float32 that does not match type int32 of argument 'x'.

正如程序的报错所示:m1是int32的数据类型,而m2是float32的数据类型,两者的数据类型不匹配,所以发生了错误。所以我们在实际编程时,一定注意参与运算的张量数据类型要相同。

常量、变量及占位符

Tensorflow中对常量的初始化,不管是对数值、向量还是对矩阵的初始化,都是通过调用constant()函数实现的。因为constant()函数在Tensorflow中的使用非常频繁,经常被用于构建图模型中常量的定义,所以接下来,我们通过程序1-3了解一下constant()的相关属性:

程序2-1:

import tensorflow as tf
a = tf.constant([2.0,3.0],name="a",shape=(2,0),dtype="float64",verify_shape="true")
print(a)>> Tensor("a_11:0", shape=(2, 0), dtype=float64)

如程序2-1所示,函数constant有五个参数,分别为value,name,dtype,shape和verify_shape。其中value为必选参数,其它均为可选参数。Value为常量的具体值,可以是一个数字,一维向量或是多维矩阵。Name是常量的名字,用于区别其它常量。Dtype是常量的类型,具体类型可参见图2-2。Shape是指常量的维度,我们可以自行定义常量的维度。

verify_shape是验证shape是否正确,默认值为关闭状态(False)。也就是说当该参数true状态时,就会检测我们所写的参数shape是否与value的真实shape一致,若不一致就会报TypeError错误。如:上例中的实际shape为(2,0),若我们将参数中的shape属性改为(2,1),程序就会报如下错误:

TypeError: Expected Tensor's shape: (2, 1), got (2,).

Tensorflow还提供了一些常见常量的初始化,如:tf.zeros、tf.ones、tf.fill、tf.linspace、tf.range等,均可以快速初始化一些常量。例如:我们想要快速初始化N维全0的矩阵,我们可以利用tf.zeros进行初始化,如程序1-4所示:

程序2-2:

import tensorflow as tf
a=tf.zeros([2,2],tf.float32)
b=tf.zeros_like(a,optimize=True)with tf.Session() as sess:print(sess.run(a))
print(sess.run(b))>> [[ 0.  0.][ 0.  0.]][[ 0.  0.][ 0.  0.]]

程序2-2向我们展示了tf.zeros和tf.zeros_like的用法。其它常见常量的具体初始化用法可以参考Tensorflow官方手册:

https://www.tensorflow.org/api_guides/python/constant_op

此外,Tensorflow还可以生成一些随机的张量,方便快速初始化一些随机值。如:tf.random_normal()、tf.truncated_normal()、tf.random_uniform()、tf.random_shuffle()等。如程序1-5所示,我们以tf.random_normal()为例,来看一下随机张量的具体用法:

程序2-3:

import tensorflow as tf
random_num=tf.random_normal([2, 3], mean=-1, stddev=4,
dtype=tf.float32,seed=None,name='rnum')with tf.Session() as sess:
print(sess.run(random_num))>> [[-2.71897316  1.04246855 -3.12996817][-1.34851456 -0.13599336  4.60532522]]

随机张量random_normal()有shape、mean、stddev、dtype、seed、name六个属性。shape是指张量的形状,如上述程序是生成一个2行3列的tensor;mean是指正态分布的均值;stddev是指正太分布的标准差;dtype是指生成tensor的数据类型;seed是分发创建的一个随机种子;而name是给生成的随机张量命名。

Tensorflow中的其它随机张量的具体使用方法和属性介绍,可以参见Tensorflow官方手册:https://www.tensorflow.org/api_guides/python/constant_op。这里将不在一一赘述。

除了常量constant(),变量variable()也是在Tensorflow中经常会被用到的函数。变量的作用是保存和更新参数。执行图模型时,一定要对变量进行初始化,经过初始化后的变量才能拿来使用。变量的使用包括创建、初始化、保存、加载等操作。首先,我们通过程序2-4了解一下变量是如何被创建的:

程序2-4:

import tensorflow as tfA = tf.Variable(3, name="number")
B = tf.Variable([1,3], name="vector")
C = tf.Variable([[0,1],[2,3]], name="matrix")
D = tf.Variable(tf.zeros([100]), name="zero")
E = tf.Variable(tf.random_normal([2,3], mean=1, stddev=2, dtype=tf.float32))

程序2-4展示了创建变量的多种方式。我们可以把函数variable()理解为构造函数,构造函数的使用需要初始值,而这个初始值是一个任何形状、类型的Tensor。也就是说,我们既可以通过创建数字变量、一维向量、二维矩阵初始化Tensor,也可以使用常量或是随机常量初始化Tensor,来完成变量的创建。

当我们完成了变量的创建,接下来,我们要对变量进行初始化。变量在使用前一定要进行初始化,且变量的初始化必须在模型的其它操作运行之前完成。通常,变量的初始化有三种方式,如程序2-5所示:

程序2-5:

#初始化全部变量:
init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)#初始化变量的子集:
init_subset=tf.variables_initializer([b,c], name="init_subset")
with tf.Session() as sess:sess.run(init_subset)#初始化单个变量:
init_var = tf.Variable(tf.zeros([2,5]))
with tf.Session() as sess:sess.run(init_var.initializer)

程序2-5说明了初始化变量的三种方式:初始化全部变量、初始化变量的子集以及初始化单个变量。首先,global_variables_initializer()方法是不管全局有多少个变量,全部进行初始化,是最简单也是最常用的一种方式;variables_initializer()是初始化变量的子集,相比于全部初始化化的方式更加节约内存;Variable()是初始化单个变量,函数的参数便是要初始化的变量内容。通过上述的三种方式,我们便可以实现变量的初始化,放心的使用变量了。

我们经常在训练模型后,希望保存训练的结果,以便下次再使用或是方便日后查看,这时就用到了Tensorflow变量的保存。变量的保存是通过tf.train.Saver()方法创建一个Saver管理器,来保存计算图模型中的所有变量。具体代码如程序2-6所示:

程序2-6:

import tensorflow as tfvar1 = tf.Variable([1,3], name="v1")
var2 = tf.Variable([2,4], name="v2")#对全部变量进行初始化
init = tf.initialize_all_variables()#调用Saver()存储器方法
saver = tf.train.Saver()#执行图模型
with tf.Session() as sess:sess.run(init)#设置存储路径save_path = saver.save(sess,"test/save.ckpt")

我们要注意,我们的存储文件save.ckpt是一个二进制文件,Saver存储器提供了向该二进制文件保存变量和恢复变量的方法。保存变量的方法就是程序中的save()方法,保存的内容是从变量名到tensor值的映射关系。完成该存储操作后,会在对应目录下生成如图2-1所示的文件:

图2-1 保存变量生成的相应文件

Saver提供了一个内置的计数器自动为checkpoint文件编号。这就支持训练模型在任意步骤多次保存。此外,还可以通过global_step参数自行对保存文件进行编号,例如:global_step=2,则保存变量的文件夹为model.ckpt-2。

那如何才能恢复变量呢?首先,我们要知道一定要用和保存变量相同的Saver对象来恢复变量。其次,不需要事先对变量进行初始化。具体代码如程序2-7所示:

程序2-7:

import tensorflow as tfvar1 = tf.Variable([0,0], name="v1")
var2 = tf.Variable([0,0], name="v2")
saver = tf.train.Saver()
module_file = tf.train.latest_checkpoint('test/')with tf.Session() as sess:saver.restore(sess,module_file)print("Model restored.")

本程序示例中,我们要注意:变量的获取是通过restore()方法,该方法有两个参数,分别是session和获取变量文件的位置。我们还可以通过latest_checkpoint()方法,获取到该目录下最近一次保存的模型。

以上就是对变量创建、初始化、保存、加载等操作的介绍。此外,还有一些与变量相关的重要函数,如:eval()等。

认识了常量和变量,Tensorflow中还有一个非常重要的常用函数——placeholder。placeholder是一个数据初始化的容器,它与变量最大的不同在于placeholder定义的是一个模板,这样我们就可以session运行阶段,利用feed_dict的字典结构给placeholder填充具体的内容,而无需每次都提前定义好变量的值,大大提高了代码的利用率。Placeholder的具体用法如程序2-8所示:

程序序2-8:

import tensorflow as tfa = tf.placeholder(tf.float32,shape=[2],name=None)
b = tf.constant([6,4],tf.float32)
c = tf.add(a,b)with tf.Session() as sess:print(sess.run(c,feed_dict={a:[10,10]}))

程序2-8演示了placeholder占位符的使用过程。Placeholder()方法有dtype,shape和name三个参数构成。dtype是必填参数,代表传入value的数据类型;shape是选填参数,代表传入value的维度;name也是选填参数,代表传入value的名字。我们可以把这三个参数看作为形参,在使用时传入具体的常量值。这也是placeholder不同于常量的地方,它不可以直接拿来使用,而是需要用户传递常数值。

最后,Tensorflow中还有一个重要的概念——fetch。Fetch的含义是指可以在一个会话中同时运行多个op。这就方便我们在实际的建模过程中,输出一些中间的op,取回多个tensor。Fetch的具体用法如程序2-9所示:

程序2-9:

import tensorflow as tfa = tf.constant(5)
b = tf.constant(6)
c = tf.constant(4)
add = tf.add(b, c)
mul = tf.multiply(a, add)with tf.Session() as sess:result = sess.run([mul, add])print(result)>> [50,10]

程序2-10展示了fetch的用法,即我们利用session的run()方法同时取回多个tensor值,方便我们查看运行过程中每一步op的输出结果。

程序2-10:

import tensorflow as tfvar1 = tf.Variable([0,0], name="v1")
var2 = tf.Variable([0,0], name="v2")
Saver = tf.train.Saver()
module_file = tf.train.latest_checkpoint('test/')with tf.Session() as sess:saver.restore(sess,module_file)print("Model restored.")

小结:本节旨在让大家学会Tensorflow的基础知识,为后边实战的章节打下基础。主要介绍了Tensor的概念,以及Tensorflow中的常量、变量、占位符、feed等知识点。


欢迎扫码关注:

 点击下方 |  | 了解更多

Tensorflow系列 | Tensorflow从入门到精通(二):附代码实战相关推荐

  1. Kali Linux 从入门到精通(二)-安装

    Kali Linux 从入门到精通(二)-安装 Kail Linux 安装 持久加密USB安装-1 LUSK:Linux Unified Key Setup 磁盘分区加密规范 不依赖与操作系统的磁盘级 ...

  2. Mybatis从入门到精通二(入门详解)

    Mybatis从入门到精通二(想学Mybatis,看了这一篇你就不需要其他的了) 本课程分为两天第一天的请参考: https://blog.csdn.net/weixin_43564627/artic ...

  3. java从入门到精通二十四(三层架构完成增删改查)

    java从入门到精通二十四(三层架构完成增删改查) 前言 环境准备 创建web项目结构 导入依赖和配置文件 创建层次模型 实现查询 实现添加 实现修改 完成删除 做一个用户登录验证 会话技术 cook ...

  4. java从入门到精通二十三(Servlet)

    java从入门到精通二十三(Servlet) Servlet 说明 Servlet初步入门尝试 Servlet生命周期 Servlet方法说明和体系结构 方法说明 体系结构说明 一些优化封装 urlP ...

  5. 北风网web开发资深讲师李炎恢出品--ASP系列课程从入门到精通

    北风网web开发资深讲师李炎恢出品--ASP系列课程从入门到精通 http://www.verycd.com/topics/2755115/ 中文名: 北风网web开发资深讲师李炎恢出品--ASP系列 ...

  6. 雷达通信 技术《相控阵入门到精通》 视频教程 代码 下载

    雷达通信 技术<相控阵入门到精通> 视频 代码 下载 01 电扫阵列_MATLAB建模与仿真(含全书源码) 02 相控阵天线的基础理论(含源码) 03 宽带阵列信号处理 04 无源和有源相 ...

  7. MyBatis从入门到精通(二):MyBatis XML方式的基本用法之Select

    最近在读刘增辉老师所著的<MyBatis从入门到精通>一书,很有收获,于是将自己学习的过程以博客形式输出,如有错误,欢迎指正,如帮助到你,不胜荣幸! 1. 明确需求 书中提到的需求是一个基 ...

  8. 它来了!ROS2从入门到精通:理论与实战

    ROS是什么? 随着人工智能技术的飞速发展与进步,机器人的智能化已经成为现代机器人发展的终极目标.机器人发展的速度在不断提升,应用范围也在不断拓展,例如自动驾驶.移动机器人.操作机器人.信息机器人等. ...

  9. Java学习路线导航,带你入门到精通(附Java全套学习资源)

    最近也有很多小伙伴来向我请教,他们大多是一些Java刚入门的新手,还不了解Java这个行业,也不知道Java零基础该从何学起,开始的时候非常迷茫,所以今天写了这篇文章,具体来说说Java的学习路线. ...

最新文章

  1. android怎么判断程序进入了后台,Android检测应用程序是否进入后台
  2. 《JAVA与模式》之访问者模式
  3. redhat7.4安装神通数据库
  4. 05-WIFI通讯客户端搭建
  5. 数据埋点太难!知乎的做法有何可借鉴之处?
  6. 过气旗舰不如?刘作虎确认一加新机:比一加7 Pro更超值
  7. 免校准的电量计量芯片_单相电能表如何校准(单相电能计量芯片+MCU)
  8. SQLi LABS Less 16 布尔盲注
  9. Java学习笔记 06 数字格式化及数学运算
  10. Git服务器-Gogs搭建
  11. 易语言自定义数据类型转c,转换JSON结构为易语言代码自定义数据类型
  12. windows安装syslog日志转发客户端nxlog
  13. 虚幻引擎(UE4) UMG实例
  14. 【PB】数据窗口的修改属性
  15. 谷歌浏览器windows以及mac系统下设置跨域
  16. 【工大SCIR】AAAI20 基于Goal(话题)的开放域多轮对话规划
  17. 简易计算机课程设计总结,简单计算器课程设计报告.doc
  18. 顶级的 18 款开源的低代码开发平台,经典收藏
  19. word中的表格怎么按照章节自动插入题注(即表头的编号)
  20. 万卷书- 创新型学校 [Creative Schools]

热门文章

  1. 2018年3月国家统计局对《三次产业划分规定(2012)》的新调整
  2. app端UI的制图规范
  3. Navicat Premium 免费
  4. 66-甲说乙说谎,乙说丙说谎,丙说甲乙说谎
  5. TinyXML的TiXmlElement::GetText()返回NULL
  6. python连接mt4服务器_如何从MetaTrader 4/5终端向外部服务器发送数据?
  7. Apache Linkis 中间件架构及快速安装
  8. 线性代数之——相似矩阵
  9. OpenCV入门系列 —— cv::dilate 图像膨胀
  10. 大顶堆和小顶堆-java