tfrecords文件存放在文件 "../../DATA/imglists/PNet/train_PNet_landmark.tfrecord_shuffle中,接下来用它来训练PNet网络,即文件train_PNet.py;

def train_PNet(base_dir, prefix, end_epoch, display, lr):"""train PNet:param dataset_dir: tfrecord 文件路径:param prefix:  模型存放位置:param end_epoch: 训练循环次数:param display::param lr: 学习率:return:"""net_factory = P_Net  #加载神经网络结构train(net_factory,prefix, end_epoch, base_dir, display=display, base_lr=lr) #进行训练
if __name__ == '__main__':#data pathbase_dir = '../../DATA/imglists/PNet'model_name = 'MTCNN'model_path = '../data/%s_model/PNet_landmark/PNet' % 模型存放路径          prefix = model_pathend_epoch = 30display = 100lr = 0.001   #学习率train_PNet(base_dir, prefix, end_epoch, display, lr)

1、首先介绍net_factory=P_Net

def P_Net(inputs,label=None,bbox_target=None,landmark_target=None,training=True):#define common paramwith slim.arg_scope([slim.conv2d],activation_fn=prelu,weights_initializer=slim.xavier_initializer(),biases_initializer=tf.zeros_initializer(),weights_regularizer=slim.l2_regularizer(0.0005), padding='valid'):#初始化卷积层print(inputs.get_shape())    #输出input的shapnet = slim.conv2d(inputs, 10, 3, stride=1,scope='conv1') #num_outputs=10,kernel_size=[3,3],#卷积操作,卷积核的个数是10,卷积核的形式是[3,3],步长为1,其余的参数和上面的slim.arg_scope一样_activation_summary(net) #数据的记录print(net.get_shape())net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool1', padding='SAME') #池化_activation_summary(net)print(net.get_shape())net = slim.conv2d(net,num_outputs=16,kernel_size=[3,3],stride=1,scope='conv2')_activation_summary(net)print(net.get_shape())net = slim.conv2d(net,num_outputs=32,kernel_size=[3,3],stride=1,scope='conv3')_activation_summary(net)print(net.get_shape())#batch*H*W*2conv4_1 = slim.conv2d(net,num_outputs=2,kernel_size=[1,1],stride=1,scope='conv4_1',activation_fn=tf.nn.softmax)  #输出人脸类别:人脸,非人脸_activation_summary(conv4_1)print (conv4_1.get_shape())#batch*H*W*4bbox_pred = slim.conv2d(net,num_outputs=4,kernel_size=[1,1],stride=1,scope='conv4_2',activation_fn=None) #人脸边框回归_activation_summary(bbox_pred)print (bbox_pred.get_shape())#batch*H*W*10landmark_pred = slim.conv2d(net,num_outputs=10,kernel_size=[1,1],stride=1,scope='conv4_3',activation_fn=None) #人脸特征点坐标预测_activation_summary(landmark_pred)print (landmark_pred.get_shape())# add projectors for visualization#cls_prob_original = conv4_1 #bbox_pred_original = bbox_predif training:#batch*2# calculate classification losscls_prob = tf.squeeze(conv4_1,[1,2],name='cls_prob')  #tf.squeeze()删除conv4_1中所指定位置大小是1的维度,conv4_1=[batch,1,1,2],变成了[batch,2]cls_loss = cls_ohem(cls_prob,label)  #获得人脸分类训练的loss值#batch*4# cal bounding box error, squared sum errorbbox_pred = tf.squeeze(bbox_pred,[1,2],name='bbox_pred')   #=[batch,1,1,4],变成了[batch,4]bbox_loss = bbox_ohem(bbox_pred,bbox_target,label)           #获得人脸框训练的loss值#batch*10landmark_pred = tf.squeeze(landmark_pred,[1,2],name="landmark_pred")  #[batch,1,1,10],变成了[batch,10]landmark_loss = landmark_ohem(landmark_pred,landmark_target,label)    #获得人脸特征点训练的loss值accuracy = cal_accuracy(cls_prob,label)    #人脸分类精度L2_loss = tf.add_n(slim.losses.get_regularization_losses())return cls_loss,bbox_loss,landmark_loss,L2_loss,accuracy#testelse:#测试时,batch_size = 1cls_pro_test = tf.squeeze(conv4_1, axis=0)bbox_pred_test = tf.squeeze(bbox_pred,axis=0)landmark_pred_test = tf.squeeze(landmark_pred,axis=0)return cls_pro_test,bbox_pred_test,landmark_pred_test

