这个问题应该算是很简单的,只不过我是新手,需要多记录下。在看Stanford的TensorFlow教程(地址为:https://www.youtube.com/watch?v=g-EvyKpZjmQ&list=PLQ0sVbIj3URf94DQtGPJV629ctn2c1zN-)Lecture 1的一段代码的时候,发现并不能运行:

import tensorflow as tfwith tf.device('/gpu:1'):a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='a')b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='b')c = tf.matmul(a, b)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))

报错为:ValueError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [6], [6].

TensorFlow才接触不久,基本都是运行下别人的代码,看看效果,所以对其中的方法也都是混个脸熟,并不十分清楚。这里的tf.matmul()方法和另一个tf.mul()要区分下,tf.mul实际上在新版的TensorFlow中已经修改为tf.multiply()了,我是参考https://blog.csdn.net/liuyuemaicha/article/details/70305678这篇博文学习的,测试下multiply:

import tensorflow as tfa = tf.get_variable('a', [2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
b = tf.get_variable('b', [2, 3], initializer=tf.constant_initializer(2))
c = tf.get_variable('c', [3, 2], initializer=tf.ones_initializer())init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)print('a:\n', sess.run(a))print('b:\n', sess.run(b))print('c:\n', sess.run(c))print('multiply a, b')print(sess.run(tf.multiply(a, b)))print('matmul a, c')print(sess.run(tf.matmul(a, c)))

tf.get_variable()方法的使用第一个参数是name,第二个是shape,第三个是initializer。tf.random_normal_initializer()方法就是返回一个具有正态分布的张量初始化器,均值(期望值)mean默认为0,标准差默认为1,也就是默认为标准正态分布。得到的结果为:

a:
 [[-1.2580129   0.42341614  0.2203044 ]
 [-1.1805797  -1.8744725  -0.1812443 ]]
b:
 [[2. 2. 2.]
 [2. 2. 2.]]
c:
 [[1. 1.]
 [1. 1.]
 [1. 1.]]
multiply a, b
[[-2.5160258  0.8468323  0.4406088]
 [-2.3611593 -3.748945  -0.3624886]]
matmul a, c
[[-0.6142924 -0.6142924]
 [-3.2362967 -3.2362967]]

可以看到tf.multiply()方法是对应位置元素直接相乘的,因此要求二者的shape相等,该操作也成为哈达马积(Hadamard)。a和c两个变量一个是2行3列,一个3行2列,可以用tf.matmul()方法求矩阵乘积,得到了2行2列的一个矩阵。

回到刚刚的问题,比如参考https://blog.csdn.net/blythe0107/article/details/74171870,可以采用reshape的方式,使前者的列等于后者的行也就行了,如下:

import tensorflow as tf
import numpy as npwith tf.device('/gpu:0'):a = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='a')b = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='b')c = tf.matmul(a, b)with tf.device('/gpu:1'):d = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='d')e = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='e')f = tf.matmul(d, e)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))
print(sess.run(f))

这样得到的输出如下:

2018-08-02 15:52:42.801535: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 1:   Y N
2018-08-02 15:52:42.801871: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10388 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1)
2018-08-02 15:52:42.905229: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10407 MB memory) -> physical GPU (device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1)
Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1
2018-08-02 15:52:43.010702: I tensorflow/core/common_runtime/direct_session.cc:288] Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1

MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011677: I tensorflow/core/common_runtime/placer.cc:886] MatMul: (MatMul)/job:localhost/replica:0/task:0/device:GPU:0
MatMul_1: (MatMul): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011720: I tensorflow/core/common_runtime/placer.cc:886] MatMul_1: (MatMul)/job:localhost/replica:0/task:0/device:GPU:1
a: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011741: I tensorflow/core/common_runtime/placer.cc:886] a: (Const)/job:localhost/replica:0/task:0/device:GPU:0
b: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011760: I tensorflow/core/common_runtime/placer.cc:886] b: (Const)/job:localhost/replica:0/task:0/device:GPU:0
d: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011778: I tensorflow/core/common_runtime/placer.cc:886] d: (Const)/job:localhost/replica:0/task:0/device:GPU:1
e: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011795: I tensorflow/core/common_runtime/placer.cc:886] e: (Const)/job:localhost/replica:0/task:0/device:GPU:1
[[22. 28.]
 [49. 64.]]
[[22. 28.]
 [49. 64.]]

可以看到,变量和op可以指定GPU,本例中a和b用了GPU0,另外也处理了matmul()的操作。而d和e即计算f的任务则放在了GPU1上,这个可能算是最简单了单主机多GPU使用了。

关于前面的变量使用,记录如下。

TensorFlow有两个关于variable的op,即tf.Variable()和tf.get_variable(),这里参考

https://blog.csdn.net/u012436149/article/details/53696970/学习下。比如下面的代码:

import tensorflow as tfw_1 = tf.Variable(3, name='w_1')
w_2 = tf.Variable(1, name='w_1')print(w_1.name)
print(w_2.name)init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)sess.run(tf.Print(w_1, [w_1, w_1.name, str(w_1.value)]))sess.run(tf.Print(w_2, [w_2, w_2.name, str(w_2.value)]))

这里使用了tf.Print()方法来输出一些调试信息,其value部分用str()方法处理下不然报错。输出结果:

w_1:0
w_1_1:0

