#-*- coding: utf-8 -*-

from skimage importio,transformimportglobimportosimporttensorflow as tfimportnumpy as npimporttime

path='D:/code/python/Anaconda3/envs/faces'

#将所有的图片resize成100*100

w=128h=128c=3

#读取图片

defread_img(path):

cate=[path+'/'+x for x in os.listdir(path) if os.path.isdir(path+'/'+x)]

imgs=[]

labels=[]for idx,folder inenumerate(cate):for im in glob.glob(folder+'/*.png'):print('reading the images:%s'%(im))

img=io.imread(im)

img=transform.resize(img,(w,h,c))

imgs.append(img)

labels.append(idx)returnnp.asarray(imgs,np.float32),np.asarray(labels,np.int32)

data,label=read_img(path)#打乱顺序

num_example=data.shape[0]

arr=np.arange(num_example)

np.random.shuffle(arr)

data=data[arr]

label=label[arr]#将所有数据分为训练集和验证集

ratio=0.8s=np.int(num_example*ratio)

x_train=data[:s]

y_train=label[:s]

x_val=data[s:]

y_val=label[s:]#-----------------构建网络----------------------#占位符

x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')

y_=tf.placeholder(tf.int32,shape=[None,],name='y_')defCNNlayer():#第一个卷积层(128——>64)

conv1=tf.layers.conv2d(

inputs=x,

filters=32,

kernel_size=[5, 5],

padding="same",

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)#第二个卷积层(64->32)

conv2=tf.layers.conv2d(

inputs=pool1,

filters=64,

kernel_size=[5, 5],

padding="same",

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)#第三个卷积层(32->16)

conv3=tf.layers.conv2d(

inputs=pool2,

filters=128,

kernel_size=[3, 3],

padding="same",

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)#第四个卷积层(16->8)

conv4=tf.layers.conv2d(

inputs=pool3,

filters=128,

kernel_size=[3, 3],

padding="same",

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1= tf.reshape(pool4, [-1, 8 * 8 * 128])#全连接层

dense1 = tf.layers.dense(inputs=re1,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))

dense2= tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))

logits= tf.layers.dense(inputs=dense2,

units=60,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))returnlogits#---------------------------网络结束---------------------------

logits=CNNlayer()

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction= tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#定义一个函数,按批次取数据

def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):assert len(inputs) ==len(targets)ifshuffle:

indices=np.arange(len(inputs))

np.random.shuffle(indices)for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):ifshuffle:

excerpt= indices[start_idx:start_idx +batch_size]else:

excerpt= slice(start_idx, start_idx +batch_size)yieldinputs[excerpt], targets[excerpt]#训练和测试数据,可将n_epoch设置更大一些

saver=tf.train.Saver(max_to_keep=3)

max_acc=0

f=open('ckpt1/acc.txt','w')

n_epoch=10batch_size=64sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())for epoch inrange(n_epoch):

start_time=time.time()#training

train_loss, train_acc, n_batch =0, 0, 0for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):

_,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})

train_loss+= err; train_acc += ac; n_batch += 1

print("train loss: %f" % (train_loss/n_batch))print("train acc: %f" % (train_acc/n_batch))#validation

val_loss, val_acc, n_batch =0, 0, 0for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):

err, ac= sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})

val_loss+= err; val_acc += ac; n_batch += 1

print("validation loss: %f" % (val_loss/n_batch))print("validation acc: %f" % (val_acc/n_batch))

f.write(str(epoch+1)+', val_acc:'+str(val_acc)+'\n')if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt1/faces.ckpt',global_step=epoch+1)

f.close()

sess.close()

