3、tensorflow2.0 实现MTCNN、训练O_net网络,并进行测试图片
训练O_net网络,并测试图片
- 上一篇,我们已经知道如何生成O_net训练集,这次我们开始训练Onet网络。
- 训练完成后,保存权重,我们随机抽取一张照片,测试一下效果。
代码:
train_Onet.py
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import metrics
from red_tf import *
from MTCNN_ import Onet,cls_ohem,cal_accuracy,bbox_ohem
from tqdm import tqdm
import cv2data_path = "48/train_ONet_landmark.tfrecord_shuffle"# 加载pokemon数据集的工具!
def load_pokemon(mode='train'):""" 加载pokemon数据集的工具!:param root: 数据集存储的目录:param mode: mode:当前加载的数据是train,val,还是test:return:"""# # 创建数字编码表,范围0-4;# name2label = {} # "sq...":0 类别名:类标签; 字典 可以看一下目录,一共有5个文件夹,5个类别:0-4范围;# for name in sorted(os.listdir(os.path.join(root))): # 列出所有目录;# if not os.path.isdir(os.path.join(root, name)):# continue# # 给每个类别编码一个数字# name2label[name] = len(name2label.keys())# 读取Label信息;保存索引文件images.csv# [file1,file2,], 对应的标签[3,1] 2个一一对应的list对象。# 根据目录,把每个照片的路径提取出来,以及每个照片路径所对应的类别都存储起来,存储到CSV文件中。size = 48images,labels,boxes = red_tf(data_path,size)# 图片切割成,训练70%,验证15%,测试15%。if mode == 'train': # 100% 训练集images = images[:int(len(images))]labels = labels[:int(len(labels))]boxes = boxes[:int(len(boxes))]elif mode == 'val': # 15% = 70%->85% 验证集images = images[int(0.7 * len(images)):int(0.85 * len(images))]labels = labels[int(0.7 * len(labels)):int(0.85 * len(labels))]boxes = boxes[int(0.7 * len(boxes)):int(0.85 * len(boxes))]else: # 15% = 70%->85% 测试集images = images[int(0.85 * len(images)):]labels = labels[int(0.85 * len(labels)):]boxes = boxes[int(0.85 * len(boxes)):]ima = tf.data.Dataset.from_tensor_slices(images)lab = tf.data.Dataset.from_tensor_slices(labels)roi = tf.data.Dataset.from_tensor_slices(boxes)train_data = tf.data.Dataset.zip((ima, lab, roi)).shuffle(1000).batch(6)train_data = list(train_data.as_numpy_iterator())return train_data# 图像色相变换
def image_color_distort(inputs):inputs = tf.image.random_contrast(inputs, lower=0.5, upper=1.5)inputs = tf.image.random_brightness(inputs, max_delta=0.2)inputs = tf.image.random_hue(inputs,max_delta= 0.2)inputs = tf.image.random_saturation(inputs,lower = 0.5, upper= 1.5)return inputsdef train(eopch):model = Onet()model.load_weights("onet.h5")optimizer = keras.optimizers.Adam(learning_rate=1e-3)off = 1000acc_meter = metrics.Accuracy()for epoch in tqdm(range(eopch)):for i,(img,lab,boxes) in enumerate(load_pokemon("train")):img = image_color_distort(img)# 开一个gradient tape, 计算梯度with tf.GradientTape() as tape:cls_prob, bbox_pred,laim = model(img)cls_loss = cls_ohem(cls_prob, lab)bbox_loss = bbox_ohem(bbox_pred, boxes,lab)# landmark_loss = landmark_loss_fn(landmark_pred, landmark_batch, label_batch)# accuracy = cal_accuracy(cls_prob, label_batch)total_loss_value = cls_loss + 0.5 * bbox_lossgrads = tape.gradient(total_loss_value, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if i % 200 == 0:print('Training loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))print('Seen so far: %s samples' % ((i + 1) * 6))for i, (v_img, v_lab1, boxes) in enumerate(load_pokemon("val")):v_img = image_color_distort(v_img)with tf.GradientTape() as tape:cls_prob, bbox_pred,laim= model(v_img)cls_loss = cls_ohem(cls_prob, v_lab1)bbox_loss = bbox_ohem(bbox_pred, boxes,v_lab1)# landmark_loss = landmark_loss_fn(landmark_pred, landmark_batch, label_batch)# accuracy = cal_accuracy(cls_prob, label_batch)total_loss_value = cls_loss + 0.5 * bbox_lossgrads = tape.gradient(total_loss_value, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if i % 200 == 0:print('val___ loss (for one batch) at step %s: %s' % (i, float(total_loss_value)))print('Seen so far: %s samples' % ((i + 1) * 6))model.save_weights('./Weights/Onet_wight/onet_30.ckpt')
train(30)
到这里我们已经训练完成了,并把网络权重参数也保存下来,接下来我们开始进行预测。
预测代码:
from Detection.Detect import detect_pnet,detect_Rnet,detect_Onetimport cv2
import numpy as npdef prediction(image_path):image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)P_boxes, P_boxes_c = detect_pent(image)R_boxes,R_boxes_c = detect_Rnet(image,P_boxes_c)O_boxes,O_boxes_c = detect_Onet(image,R_boxes_c)# if ret == False:# # 未检测到人脸# print("该图片未检测到人脸")for i in range(O_boxes_c.shape[0]):bbox = O_boxes_c[i, :4]score = O_boxes_c[i, 4]corpbbox = [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]# 画人脸框cv2.rectangle(image, (corpbbox[0], corpbbox[1]),(corpbbox[2], corpbbox[3]), (255, 0, 0), 1)# 判别为人脸的置信度cv2.putText(image, '{:.2f}'.format(score),(corpbbox[0], corpbbox[1] - 2),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)cv2.imshow('im', image)cv2.waitKey(0)cv2.destroyAllWindows()image = cv2.imread("./24.jpg")
prediction(image)
这是该MTCNN检测出来的目标。看来效果不错。
到这里就完。代码中比较重要的回归框我没有表述,这个我觉得我描述的比较菜,我这里推荐一位大神些的边框回归原理,有兴趣的可以去看一下。
边框回归(Bounding Box Regression)详解
3、tensorflow2.0 实现MTCNN、训练O_net网络,并进行测试图片相关推荐
- 训练O_net网络,并测试图片
训练O_net网络,并测试图片 上一篇,我们已经知道如何生成O_net训练集,这次我们开始训练Onet网络. 训练完成后,保存权重,我们随机抽取一张照片,测试一下效果. 代码: train_Onet. ...
- tensorflow2.0 实现MTCNN、P_net数据生成,及训练-1
1.MTCNN 的优点及必须要了解基础点. MTCNN 的 "MT"是指多任务学习(Multi-Task),在同一个任务中同时学习"分类识别"."边框 ...
- Tensorflow2.0:使用Keras自定义网络实战
tensorflow2.0建议使用tf.keras作为构建神经网络的高级API 接下来我就使用tensorflow实现VGG16去训练数据 背景介绍: 2012年 AlexNet 在 ImageNet ...
- Tensorflow2.0 之深度残差收缩网络 (DRSN)
文章目录 DRSN 原理 残差网络 自注意力网络 软阈值化 代码实现 DRSN 原理 DRSN 由三部分组成:残差网络.自注意力网络和软阈值化. 残差网络 残差网络(或称深度残差网络.深度残差学习,英 ...
- 基于tensorflow2.0实现猫狗大战(搭建网络迁移学习)
猫狗大战是kaggle平台上的一个比赛,用于实现猫和狗的二分类问题.最近在学卷积神经网络,所以自己动手搭建了几层网络进行训练,然后利用迁移学习把别人训练好的模型直接应用于猫狗分类这个数据集,比较一下实 ...
- 再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现
TF版本2.2及以上 def creat_FGM(epsilon=1.0):@tf.function def train_step(self, data):'''计算在embedding上的gradi ...
- Tensorflow2.0入门教程22:RNN网络实现文本分类
RNN实现文本分类 import tensorflow as tf 下载数据集 imdb=tf.keras.datasets.imdb (train_x, train_y), (test_x, tes ...
- TensorFlow2.0:自定义层与自定义网络
自定义层函数需要继承layers.Layer,自定义网络需要继承keras.Model. 其内部需要定义两个函数: 1.__init__初始化函数,内部需要定义构造形式: 2.call函数,内部需要定 ...
- pip更新失败_最全Tensorflow2.0 入门教程持续更新
最全Tensorflow 2.0 入门教程持续更新: Doit:最全Tensorflow 2.0 入门教程持续更新zhuanlan.zhihu.com 完整tensorflow2.0教程代码请看ht ...
最新文章
- centos lamp 连接mysql_centOS下lamp安装
- oracle-sqlloader的简单使用
- nginx动态配置及服务发现那些事
- RemoveError: ‘setuptools‘ is a dependency of conda
- Zookeeper实现Master选举(哨兵机制)
- 【微学堂】线上Linux服务器运维安全策略经验分享
- 【云计算平台】Hadoop单机模式环境搭建
- linux c 编译器安装,安装 GNU 的 C/C++ 编译器
- 解决苹果手机返回不刷新问题
- 《大数据分析原理与实践》一一导读
- oracle10修改时区,ORACLE10g时区配置错误问题
- 程序员新入手MacStudio的装机环境
- Untiy学习 简单的脚本方法
- 三八节礼物推荐,不能错过的四款数码好物推荐
- 最简单可靠的机房温度电话报警
- element ui框架(准备)
- php中左移和右移,c语言左移和右移的示例详解
- 你还停留在使用Dagger2吗? 带你一步一步走进Dagger2的世界
- 小程序的老祖宗PWA为什么没有火起来?
- Go 单元测试综合案例
热门文章
- (附源码)springboot社区快递代取服务系统 毕业设计051434
- Unity3d 自发光(荧光)Bloom效果的实现
- Linux笔记-ftp主动和被动模式下iptables的规则配置
- 正点原子linux驱动教程,正点原子 手把手教你学Linux之驱动开发篇
- 《提问的智慧》读后感
- 元宇宙基础设施五层级模型的关系作用与实力。
- dhcp服务器设置(路由器dhcp服务器怎么设置)
- 《C++ Primer》第15章 15.4节习题答案
- 5700教程☆问题汇总
- 魔兽服务器排队微信,服务器排队严重:《魔兽世界》经典怀旧服执行47服免费角色转移计划...