1、_activation_summary函数,输入一个张量:

def _activation_summary(x):tensor_name = x.op.nameprint('load summary for : ',tensor_name) #打印tensor的名字tf.summary.histogram(tensor_name + '/activations',x)#以直方图的形式显示tensor在训练过程的值的分布情况

histogram(name, values, collections=None, family=None)输出一个直方图 .

  • name:生成的节点名称.作为TensorBoard中的一个系列名称.
  • values:一个实数张量.用于构建直方图的值.
  • collections:图形集合键的可选列表.添加新的summary操作到这些集合中.默认为GraphKeys.SUMMARIES.
  • family: summary标签名称的前缀,用于在Tensorboard上显示的标签名称.(可选项)

tf.summary.histogram()将输入的一个任意大小和形状的张量压缩成一个由宽度和数量组成的直方图数据结构.假设输入 [0.5, 1.1, 1.3, 2.2, 2.9, 2.99],则可以创建三个bin,分别包含0-1之间/1-2之间/2-3之间的所有元素,即三个bin中的元素分别为[0.5]/[1.1,1.3]/[2.2,2.9,2.99]. 这样,通过可视化张量在不同时间点的直方图来显示某些分布随时间变化的情况
from:https://blog.csdn.net/akadiao/article/details/79551180

2、squeeze函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果,axis可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1

# 't' 是一个维度[1, 2, 1, 3, 1, 1]的张量
tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1],标号从零开始,只删掉了2和4维的1

3、cls_loss = cls_ohem(cls_prob,label) #获得人脸分类训练的loss值

def cls_ohem(cls_prob, label):#cls_prob的shape是[384,2]#label的shape是[384]zeros = tf.zeros_like(label)   #建立一个和label相同shape的全0数组#label=-1 --> label=0net_factory##将正样本的label保持为1,负样本label为0,其他的的两个样本label值为0。pos -> 1, neg -> 0, others -> 0label_filter_invalid = tf.where(tf.less(label,0), zeros, label)num_cls_prob = tf.size(cls_prob)  #返回cls_prob中元素的个数,384*2cls_prob_reshape = tf.reshape(cls_prob,[num_cls_prob,-1]) #reshape成768行label_int = tf.cast(label_filter_invalid,tf.int32)#获取class_prob的行数,384num_row = tf.to_int32(cls_prob.get_shape()[0])#row = [0,2,4.....]row = tf.range(num_row)*2# 因为conv4_1输出的的shape是[batch,2],每一张图片经过网络后,再经过softmax函数,输出的的是两个概率值# 第一个值表示非人脸的概率,第二个值表示是人脸的概率,加起来的和为1,此时的label_int仍是[batch]# 只是里面只有pos样本对应的label值是1,其余均为0,所以row + label_int表示的是cls_prob_reshape里面pos样本的索引indices_ = row + label_int# 使用tf.gather函数将pos样本提取出来,再使用tf.squeeze函数将shape变成(384,)# 此时label_prob是里面有384个概率,是pos样本对应的概率和非pos样本对应的概率label_prob = tf.squeeze(tf.gather(cls_prob_reshape, indices_))loss = -tf.log(label_prob+1e-10)  #是一个二分类问题,使用的交叉熵损失函数。加上1e-10,是为了防止里面的label_prob值太小,输出为负无穷zeros = tf.zeros_like(label_prob, dtype=tf.float32)ones = tf.ones_like(label_prob,dtype=tf.float32)# set pos and neg to be 1, rest to be 0valid_inds = tf.where(label < zeros,zeros,ones)num_valid = tf.reduce_sum(valid_inds)#获得正样本和负样本的数量keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)#num_keep_radio:只计算前70%的损失#FILTER OUT PART AND LANDMARK DATAloss = loss * valid_inds     #去除掉part样本和landmark样本(因为他们对应的valid_inds=0)loss,_ = tf.nn.top_k(loss, k=keep_num)  #得到loss值中大小排前百分之七十的样本return tf.reduce_mean(loss)
  • cast(x, dtype, name=None)

函数的作用是执行 tensorflow 中张量数据类型转换,比如读入的图片如果是int8类型的,一般在要在训练前把图像的数据格式转换为float32。

第一个参数 x:   待转换的数据(张量)
第二个参数 dtype: 目标数据类型
第三个参数 name: 可选参数,定义操作的名称

  • tf.less(label,0):

