AI:CNN神经网络猫狗分类经典案例

猫狗的训练数据可以在kaggle下载:

https://www.kaggle.com/tongpython/cat-and-dog/data

本例使用ImageDataGenerator在迭代生成训练数据时候,需要把训练数据和验证,测试数据分类放置到data下面三个不同目录文件夹下。如图:

因为有猫和狗两类,所有在data/train目录下,再建两个目录data/train/dog和data/train/cat:

同理,其他的data/validation和data/test目录下,再建两个目录:cat和data/,在cat和dog目录下,放置对应的图片。

分类建立完成后,Keras会在ImageDataGenerator迭代过程中,自动的为data/train,data/test,data/validation内部生产训练标签,标签依据就是在data/train,data/test,data/validation下面的分类目录,本例是/dog和/cat目录文件夹作为两分类。

import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator, image
from keras import layers
from keras import models
from keras.layers import Dropout
from keras import optimizers
from keras.models import load_modeltrain_dir = './data/train/'
validation_dir = './data/validation/'
model_file_name = 'cat_dog_model.h5'def init_model():model = models.Sequential()KERNEL_SIZE = (3, 3)model.add(layers.Conv2D(filters=32, kernel_size=KERNEL_SIZE, activation='relu', input_shape=(150, 150, 3)))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(filters=64, kernel_size=KERNEL_SIZE, activation='relu'))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(filters=128, kernel_size=KERNEL_SIZE, activation='relu'))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(filters=128, kernel_size=KERNEL_SIZE, activation='relu'))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Flatten())model.add(layers.Dense(512, activation='relu'))model.add(Dropout(0.5))model.add(layers.Dense(1, activation='sigmoid'))model.compile(loss='binary_crossentropy',optimizer=optimizers.RMSprop(lr=1e-3),metrics=['accuracy'])return modeldef fig_loss(history):history_dict = history.historyloss_values = history_dict['loss']val_loss_values = history_dict['val_loss']epochs = range(1, len(loss_values) + 1)plt.plot(epochs, loss_values, 'b', label='Training loss')plt.plot(epochs, val_loss_values, 'r', label='Validation loss')plt.title('Training and validation loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.grid()plt.show()def fig_acc(history):history_dict = history.historyacc = history_dict['accuracy']val_acc = history_dict['val_accuracy']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'g', label='Training acc')plt.plot(epochs, val_acc, 'r', label='Validation acc')plt.title('Training and validation accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.grid()plt.show()def fit(model):train_datagen = ImageDataGenerator(rescale=1. / 255)validation_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=256,class_mode='binary')validation_generator = validation_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=64,class_mode='binary')history = model.fit_generator(train_generator,# steps_per_epoch=,epochs=10,validation_data=validation_generator,# validation_steps=,)model.save(model_file_name)fig_loss(history)fig_acc(history)def predict():model = load_model(model_file_name)print(model.summary())img_path = './data/test/cat/cat.4021.jpg'img = image.load_img(img_path, target_size=(150, 150))img_tensor = image.img_to_array(img)img_tensor = img_tensor / 255img_tensor = np.expand_dims(img_tensor, axis=0)# 其形状为 (1, 150, 150, 3)plt.imshow(img_tensor[0])plt.show()result = model.predict(img_tensor)print(result)# 画出count个预测结果和图像
def fig_predict_result(model, count):test_datagen = ImageDataGenerator(rescale=1. / 255)test_generator = test_datagen.flow_from_directory('./data/test/',target_size=(150, 150),batch_size=256,class_mode='binary')text_labels = []plt.figure(figsize=(30, 20))# 迭代器可以迭代很多条数据,但我这里只取第一个结果看看for batch, label in test_generator:pred = model.predict(batch)for i in range(count):true_reuslt = label[i]print(true_reuslt)if pred[i] > 0.5:text_labels.append('dog')else:text_labels.append('cat')# 4列,若干行的图plt.subplot(count / 4 + 1, 4, i + 1)plt.title('This is a ' + text_labels[i])imgplot = plt.imshow(batch[i])plt.show()# 可以接着画很多,但是只是随机看看几条结果。所以这里停下来。breakif __name__ == '__main__':model = init_model()fit(model)# 利用训练好的模型预测结果。predict()model = load_model(model_file_name)#随机查看10个预测结果并画出它们fig_predict_result(model, 10)

由于太耗时,本例只训练了10轮。

训练损失和验证损失:

训练精度和验证精度:

随机测试在data/test/目录下的猫狗图片,10张,看看 模型预测的结果:

其中1张预测失败,剩余9张预测正确。

CNN神经网络猫狗分类经典案例相关推荐

  1. CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化

    AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...

  2. CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用

    CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用 目录 基于tensorflow框架采用CNN(改进 ...

  3. 【飞桨】卷积神经网络(CNN)实现猫狗分类

    目录 什么是卷积神经网络? 一.数据准备 二.网络配置 1. 定义网络 2. 定义输入数据的格式 3. 定义损失函数和准确率 4. 定义优化方法 三.模型训练&评估 四.模型预测 五.完整代码 ...

  4. 初学者友好项目 - 使用 CNN 的猫狗分类 ​

    使用CNN进行猫狗分类 卷积神经网络 (CNN) 是一种算法,将图像作为输入,然后为图像的所有方面分配权重和偏差,从而区分彼此.神经网络可以通过使用成批的图像进行训练,每个图像都有一个标签来识别图像的 ...

  5. 【TensorFlowKeras】基于卷积神经网络CNN的猫狗分类

    文章目录 一.猫狗数据集 二.构建网络 三.基准模型调整 四.使用VGG19实现猫狗分类 五.参考

  6. 卷积神经网络——猫狗分类

    目录 一.搭建环境,完成猫狗分类 一)安装TensorFlow和Keras 1.Anaconda中安装 2.cmd中安装 二)猫狗分类实验 1.先制作数据集 2.卷积神经网络CNN 三)附加问题 二. ...

  7. 基于TensorFlow的CNN模型——猫狗分类识别器(五)之训练和评估CNN模型

    注意:这是一个完整的项目,建议您按照完整的博客顺序阅读. 目录 三.训练和优化CNN模型 1.搭建训练主循环 2.训练时间的记录 3.早期终止机制 4.训练数据的可视化 5.训练数据的保存与加载 四. ...

  8. Keras--基于VGG16卷积神经网络---猫狗分类

    Cats vs. Dogs(猫狗大战)来源于 Kaggle 上的一个竞赛,内容非常简单, Kaggle 提供了一个猫和狗的数据集,我们需要建立一个算法进行训练,最后这个算法要能准确识别出猫和狗.Kag ...

  9. 猫狗分类-简单CNN

    文章 1.导入第三方库 2.定义模型 3.训练数据和测试数据生成 4.训练模型 猫狗分类的数据集可以查看图像数据预处理. 代码运行平台为jupyter-notebook,文章中的代码块,也是按照jup ...

