之前也写过关于使用tensorflow在猫狗数据集上的训练,想要学习的可以看一下

数据集下载

猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA   密码:dmp4

代码

对于猫狗数据上的训练,存在一点问题在与不能够将所有的训练集读入到内存,这时候需要用到keras中的

model.fit_generator()

接下来还是看代码吧

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from keras.optimizers import RMSpropfrom keras.preprocessing.image import ImageDataGenerator#下面的一部分是进行数据预处理
import os
import shutildataset_dir = 'kaggle/train/'
train_cats_dir = 'kaggle/train/cats/'
train_dogs_dir = 'kaggle/train/dogs/'
validation_cats_dir = 'kaggle/validation/cats/'
validation_dogs_dir = 'kaggle/validation/dogs/'if not os.path.exists(train_cats_dir):os.mkdir(train_cats_dir)
if not os.path.exists(train_dogs_dir):os.mkdir(train_dogs_dir)
if not os.path.exists(validation_cats_dir):os.mkdir('kaggle/validation/')os.mkdir(validation_cats_dir)
if not os.path.exists(validation_dogs_dir):os.mkdir(validation_dogs_dir)cat_count = 0
dog_count = 0image_list = os.listdir(dataset_dir)
for image in image_list:print(image)animal = image.split('.')[0]image_path = os.path.join(dataset_dir, image)if animal == 'cat':cat_count += 1if cat_count % 5 == 0:shutil.move(image_path, validation_cats_dir)else:shutil.move(image_path, train_cats_dir)if animal == 'dog':dog_count += 1if dog_count % 5 == 0:shutil.move(image_path, validation_dogs_dir)else:shutil.move(image_path, train_dogs_dir)train_dir = 'kaggle/train/'
validation_dir = 'kaggle/validation/'train_datagen = ImageDataGenerator(rescale=1. / 255)
validation_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_dir,  # target directorytarget_size=(150, 150),  # resize图片batch_size=20,class_mode='binary'
)validation_generator = validation_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary'
)for data_batch, labels_batch in train_generator:print('data batch shape:', data_batch.shape)print('labels batch shape:', labels_batch.shape)breakmodel = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu',input_shape=(150, 150, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))print(model.summary())model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=1e-4),metrics=['acc'])hist = model.fit_generator(train_generator,steps_per_epoch=100,epochs=30,validation_data=validation_generator,validation_steps=50)model.save('cat_and_dog.h5')

上面的代码首先将kaggle/train/中的数据集分成两个部分,一部分是train的,一部分是validation的。因为model.fit_generator()函数的需求,还需要将train集中的猫狗数据分别放在kaggle/train/cats/和kaggle/train/dogs/;同理也需要将validation中的数据这样处理。然后就可以直接训练了。

注释:如果不成功的话,最好删除生成的kaggle文件,然后放上原始的kaggle文件。

解释

需要解释的估计就是model.fit_generator()

还有关于ImageDataGenerator的使用方法

fit_generator

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

使用 Python 生成器或 Sequence 实例逐批生成的数据,按批次训练模型。

生成器与模型并行运行,以提高效率。 例如,这可以让你在 CPU 上对图像进行实时数据增强,以在 GPU 上训练模型。

keras.utils.Sequence 的使用可以保证数据的顺序, 以及当 use_multiprocessing=True 时 ,保证每个输入在每个 epoch 只使用一次。