less( x, y, name=None )函数:以元素方式返回(x <y)的真值,返回值:该函数返回 bool 类型的张量.

  • tf.where(condition, x=None, y=None, name=None):

condition是bool型值,True/False。返回值:condition中元素为True的元素替换为x中的元素,condition中为False的元素替换为y中对应元素,x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工。由于是替换,返回值的维度,和condition,x , y都是相等的。

  • tf.nn.top_k(input, k, name=None)

解释:这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引。

  • tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来

4、bbox_loss = bbox_ohem(bbox_pred,bbox_target,label) #获得人脸框训练的loss值

def bbox_ohem(bbox_pred,bbox_target,label):''':param bbox_pred::param bbox_target::param label: class label:return: mean euclidean loss for all the pos and part examples'''zeros_index = tf.zeros_like(label, dtype=tf.float32)ones_index = tf.ones_like(label,dtype=tf.float32)valid_inds = tf.where(tf.equal(tf.abs(label), 1),ones_index,zeros_index) #获取pos样本和part样本,label是1或-1的置1,其他的置0#(batch,)#计算平方误差square_error = tf.square(bbox_pred-bbox_target)square_error = tf.reduce_sum(square_error,axis=1)#keep_num scalarnum_valid = tf.reduce_sum(valid_inds)    #计算pos样本和part样本的数量keep_num = tf.cast(num_valid, dtype=tf.int32)square_error = square_error*valid_inds       #去掉neg样本和landmark样本的平方和,即将其error置0# keep top k examples, k equals to the number of positive examples_, k_index = tf.nn.top_k(square_error, k=keep_num)square_error = tf.gather(square_error, k_index)   #将所有pos样本和part样本的平方和提取出来return tf.reduce_mean(square_error)

5、landmark_loss = landmark_ohem(landmark_pred,landmark_target,label) #获得人脸特征点训练的loss值

def landmark_ohem(landmark_pred,landmark_target,label):''':param landmark_pred::param landmark_target::param label::return: mean euclidean loss'''#keep label =-2  then do landmark detectionones = tf.ones_like(label,dtype=tf.float32)zeros = tf.zeros_like(label,dtype=tf.float32)valid_inds = tf.where(tf.equal(label,-2),ones,zeros)       #只用label=-2的样本用于训练square_error = tf.square(landmark_pred-landmark_target)square_error = tf.reduce_sum(square_error,axis=1)       #平方误差num_valid = tf.reduce_sum(valid_inds)                   #样本数keep_num = tf.cast(num_valid, dtype=tf.int32)square_error = square_error*valid_inds        #只保留label=-2样本的error_, k_index = tf.nn.top_k(square_error, k=keep_num)  #返回索引square_error = tf.gather(square_error, k_index)     #提取landmark的errorreturn tf.reduce_mean(square_error)

6、accuracy = cal_accuracy(cls_prob,label) #人脸分类精度

def cal_accuracy(cls_prob,label):''':param cls_prob::param label::return:calculate classification accuracy for pos and neg examples only'''#按行返回cls_prob的最大值的索引,索引值为0或者1,因为输出cls_prob有两个值#第一个值表示非人脸的概率,第二个值表示人脸的概率,所以索引等于0时,表示这个图片网络预测为非人脸;为1时网络预测这张图片为人脸pred = tf.argmax(cls_prob,axis=1)label_int = tf.cast(label,tf.int64)#tf.greater_equal()函数判断label_int是否大于等于0,返回True或者False#tf.where函数返回True值对应的索引,即cond是pos样本和neg样本对应的索引cond = tf.where(tf.greater_equal(label_int,0))picked = tf.squeeze(cond)# 获得pos样本和neg样本的labellabel_picked = tf.gather(label_int,picked)pred_picked = tf.gather(pred,picked)  #预测值#通过tf.equal()函数返回的True或者False值得到网络的预测值是否准确#将True和Flase转化为1和0求得平均值即得到准确率# ACC = (TP+FP)/total populationaccuracy_op = tf.reduce_mean(tf.cast(tf.equal(label_picked,pred_picked),tf.float32))return accuracy_op

第二部分:

train(net_factory,prefix, end_epoch, base_dir, display=display, base_lr=lr) #进行训练

