深度学习之基于Xception实现四种动物识别
本次实验类似于猫狗大战,只不过将两种动物识别变为了四种动物识别。
本文的重点是卷积神经网络Xception的实践,在之前的学习中,我们已经实验过其他几种比较常用的网络模型,但是Xception网络并未实践过。在弄本科毕设的时候,一个好朋友的毕设就是基于Xception实现海洋垃圾的识别,最终的实验效果达到了99%左右,由此可见Xception的模型性能还是不错的。
本次实验基于Xception实现动物识别,最终的模型准确率在95%左右。
1.导入库
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os,pathlib,PIL# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
2.数据加载
data_dir = "E:/tmp/.keras/datasets/animal_photos"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*')))
共4000张图片
all_images_paths = list(data_dir.glob('*'))
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths]
分为四类: ['cat', 'chook', 'dog', 'horse']
超参数的设置:
height = 224
width = 224
epochs =10
batch_size = 128
图像增强:
一共分为4类,每一类有1000张图片,数据并不是很多,因此对原数据进行数据加强。并按照8:2的比例划分训练集与测试集。
train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,rotation_range=45,shear_range=0.2,zoom_range=0.2,validation_split=0.2,horizontal_flip=True
)
train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training'
)
test_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation'
)
显示图像:
plt.figure(figsize=(15, 10)) # 图形的宽为15高为10for images, labels in train_ds:for i in range(8):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break
plt.show()
3.Xception模型
Xception是Inception的改进版本,创新点便是 深度可分离卷积。
深度可分离卷积 = 深度卷积+逐点卷积。具体步骤如下所示:
第一步:Depthwise 卷积,对输入的每个channel,分别进行 3 × 3 卷积操作,并将结果 concat。
第二步:Pointwise 卷积,对 Depthwise 卷积中的 concat 结果,进行 1 × 1 卷积操作。
标准卷积与深度可分离卷积的对比如下所示:图片来源
既然最终的结果是一样的,那为什么深度可分离卷积方式更优呢?
原因就是利用深度可分离卷积,参数更少,从而在迭代更新的过程中,计算量就更小。
本次实验利用迁移学习采用官方模型进行训练
base_model = tf.keras.applications.Xception(weights = 'imagenet',include_top = False,pooling = 'max',input_shape = (height,width,3))
base_model.trainable = False#前面的参数设置为不可训练
input = base_model.input
x = tf.keras.layers.Dense(256,activation='relu')(base_model.output)
x = tf.keras.layers.Dense(128,activation='relu')(x)
output = tf.keras.layers.Dense(4,activation='sigmoid')(x)
model = tf.keras.models.Model(inputs = input,outputs = output)
优化器的设置:
# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=300,decay_rate=0.96,staircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
网络编译&&训练
model.compile(optimizer = optimizer,loss = "categorical_crossentropy",metrics = ['accuracy']
)history = model.fit(train_ds,validation_data = test_ds,epochs = epochs
)
Accuracy与Loss图如下所示:
模型准确率比较高,在95%左右。
4.预测&&混淆矩阵
模型保存:
model.save("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")
模型加载:
model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")
预测:
plt.figure(figsize=(50,50))
num = 0
for images,labels in test_ds:for i in range(64):ax = plt.subplot(8,8,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)pre = model.predict(img_array)if np.argmax(pre) == np.argmax(labels[i]):plt.title(all_label_names[np.argmax(pre)])else:plt.title("False :"+str(all_label_names[np.argmax(pre)]))if np.argmax(pre) == np.argmax(labels[i]):num += 1plt.axis("off")break
plt.suptitle("The Acc rating is:{}".format(num / 64))
plt.show()
混淆矩阵的绘制:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd#绘制混淆矩阵
def plot_cm(labels,pre):conf_numpy = confusion_matrix(labels,pre)#根据实际值和预测值绘制混淆矩阵conf_df = pd.DataFrame(conf_numpy,index=all_label_names,columns=all_label_names)#将data和all_label_names制成DataFrameplt.figure(figsize=(8,7))sns.heatmap(conf_df,annot=True,fmt="d",cmap="BuPu")#将data绘制为混淆矩阵plt.title('混淆矩阵',fontsize = 15)plt.ylabel('真实值',fontsize = 14)plt.xlabel('预测值',fontsize = 14)plt.show()
test_pre = []
test_label = []
num = 0
for images,labels in test_ds:num = num + 1for image,label in zip(images,labels):img_array = tf.expand_dims(image,0)#增加一个维度pre = model.predict(img_array)#预测结果test_pre.append(all_label_names[np.argmax(pre)])#将预测结果传入列表test_label.append(all_label_names[np.argmax(label)])#将真实结果传入列表if num == 3:#由于硬件问题,只测试了3个batch_sizebreak
plot_cm(test_label,test_pre)#绘制混淆矩阵
努力加油a啊
深度学习之基于Xception实现四种动物识别相关推荐
- 深度学习之基于VGG16与ResNet50实现鸟类识别
鸟类识别在之前做过,但是效果特别差.而且ResNet50的效果直接差到爆炸,这次利用VGG16与ResNet50的官方模型进行鸟类识别. 1.导入库 import tensorflow as tf i ...
- 深度学习之基于opencv和CNN实现人脸识别
这个项目在之前人工智能课设上做过,但是当时是划水用的别人的.最近自己实现了一下,基本功能可以实现,但是效果并不是很好.容易出现错误识别,或者更改了背景之后识别效果变差的现象.个人以为是数据选取的问题, ...
- 斯坦福大学深度学习与自然语言处理第四讲:词窗口分类和神经网络
斯坦福大学在三月份开设了一门"深度学习与自然语言处理"的课程:CS224d: Deep Learning for Natural Language Processing,授课老师是 ...
- [深度学习] 分布式Horovod介绍(四)
[深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow介绍(二) [深度学习] 分布式Pytorch 1.0介绍(三) [深度学习] 分布式Horovod介绍(四) 实际应用中, ...
- 深度学习入门笔记(十四):Softmax
欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...
- 计算机视觉面试宝典--深度学习机器学习基础篇(四)
计算机视觉面试宝典–深度学习机器学习基础篇(四) 本篇主要包含SVM支持向量机.K-Means均值以及机器学习相关常考内容等相关面试经验. SVM-支持向量机 支持向量机(support vector ...
- 深度学习之图像分类(十四)--ShuffleNetV2 网络结构
深度学习之图像分类(十四)ShuffleNetV2 网络结构 目录 深度学习之图像分类(十四)ShuffleNetV2 网络结构 1. 前言 2. Several Practical Guidelin ...
- 深度学习入门(六十四)循环神经网络——编码器-解码器架构
深度学习入门(六十四)循环神经网络--编码器-解码器架构 前言 循环神经网络--编码器-解码器架构 课件 重新考察CNN 重新考察RNN 编码器-解码器架构 总结 教材 1 编码器 2 解码器 3 合 ...
- 【深度学习】基于深度神经网络进行权重剪枝的算法(二)
[深度学习]基于深度神经网络进行权重剪枝的算法(二) 文章目录 1 摘要 2 介绍 3 OBD 4 一个例子 1 摘要 通过从网络中删除不重要的权重,可以有更好的泛化能力.需求更少的训练样本.更少的学 ...
最新文章
- C语言程序设计第十章字符串,C语言程序设计(字符串)
- awk算术运算一例:统计hdfs上某段时间内的文件大小
- CICS FILE OPEN
- javascript DOM 编程艺术 札记2 平稳退化
- UA OPTI570 量子力学30 Degenerate Stationary Perturbation Theory简介
- 软考-信息系统项目管理师-流程管理
- 什么是Handler(二)
- 十二月份找工作好找吗_小儿推拿师工作好找吗?工资高吗?
- Magento教程 23:如何获取销售报表?
- 【C】揭秘rand()函数;
- 批处理 批量s扫1433_批处理批量字符替换
- SQL的一个排序的问题
- AttributeError: module 'torch._C' has no attribute '_cuda_setDevice'(在python命令后面加上 --gpu_ids -1)
- Plugin “GsonFormat“ is incompatible
- matlab调和均值滤波_求matlab均值滤波、中值滤波和领域平均滤波算法
- 微信公众号群发模板消息占用每月4次群发次数吗
- python PIL彩色图片转黑白图片
- innodb_io_capacity、innodb_io_capacity_max 的影响
- 手把手教你开发一款简单的AR软件
- 干货 | Between 运算符
热门文章
- C语言的条件编译#if, #elif, #else, #endif、#ifdef, #ifndef
- Everything的下载
- java 根据经纬度计算多边形的面积_强基初中数学amp;学Python——第二十九课 根据海伦秦九韶公式编程计算三角形面积...
- osg节点访问和遍历
- HTML与CSS基础之子元素的伪类(七)
- linux shell 网盘,linux在shell中获取时间
- python iter next_python类中的__iter__, __next__与built-in的iter()函数举例
- 虚拟主机搭建微信公众号服务器,建web服务器同时如何搭建虚拟主机?方法有几种?...
- clone的fork与pthread_create创建线程有何不同pthread多线程编程的学习小结(转)
- Android Studio经常使用配置及使用技巧(二)