参数

  • generator: 一个生成器或 Sequence (keras.utils.Sequence)。 生成器的输出应该为以下之一:
  • 一个 (inputs, targets) 元组
  • 一个 (inputs, targets, sample_weights) 元组。 所有的数组都必须包含同样数量的样本。生成器将无限地在数据集上循环。当运行到第 steps_per_epoch 时,记一个 epoch 结束。
  • steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator 产生的总步数(批次样本)。它通常应该等于你的数据集的样本数量除以批量大小。可选参数 Sequence:如果未指定,将使用len(generator) 作为步数。
  • epochs: 整数,数据的迭代总轮数。请注意,与 initial_epoch 一起,参数 epochs 应被理解为 「最终轮数」。模型并不是训练了 epochs 轮,而是到第 epochs 轮停止训练。
  • verbose: 日志显示模式。0,1 或 2。
  • callbacks: 在训练时调用的一系列回调函数。
  • validation_data: 它可以是以下之一:
  • 验证数据的生成器或 Sequence 实例
  • 一个 (inputs, targets) 元组
  • 一个 (inputs, targets, sample_weights) 元组。
  • validation_steps: 仅当 validation_data 是一个生成器时才可用。 每个 epoch 结束时验证集生成器产生的步数。它通常应该等于你的数据集的样本数量除以批量大小。可选参数 Sequence:如果未指定,将使用len(generator) 作为步数。
  • class_weight: 将类别映射为权重的字典。
  • max_queue_size: 生成器队列的最大尺寸。
  • workers: 使用的最大进程数量。
  • use_multiprocessing: 如果 True,则使用基于进程的多线程。 请注意,因为此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。
  • shuffle: 是否在每轮迭代之前打乱 batch 的顺序。只能与 Sequence (keras.utils.Sequence) 实例同用。
  • initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。

返回

一个 History 对象。

问题

使用这种方法进行训练,感觉标签就是文件夹的个数,对于大规模的数据训练还是不怎么好用,可以使用接下来的博客,可以将所有图片的地址保存到内存中,然后进行训练。这样就不会出现内存不足的情况了

单张图片多张图片的测试

直接代码吧

import os
from keras.models import load_model
from keras.preprocessing import image
import  matplotlib.pyplot as plt
import numpy as np#单张图片的识别
model = load_model('cats_and_dogs_small_1.h5')img = image.load_img('kaggle/test1/2.jpg', target_size=(150, 150))
# plt.imshow(img)
# plt.show()#【3】将图片转化为4d tensor形式
x = image.img_to_array(img)
print(x.shape) #(224, 224, 3)
x = np.expand_dims(x, axis=0)
print(x.shape) #(1, 224, 224, 3)pres = model.predict(x)
print(int(pres[0][0]))
if int(pres[0][0]) > 0.5:print('识别的结果是狗')#多张图片的识别,这里买可以从kaggle/test1/中复制几张图片放到一个新建的文件夹test中,然后进行测试
#如果加载kaggle/test1/中的全部分会超出内存,当然也可以使用训练时候的策略进行测试
file_list = os.listdir('test_image/')
images = []for file in file_list:# print(file)img = image.load_img(os.path.join('test_image/', file), target_size=(150, 150))img = image.img_to_array(img)img = np.expand_dims(img, axis=0)images.append(img)x_train = np.array(images, dtype="float") / 255.0
x = np.concatenate([x for x in x_train])#预测
y = model.predict(x)#根据结果可以看出来,0代表的是猫,1代表的是狗。
# 同时也可以从训练cats_and_dogs_small/train/里面文件的顺序知道类别代表的信息for i in range(len(file_list)):print(y[i][0])# print('image class:', int(y[i]))# print('image class:', round(y[i]))if y[i][0] > 0.5:print('image {} class:'.format(file_list[i]), 1)else:print('image {} class:'.format(file_list[i]), 0)