def train(net_factory, prefix, end_epoch, base_dir,display=200, base_lr=0.01):"""train PNet/RNet/ONet:param net_factory:  P_Net网络:param prefix:   模型存放位置:param end_epoch:  训练循环次数:param dataset::param display::param base_lr::return:"""net = prefix.split('/')[-1]       #PNet#label filelabel_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)  #读取样本信息print(label_file)f = open(label_file, 'r')num = len(f.readlines())#训练样本总数print("Total size of the dataset is: ", num)print(prefix)#PNet use this method to get dataif net == 'PNet':dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net) #读取样本的tfrecordprint('dataset dir is:',dataset_dir)image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)  #设置config.BATCH_SIZE = 384##RNet使用了四个tfrecord文件获取数据else:pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')landmark_dir = os.path.join('../../DATA/imglists/RNet','landmark_landmark.tfrecord_shuffle')dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0/6;neg_radio=3.0/6   #比例:1:1:1:3pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))assert pos_batch_size != 0,"Batch Size Error "part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))assert part_batch_size != 0,"Batch Size Error "neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))assert neg_batch_size != 0,"Batch Size Error "landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))assert landmark_batch_size != 0,"Batch Size Error "batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,landmark_batch_size]#各类样本的个数image_batch, label_batch, bbox_batch,landmark_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)        #landmark_dir    if net == 'PNet':image_size = 12radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;#3个损失函数的权重elif net == 'RNet':image_size = 24radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;else:radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 1;image_size = 48#define placeholderinput_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')  #对于PNet:[384,12,12,3]label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')landmark_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,10],name='landmark_target')#对图像加入颜色干扰input_image = image_color_distort(input_image)# 通过net_factory(RNet)获得loss和accuracy值cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,landmark_target,training=True)#train,update learning rate(3 loss)total_loss_op  = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_optrain_op, lr_op = train_model(base_lr,total_loss_op,num)
。。。。。。

1、config定义一些参数

config.BATCH_SIZE = 384
config.CLS_OHEM = True
config.CLS_OHEM_RATIO = 0.7
config.BBOX_OHEM = False
config.BBOX_OHEM_RATIO = 0.7
config.EPS = 1e-14
config.LR_EPOCH = [6,14,20]

2、增加颜色干扰

def image_color_distort(inputs):inputs = tf.image.random_contrast(inputs, lower=0.5, upper=1.5)inputs = tf.image.random_brightness(inputs, max_delta=0.2)inputs = tf.image.random_hue(inputs,max_delta= 0.2)inputs = tf.image.random_saturation(inputs,lower = 0.5, upper= 1.5)return inputs

3、配置学习率,优化器。

def train_model(base_lr, loss, data_num):"""train model:param base_lr: 学习率:param loss: loss 损失函数:param data_num: 训练样本总数:return: train_op, lr_op"""lr_factor = 0.1global_step = tf.Variable(0, trainable=False)#LR_EPOCH [8,14]#boundaried [num_batch,num_batch]boundaries = [int(epoch * data_num / config.BATCH_SIZE) for epoch in config.LR_EPOCH]#lr_values[0.01,0.001,0.0001,0.00001]lr_values = [base_lr * (lr_factor ** x) for x in range(0, len(config.LR_EPOCH) + 1)]#control learning ratelr_op = tf.train.piecewise_constant(global_step, boundaries, lr_values)optimizer = tf.train.MomentumOptimizer(lr_op, 0.9)train_op = optimizer.minimize(loss, global_step)return train_op, lr_op

在tensorflow中,在训练过程中更改学习率主要有两种方式,第一个是学习率指数衰减,第二个就是迭代次数在某一范围指定一个学习率。学习率采用第二种:,x={0,1,2,3}.根据config.LR_EPOCH=[6,14,20],知道在前6次训练时用基础学习率,之后学习率降低1/10,训练14次后,再降低1/10,训练20次,再降低1/10.(总训练次数这里设为30次)

  • tf.train. piecewise_constant(x, boundaries, values, name=None)就是为第二种学习率变化方式而设计的

