代码在 github

import tensorflow as tf
from absl.flags import FLAGS@tf.function
def transform_targets_for_output(y_true, grid_size, anchor_idxs):#这个函数分别对比某一类anchors (一共是三类,每一类对应不同的尺寸的box)#每一类box 对应的尺寸翻倍# y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor))#这里的N是样本的数量N = tf.shape(y_true)[0]# y_true_out: (N, grid, grid, anchors, [x1, y1, x2, y2, obj, class])#输出的张量尺寸#tf.shape(anchor_idxs)=3=len(anchor_idxs)y_true_out = tf.zeros((N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))anchor_idxs = tf.cast(anchor_idxs, tf.int32)#这是动态数组indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)idx = 0#N 对应的是样本数量#二维遍历,i对应的是每一个样本for i in tf.range(N):#tf.shape(y_true)=[  N, 100,   6]#一 张图片最多识别100个目标,因为一幅图最多对应100个for j in tf.range(tf.shape(y_true)[1]):"""++++++++++++++++ x2,y2+                ++                ++                +x1,y1 ++++++++++++"""# x2=y_true[i][j][2]对应的是标记的矩形终点坐标#如果x2==0那么就没有这个类别 passif tf.equal(y_true[i][j][2], 0):continue#这里指的是y_true[i][j][5] 这个种类的anchor 是否在这个 anchor_idxs中anchor_eq = tf.equal(anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))print(anchor_eq)# print(anchor_idxs.numpy(), '##############',y_true[i][j][5].numpy())print('-'*30+'>')# print(i,j)#如果y_true[i][j][5] 这个种类的anchor 是在这个 anchor_idxs中#即 anchor_idxs 存在一个 值为True if tf.reduce_any(anchor_eq):#这是box的坐标box = y_true[i][j][0:4]#box 的中点坐标box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2#找到标注的那个box 对应的anchor 对应的位置,这里重新编码了anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)#grid_xy是grid_size*grid_size 这个真实 box下anchor中心的坐标grid_xy = tf.cast(box_xy // (1/grid_size), tf.int32)# grid[y][x][anchor] = (tx, ty, bw, bh, obj, class)indexes = indexes.write(#id   i=样本编号(0-6),anchor中心坐标x,y    anchor 种类取值在[0,1,2]idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])updates = updates.write(#对应的标注坐标 和        #1只是占位  类别idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])idx += 1# tf.print(indexes.stack())# tf.print(updates.stack())#y_true_out.shape=[3, 104, 104, 3, 6]#3是样本数量#104是指的是box 的大小,每一个pixel都有可能是anchor 的中心点#所以就粗暴的给每一个pixel分配了一个内存空间#3 是同一个尺度的anhor 点有3个box #6对应 [x1, y1, x2, y2, class , anchor_class]##return tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())def transform_targets(y_train, anchors, anchor_masks, size):y_outs = []#将图像分成32*32格#grid_size=13grid_size = size // 32# calculate anchor index for true boxesanchors = tf.cast(anchors, tf.float32)#anchors 是聚类出来的点,x,y分别是聚类框框的宽度和高度#这里是每个anchor 框框的面积anchor_area = anchors[..., 0] * anchors[..., 1]#box_wh的宽度-高度, box_wh.shape=[k, 100, 2],k是样本的数量box_wh = y_train[..., 2:4] - y_train[..., 0:2]#这里将box_wh从三维扩张到四维box_wh_expand=tf.expand_dims(box_wh, -2)# tf.tile是将量在某个或某几个维度上复制,这里是在第三个维度上复制,复制9个,因为一共9个锚点#box_wh.shape=[3, 100, 9, 2],从原来的一行两列变成9行两列'''box_wh[0][0]=<tf.Tensor: shape=(9, 2), dtype=float32, numpy=
array([[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998],[0.55466664, 0.32999998]], dtype=float32)>'''box_wh = tf.tile(box_wh_expand,(1, 1, tf.shape(anchors)[0], 1))#box_area.shape=[k, 100, 9]'''box_area[0][0]Out[362]: <tf.Tensor: shape=(9,), dtype=float32, numpy=array([0.18303998, 0.18303998, 0.18303998, 0.18303998, 0.18303998,0.18303998, 0.18303998, 0.18303998, 0.18303998], dtype=float32)>'''box_area = box_wh[..., 0] * box_wh[..., 1]#tf.minimum(A,B), A的维度为mn,B的维度为kn,且m=n,或者 n=1,就可以比较大小#intersection是交集#这里用到了矩阵的广播机制,分别与9个anchor box 进行比较#                           delta  xintersection = tf.minimum(box_wh[..., 0], anchors[..., 0]) * \tf.minimum(box_wh[..., 1], anchors[..., 1])  #delta y#交并比iou = intersection / (box_area + anchor_area - intersection)#找到和标记的框框最接近那个anchor ,输出anchor_id anchor_idx = tf.cast(tf.argmax(iou, axis=-1), tf.float32)anchor_idx = tf.expand_dims(anchor_idx, axis=-1)#这里的y_train.shape=3, 100, 6],最后一个维度是6#[x1,y1,x2,y2,class_id,anchor_idx]y_train = tf.concat([y_train, anchor_idx], axis=-1)for anchor_idxs in anchor_masks:y_outs.append(transform_targets_for_output(y_train, grid_size, anchor_idxs))grid_size *= 2return tuple(y_outs)import pickle
import numpy as np# data_output = open('data.pkl','wb')
# pickle.dump(kk1,data_output)
# data_output.close()# rb 以二进制读取
data_input = open('data.pkl','rb')
y_train = pickle.load(data_input)
data_input.close()size=416anchors = np.array([(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),(59, 119), (116, 90), (156, 198), (373, 326)],np.float32) / 416
anchor_masks = np.array([[6, 7, 8], [3, 4, 5], [0, 1, 2]])#y_train.shape=(6, 100, 5)#6是样本数量#100是标签数量# 5[x1,y1,x2,y2,class]cc= transform_targets(y_train, anchors, anchor_masks, size)
对应的标注坐标 和   #1只是占位  类别,所以最后一个维度是6
倒数第二个维度3=len(anchor_masks[k]),k=0,1,2

[box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])

cc[0].shape
Out[248]: TensorShape([6, 13, 13, 3, 6])cc[1].shape
Out[249]: TensorShape([6, 26, 26, 3, 6])cc[2].shape
Out[250]: TensorShape([6, 52, 52, 3, 6])

yolov3 数据预处理相关推荐

  1. YOLOV3林业病虫害数据集和数据预处理-paddle教程

    林业病虫害数据集和数据预处理方法介绍 在本课程中,将使用百度与林业大学合作开发的林业病虫害防治项目中用到昆虫数据集. 读取AI识虫数据集标注信息 AI识虫数据集结构如下: 提供了2183张图片,其中训 ...

  2. 【2022】全网最详细纯手工写YOLOv3之数据预处理,你必须看得懂(二)

    二. YOLO数据预处理流程 2.1 流程简述 原始数据一般是图片数据和标注数据,其中标注数据目前有两种,一种是VOC格式的.xml文件存储标注信息,另外一种标注格式CoCo用json来存储标注信息, ...

  3. 卷积在计算机中实现+pool作用+数据预处理目的+特征归一化+理解BN+感受野理解与计算+梯度回传+NMS/soft NMS

    一.卷积在计算机中实现 1.卷积 将其存入内存当中再操作(按照"行先序"): 这样就造成混乱. 故需要im2col操作,将特征图转换成庞大的矩阵来进行卷积计算,利用矩阵加速来实现, ...

  4. 机器学习PAL数据预处理

    机器学习PAL数据预处理 本文介绍如何对原始数据进行数据预处理,得到模型训练集和模型预测集. 前提条件 完成数据准备,详情请参见准备数据. 操作步骤 登录PAI控制台. 在左侧导航栏,选择模型开发和训 ...

  5. 深度学习——数据预处理篇

    深度学习--数据预处理篇 文章目录 深度学习--数据预处理篇 一.前言 二.常用的数据预处理方法 零均值化(中心化) 数据归一化(normalization) 主成分分析(PCA.Principal ...

  6. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  7. 第七篇:数据预处理(四) - 数据归约(PCA/EFA为例)

    前言 这部分也许是数据预处理最为关键的一个阶段. 如何对数据降维是一个很有挑战,很有深度的话题,很多理论书本均有详细深入的讲解分析. 本文仅介绍主成分分析法(PCA)和探索性因子分析法(EFA),并给 ...

  8. 数据预处理--噪声_为什么数据对您的业务很重要-以及如何处理数据

    数据预处理--噪声 YES! Data is extremely important for your business. 是! 数据对您的业务极为重要. A human body has five ...

  9. 数据预处理(完整步骤)

    原文:http://dataunion.org/5009.html 一:为什么要预处理数据? (1)现实世界的数据是肮脏的(不完整,含噪声,不一致) (2)没有高质量的数据,就没有高质量的挖掘结果(高 ...

最新文章

  1. java1.5连接oracle12c_java1.5连接oracle12c
  2. php 编译ext目录下的,PHP编译安装后的目录和文件解释?
  3. noip2016 换教室
  4. 三元组顺序表表示的稀疏矩阵加法_Matlab入门教程 第 2 章 Matlab矩阵处理之稀疏矩阵...
  5. c语言调用oracle函数返回值吗,C语言通过值和引用函数
  6. 一文重新认识联邦学习
  7. shell 文本后几行_Shell和Vi编辑器
  8. 计算机系统 就业前景,计算机系统结构就业前景
  9. 线性表之带头双向循环链表
  10. R查看和更改工作路径的命令
  11. 关于通过邮箱找回密码的实现
  12. 熔断机制什么意思_熔断机制是什么意思?熔断机制的作用
  13. Vue(1706E)
  14. 赔97.6万元!腾讯一程序员违反竞业协议,三年白干了!
  15. 请详细解释下小波去噪的原理
  16. excel切片器_给我1分钟,让你的Excel表格好看些,再好看些!
  17. Hadoop 安装
  18. Mysql基础篇-23-触发器Tigger
  19. 若你喜欢怪人 其实我很美
  20. 怎样解决eclipse在线安装插件奇慢无比问题

热门文章

  1. 6.MATLAB变量——矩阵操作一
  2. 什么是OOP(面向对象编程)?
  3. WPF快速入门系列(8)——MVVM快速入门
  4. 最简单的composer 包 使用
  5. 关于sendmail报错“did not issue MAIL/EXPN/VRFY/ETRN during connection to
  6. 《算法设计手册》面试题解答 第三章:数据结构
  7. 通俗易懂地解决中文乱码问题(2) --- 分析解决Mysql插入移动端表情符报错 ‘incorrect string value: '\xF0......
  8. 数据库中字段类型对应C#中的数据类型
  9. java中list排序
  10. GCC全过程详解+剖析生成的.o文件(2)