最新文章

  1. LabVIEW保存、读取配置文件
  2. poj3114Countries in War(缩点+DIJK)
  3. 进程线程005 SwapContext函数分析
  4. Redis-cluster架构
  5. VMware 7.1.4安装Mac.OS.X.Lion.操作系统 key:安装 系统
  6. MySQL 高可用架构在业务层面的应用分析
  7. 【MATLAB统计分析与应用100】案例001:matlab使用Importdata函数导入文本txt数据
  8. Visual Studio 选择相同变量高亮
  9. c加加中print是什么意思_砖家财经:基金名字后面的A、B和C,分别代表什么意思?...
  10. asp.net生成高质量缩略图通用函数
  11. copy 自定义对象
  12. c语言入门基础知识总结
  13. 【菜鸡的LeetCode答案】【C#】7.反转整数
  14. 增强学习之一——Q-Learning公式
  15. 【Prometheus】Prometheus联邦的一次优化记录[续]
  16. Linux服务器下安装ANSYS
  17. python微博相册爬虫
  18. 无人值守地磅称重系统方案的设计原理
  19. 【小沐学C++】C++17实现文件操作<filesystem>
  20. WordPress 不修改代码通过sql语句修改数据库批量增加文章阅读量

热门文章

  1. 花与剑尚未获取服务器信息,花与剑澄心无忆攻略,触发条件及完成方式介绍
  2. 推理时 cnn bn 折叠;基于KWS项目
  3. 用Word转换向导批量转换Word文档(转)
  4. YTU OJ 1329: 手机尾号评分
  5. 出现ERROR 1698 (28000): Access denied for user ‘root‘@‘localhost‘ 的解决方法
  6. Linux中的UID与GID
  7. 阿里云的NoSQL存储服务OTS的应用分析
  8. idea查看类层级hierarchy快捷键
  9. 生物识别技术是什么,生物识别技术的比较介绍
  10. 十种常见的图像标注方法 | 数据标注