无意间在kaggle上发现的一个数据集,旨在提高网络模型判别男女的准确率,博主利用迁移学习试验了多个卷积神经网络,最终的模型准确率在95%左右。并划分了训练集、测试集、验证集三类,最终在验证集上的准确率为93.63%.

1.导入库

import tensorflow as tf
import matplotlib.pyplot as plt
import os,PIL,pathlib
import pandas as pd
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,models
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout,BatchNormalization,Activation
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Concatenate, Lambda,GlobalAveragePooling2D
from tensorflow.keras import backend as K# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

2.导入数据

kaggle上下载的原数据一共有20000+张,由于硬件原因,博主选取了4286张作为训练集和测试集,剩下的作为验证集。

data_dir = "E:/tmp/.keras/datasets/Man_Women/faces_test"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*')))
print(img_count)all_images_paths = list(data_dir.glob('*'))
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[6].split(".")[0] for path in all_images_paths]
print(all_label_names)
4286
['man', 'woman']
Found 4286 images belonging to 2 classes.

参数设置:

height = 224
width = 224
epochs = 15
batch_size = 32

3.训练集与测试集

按照8:2的比例划分训练集和测试集

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,#归一化validation_split=0.2,horizontal_flip=True#进行水平翻转,作为数据增强
)train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training'
)
test_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation'
)
Found 3430 images belonging to 2 classes.
Found 856 images belonging to 2 classes.

图片展示:

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds:for i in range(30):ax = plt.subplot(5, 6, i + 1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break
plt.show()

3.迁移学习VGG16网络

base_model = tf.keras.applications.VGG16(include_top=False, weights="imagenet",input_shape=(height,width,3),pooling = 'max')
x = base_model.output
x = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001 )(x)
x = tf.keras.layers.Dense(256, activation='relu')(x)
x = tf.keras.layers.Dropout(rate=.45, seed=123)(x)
output = tf.keras.layers.Dense(2, activation='softmax')(x)
model=Model(inputs=base_model.input, outputs=output)

设置优化器

# #设置优化器
# #起始学习率
init_learning_rate = 1e-4
lr_sch = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=init_learning_rate,decay_steps=50,decay_rate=0.96,staircase=True
)
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_sch)

网络编译&&训练

model.compile(optimizer=gen_optimizer,loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy']
)history = model.fit(train_ds,epochs=epochs,validation_data=test_ds
)

训练结果如下所示:

最终的模型准确率为95%左右,在博主试验的这些网络模型中,VGG16的模型准确率是最高的。

4.混淆矩阵的绘制

网络保存:

model.save("E:/Users/yqx/PycharmProjects/Man_Women_Rec/model_.h5")

网络加载:

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/Man_Women_Rec/model.h5")

利用模型对验证集的数据进行测试:

plt.figure(figsize=(50,50))for images,labels in validation_ds:num = 0total = 0for i in range(64):total += 1ax = plt.subplot(8,8,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)pre = model.predict(img_array)if np.argmax(pre) == np.argmax(labels[i]):num += 1plt.title(all_label_names[np.argmax(pre)])plt.axis("off")print(total)print(num)breakplt.suptitle("The acc rating of validation is:{}".format((num / total)))plt.show()


绘制混淆矩阵:

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 绘制混淆矩阵
def plot_cm(labels, pre):conf_numpy = confusion_matrix(labels, pre)  # 根据实际值和预测值绘制混淆矩阵conf_df = pd.DataFrame(conf_numpy, index=all_label_names,columns=all_label_names)  # 将data和all_label_names制成DataFrameplt.figure(figsize=(8, 7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")  # 将data绘制为混淆矩阵plt.title('混淆矩阵', fontsize=15)plt.ylabel('真实值', fontsize=14)plt.xlabel('预测值', fontsize=14)plt.show()test_pre = []
test_label = []
for images, labels in validation_ds:for image, label in zip(images, labels):img_array = tf.expand_dims(image, 0)  # 增加一个维度pre = model.predict(img_array)  # 预测结果test_pre.append(all_label_names[np.argmax(pre)])  # 将预测结果传入列表test_label.append(all_label_names[np.argmax(label)])  # 将真实结果传入列表break  # 由于硬件问题。这里我只用了一个batch,一共128张图片。
plot_cm(test_label, test_pre)  # 绘制混淆矩阵

5.测试验证集

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/Man_Women_Rec/model.h5")
model.evaluate(validation_ds)

最终结果如下所示:

716/716 [==============================] - 418s 584ms/step - loss: 0.5345 - accuracy: 0.9363
[0.5345107175451417, 0.936279]#loss值与acc率

模型准确率比较高。在kaggle上,看到有模型准确率在99%左右,路过的大佬可以试验一下。

