本次实验类似于猫狗大战,只不过将两种动物识别变为了四种动物识别。
本文的重点是卷积神经网络Xception的实践,在之前的学习中,我们已经实验过其他几种比较常用的网络模型,但是Xception网络并未实践过。在弄本科毕设的时候,一个好朋友的毕设就是基于Xception实现海洋垃圾的识别,最终的实验效果达到了99%左右,由此可见Xception的模型性能还是不错的。
本次实验基于Xception实现动物识别,最终的模型准确率在95%左右。

1.导入库

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os,pathlib,PIL# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

2.数据加载

data_dir = "E:/tmp/.keras/datasets/animal_photos"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*')))

共4000张图片

all_images_paths = list(data_dir.glob('*'))
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths]
分为四类: ['cat', 'chook', 'dog', 'horse']

超参数的设置:

height = 224
width = 224
epochs =10
batch_size = 128

图像增强:
一共分为4类,每一类有1000张图片,数据并不是很多,因此对原数据进行数据加强。并按照8:2的比例划分训练集与测试集。

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,rotation_range=45,shear_range=0.2,zoom_range=0.2,validation_split=0.2,horizontal_flip=True
)
train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training'
)
test_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation'
)

显示图像:

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds:for i in range(8):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break
plt.show()

3.Xception模型

Xception是Inception的改进版本,创新点便是 深度可分离卷积

深度可分离卷积 = 深度卷积+逐点卷积。具体步骤如下所示:

第一步:Depthwise 卷积,对输入的每个channel,分别进行 3 × 3 卷积操作,并将结果 concat
第二步:Pointwise 卷积,对 Depthwise 卷积中的 concat 结果,进行 1 × 1 卷积操作。

标准卷积与深度可分离卷积的对比如下所示:图片来源

既然最终的结果是一样的,那为什么深度可分离卷积方式更优呢?
原因就是利用深度可分离卷积,参数更少,从而在迭代更新的过程中,计算量就更小

本次实验利用迁移学习采用官方模型进行训练

base_model = tf.keras.applications.Xception(weights = 'imagenet',include_top = False,pooling = 'max',input_shape = (height,width,3))
base_model.trainable = False#前面的参数设置为不可训练
input = base_model.input
x = tf.keras.layers.Dense(256,activation='relu')(base_model.output)
x = tf.keras.layers.Dense(128,activation='relu')(x)
output = tf.keras.layers.Dense(4,activation='sigmoid')(x)
model = tf.keras.models.Model(inputs = input,outputs = output)

优化器的设置:

# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=300,decay_rate=0.96,staircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

网络编译&&训练

model.compile(optimizer = optimizer,loss = "categorical_crossentropy",metrics = ['accuracy']
)history = model.fit(train_ds,validation_data = test_ds,epochs = epochs
)

Accuracy与Loss图如下所示:

模型准确率比较高,在95%左右。

4.预测&&混淆矩阵

模型保存:

model.save("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")

模型加载:

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")

预测:

plt.figure(figsize=(50,50))
num = 0
for images,labels in test_ds:for i in range(64):ax = plt.subplot(8,8,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)pre = model.predict(img_array)if np.argmax(pre) == np.argmax(labels[i]):plt.title(all_label_names[np.argmax(pre)])else:plt.title("False :"+str(all_label_names[np.argmax(pre)]))if np.argmax(pre) == np.argmax(labels[i]):num += 1plt.axis("off")break
plt.suptitle("The Acc rating is:{}".format(num / 64))
plt.show()


混淆矩阵的绘制:

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd#绘制混淆矩阵
def plot_cm(labels,pre):conf_numpy = confusion_matrix(labels,pre)#根据实际值和预测值绘制混淆矩阵conf_df = pd.DataFrame(conf_numpy,index=all_label_names,columns=all_label_names)#将data和all_label_names制成DataFrameplt.figure(figsize=(8,7))sns.heatmap(conf_df,annot=True,fmt="d",cmap="BuPu")#将data绘制为混淆矩阵plt.title('混淆矩阵',fontsize = 15)plt.ylabel('真实值',fontsize = 14)plt.xlabel('预测值',fontsize = 14)plt.show()
test_pre = []
test_label = []
num = 0
for images,labels in test_ds:num = num + 1for image,label in zip(images,labels):img_array = tf.expand_dims(image,0)#增加一个维度pre = model.predict(img_array)#预测结果test_pre.append(all_label_names[np.argmax(pre)])#将预测结果传入列表test_label.append(all_label_names[np.argmax(label)])#将真实结果传入列表if num == 3:#由于硬件问题,只测试了3个batch_sizebreak
plot_cm(test_label,test_pre)#绘制混淆矩阵


