Xception迁移学习:玉米叶片病害识别分类

  • 数据集:来自网上公开的PlantVillage数据集中的玉米叶片部分。
  • 运行环境:Tensorflow深度学习开源框架,选用Python 3.6.12作为编程语言。

本代码是自己查阅了很多博客代码最后根绝自己要用的数据集综合而成的,由于过于久远,不记得参考了哪些博客,这里就不放链接了。记录下来,便于自己以后查阅。也是刚入门的小白,欢迎大佬指教!

代码如下

1. 导入

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os as os
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Dense,Flatten,GlobalAveragePooling2D,Dropout
from tensorflow.keras.models import Model,load_model
from tensorflow.keras.optimizers import SGD

2. 设置参数和路径

IMG_SIZE:输入图片的尺寸;
batch_size:每次读取图片的数量;
EPOCHS:训练轮次;
train_path:训练集路径;val_path:验证集路径。

IMG_SIZE = 150
batch_size = 16
EPOCHS=100
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
train_path = 'D:/tmp/New Maize Data set/Train_maize'
val_path='D:/tmp/New Maize Data set/Vali_maize'

3. 数据增强

由于电脑的配置低,带不动很多图片,所以只选取了每种病害图片的几百张作为训练集,故需要数据增强操作,提高分类准确率。
使用keras提供的图像生成器ImageDataGenerator类来实现数据增强。主要做法是每次取一个批次即batch_size大小的样本数据提供给模型,同时对每批样本进行归一化、随机旋转40°、随机水平和上下位置平移、随机错切变换角度、随机缩放比例、随机将一半图像水平翻转等操作。这样每一轮训练时输入的样本批次就不会完全相同,可以增强模型的泛化能力。
数据增强后的结果如图:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_gen = ImageDataGenerator(rescale=1 / 255,rotation_range=40,  # 角度值,0-180.表示图像随机旋转的角度范围width_shift_range=0.2,  # 平移比例,下同height_shift_range=0.2,shear_range=0.2,  # 随机错切变换角度zoom_range=0.2,  # 随即缩放比例horizontal_flip=True,  # 随机将一半图像水平翻转validation_split=0.2,fill_mode='nearest'  # 填充新创建像素的方法
)train_generator = train_gen.flow_from_directory(directory=train_path,shuffle = True,batch_size = batch_size,class_mode = 'categorical',target_size = IMG_SHAPE[:-1],color_mode='rgb',#classes =classes,#subset='training'
)
validation_generator = train_gen.flow_from_directory(directory=val_path,shuffle = True,batch_size = batch_size,class_mode = 'categorical',target_size =IMG_SHAPE[:-1],color_mode='rgb',#classes =classes,#subset='validation')

4. 构建模型

这里所用模型直接调用keras中的Xception模型

#构建模型
model = tf.keras.Sequential([tf.keras.applications.Xception(input_shape=(150,150,3),weights='imagenet',include_top=False),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(4,activation='softmax')])

设置迁移学习冻结模型的层数:冻结部分网络层,即只训练其中的一部分网络层。

for i, layer in enumerate(model.layers[0].layers):
if i > 85:
layer.trainable = True
else:
layer.trainable = False

5. 编译模型

#编译模型
model.compile(optimizer='adam',loss = 'categorical_crossentropy',metrics=['accuracy'])

6. 打印模型

model.summary()

模型打印结果可以看到可训练的参数数量

7. 训练模型

history=model.fit_generator(train_generator,
steps_per_epoch=max(1, train_generator.n//batch_size),
validation_data=validation_generator,
validation_steps=max(1, validation_generator.n//batch_size),
epochs =100,
#initial_epoch=0,
#callbacks=[checkpoint]
)

8. 保存模型

将模型保存为.h5文件

model.save('model/Xception_2_85_model.h5')

9. 绘制损失值曲线和准确率曲线

# 记录准确率和损失值
history_dict = history.history
train_loss = history_dict["loss"]
train_accuracy = history_dict["acc"]
val_loss = history_dict["val_loss"]
val_accuracy = history_dict["val_acc"]# 绘制损失值曲线
plt.figure()
plt.title('InceptionV3-1')
plt.plot(range(EPOCHS),train_loss,c='k' ,ls='--',label='train_loss')
plt.plot(range(EPOCHS),val_loss,'k' ,label='val_loss' )
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')import matplotlib as mpl
#中文字体设置
mpl.rcParams["font.family"] = "SimHei"
mpl.rcParams["axes.unicode_minus"] = False
mpl.rcParams["font.style"] = "normal"
mpl.rcParams["font.size"] = 10
# 绘制准确率曲线
plt.figure()
#plt.title('InceptionV3-1')
plt.plot(range(EPOCHS), train_accuracy,ls='--', c="k",label="训练集准确率")
plt.plot(range(EPOCHS), val_accuracy,c="k",label="验证集准确率")
plt.ylim(0.5,1)
plt.legend(loc='lower right')
plt.xlabel("训练轮次")
plt.ylabel("准确率")
plt.show()

10. 测试模型

测试结果可以输出一个混淆矩阵,查看每种病害类别的准确率。

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model
import datetime
from tensorflow.keras.callbacks import TensorBoard
from keras.backend.tensorflow_backend import set_session
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertoolsconfig = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
set_session(sess)
keras.backend.clear_session() #清理session#test image directory
dst_path = 'D:/tmp/New Maize Data set/Test_maize'
#model path
model_file ='C:/Users/name/model/Xception_2_85_model.h5'
batch_size = 8def plot_confusion_matrix(cm,target_names,title='Confusion Matrix',cmap=plt.cm.Greens,  # 设置混淆矩阵的颜色主题normalize=True):
accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracyif cmap is None:
cmap = plt.get_cmap('Blues')plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=cmap)
# plt.title(title)
plt.title(title+'\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
plt.colorbar()if target_names is not None:
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]thresh = cm.max() / 1.5 if normalize else cm.max() / 2
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
if normalize:
plt.text(j, i, "{:0.4f}".format(cm[i, j]),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")
else:
plt.text(j, i, "{:,}".format(cm[i, j]),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True label')
plt.xlabel('Predicted label')# load model
model = load_model(model_file)
# generator image
test_datagen = ImageDataGenerator(rescale=1. / 255)test_generator = test_datagen.flow_from_directory(
dst_path,
target_size=(150, 150),
batch_size=batch_size,
shuffle=False)labels = test_generator.class_indices #查看类别的label
#labels = ['blight', 'cercos', 'healthy','rust']
#然后直接用predice_geneorator 可以进行预测
test_generator.reset()
pred = model.predict_generator(test_generator, verbose=1)
# 输出每个图像的预测类别
predicted_class_indices = np.argmax(pred, axis=1)
#测试集的真实类别
true_label= test_generator.classes#简单画出混淆矩阵
import pandas as pd
table=pd.crosstab(true_label,predicted_class_indices,colnames=['predict'],rownames=['label'])
print(table)
#图片化显示混淆矩阵
conf_mat = confusion_matrix(y_true=true_label,y_pred=predicted_class_indices)
plt.figure()
plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')

测试结果如下:可以看出每种类别的识别率都很高

Xception迁移学习:玉米叶片病害识别分类相关推荐

  1. 基于 CNN 和迁移学习的农作物病害识别方法研究

    基于 CNN 和迁移学习的农作物病害识别方法研究 1.研究思路 采用互联网公开的 ImageNet 图像大数据集和PlantVillage 植物病害公共数据集, 以实验室的黄瓜和水稻病害数据集 AES ...

  2. 玉米叶片病害识别与分类的优化密集卷积神经网络模型

    An optimized dense convolutional neural network model for disease recognition and classification in ...

  3. 苹果叶片病害识别中的深度学习研究

    苹果叶片病害识别中的深度学习研究 1.研究内容 基于DenseNet-121深度卷积网络,提出了回归.多标签分类和聚焦损失函数3种苹果叶片病害识别方法. 2.数据集介绍 用于识别的图像数据集来源于Ai ...

  4. 玉米叶片病害分类的深度转移模型(改进AlexNet)

    目前,深度学习在图像分析和目标分类中发挥着重要作用.玉米病害导致产量下降,进而成为全球农业经济损失的突出因素.此前,研究人员已经使用手工制作的特征对玉米植株的叶片疾病进行图像分类和检测.如今,深度学习 ...

  5. 基于卷积神经网络与迁移学习的油茶病害图像识别

    基于卷积神经网络与迁移学习的油茶病害图像识别 1.研究思路 利用深度卷积神经网络强大的特征学习和特征表达能力来自动学习油茶病害特征,并借助迁移学习方法将AlexNet模型在ImageNet图像数据集上 ...

  6. 基于深度残差网络的番茄叶片病害识别方法

    基于深度残差网络的番茄叶片病害识别方法 1.研究思路 该方法首先利用贝叶斯优化算法自主学习网络中难以确定的超参数,降低了深度学习网络的训练难度.在此基础上,通过在传统深度神经网络中添加残差单元,解决了 ...

  7. 基于自动图像分割算法和扩展数据集深度学习的经济作物病害识别

    基于自动图像分割算法和扩展数据集深度学习的经济作物病害识别 1.作物病害识别出现的问题 实际应用中作物图像的复杂背景信息和训练数据不足会导致深度学习的错误识别. 2.研究内容 提出了一种基于自动图像分 ...

  8. 基于深度学习的大豆叶片病害识别(自然环境下1470张图像)

    Abstract 本文提出了一种利用卷积神经网络(CNN)识别自然环境下大豆叶片病害的新方法.使用AlexNet.GoogLeNet和ResNet进行迁移学习.首先,通过设置不同的批量大小和迭代次数, ...

  9. 作物叶片病害识别系统

    摘 要 农作物病害的种类繁多,这直接影响了农作物的产量和品质,造成不可估量的损失.此数据集是使用原始数据集的脱机扩充重新创建的.该数据集由大约87K健康和病害作物叶的rgb图像组成,分为38个不同的类 ...

最新文章

  1. JavaScript setTimeout() 介绍
  2. C - Multiplication Table CodeForces - 448D
  3. WANTS好物CEO李毅秋:初创公司如何避免这些坑
  4. 网易实战分享|Docker文件系统实战
  5. PyTorch基础-交叉熵函数mnist数据集识别-04
  6. Flume sink=avro rpc connection error
  7. 漏洞工具:nmap和nessus
  8. android 滑动过程 触发,android 代码实现模拟用户点击、滑动等操作
  9. ansys怎么使用anand模型_【干货】经典ANSYS 与 Workbench如何实现联合仿真,相互切换操作。...
  10. 12012.memtester内存测试
  11. poj 1056 IMMEDIATE DECODABILITY
  12. 懂得智能配色的ImageView,还能给自己设置多彩的阴影(PaletteImageView)
  13. 在指定命令下打开命令提示符的几种方式
  14. Python爬虫请求头、请求代理以及cookie操作
  15. c语言之图形编程 pdf,《C语言图形编程》.pdf
  16. html5 前端js框架,前端h5框架总结
  17. 豪斯多夫(Hausdorff)距离
  18. 【LeetCode】一年中的第几天
  19. FastDFS集群tracker实现负载均衡
  20. 干接点信号_百度百科

热门文章

  1. 【Eclipse 开发工具常用快捷键】
  2. Java每天/每周定时执行任务
  3. 如何实现一个React全家桶项目(附完整教程及代码)
  4. 【华为OD机试真题2023 JAVAJS】单核CPU任务调度
  5. 聊一聊Mysql中的字符串拼接函数
  6. C++调用python遇到的各种问题
  7. elastic search 官网
  8. 微信红包(拼手气红包)
  9. 湖南大学CS-2018期末考试解析
  10. word考试计算机试题及答案,2017职称计算机考试Word2003冲刺试题及答案(1)