代码地址 https://github.com/AITTSMD/MTCNN-Tensorflow
这里我就不在进行MTCNN的介绍了。分析的再清楚都不如从源码的实现去分析。
Talk is cheap, just show me the code。
MTCNN主要分为三个网络 PNet RNet ONet
其中PNet是个全卷积网络 这是和RNet ONet最大的区别
由于篇幅有限 分成多篇进行分析
MTCNN源码详细解读(2)- PNet的训练和数据集的构建
MTCNN源码详细解读(3)- RNet的训练和数据集的构建

def P_Net(inputs,label=None,bbox_target=None,landmark_target=None,training=True):#define common param# 为相同的卷积操作 设置一样的初始化参数和激活函数preluwith slim.arg_scope([slim.conv2d],activation_fn=prelu,weights_initializer=slim.xavier_initializer(),biases_initializer=tf.zeros_initializer(),weights_regularizer=slim.l2_regularizer(0.0005), padding='valid'):# PNet 训练输入时(batch_size, 12, 12, 3)# (batch_size, 10, 10, 10)net = slim.conv2d(inputs, 10, 3, stride=1,scope='conv1')# (batch_size, 5, 5, 10)net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool1', padding='SAME')# (batch_size, 3, 3, 16)net = slim.conv2d(net,num_outputs=16,kernel_size=[3,3],stride=1,scope='conv2')# (batch_size, 1, 1, 32)net = slim.conv2d(net,num_outputs=32,kernel_size=[3,3],stride=1,scope='conv3')#batch*H*W*2# 用 1 * 1卷积核来做输出# 这里是类别输出 虽然是二分类 但是作者用2的维度来表示  第一位表示不是人脸置信度 第二位表示是人脸的置信度conv4_1 = slim.conv2d(net,num_outputs=2,kernel_size=[1,1],stride=1,scope='conv4_1',activation_fn=tf.nn.softmax)#batch*H*W*4# 这里就是输出坐标的偏移 4个值bbox_pred = slim.conv2d(net,num_outputs=4,kernel_size=[1,1],stride=1,scope='conv4_2',activation_fn=None)#batch*H*W*10# 这里是landmark五个点的坐标就是10个值 回归值landmark_pred = slim.conv2d(net,num_outputs=10,kernel_size=[1,1],stride=1,scope='conv4_3',activation_fn=None)if training:#batch*2# (batch, 1, 1, 2) 去掉dim=[1, 2]两个维度# 下面也是同理cls_prob = tf.squeeze(conv4_1,[1,2],name='cls_prob')# 计算分类损失cls_loss = cls_ohem(cls_prob,label)#batchbbox_pred = tf.squeeze(bbox_pred,[1,2],name='bbox_pred')# 计算坐标损失bbox_loss = bbox_ohem(bbox_pred,bbox_target,label)#batch*10landmark_pred = tf.squeeze(landmark_pred,[1,2],name="landmark_pred")# 计算landMark损失landmark_loss = landmark_ohem(landmark_pred,landmark_target,label)accuracy = cal_accuracy(cls_prob,label)L2_loss = tf.add_n(slim.losses.get_regularization_losses())return cls_loss,bbox_loss,landmark_loss,L2_loss,accuracy

网络结构看上去简单清晰 下面分析下三个损失函数
1 分类损失cls_ohem 常用的交叉熵损失

def cls_ohem(cls_prob, label):# 构建一个和label shape一致的0数组# (batch, )zeros = tf.zeros_like(label)#label=-1 --> label=0net_factory# 对于label小于0的过滤掉 label {0, 1}的保留# 这里先简单说下 PNet总共有三种label 0-negative 1-positive -1-part  -2-landmark 后面在数据集构建的时候会详细说明# 对于分类损失只需要计算 label为 0, 1的图片label_filter_invalid = tf.where(tf.less(label,0), zeros, label)# (batch_size, 2) --> size: batch_size * 2num_cls_prob = tf.size(cls_prob)# reshape 后 (batch_size * 2, 1)  为什么这么做呢因为这里对二分类用了2个输出表示 所有每个位置的值度需要计算损失# 如果用1个值来表示就没必要这么麻烦cls_prob_reshape = tf.reshape(cls_prob,[num_cls_prob,-1])# 将上面的label转成intlabel_int = tf.cast(label_filter_invalid,tf.int32)# cls_prob shape 还是 (batch_size, 2) 所有 num_row就是batchnum_row = tf.to_int32(cls_prob.get_shape()[0])# 这里对num_row * 2 因为有两个值表示置信度  第一个位置不是人脸的 第二个位置是人脸的# 这里详细分析下为什么乘2# 假设batch_size=5  row = [0, 2, 4, 6, 8]  假设我们的label经过过滤后[1, 0, 0, 0, 1] # 相加变成 [1, 2, 4, 6, 9] 也就是说如果第i张图片label为1 就把第i张图片输出第二个位置的置信度值取出来 对于0的不变就是第一个位置置信度# 有可能会有人说那过滤掉的label也不是0嘛 后label为0的没区分开来 这里不用担心 坐着下面会做mask 这是个常用手段 不需要参与计算的位置都mask掉row = tf.range(num_row)*2indices_ = row + label_int# 从 (batch_size *2, 1)中取出对应位置的label进行损失计算label_prob = tf.squeeze(tf.gather(cls_prob_reshape, indices_))# 计算负的log损失loss = -tf.log(label_prob+1e-10)zeros = tf.zeros_like(label_prob, dtype=tf.float32)ones = tf.ones_like(label_prob,dtype=tf.float32)# 这里就是添加mask  对于label小于0的mask掉# 下面就是简单的求和valid_inds = tf.where(label < zeros,zeros,ones)num_valid = tf.reduce_sum(valid_inds)keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)#set 0 to invalid sampleloss = loss * valid_indsloss,_ = tf.nn.top_k(loss, k=keep_num)return tf.reduce_mean(loss)