keras笔记(3)-猫狗数据集上的训练以及单张图片多张图片的测试相关推荐

  1. Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)

    Pytorch采用AlexNet实现猫狗数据集分类(训练与预测) 介绍 AlexNet网络模型 猫狗数据集 AlexNet网络训练 训练全代码 预测 预测图片 介绍 AlexNet模型是CNN网络中经 ...

  2. 【猫狗数据集】pytorch训练猫狗数据集之创建数据集

    数据集下载地址: 链接:https://pan.baidu.com/s/1tJQIY0ob2EyQn3cDipPkow?pwd=7gch  提取码:7gch 猫狗数据集的分为训练集25000张,在训练 ...

  3. 深度学习笔记:在小数据集上从头训练卷积神经网络

    目录 0. 前言 1. 数据下载和预处理¶ 2. 搭建一个小的卷积网络 3. 数据预处理 4. 模型训练¶ 5. 在测试集进行模型性能评估 6. 小结¶ 0. 前言 本文(以及接下来的几篇)介绍如何搭 ...

  4. Tensorflow2.0实战练习之猫狗数据集(包含自定义训练和迁移学习)

    最近在学习使用Tenforflow2.0,写下这篇文章,用来帮助和我一样的初学者,文章中如果存在某些问题,还希望各位指出. 目录 数据集介绍 数据处理及增强 VGG模型介绍 模型搭建 训练及结果展示 ...

  5. 基于VGG深度学习神经网络的猫狗数据集分类

    摘要:VGG网络是由牛津大学视觉几何组完成的基于深度卷积神经网络的大规模图像识别架构,该网络参考了AlexNet.ZFNet.OverFeat等经典的网络架构,从而得出的.这个架构参加了ILSVRC- ...

  6. 基于Keras2《面向小数据集构建图像分类模型》——Kaggle猫狗数据集

    概述 在本文中,将使用VGG-16模型提供一种面向小数据集(几百张到几千张图片)构造高效.实用的图像分类器的方法并给出试验结果. 本文将探讨如下几种方法: 从图片中直接训练一个小网络(作为基准方法) ...

  7. 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...

  8. TensorFlowKeras入门猫狗数据集识别

    一.CNN卷积网络神经介绍 1.卷积神经网络结构介绍 如果用全连接神经网络处理大尺寸图像具有三个明显的缺点: (1)首先将图像展开为向量会丢失空间信息: (2)其次参数过多效率低下,训练困难: (3) ...

  9. 神经网络学习小记录17——使用AlexNet分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录17--使用AlexNet分类模型训练自己的数据(猫狗数据集) 学习前言 什么是AlexNet模型 训练前准备 1.数据集处理 2.创建Keras的AlexNet模型 开始训练 1. ...

最新文章

  1. Qt的4个图像类QImage/QPixmap/QBitmap/QPicture 转
  2. Oracle迁移至PostgreSQL工具之Ora2Pg
  3. 共享卫士2.0版设置说明
  4. cocos2dx 开启控制台
  5. vc/vs开发的应用程序添加dump崩溃日志转
  6. python groupby用法_Python 标准库实践之合并字典组成的列表
  7. H5桌面通知: Notification API 的应用
  8. 阿里云mysql可视化_MySql可视化工具MySQL Workbench使用教程
  9. 镜像分发工具压测解决方案——hijack压测
  10. selenium不定位元素直接操作键盘之Keys.CONTROL
  11. Laravel Scout 包在 Elasticsearch 中的使用记录
  12. CSDN博客放阿里妈妈广告代码的方法
  13. 安卓眼球追踪_iPhone 11 Pro 可配合 Eyeware Beam 眼球追踪玩 PC 大屏游戏
  14. oracle、mysql、sqlserver的对比数据库引擎的对比与选型InnoDB解决幻读
  15. extjs 数字校园-云资源平台 2014.2.4-班级座位表
  16. centos中redis设置密码
  17. element更改导航菜单被选中项的背景颜色
  18. 【目标跟踪】------deepsort
  19. html 预览 base64 PDF
  20. bzoj2754【SCOI2012】喵星球上的点名

热门文章

  1. 关于ListView中Detail模式下的一些基本操作
  2. Django—自定义分页
  3. django-pure-pagination 分页插件
  4. BZOJ_2179_FFT快速傅立叶_(FFT)
  5. NSArray 所有基础点示例
  6. url即统一资源定位符
  7. 【数据结构】栈的存储实现
  8. windows8 开发教程 教你制作 多点触控Helper可将任意容器内任意对象进行多点缩放...
  9. fck2.6.3配置
  10. html下拉表覆盖透明,css透明元素如何遮挡住fixed元素