努力加油a啊

深度学习之基于Xception实现四种动物识别相关推荐

  1. 深度学习之基于VGG16与ResNet50实现鸟类识别

    鸟类识别在之前做过,但是效果特别差.而且ResNet50的效果直接差到爆炸,这次利用VGG16与ResNet50的官方模型进行鸟类识别. 1.导入库 import tensorflow as tf i ...

  2. 深度学习之基于opencv和CNN实现人脸识别

    这个项目在之前人工智能课设上做过,但是当时是划水用的别人的.最近自己实现了一下,基本功能可以实现,但是效果并不是很好.容易出现错误识别,或者更改了背景之后识别效果变差的现象.个人以为是数据选取的问题, ...

  3. 斯坦福大学深度学习与自然语言处理第四讲:词窗口分类和神经网络

    斯坦福大学在三月份开设了一门"深度学习与自然语言处理"的课程:CS224d: Deep Learning for Natural Language Processing,授课老师是 ...

  4. [深度学习] 分布式Horovod介绍(四)

    [深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow介绍(二) [深度学习] 分布式Pytorch 1.0介绍(三) [深度学习] 分布式Horovod介绍(四) 实际应用中, ...

  5. 深度学习入门笔记(十四):Softmax

    欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...

  6. 计算机视觉面试宝典--深度学习机器学习基础篇(四)

    计算机视觉面试宝典–深度学习机器学习基础篇(四) 本篇主要包含SVM支持向量机.K-Means均值以及机器学习相关常考内容等相关面试经验. SVM-支持向量机 支持向量机(support vector ...

  7. 深度学习之图像分类(十四)--ShuffleNetV2 网络结构

    深度学习之图像分类(十四)ShuffleNetV2 网络结构 目录 深度学习之图像分类(十四)ShuffleNetV2 网络结构 1. 前言 2. Several Practical Guidelin ...

  8. 深度学习入门(六十四)循环神经网络——编码器-解码器架构

    深度学习入门(六十四)循环神经网络--编码器-解码器架构 前言 循环神经网络--编码器-解码器架构 课件 重新考察CNN 重新考察RNN 编码器-解码器架构 总结 教材 1 编码器 2 解码器 3 合 ...

  9. 【深度学习】基于深度神经网络进行权重剪枝的算法(二)

    [深度学习]基于深度神经网络进行权重剪枝的算法(二) 文章目录 1 摘要 2 介绍 3 OBD 4 一个例子 1 摘要 通过从网络中删除不重要的权重,可以有更好的泛化能力.需求更少的训练样本.更少的学 ...

最新文章

  1. C语言程序设计第十章字符串,C语言程序设计(字符串)
  2. awk算术运算一例:统计hdfs上某段时间内的文件大小
  3. CICS FILE OPEN
  4. javascript DOM 编程艺术 札记2 平稳退化
  5. UA OPTI570 量子力学30 Degenerate Stationary Perturbation Theory简介
  6. 软考-信息系统项目管理师-流程管理
  7. 什么是Handler(二)
  8. 十二月份找工作好找吗_小儿推拿师工作好找吗?工资高吗?
  9. Magento教程 23:如何获取销售报表?
  10. 【C】揭秘rand()函数;
  11. 批处理 批量s扫1433_批处理批量字符替换
  12. SQL的一个排序的问题
  13. AttributeError: module 'torch._C' has no attribute '_cuda_setDevice'(在python命令后面加上 --gpu_ids -1)
  14. Plugin “GsonFormat“ is incompatible
  15. matlab调和均值滤波_求matlab均值滤波、中值滤波和领域平均滤波算法
  16. 微信公众号群发模板消息占用每月4次群发次数吗
  17. python PIL彩色图片转黑白图片
  18. innodb_io_capacity、innodb_io_capacity_max 的影响
  19. 手把手教你开发一款简单的AR软件
  20. 干货 | Between 运算符

热门文章

  1. C语言的条件编译#if, #elif, #else, #endif、#ifdef, #ifndef
  2. Everything的下载
  3. java 根据经纬度计算多边形的面积_强基初中数学amp;学Python——第二十九课 根据海伦秦九韶公式编程计算三角形面积...
  4. osg节点访问和遍历
  5. HTML与CSS基础之子元素的伪类(七)
  6. linux shell 网盘,linux在shell中获取时间
  7. python iter next_python类中的__iter__, __next__与built-in的iter()函数举例
  8. 虚拟主机搭建微信公众号服务器,建web服务器同时如何搭建虚拟主机?方法有几种?...
  9. clone的fork与pthread_create创建线程有何不同pthread多线程编程的学习小结(转)
  10. Android Studio经常使用配置及使用技巧(二)