深度学习(14)TensorFlow高阶操作三: 张量排序
深度学习(14)TensorFlow高阶操作三: 张量排序
- 一. Sort, argsort
- 1. 一维Tensor
- 2. 多维Tensor
- 二. Top_k
- 三. Top-k accuracy(Top-k应用)
- 1. 问题描述
- 2. 问题解决
- 3. 代码
Outline
- Sort/argsort(排序/在序列中的位置)
- Topk(前k个排序结果)
- Top-5 Acc.(Topk应用)
一. Sort, argsort
1. 一维Tensor
(1) a = tf.random.shuffle(tf.range(5))
: 创建一个[0~5]的Tensor,再将其随机打乱,打乱后a=[2, 0, 3, 4, 1];
(2) tf.sort(a, direction=‘DESCENDING’)
: 将a按降序排列,即按照由大到小进行排列,排列完为[4, 3, 2, 1, 0];
注: 升序就是ASCENDING
(3) tf.argsort(a, direction=‘DESCENDING’)
: 将a按照降序排列,然后得到每个元素在原来a中的位置。例如,4在原来a中的位置是3,那么argsort中的第1个元素就是3; 3在原来a中的位置是2,那么argsort中的第2个元素就是2; 2在原来a中的位置是0,那么argsort中的第3个元素就是0; 1在原来a中的位置是4,那么argsort中的第4个元素就是4; 0在原来a中的位置是1,那么argsort中的第5个元素就是1; 所以得到argsort=[3, 2, 0, 4, 1];
(4) tf.gather(a, idx)
: 将a按照idx中的索引来排序,得到[4, 3, 2, 1, 0];
2. 多维Tensor
(1) tf.sort(a)
: 将a升序排列。不指定direction就是默认升序排列;
(2) tf.dort(a, direction=‘DESCENDING’)
: 将a降序排列;
(3) idx = tf.argsort(a)
: 将a的升序排列进行索引排列;
二. Top_k
- Only return top-k values and indices
这个方法只能返回tok-k的值和索引。
(1) res = tf.math.top_k(a, 2)
: 将a中每行按照top2排列;
(2) res.indices
: 返回res的索引顺序;
(3) rea.values
: 返回res的值的顺序;
三. Top-k accuracy(Top-k应用)
1. 问题描述
- Prob: [0.1, 0.2, 0.3, 0.4]
表示数字识别中一个数字为“0”的概率为0.1,为“1”的概率为0.2,为“2”的概率为0.3,为“3”的概率为0.4; - Label:[2]
表示这个数字的标签值(label)为2; - Only consider top-1 prediction: [3]
如果我们只考虑1个最有可能的预测值,就是这个数字为“3”,那么显然与标签值“2”对不上号,那么就会导致正确率为0%; - Only consider top-2 prediction: [3, 2]
如果我们考虑2个最有可能的预测值,使用top-2排列为[3, 2],这时准确率就为100%(只要top-2中出现预测值与标签值相同的情况,准确率就记为1); - Only consider top-3 prediction: [3, 2, 1]
如果我们考虑3个最有可能的预测值,使用top-3排列为[3, 2, 1],这时准确率就为100%;
2. 问题解决
(1) target = tf.constant([2, 0])
: 表示标签(label)值,即第1个样本的真实值为“2”, 第2个样本的真实值为“0”;
(2) k_b = tf.math.top_k(prob, 3).indices
: 表示将prob中的数据按照top-3排列并给出索引,这个结果也相当于这两个样本每个样本是什么数字的可能性,即第1个样本最有可能是“2”,次有可能是“1”,再次有可能是“0”; 第2个样本最有可能是“1”,次有可能是“0”,再次有可能是“2”;
(3) k_b = tf.transpose(k_b, [1, 0])
: 将k_b做转置操作,即将维度由[0, 1]变为[1, 0],这样的话,第1行就对应这两个样本最有可能(即top1)的预测值,第2行就对应这两个样本次有可能(即top2)的预测值;
(4) target = tf.broadcast_to(target, [3, 2])
: 使用Broadcasting方法将target的维度由[2]扩展为[3, 2]([2] →\to→ [1, 2] →\to→ [3, 2]);
接下来我们需要将预测值与标签值进行比较操作,例如: 使用top1的预测值比较,即使用[2, 1]和[2, 0]进行比较,结果为[True, False],那么其准确率就为50%,因为只有1个数字识别正确; 使用top2的预测值比较,即使用[[2, 1], [1, 0]]和[[2, 0], [2, 0]]进行比较,那么我们只需要保证top2的每个数字的预测值有1个与标签值相等即可,结果为[True, True],那么其准确率就为100%; 使用top3的预测值比较,即使用[[2, 1], [1, 0], [0, 2]]和[[2, 0], [2, 0], [2, 0]]进行比较,那么我们只需要保证top3的每个数字的预测值有1个与标签值相等即可,结果为[True, True, True],那么其准确率还是100%;
(5) def accuracy(output, target, topk=(1,))
: 定义accuracy方法,其中output.shape=[2, 3],即共有2个样本,每个样本有3种分类; target.shape=[2],即target为这2个样本的标签值(label); topk=(1,)表示取值top1;
(6) maxk = max(topk)
: 表示取最大的top值,例如topk=(1,),那么maxk就取1,也就意味着在后边的操作中我们使用top1计算accuracy; topk=(1, 2),那么maxk就取2,也就意味着在后边的操作中我们使用top2计算accuracy; topk=(1, 2, 3),那么maxk就取3,也就意味着在后边的操作中我们使用top3计算accuracy;
(7) batch_size = target.shape[0]
: target.shape=[2],所以batch_size=2;
(8) pred = tf.math.topk(output, maxk).indicies
: 将每个样本预测值按照topk的索引进行排列,例如,如果按照top1排列,那么tf.math.topk(output, 1).indicies
就为:
样本 | 预测值 |
---|---|
b0b_0b0 | 2 |
b1b_1b1 | 1 |
也就是:
[21]\begin{bmatrix}2\\1\end{bmatrix}[21]
如果按照top2排列,那么tf.math.topk(output, 2).indicies
就为:
样本 | 预测值top1 | 预测值top2 |
---|---|---|
b0b_0b0 | 2 | 1 |
b1b_1b1 | 1 | 0 |
也就是:
[2110]\begin{bmatrix}2&1\\1&0\end{bmatrix}[2110]
如果按照top1排列,那么tf.math.topk(output, 3).indicies
就为:
样本 | 预测值top1 | 预测值top2 | 预测值top3 |
---|---|---|---|
b0b_0b0 | 2 | 1 | 0 |
b1b_1b1 | 1 | 0 | 2 |
也就是:
[210102]\begin{bmatrix}2&1&0\\1&0&2\end{bmatrix}[211002]
(9) pred = tf.transpose(pred, perm=[1, 0])
: 将pred进行转置操作,也就是将pred的维度顺序索引由[0, 1]变为[1, 0],即:
如果按照top1排列,那么pred就为:
[21]\begin{bmatrix}2&1\end{bmatrix}[21]
如果按照top2排列,那么pred就为:
[2110]\begin{bmatrix}2&1\\1&0\end{bmatrix}[2110]
如果按照top3排列,那么pred就为:
[211002]\begin{bmatrix}2&1\\1&0\\0&2\end{bmatrix}⎣⎡210102⎦⎤
(10) target_ = tf.broadcast_to(target, pred.shape)
将target进行Broadcasting操作,使target.shape=pred.shape,即:
如果按照top1排列,那么target_就为:
[20]→[20]\begin{bmatrix}2&0\end{bmatrix}\to\begin{bmatrix}2&0\end{bmatrix}[20]→[20]
如果按照top2排列,那么target_就为:
[20]→[2020]\begin{bmatrix}2&0\end{bmatrix}\to\begin{bmatrix}2&0\\2&0\end{bmatrix}[20]→[2200]
如果按照top3排列,那么target_就为:
[20]→[202020]\begin{bmatrix}2&0\end{bmatrix}\to\begin{bmatrix}2&0\\2&0\\2&0\end{bmatrix}[20]→⎣⎡222000⎦⎤
(11) correct = tf.equal(pred, target_)
: 将pred与target_做比较操作,即:
如果按照top1排列,那么correct就为:
[TrueFalse]\begin{bmatrix}True&False\end{bmatrix}[TrueFalse]
如果按照top2排列,那么pred就为:
[TrueFalseFalseTrue]\begin{bmatrix}True&False\\False&True\end{bmatrix}[TrueFalseFalseTrue]
如果按照top3排列,那么pred就为:
[TrueFalseFalseTrueFalseFalse]\begin{bmatrix}True&False\\False&True\\False&False\end{bmatrix}⎣⎡TrueFalseFalseFalseTrueFalse⎦⎤
(12) res[]
: 新建res列表,用于存放准确率结果;
(13) correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
:
其中correct[:k]
表示correct的前k-1行;
tf.reshape(correct[:k], [-1])
表示将correct的前k-1行进行reshape操作,即: 如果按照top1排列,那么correct[:k].shape就由[1, 2]变为[1]; 如果按照top2排列,那么correct[:k].shape就由[2, 2]变为[2]; 如果按照top3排列,那么correct[:k].shape就由[3, 2]变为[3];
tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
表示将reshape后的correct里的数据类型变为tf.float32,即将True变为1,False变为0;
(14) correct_k = tf.reduce_sum(correct_k)
: 将correct_k中的每行数据求和,即:
如果按照top1排列,那么correct_k=1+0=1correct\_k=1+0=1correct_k=1+0=1;
如果按照top2排列,那么correct_k=1+0+0+1=2correct\_k=1+0+0+1=2correct_k=1+0+0+1=2;
如果按照top3排列,那么correct_k=1+0+0+1+0+0=2correct\_k=1+0+0+1+0+0=2correct_k=1+0+0+1+0+0=2;
(15) acc = float(correct_k / batch_size)
: 计算top_k的准确率;
(16) res.append(acc)
: 将准确率acc放入到结果res中;
3. 代码
import tensorflow as tf
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.random.set_seed(2467)def accuracy(output, target, topk=(1,)):maxk = max(topk)batch_size = target.shape[0]pred = tf.math.top_k(output, maxk).indicespred = tf.transpose(pred, perm=[1, 0])target_ = tf.broadcast_to(target, pred.shape)# [10, b]correct = tf.equal(pred, target_)res = []for k in topk:correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)correct_k = tf.reduce_sum(correct_k)acc = float(correct_k* (100.0 / batch_size) )res.append(acc)return resoutput = tf.random.normal([10, 6])
output = tf.math.softmax(output, axis=1)
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print('prob:', output.numpy())
pred = tf.argmax(output, axis=1)
print('pred:', pred.numpy())
print('label:', target.numpy())acc = accuracy(output, target, topk=(1,2,3,4,5,6))
print('top-1-6 acc:', acc)
运行结果如下:
可以看到,随着top-k的增加,准确率不断提升,在数字识别种类有6种(即0~5)的情况下,如果我们计算top-6的话,准确率是一定会达到100%的。
参考文献:
[1] 龙良曲:《深度学习与TensorFlow2入门实战》
深度学习(14)TensorFlow高阶操作三: 张量排序相关推荐
- 深度学习(16)TensorFlow高阶操作五: 张量限幅
深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...
- 深度学习(17)TensorFlow高阶操作六: 高阶OP
深度学习(17)TensorFlow高阶操作六: 高阶OP 1. Where(tensor) 2. where(cond, A, B) 3. 1-D scatter_nd 4. 2-D scatter ...
- 深度学习(15)TensorFlow高阶操作四: 填充与复制
深度学习(15)TensorFlow高阶操作四: 填充与复制 1. Pad 2. 常用于Image Padding 3. tile 4. tile VS broadcast_to Outline pa ...
- 深度学习(12)TensorFlow高阶操作一: 合并与分割
深度学习(12)TensorFlow高阶操作一: 合并与分割 1. concat 2. stack: create new dim 3. Dim mismatch 4. unstuck 5. spli ...
- 深度学习中的高阶特征
由于自己研究方向为基于高阶的图像分类,故在这里对相关论文做一个简单的划分和总结. 按照计算高阶的层,位于卷积神经网络的位置划分,可以分为: 网络末端 网络中部 2022-05-24 update (C ...
- 深度学习入门及高阶经典课程、教程等资源合集(长期整理)
深度学习资料 经典课程 MIT 图分析 yale 图统计推断 standford 机器学习 stanford 机器学习系统设计 stanford 实用机器学习 纽约大学深度学习2020 吴恩达深度学习 ...
- 深度学习——在TensorFlow中查看和设定张量的形态
参考书籍:<深度学习--基于Python语言和TensorFlow平台> import tensorflow as tfx = tf.placeholder(dtype=tf.float3 ...
- TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络
TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...
- Tensorflow学习四---高阶操作
Tensorflow学习四-高阶操作 Merge and split 1.tf.concat 拼接 a = tf.ones([4,32,8]) b = tf.ones([2,32,8]) print( ...
最新文章
- ZBar与ZXing使用后感觉
- ubuntu下python2完全卸载
- 如何从零开始学python_从零开始学Python【4】--numpy
- apache 目录访问加密 简单
- JPA使用指南 javax.persistence的注解配置
- 谈家政O2O平台的出路
- 提升win双屏体验_海信双屏A6L评测,在自由阅读中植入护眼水墨屏
- Windows系统帮助中心程序的0day漏洞
- 整流、开关、肖特基区别
- 查找在Git中删除文件的时间
- ASPNET MVC Error 403.14
- 汉诺塔c 语言程序代码,汉诺塔 (C语言代码)
- STM32接入机智云--实现数据上传和命令下发
- Android SVG矢量图/矢量动画、Vector和VectorDrawable矢量图及动画,减少App Size
- 常用密码技术-对称加密
- 木瓜奇迹洗服务器维护,木瓜奇迹各种职业+点法
- 赔 1100 万美元!谷歌招聘年龄歧视
- MSVC编译器-C2001 常量中有换行符错误解决方法
- linux查看是什么系统
- 龙岭迷窟真的这么好看?今天我们就用 Java 爬取豆瓣数据好好分析一下!
热门文章
- .net mysql 备份_windows mysql 自动备份的几种方法
- html 自动滚动标签,HTML滚动标签(marquee标签)
- wrapper包装java_java Object 类 与 Wrapper包装类
- java怎么获取ajax_Java学习路线
- Gradle系列(三):项目实践
- Swif语法基础 要点归纳(一)
- GCJ 2008 Round 1A Minimum Scalar Product( 水 )
- 图解WildFly8之Servlet容器Undertow剖析
- mysql 高性能压力测试(总结了好久)
- apache rewrite 二级域名