努力加油a啊

深度学习之基于卷积神经网络(VGG16)实现性别判别相关推荐

  1. Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类)

    Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类) 1.卷积神经网络 1.1卷积神经网络简介 1.2卷积运算 1.3 深度学习与小数据问题的相关性 2.下载数据 2.1下载原始数据 ...

  2. 深度学习之基于卷积神经网络(VGG16CNN)实现海贼王人物识别

    硬件问题真的是搞机器学习的一个痛处,更何况这只是入门级别的. 基于CNN和VGG16,实现对海贼王人物的分类识别.本次自己动手搭建了VGG16 网络,并且和迁移学习的VGG16的网络的实验效果做了一个 ...

  3. 深度学习之基于卷积神经网络实现花朵识别

    类比于猫狗大战,利用自己搭建的CNN网络和已经搭建好的VGG16实现花朵识别. 1.导入库 注:导入的库可能有的用不到,之前打acm时留下的毛病,别管用不用得到,先写上再说 import tensor ...

  4. 深度学习之基于卷积神经网络实现服装图像识别

    本博客与手写数字识别大同小异. 1.导入所需库 import tensorflow as tf from tensorflow.keras import datasets, layers, model ...

  5. 深度学习之基于卷积神经网络实现超大Mnist数据集识别

    在以往的手写数字识别中,数据集一共是70000张图片,模型准确率可以达到99%以上的准确率.而本次实验的手写数字数据集中有120000张图片,而且数据集的预处理方式也是之前没有遇到过的.最终在验证集上 ...

  6. 吴恩达.深度学习系列-C4卷积神经网络-W2深度卷积模型案例

    吴恩达.深度学习系列-C4卷积神经网络-W2深度卷积模型案例 (本笔记部分内容直接引用redstone的笔记http://redstonewill.com/1240/.原文整理的非常好,引入并添加我自 ...

  7. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-03-基于Python的LeNet之LR

    原地址可以查看更多信息 本文主要参考于:Classifying MNIST digits using Logistic Regression  python源代码(GitHub下载 CSDN免费下载) ...

  8. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  9. 深度学习笔记:卷积神经网络的可视化--卷积核本征模式

    目录 1. 前言 2. 代码实验 2.1 加载模型 2.2 构造返回中间层激活输出的模型 2.3 目标函数 2.4 通过随机梯度上升最大化损失 2.5 生成滤波器模式可视化图像 2.6 将多维数组变换 ...

最新文章

  1. 大牛推荐的5本 Linux 经典必读书
  2. 网络安全应急演练方案内容_开展应急演练,筑牢网络安全
  3. 拳王公社:没钱没资源没人脉!网络创业凭副业年赚20W+!
  4. 计算机管理员账户不能创建新的用户名,win10为什么无法更改账户名称解决方法 win10系统管理员用户名更改...
  5. Linux如何查看信号宏定义,转  LINUX 调试宏定义
  6. 运维自动化之使用PHP+MYSQL+SHELL打造私有监控系统
  7. 关于map的几种非常规排序
  8. 青蛙学Linux—Zabbix Web使用之模板④基于触发器的动作和告警媒介
  9. 表带可作为显示操作装置
  10. Introduction to Computer Networking学习笔记(二十三):拥塞控制-TCP Tahoe
  11. Qt/QML离线地图瓦片下载工具(瓦片地图)
  12. Bug heroes虫虫英雄 超详细翻译+基本攻略
  13. 大芒果mysql下载_大芒果魔兽世界单机版
  14. EasyGBS国标视频云服务平台可以获取录像却无法播放是什么原因?
  15. FileZilla的下载与安装
  16. DMA RDMA 技术详解
  17. 淘宝直通车中的类目推广
  18. 现阶段云计算的市场运用
  19. android平板值得买吗,最值得买大推荐 全新安卓平板你选谁?
  20. 我奋斗了18年才和你坐在一起喝咖啡与我奋斗了18年不是为了和你一起喝咖啡

热门文章

  1. uni-app使用input框 v-model双向绑定不起作用解决方案
  2. 排除网络故障课后习题参考答案
  3. oracle 11g 精简,Oracle 11g 精简客户端
  4. 列车停站方案_4月10日零时起阜阳高铁、铁路大调图!最新列车时刻表来了!看看有没有你经常乘坐的列车?...
  5. python爬虫贴吧_Python爬虫简单实现,贴吧图片一键下
  6. 在Linux系统下, 可以用一个命令很容易批量删除.svn的文件夹
  7. 置顶带滚动效果_前端面试:如何实现轮播图效果?
  8. Spring_AOP架构介绍与源码分析(含事务深度分析)
  9. 解决svn log显示no author,no date的方法之一
  10. 关于ES6的10个最佳特性