python构建cnn图片匹配_tensorflow搭建cnn人脸识别训练+识别代码(python)相关推荐

  1. python 二进制流转图片_Python零基础入门到精通-5.1节:Python程序的执行过程

    教程引言: 系统地讲解计算机基础知识,Python的基础知识, 高级知识,web开发框架,爬虫开发,数据结构与算法,nginx, 系统架构.一步步地帮助你从入门到就业. 5.1.1 在命令行中执行Py ...

  2. python label显示图片_高大上的YOLOV3对象检测算法,使用python也可轻松实现

    继续我们的目标检测算法的分享,前期我们介绍了SSD目标检测算法的python实现以及Faster-RCNN目标检测算法的python实现以及yolo目标检测算法的darknet的window环境安装, ...

  3. 用python构建多只股票日收益率直方图_Barra纯因子收益率的Python实现

    人生若只如初见,何事秋风悲画扇.等闲变却故人心,却道故人心易变. --<木兰花> 纳兰容若 多因子模型的介绍文章汗牛充栋,但系统性的归纳整理首推石川博士的多因子系列文章,看完绝对让人有醍醐 ...

  4. python做三维图片挑战眼力_这几天有django和python做了一个多用户博客系统(可选择模板) 没完成,先分享下...

    最新请看这里:http://my.oschina.net/djangochina/blog/140099 断断续续2周时间吧,用django做了一个多用户博客系统,现在还没有做完,做分享下 做的时候房 ...

  5. python 拼多多抢券_拼多多满减优惠 AC代码 python

    思路就是从价值最高的优惠券开始遍历尝试,价值相同的优惠券则先尝试需要满足的金额小的,然后用在比满减所需金额大的商品中最便宜的那个上,就ok from sys import stdin as f [n, ...

  6. TensorFlow搭建CNN实现时间序列预测(风速预测)

    目录 I. 数据集 II. 特征构造 III. 一维卷积 IV. 数据处理 1. 数据预处理 2. 数据集构造 V. CNN模型 1. 模型搭建 2. 模型训练及表现 VI. 源码及数据 时间序列预测 ...

  7. python搭建django框架,Python之Web框架Django项目搭建全过程

    Python之Web框架Django项目搭建全过程 IDE说明: Win7系统 Python:3.5 Django:1.10 Pymysql:0.7.10 Mysql:5.5 注:可通过pip fre ...

  8. thinkcmf5调用指定分类的二级_Tengine快速上手系列教程amp;视频:基于Python API的图片分类应用入门丨附彩蛋...

    前言:近期,Tengine团队加班加点,好消息接踵而来,OpenCV 4.3.0发布,OPEN AI LAB AIoT智能开发平台Tengine与OpenCV合作共同加速边缘智能,Tengine再获业 ...

  9. python怎么加图片_python中如何保存图片

    一提到数字图像处理,可能大多数人就会想到matlab,但matlab也有自身的缺点: 1.不开源,价格贵 2.软件容量大.一般3G以上,高版本甚至达5G以上. 3.只能做研究,不易转化成软件.pyth ...

最新文章

  1. ipython使用_IPython的介绍与使用
  2. Spring Boot笔记-@Qualifier与@Autowired与@Bean
  3. 《深入理解 Spring Cloud 与微服务构建》第三章 Spring Cloud
  4. JDK8新特性(六)之Stream流的forEach()方法
  5. 脉位调制解调 matlab,基于matlab的am调制解调
  6. 远程会议总卡顿?8 个“小白”办法一看就会!
  7. Javascript调用后台方法
  8. [CF600E]Dsu on tree
  9. java语言飞机大战代码_飞机大战JAVA代码
  10. 关闭安卓系统导航栏右下角自动旋转按钮
  11. Proxifier实现指定进程代理IP 雷电模拟器为例
  12. 博客网站怎么做,怎样建立一个自己的网站
  13. linux 无法生成图片大小,简单点。表演()在Linux上的ImageJ中生成错误
  14. .netcf 图片区域拷贝[图片切割]
  15. SAP BAPI创建交货单拆单原因调查
  16. Docker 启动和退出一个容器
  17. Openvino学习之openvino2022.1版安装配置
  18. java 实例化list_java中List的用法和实例详解
  19. JPA——Java.util.Date和Java.sql.Date
  20. 使用IText7 生成PDF文档

热门文章

  1. JAVA课上动手动脑问题2
  2. 04:sqlalchemy操作数据库 不错
  3. Jquery前端分页插件pagination同步加载和异步加载
  4. 实时排行榜的后台数据功能实现
  5. [skill] vim 操作多个window
  6. 汉诺塔(三)_栈的应用
  7. 计算机病毒实践汇总五:搭建虚拟网络环境
  8. [重磅] 让HTML5达到原生的体验 系列之中的一个 避免切页白屏
  9. studio2008 无法显示该网页
  10. MyEclipse提示键配置、提示快捷键、提示背景色、关键字颜色、代码显示