mobilenet是轻量级神经网络模型,其独特之处在于将普通的卷积操作变化为深度分离卷积操作,使其在参数数量和计算量上都大大减少,从而适用于一些性能不是很好的机器(精度略微降低,但是运行速度明显提升)。mobilenet_v1和mobilenet_v2都有创新的地方,我也在学习中。感兴趣的朋友可以看看这篇博客https://blog.csdn.net/u011974639/article/details/79199306,https://blog.csdn.net/u011995719/article/details/79135818作者写的很详细。所以我这里只写一下深度分离卷积和普通卷积的具体区别,方便理解。

先看普通卷积:对于一张8x8的RGB三通道图,对其做卷积核3x3的卷积,步长为1,使用5个卷积核,输出结果是6x6x5。具体的操作是:R、G、B通道分别于卷积核1做卷积操作,然后将三个通道的值相加,输出一个特征图。5和卷积核就得到5个特征图。普通卷积会把输入的各个通道的信息整合在一起。

再来看深度分离卷积:第一步深度卷积(depthwise convolution),对于一张8x8的RGB三通道图,先用3个卷积核分别对R、G、B三个通道做卷积,需要注意的是,这里一个卷积核只处理一个通道,不像普通卷积那样,一个卷积核处理三个通道。深度卷积的输出是6x6x3。第二步逐点卷积(pointwise convolution),这里的输入就是深度卷积的输出:6x6x3。然后使用5个卷积核做1x1的卷积操作,这里的卷积操作和普通的卷积操作没区别,只是卷积核尺寸变为1x1。这么做的目的就是将之前三个通道的结果整合在一起,使得最后的输出为6x6x5。至于1x1卷积核,还有一个非常重要的作用:改变数据维度

说完这些,下面上训练代码:

# conding=utf-8import tensorflow as tf
from datetime import datetime
from mobilenet_v2 import mobilenetv2batch_size = 64
lr = 0.001
n_cls = 2
max_steps = 10000def read_and_decode(filename):#根据文件名生成一个队列filename_queue = tf.train.string_input_producer([filename])reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)  # 返回文件名和文件features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw': tf.FixedLenFeature([], tf.string),})  # 取出包含image和label的feature对象img = tf.decode_raw(features['img_raw'], tf.float32)img = tf.reshape(img, [128, 128, 1])label = tf.cast(features['label'], tf.int64)return img,labeldef train():x = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 1], name='input')y = tf.placeholder(dtype=tf.float32, shape=[None, n_cls], name='label')keep_prob = tf.placeholder(tf.float32)output,_ = mobilenetv2(x, n_cls, is_train=True)loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y))l2_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))total_loss = loss + l2_loss# optimizerupdata_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(updata_ops):train_op = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.9).minimize(total_loss)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(output, 1),tf.arg_max(y, 1)), tf.float32))images, labels = read_and_decode('train_float_1chanel.tfrecords')img_batch, label_batch = tf.train.batch([images, labels], batch_size=batch_size, capacity=40)label_batch = tf.one_hot(label_batch, n_cls, 1, 0)init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(max_steps):batch_x, batch_y = sess.run([img_batch, label_batch])_,loss_val = sess.run([train_op, total_loss], feed_dict={x:batch_x, y:batch_y, keep_prob:0.8})if i%10 == 0:train_arr = accuracy.eval(feed_dict={x:batch_x, y:batch_y, keep_prob:1.0})print('%s: Step [%d] Loss: %f, training accuracy: %g' %(datetime.now(), i, loss_val, train_arr))if (i%1000) == 0:saver.save(sess, './model_1chanel/model.ckpt', global_step=i)coord.request_stop()coord.join(threads)if __name__ == '__main__':train()

训练完了,上测试代码:

#conding=utf-8import tensorflow as tf
from mobilenet_v2 import mobilenetv2
import os
import re
import numpy as np
import xml.dom.minidomdef test(log_path, test_path):x = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 1], name='input')keep_prob = tf.placeholder(tf.float32)output,_ = mobilenetv2(x, num_classes=2, is_train=False)score = tf.nn.softmax(output)f_cls = tf.argmax(score, 1)sess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())saver = tf.train.Saver()print('\n载入检查点...')ckpt = tf.train.get_checkpoint_state(log_path)if ckpt and ckpt.model_checkpoint_path:global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]saver.restore(sess, ckpt.model_checkpoint_path)print('载入成功,global_step = %s\n' % global_step)else:print('没有找到检查点')numb = 0files_number = 0error_list =[]error = open('error_float_1chanel.txt','w')classes = sorted(os.walk(test_path).__next__()[1])for c in classes:c_dir = os.path.join(test_path, c)walk = os.walk(c_dir).__next__()[2]for sample in walk:imgpath = os.path.join(c_dir, sample)if sample.endswith('.xml'):dom = xml.dom.minidom.parse(imgpath)cc = dom.getElementsByTagName('data')c1 = cc[0]f = c1.firstChild.datafiledata1 = re.findall('\d+', f)filedata = np.array(filedata1)data = filedata.reshape(128, 128, 1)data = data.astype(np.float32)im = np.expand_dims(data, axis=0)pred, _score = sess.run([f_cls, score], feed_dict={x: im, keep_prob: 1.0})prob = round(np.max(_score), 4)print('{}flowers class is : {}, score: {}'.format(imgpath, int(pred), prob))error_log = ('{}flowers class is : {}, score: {}'.format(imgpath, int(pred), prob))if c == 'ne':  #这里存放测试的负样本数据if pred == 1:numb += 1else:error_list.append(error_log)if c == 'po':  #这里存放测试的正样本数据if pred == 0:numb += 1else:error_list.append(error_log)files_number = files_number + 1acc = numb / files_number * 100print('测试样本数:',files_number)print('测试误识数:',len(error_list))print('测试准确率:',acc)for i in range(len(error_list)):line = str(error_list[i] + '\n')error.write(line)if __name__ == '__main__':test_path = './test'log_path = 'model_1chanel'test(log_path,test_path)