2 边框回归损失bbox_ohem 这里用的均方误差或者smoothL1 和RCNN提出的smoothL1一致

#label=1 or label=-1 then do regression
def bbox_ohem(bbox_pred,bbox_target,label):zeros_index = tf.zeros_like(label, dtype=tf.float32)ones_index = tf.ones_like(label,dtype=tf.float32)# 对label为-1, 1的做边框回归valid_inds = tf.where(tf.equal(tf.abs(label), 1),ones_index,zeros_index)#(batch,)# 下面就是简单的均方误差square_error = tf.square(bbox_pred-bbox_target)square_error = tf.reduce_sum(square_error,axis=1)#keep_num scalarnum_valid = tf.reduce_sum(valid_inds)#keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)keep_num = tf.cast(num_valid, dtype=tf.int32)#keep valid index square_errorsquare_error = square_error*valid_inds# 这里有个小技巧 支取topK个用来做反向传播# 思想就是训练误差最大的topk个_, k_index = tf.nn.top_k(square_error, k=keep_num)square_error = tf.gather(square_error, k_index)return tf.reduce_mean(square_error)

3 就是landmark损失 landmark_ohem

def landmark_ohem(landmark_pred,landmark_target,label):''':param landmark_pred::param landmark_target::param label::return: mean euclidean loss'''#keep label =-2  then do landmark detection# 对于landmark的样本label = -2 所以这里需要找到label为-2的样本ones = tf.ones_like(label,dtype=tf.float32)zeros = tf.zeros_like(label,dtype=tf.float32)valid_inds = tf.where(tf.equal(label,-2),ones,zeros)# 这个和边框回归损失是一致的 都市MSE损失 然后选取loss最大的来进行反向传播square_error = tf.square(landmark_pred-landmark_target)square_error = tf.reduce_sum(square_error,axis=1)num_valid = tf.reduce_sum(valid_inds)#keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)keep_num = tf.cast(num_valid, dtype=tf.int32)square_error = square_error*valid_inds_, k_index = tf.nn.top_k(square_error, k=keep_num)square_error = tf.gather(square_error, k_index)return tf.reduce_mean(square_error)

最后对于RNet和ONet的网络结构和PNet基本差不多,最大差别就是 RNet和ONet不是全卷机网络最后接了fc

def R_Net(inputs,label=None,bbox_target=None,landmark_target=None,training=True):with slim.arg_scope([slim.conv2d],activation_fn = prelu,weights_initializer=slim.xavier_initializer(),biases_initializer=tf.zeros_initializer(),weights_regularizer=slim.l2_regularizer(0.0005),                        padding='valid'):net = slim.conv2d(inputs, num_outputs=28, kernel_size=[3,3], stride=1, scope="conv1")net = slim.max_pool2d(net, kernel_size=[3, 3], stride=2, scope="pool1", padding='SAME')net = slim.conv2d(net,num_outputs=48,kernel_size=[3,3],stride=1,scope="conv2")net = slim.max_pool2d(net,kernel_size=[3,3],stride=2,scope="pool2")net = slim.conv2d(net,num_outputs=64,kernel_size=[2,2],stride=1,scope="conv3")fc_flatten = slim.flatten(net)fc1 = slim.fully_connected(fc_flatten, num_outputs=128,scope="fc1")#batch*2cls_prob = slim.fully_connected(fc1,num_outputs=2,scope="cls_fc",activation_fn=tf.nn.softmax)#batch*4bbox_pred = slim.fully_connected(fc1,num_outputs=4,scope="bbox_fc",activation_fn=None)#batch*10landmark_pred = slim.fully_connected(fc1,num_outputs=10,scope="landmark_fc",activation_fn=None)#trainif training:cls_loss = cls_ohem(cls_prob,label)bbox_loss = bbox_ohem(bbox_pred,bbox_target,label)accuracy = cal_accuracy(cls_prob,label)landmark_loss = landmark_ohem(landmark_pred,landmark_target,label)L2_loss = tf.add_n(slim.losses.get_regularization_losses())return cls_loss,bbox_loss,landmark_loss,L2_loss,accuracyelse:return cls_prob,bbox_pred,landmark_pred