[3][w_1:0][<bound method Variable.value of <tf.Variable \'w_1:0\' shape=() dtype=int32_ref>>]
[1][w_1_1:0][<bound method Variable.value of <tf.Variable \'w_1_1:0\' shape=() dtype=int32_ref>>]

使用tf.Variable()系统会自动处理命名冲突,这里如果用tf.get_variable()则会报错w_1变量已存在。所以当我们需要共享变量的时候,用tf.get_variable()。关于其实质区别,看下这段代码:

import tensorflow as tfwith tf.variable_scope('scope1'):w1 = tf.get_variable('w1', shape=[])w2 = tf.Variable(0.0, name='w_1')with tf.variable_scope('scope1', reuse=True):w1_p = tf.get_variable('w1', shape=[])w2_p = tf.Variable(1.0, name='w2')print(w1 is w1_p, w2 is w2_p)

输出为True False。由于tf.Variable()每次都在创建新对象,所有reuse=True 和它并没有什么关系。对于get_variable(),如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。

TensorFlow MatMul操作rank错误问题记录相关推荐

  1. 深度学习(8)TensorFlow基础操作四: 维度变换

    深度学习(8)TensorFlow基础操作四: 维度变换 1. View 2. 示例 3. Reshape操作可能会导致潜在的bug 4. tf.transpose 5. Squeeze VS Exp ...

  2. 深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战

    深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战 1. 导包 2. 加载数据集 3. 转换数据类型 4. 查看x.shape, y.shape, x.dtype, y.dtype ...

  3. 深度学习(10)TensorFlow基础操作六: 数学运算

    深度学习(10)TensorFlow基础操作六: 数学运算 1. Operation type 2. + - * / % // 3. tf.math.log & tf.exp 4. log2, ...

  4. 深度学习(5)TensorFlow基础操作一: TensorFlow数据类型

    深度学习(5)TensorFlow基础操作一: TensorFlow数据类型 Data Container(数据载体) What's Tensor TF is a computing lib(科学计算 ...

  5. tensorFlow基础操作及常用函数

    tensorFlow基础操作及常用函数 1. 安装Tensorflow 2. TensorFlow基本操作 3. TensorFlow常用函数 3.1 常用矩阵创建方式 3.2 高斯初始化及洗牌操作 ...

  6. mysql 错误信息大全,MySQL错误信息记录

    MySQL错误信息记录 MySQL错误信息记录 考虑到MySQL是一门玄学,难免有些看不见的坑要踩,于是开这篇博文,以此来记录学习及以后使用过程中踩过的Error,如果你也想加入欢迎留言参加 ERRO ...

  7. access里的多步oledb错误_多步 OLE DB 操作产生错误,这问题怎么解决啊

    相信大家在调试程序时曾碰到过下面错误 数据库:ACCESS ------------------------------------------------------------------- Mi ...

  8. App错误日志记录到本地

    1.使用背景:客户使用App过程中程序出错,如无法复现Bug,会很难受.所以,将错误记录至本地并发送后台,会方便日后优化及维护. 2.bug捕捉的工具类 import java.io.File; im ...

  9. TensorFlow常用操作:代码示例

    1,定义矩阵代码示例: import tensorflow as tftf.zeros([3,4]) #定义3行4列元素均为0的矩阵tensor=tf.constant([1,2,3,4])#定义一维 ...

最新文章

  1. Linux下如何对tomcat Java线程进行分析?
  2. SAP WM初阶LQ02报错 - Movement Type 901 for manual transfer orders does not exist -
  3. 「模型解读」从2D卷积到3D卷积,都有什么不一样
  4. 如何检测Safari,Chrome,IE,Firefox和Opera浏览器?
  5. Linux下压缩包生成与解压命令以及进度
  6. layui导航栏页面滚动固定_网站建设页面导航如何降低用户寻找的时间
  7. 数据结构与算法-黑盒与白盒测试法
  8. pythonutf-8是不是二进制_python集合、字符编码、bytes与二进制
  9. HashMap(JDK1.8)
  10. 阶段1 语言基础+高级_1-2 -面向对象和封装_1面向对象思想的概述
  11. Linux的diff和git diff生成patch/打patch
  12. SQL Server 动态行转列(参数化表名、分组列、行转列字段、字段值)
  13. python 流程控制基础知识总结 和九九乘法表、质数、水仙花数、猜拳游戏练习
  14. pycharm Cannot connect to the Docker daemon at unix:///var/run/docker.sock. Is the docker daemon run
  15. github下载慢时可采用码云快速下载资源
  16. 求出数组最大值的方法
  17. Leetcode--Java--212. 单词搜索 II
  18. python数据按照分组进行频率分布_3.2.1 分布分析
  19. Android P Settings默认显示开发者选项
  20. Excel学习日记:L21-表格数值格式

热门文章

  1. 添加过的PDF注释可以修改吗?怎么修改PDF注释?
  2. 【学习笔记】打印1-100以内的质数优化
  3. DRV8837 12V单通道全桥电机驱动芯片替代料GC8837
  4. 【技巧】数据生成器对拍
  5. ASP.net2.0的machineKey
  6. word python_word_python
  7. ARIMA模型,ARIMAX模型预测冰淇淋消费时间序列数据
  8. python实现异步的几种方式_Python 异步编程
  9. C语言输入输出及选择结构程序设计的综合应用——简单超级战士游戏and 取牙签游戏
  10. MTK MT6739P 项目克隆脚本