• 官方Document: https://tensorflow.google.cn/api_guides/python/sparse_ops
  • 开发测试环境:
    • Win10
    • Python 3.6.4
    • tensorflow-gpu 1.6.0

SparseTensor与SparseTensorValue的理解

SparseTensor(indices, values, dense_shape)

稀疏矢量的表示

  • indices shape为[N, ndims]的2-D int64矢量,用以指定非零元素的位置,比如indices=[[1,3], [2,4]]表示[1,3]和[2,4]位置的元素为非零元素。
  • values shape为[N]的1-D矢量,对应indices所指位置的元素值
  • dense_shape shape为[ndims]的1-D矢量,代表稀疏矩阵的shape
SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
>>
[[1, 0, 0, 0][0, 0, 2, 0][0, 0, 0, 0]]SparseTensor(indices=[[0], [3]], values=[4, 6], dense_shape=[7])
>>[4, 0, 0, 6, 0, 0, 0]

稀疏矢量的封装并不直观,可以通过稀疏矢量的方式构建矢量(sparse_to_dense)或者将稀疏矢量转换成矢sparse_tensor_to_dense量的方式来感受一下:

def sparse_to_dense(sparse_indices,output_shape,sparse_values,default_value=0,validate_indices=True,name=None)
  • sparse_indices sparse_indices:稀疏矩阵中那些个别元素对应的索引值。

    • sparse_indices是个数,那么它只能指定一维矩阵的某一个元素
    • sparse_indices是个向量,那么它可以指定一维矩阵的多个元素
    • sparse_indices是个矩阵,那么它可以指定二维矩阵的多个元素
  • output_shape 输出的稀疏矩阵的shape
  • sparse_value 个别元素的值
    • sparse_values是个数:所有索引指定的位置都用这个数
    • sparse_values是个向量:输出矩阵的某一行向量里某一行对应的数(所以这里向量的长度应该和输出矩阵的行数对应,不然报错)
  • default_value:未指定元素的默认值,一般如果是稀疏矩阵的话就是0了

实例展示

import tensorflow as tf
import numpy  BATCHSIZE=6label=tf.expand_dims(tf.constant([0,2,3,6,7,9]),1)
index=tf.expand_dims(tf.range(0, BATCHSIZE),1)
# use a matrix
concated = tf.concat([index, label], 1)   # [[0, 0], [0, 2], [0, 3], [0, 6], [0, 7], [0, 9]] (6,2)
onehot_labels = tf.sparse_to_dense(concated, [BATCHSIZE,10], 1.0, 0.0)# use a vector
sparse_indices2=tf.constant([1,3,4])
onehot_labels2 = tf.sparse_to_dense(sparse_indices2, [10], 1.0, 0.0)#can use# use a scalar
sparse_indices3=tf.constant(5)
onehot_labels3 = tf.sparse_to_dense(sparse_indices3, [10], 1.0, 0.0)sparse_tensor_00 = tf.SparseTensor(indices=[[0,0,0], [1,1,2]], values=[4, 6], dense_shape=[2,2,3])
dense_tensor_00 = tf.sparse_tensor_to_dense(sparse_tensor_00)with tf.Session(config=config) as sess:result1=sess.run(onehot_labels)result2 = sess.run(onehot_labels2)result3 = sess.run(onehot_labels3)result4 = sess.run(dense_tensor_00)print ("This is result1:")print (result1)print ("This is result2:")print (result2)print ("This is result3:")print (result3)print ("This is result4:")print (result4)

输出结果如下

[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0. 0. 0.][0. 0. 0. 1. 0. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 1. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
This is result2:
[0. 1. 0. 1. 1. 0. 0. 0. 0. 0.]
This is result3:
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
This is result4:
[[[4 0 0][0 0 0]][[0 0 0][0 0 6]]]

区别

两者的区别可以通过应用来说起

If you would like to define the tensor outside the graph, e.g. define the sparse tensor for later data feed, use SparseTensorValue. In contrast, if the sparse tensor is defined in graph, use SparseTensor

在graph定义sparse_placeholder,在feed中需要使用SparseTensorValue

x_sp = tf.sparse_placeholder(dtype=tf.float32)
W = tf.Variable(tf.random_normal([6, 6]))
y = tf.sparse_tensor_dense_matmul(sp_a=x_sp, b=W)init = tf.global_variables_initializer()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(init)stv = tf.SparseTensorValue(indices=[[0, 0], [1, 2]], values=[1.1, 1.2],
dense_shape=[2,6])
result = sess.run(y,feed_dict={x_sp:stv})print(result)

在graph中做定义需要使用SparseTensor

indices_i = tf.placeholder(dtype=tf.int64, shape=[2, 2])
values_i = tf.placeholder(dtype=tf.float32, shape=[2])
dense_shape_i = tf.placeholder(dtype=tf.int64, shape=[2])
st = tf.SparseTensor(indices=indices_i, values=values_i, dense_shape=dense_shape_i)W = tf.Variable(tf.random_normal([6, 6]))
y = tf.sparse_tensor_dense_matmul(sp_a=st, b=W)init = tf.global_variables_initializer()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(init)result = sess.run(y,feed_dict={indices_i:[[0, 0], [1, 2]], values_i:[1.1, 1.2], dense_shape_i:[2,6]})print(result)

在feed中应用SparseTensor,需要使用运算

x = tf.sparse_placeholder(tf.float32)
y = tf.sparse_reduce_sum(x)config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
with tf.Session(config=config) as sess:indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)values = np.array([1.0, 2.0], dtype=np.float32)shape = np.array([7, 9, 2], dtype=np.int64)print(sess.run(y, feed_dict={x: tf.SparseTensorValue(indices, values, shape)}))  # Will succeed.print(sess.run(y, feed_dict={x: (indices, values, shape)}))  # Will succeed.sp = tf.SparseTensor(indices=indices, values=values, dense_shape=shape)sp_value = sp.eval(session=sess)print(sp_value)print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.