x指的是global_step,其实就是迭代次数,boundaries一个列表,内容指的是迭代次数所在的区间,values是个列表,存放在不同区间该使用的学习率的值。(这里一共有4个区间:[0,6],[6,14],[14,20],[20,30]

注意,values中的数目应该比boundaries中的数目大1,原因很简单,无非是两个数可以制定出三个区间嘛,有三个区间自然要用3个学习率。

  • boundaries = [int(epoch * data_num / config.BATCH_SIZE) for epoch in config.LR_EPOCH]

data_num / config.BATCH_SIZE表示训练样本需要划分的批次数,假设为1批次,则boundaries =[6,14,20],即按正常训练次数进行调节学习率;假设划分为2个批次,则boundaries =[12,28,40],即先训练12次,前6次是第一批次,后6次是第二批次,然后再调节学习率。。。

MTCNN-tensorflow源码解析之训练PNet网络-train_PNet.py相关推荐

  1. Tensorflow源码解析1 -- 内核架构和源码结构

    1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...

  2. Tensorflow源码解析5 -- 图的边 - Tensor

    1 概述 前文两篇文章分别讲解了TensorFlow核心对象Graph,和Graph的节点Operation.Graph另外一大成员,即为其边Tensor.边用来表示计算的数据,它经过上游节点计算后得 ...

  3. Tensorflow源码解析2 -- 前后端连接的桥梁 - Session

    1 Session概述 Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计算 ...

  4. Tensorflow源码解析3 -- TensorFlow核心对象 - Graph

    1 Graph概述 计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的.包括图的构建.传递.剪枝.按worker分裂.按设备二次分裂.执行.注销等.因 ...

  5. Tensorflow源码解析3 -- TensorFlow核心对象 - Graph 1

    1 Graph概述 计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的.包括图的构建.传递.剪枝.按worker分裂.按设备二次分裂.执行.注销等.因 ...

  6. Tensorflow源码解析2 -- 前后端连接的桥梁 - Session 1

    1 Session概述 Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计算 ...

  7. Tensorflow源码解析6 -- TensorFlow本地运行时

    1 概述 TensorFlow后端分为四层,运行时层.计算层.通信层.设备层.运行时作为第一层,实现了session管理.graph管理等很多重要的逻辑,是十分关键的一层.根据任务分布的不同,运行时又 ...

  8. 语义分割丨PSPNet源码解析「训练阶段」

    引言 之前一段时间在参与语义分割的项目,最近有时间了,正好把这段时间的所学总结一下. 在代码上,语义分割的框架会比目标检测简单很多,但其中也涉及了很多细节.在这篇文章中,我以PSPNet为例,解读一下 ...

  9. tensorflow 启动多个session_Tensorflow源码解析7 -- TensorFlow分布式运行时

    1 概述 TensorFlow架构设计精巧,在后端运行时这一层,除了提供本地运行时外,还提供了分布式运行时.通过分布式训练,在多台机器上并行执行,大大提高了训练速度.前端用户通过session.run ...

最新文章

  1. PLSQL常用方法汇总(转载)
  2. python画函数图-Python 绘制你想要的数学函数图形
  3. 机器人聊天软件c#_C#制作简易QQ聊天机器人
  4. linux 查看java最大内存配置,Linux和Windows下的内存设置
  5. [object detection] TypeError: can't pickle dict_values objects
  6. 7 种让 if / else 变得更加优雅的方式,你 pick 了吗?
  7. K8S使用dashboard管理集群
  8. 欢迎转载中国网站排名
  9. python盖帽法_干货:用Python进行数据清洗,这7种方法你一定要掌握
  10. python cox回归_TCGA+biomarker——多因素Cox回归
  11. 项目文档----项目描述
  12. MySQL基本增删改查以及搭配node在项目中的操作
  13. 记一次腾讯面试:进程之间究竟有哪些通信方式?如何通信? ---- 告别死记硬背
  14. mysql表不支持optimize_OPTIMIZE TABLE MYSQL
  15. 全面讨论泛化 (generalization) 和正则化 (regularization) — Part 1
  16. 以豌豆荚为例,用 Scrapy 爬取分类多级页面
  17. python获取datetime的周和星期
  18. 移动端九宫格转盘抽奖vue组件
  19. IOS 常用UI控件
  20. sql server 中获取前一天日期_图解面试题:如何比较日期数据?

热门文章

  1. linux搜索含多个字符串,关于linux:使用grep搜索多个字符串
  2. 华为鸿蒙系统学习笔记11-鸿蒙(HarmonyOS)2.0方舟编译器官方网址开源地址
  3. 软考信息安全工程师学习笔记目录
  4. php memcached 设置过期,memcached过期时间无效
  5. list ilist java_C#中IList与List区别
  6. pandas将字符串转换成时间_pandas入门: 时间字符串转换为年月日
  7. lucene 全文检索引擎的架构
  8. bzoj1593 [Usaco2008 Feb]Hotel 旅馆(线段树)
  9. apply、call、callee、caller初步了解
  10. 用mac的safari浏览器调试ios手机的网页