我只用了400个测试数据,准确在98%左右,还是有效果的。代码还有很多要改进的地方,欢迎大家一起交流!

3D活体识别使用mobilenet_v2训练模型相关推荐

  1. SVM训练3D活体识别模型

    训练数据采集分为正样本和负样本 正样本:正脸.侧脸.抬头.低头.表情,4000张 负样本:平面照片.弯曲照片.褶皱照片.照片抠眼鼻嘴做简单面具.电子屏,4000张 分辨率为128*128,因为1米以内 ...

  2. 手机被“秒解锁”?活体检测+3D人脸识别让刷脸更安全

    如今,人们使用智能手机进行刷脸解锁.刷脸支付就像吃饭喝水一样自然.人脸识别技术的进步为人们的日常生活带来了诸多便利,但同时也引发了隐私安全问题. 近日,来自清华的 Real AI(瑞莱智慧)展示了一项 ...

  3. 活体检测+3D人脸识别:为“刷脸”上道安全锁

    人脸识别技术现已广泛应用于安全管理.移动支付.司法刑侦等多个领域.所谓人脸识别,就是利用计算机技术的对比分析功能来实现身份认证的过程,这是一种基于生物特征的识别技术. 运用2D摄像头或3D摄像头进行检 ...

  4. 一文为你详解2D与3D人脸识别有什么区别?

    最近业界内刮起了一股"人脸识别安全"的大讨论,小到个人大到超市以及银行,都在使用这个刷脸认证或支付,说它好吧,确实解决了无接触,快速高效等问题,你说它不好吧,也是有原因的,比如最明 ...

  5. 轻量级3d模型查看器_耐能取得两项软件著作权,自研轻量级3D人脸识别算法领先业界...

    近日,耐能收到国家版权局颁发的两份<计算机软件著作权登记证书>,两款软件分别是人脸活体检测和人脸识别开发包软件V1.1.0.卷积神经网络简化和加速开发工具软件V2.2.17.这次取得两项软 ...

  6. 活体识别6:小视科技开源的静默活体检测

    说明 该项目为小视科技的静默活体检测项目.开源地址在 https://github.com/minivision-ai/Silent-Face-Anti-Spoofing. 由于不是论文衍生项目,所以 ...

  7. LBP+SVM 活体识别

    针对上一篇"深度摄像头-活体识别"的改进版 大致思路: 1.RGB人脸检测 2.同步人脸位置到深度图矩形框 3.裁剪矩形框,提取LBP特征 4.训练SVM模型. 5.集成模型到de ...

  8. 2D与3D人脸识别有什么本质上的区别?

    https://www.zhihu.com/question/324123433/answer/681365180 https://www.zhihu.com/question/324123433/a ...

  9. 坐地铁飞机数秒进站,关于3D人脸识别闸机你知道的有多少?

    11月1日起,5G正式进入商用,这意味着5G终于要落地普及啦!在这场5G新风口下,各类人工智能技术的结合也将加速大规模应用,重塑各个传统行业的发展.其中"刷脸"应用遍布零售店.银行 ...

最新文章

  1. Field types
  2. Django介绍和虚拟环境(django特点、MVC、MVT、Django学习资料)
  3. js脚本 处理js注入
  4. git撤消所有未提交或未保存的更改
  5. python螺旋打印二维数组_Python使用迭代器打印螺旋矩阵的思路及代码示例
  6. 关于MOSS SDK的Web Content Management
  7. Fabric--简单的资产Chaincode
  8. boost::mp11::mp_list相关用法的测试程序
  9. 纪念张首晟教授:英魂长存于行行字迹 何惧漫漫征途
  10. 10.java 关键字与保留字
  11. 我想批量删除专题内最古老的100篇文章
  12. 天玑800处理器支持鸿蒙系统吗,为何Redmi Note 9选择天玑800U处理器?和骁龙750G差距多大...
  13. php laravel 面试,当面试关问你Laravel Facade,说出这几个关键词就可以
  14. Objective-C 函数
  15. 风控建模九:一些特征工程方法及自动化工具小结
  16. 睡眠多少分钟一个循环_睡眠分多少阶段
  17. 2021暑假Leetcode刷题——Two Pointers(1)
  18. 十年一剑智能眼镜的中场战事
  19. 【浏览器】HTTP 缓存机制
  20. image-conversion 图片压缩,vue

热门文章

  1. 单相桥式有源逆变电路,单相半波可控整流电路,单相桥式半控整流电路,单相桥式全控整流电路
  2. AI视觉识别让无人机巡航拥有智慧之眼
  3. jxls导出excel
  4. 多台电脑/多系统共享键鼠神器(synergy)安装与使用
  5. 手绘计算机比赛海报,手绘海报大赛专题计划.doc
  6. 齿轮相关计算机,古希腊人的齿轮计算机
  7. 技术往事:改变世界的TCP/IP协议(珍贵多图、手机慎点)
  8. 销售ERP软件系统主要包括哪些功能?
  9. 【渝粤教育】电大中专电子商务网站建设与维护 (23)作业 题库
  10. k型热电偶材料_K型热电偶规格参数及使用性质.doc