MTCNN源码详细解读(1)- PNet/RNet/ONet的网络结构和损失函数相关推荐

  1. AFL(American Fuzzy Lop)源码详细解读(1)

    AFL(American Fuzzy Lop)源码详细解读(1) 多亏大佬们的文章,对读源码帮助很大: https://eternalsakura13.com/2020/08/23/afl/ http ...

  2. AFL(American Fuzzy Lop)源码详细解读(3)

    AFL(American Fuzzy Lop)源码详细解读(3) 本篇是关于主循环阶段的内容,整个AFL最核心的部分,篇幅较长.最后简述一下afl_fuzz整体流程. 多亏大佬们的文章,对读源码帮助很 ...

  3. AFL(American Fuzzy Lop)源码详细解读(2)

    AFL(American Fuzzy Lop)源码详细解读(2) 本篇是关于 dry run (空跑.演练) 阶段的内容,一直到主循环之前. 多亏大佬们的文章,对读源码帮助很大: https://et ...

  4. 【Vue源码解读】万行源码详细解读

    前言 Vue2 的源码2年前粗略的看过一遍,重点在对响应式属性.对象监听.watch.computed.生命周期等内容的理解,但好记忆不如烂笔头,当初没有做笔记,现在重读一遍,针对重点内容详细解读并记 ...

  5. 【原理+源码详细解读】从Transformer到ViT

    文章目录 参考文献 简介 Transformer架构 Position Encoding Self-attention Multi-head Self-attention Masked Multi-H ...

  6. WannaCry 勒索病毒复现及分析,蠕虫传播机制全网源码详细解读 | 原力计划

    作者 | 杨秀璋,责编 | 夕颜 来源 | CSDN博客 头图 | CSDN 下载自东方 IC 出品 | CSDN(ID:CSDNnews) 这篇文章将详细讲解WannaCry蠕虫的传播机制,带领大家 ...

  7. WannaCry勒索病毒复现及分析,蠕虫传播机制全网源码详细解读 | 原力计划

    作者 | 杨秀璋 编辑 | 夕颜 题图 | 东方 IC 出品 | CSDN(ID:CSDNnews) 这篇文章将详细讲解WannaCry蠕虫的传播机制,带领大家详细阅读源代码,分享WannaCry勒索 ...

  8. [网络安全自学篇] 七十三.WannaCry勒索病毒复现及分析(四)蠕虫传播机制全网源码详细解读

    这是作者网络安全自学教程系列,主要是关于安全工具和实践操作的在线笔记,特分享出来与博友们学习,希望您喜欢,一起进步.前文分享了逆向分析OllyDbg动态调试工具的基本用法,包括界面介绍.常用快捷键和T ...

  9. VueRouter源码详细解读

    路由模式 1. hash 使用 URL hash 值来作路由.支持所有浏览器,包括不支持 HTML5 History Api 的浏览器. Hash URL,当 # 后面的哈希值发生变化时,不会向服务器 ...

最新文章

  1. NSUserDefaults的用法
  2. 从VirtualBox虚拟主机访问NAT客户机的方法
  3. (转)Web Framework 的速度与激情 16 正式上映
  4. java二叉查找算法_Java手写二叉搜索树算法
  5. c语言程序设计中三子棋游戏,C语言实现简易版三子棋游戏
  6. c# datagridview 绑定mysql_c#简单的数据库查询与绑定DataGridView。
  7. html5 自动连线,基于html5二个div 连线
  8. lua redisson执行lua脚本
  9. zip和rar压缩文件的区别
  10. 使用curl清理Elasticsearch数据方法
  11. 基于JSP动漫论坛的设计与实现
  12. python蒙特卡洛方法圆周率_使用Python语言的蒙特卡洛方法计算圆周率π的一种实现...
  13. 这种性生活伤女人尿道
  14. 计算机网络连接图标 红叉,win7电脑的网络连接图标出现红叉以及一直转圈的原因和解决方法...
  15. 软件开发人员的职业发展规划
  16. 让大数据告诉你,网红“小龙虾”究竟有多火
  17. 计算机操作系统唤醒原语,计算机操作系统原语分析(范文).doc
  18. C#基础知识---飞行棋小游戏
  19. 记录下今天的搜索成果
  20. go操作MongoDB

热门文章

  1. 分布式数据库服务器时钟同步(NTP网络时钟同步)北斗卫星同步时钟起到关键性作用
  2. 懒惰使人沉睡;懈怠的人必受饥饿。
  3. RZ,NRZ,NRZI
  4. 相较国外代码托管平台 gitlab,咱们中国自己的代码托管平台有哪些优势?
  5. 发送文件的过程计算机,用电脑给别人传文件的方法步骤图
  6. IDEA小技巧之痛苦面具 主菜单不见了怎么办?
  7. UOJ#748-[UNR #6]机器人表演【dp】
  8. 海康威视网络摄像头通过浏览器网页的配置流程
  9. Galera/mysql 集群 备忘
  10. ubuntu机械盘写入cannot be copied because you do not have permissions to create it in the destination.