【TensorFlow】稀疏矢量相关推荐

  1. 基于Ganos百行代码实现亿级矢量空间数据在线可视化

    简介: 本文介绍如何使用RDS PG或PolarDB(兼容PG版或Oracle版)的Ganos时空引擎提供的数据库快显技术,仅用百行代码实现亿级海量几何空间数据的在线快速显示和流畅地图交互,且无需关注 ...

  2. 正则化--L1正则化(稀疏性正则化)

    稀疏矢量通常包含许多维度.创建特征组合会导致包含更多维度.由于使用此类高维度特征矢量,因此模型可能会非常庞大,并且需要大量的 RAM. 在高维度稀疏矢量中,最好尽可能使权重正好降至 0.正好为 0 的 ...

  3. 机器学习-过拟合、正则化、稀疏性、交叉验证概述

    在机器学习中,我们将模型在训练集上的误差称之为训练误差,又称之为经验误差,在新的数据集(比如测试集)上的误差称之为泛化误差,泛化误差也可以说是模型在总体样本上的误差.对于一个好的模型应该是经验误差约等 ...

  4. 机器学习入门13 - 正则化:稀疏性 (Regularization for Sparsity)

    原文链接:https://developers.google.com/machine-learning/crash-course/regularization-for-sparsity/ 1- L₁正 ...

  5. 高效大规模图像搜索开源实现

    传统Bag of Features方法的OpenCV C++代码. (关注"我爱计算机视觉"公众号,一个有价值有深度的公众号~) 在深度学习逐渐统治计算机视觉领域的时候,传统算法依 ...

  6. google机器学习速成教程学习笔记

    Machine Learning notes 监督式机器学习 线性回归.训练和损失 迭代方式降低损失 降低损失 (Reducing Loss):梯度下降法 使用TensorFlow 泛化 训练集和测试 ...

  7. XGBoost之类别特征的处理

    目录 Label encoding与 One-Hot encoding Label encoding one-hot encoding 利用神经网络的Embedding层处理类别特征 Embeddin ...

  8. 机器学习——Google 快速入门课程(综合版)

    前言 本文参考 Google 谷歌官网机器学习的快速入门课程,整体课程比较好理解,供大家学习参考:文章也会结合自己的理解进行优化.看到官网的消息2021/7之后就不提供中文版的机器学习快速入门课程了, ...

  9. Google机器学习速成课程 - 视频笔记整理汇总 - 基础篇核心部分

    Google机器学习速成课程 - 视频笔记整理 - 基础篇核心部分 课程网址: https://developers.google.com/machine-learning/crash-course/ ...

最新文章

  1. soup.a.parents都有哪些
  2. block为什么用copy以及如何解决循环引用
  3. Linq TO SQL 虽好,但不要滥用
  4. php三位不够前面加0,php 格式化数字 位数不足前面加0补足的实现方法
  5. GDCM:gdcm::ImageReader的测试程序
  6. tcp_nodelay memcached java_TCP_NODELAY 和 TCP_NOPUSH
  7. Sequelize 中文文档 v4 - Querying - 查询
  8. 在matplotlib中关闭绘图轴的方法
  9. macbook交叉编译linux,mac交叉编译到Linux报错
  10. 【cs231】损失函数与优化
  11. 如何在 Windows 下像 Mac 一样优雅开发
  12. 测试人员入门级的数据库知识(SQL语句)
  13. 【学习笔记】python实现图像的手绘效果
  14. contest13 CF197div2 oooxx ooooo ooooo
  15. 免费在线生成二维码网站,支持二维码自定义
  16. 吴裕雄--天生自然 诗经:鹊踏枝·谁道闲情抛弃久
  17. 十一大开源机器人平台
  18. The Picture of Dorian Gray——17
  19. 高频小信号谐振放大器【Multisim】【高频电子线路】
  20. 王道论坛《计算机网络》网课学习笔记

热门文章

  1. python time智能等待_Python Selenium智能等待
  2. mysql5.6 pt-query-digest_pt-query-digest安装及分析
  3. JAVA入门级教学之(连接运算符)
  4. python修复不了_如何修复Python代码?
  5. c语言printf函数很长时间,C语言学习之printf()函数特别注意事项
  6. python prettytable表格列数太多_excel列数太多了怎么办
  7. java new一个对象的过程中发生了什么
  8. 【LeetCode笔记】299. 猜数字游戏 (Java、偏数学)
  9. 【LeetCode笔记】34. 在排序数组中查找元素的第一个和最后一个位置(Java、二分)
  10. iqc工作职责和工作内容_监理工程师工作职责