简介

官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求,可以参考我的笔记。但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法。这个方法是读取图片的地址保存到内存中,这样就不会出现内存不足情况了。

下面我使用的数据集是猫狗的数据集,下载地址

猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA   密码:dmp4

代码

from keras.preprocessing.image import load_img, img_to_array
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense
from keras.layers.core import Activation, Flatten
from keras.models import Sequential
from keras.optimizers import RMSpropimport numpy as np
import os
import randomtrain_file_path = 'kaggle/train/'val_x = []
train_x = []
count = 0
for dir, file, images in os.walk(train_file_path):for image in images:# print(image)count += 1fullname = os.path.join(dir, image)if count%5 == 0:val_x.append(fullname)train_x.append(fullname)# print(len(train_x))
#
# for i, val in enumerate(train_x):
#     print(val)
#     if i == 10:
#         breakdef get_image_label(image_paths):# print(image_paths)image_labels = []for image_path in image_paths:# print(image_path)image_name = image_path.split('/')[2]# print(image_name)if 'cat' in image_name:image_labels.append(0)else:image_labels.append(1)# breakreturn image_labels# image_labels = get_image_label(train_x)
# for i, image_label in enumerate(image_labels):
#     print(image_label)
#     if i == 10:
#         break# 读取图片
def load_batch_image(img_path, train_set = True, target_size=(150, 150)):im = load_img(img_path, target_size=target_size)if train_set:return img_to_array(im)  # converts image to numpy arrayelse:return img_to_array(im) / 255.0# 建立一个数据迭代器
def GET_DATASET_SHUFFLE(X_samples, batch_size, train_set = True):random.shuffle(X_samples)# for i, image_label in enumerate(X_samples):#     print(image_label)#     if i == 10:#         breakbatch_num = int(len(X_samples) / batch_size)max_len = batch_num * batch_sizeX_samples = np.array(X_samples[:max_len])y_samples = np.array(get_image_label(X_samples))print('X_samples.shape:', X_samples.shape)X_batches = np.split(X_samples, batch_num)# print(X_batches)# for x_batch in X_batches:#     print(x_batch)#     breaky_batches = np.split(y_samples, batch_num)# for i, y_batch in y_batches:#     print('y_batch:', y_batch)#     if i == 10:#         break# print('y_batches:', y_batches)for i in range(len(X_batches)):if train_set:x = np.array(list(map(load_batch_image, X_batches[i], [True for _ in range(batch_size)])))else:x = np.array(list(map(load_batch_image, X_batches[i], [False for _ in range(batch_size)])))# print(x.shape)y = np.array(y_batches[i])yield x, y# 搭建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu',input_shape=(150, 150, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))print(model.summary())model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=1e-4),metrics=['acc'])batch_size = 32model.fit_generator(GET_DATASET_SHUFFLE(train_x, batch_size, True),epochs=10,steps_per_epoch=int(len(train_x)/batch_size))

里面有很多被注释的测试程序, 你可以做一下参考吧

参考博客地址:http://www.cnblogs.com/hejunlin1992/p/9371078.html

keras笔记(4)-使用Keras训练大规模数据集相关推荐

  1. keras笔记(3)-猫狗数据集上的训练以及单张图片多张图片的测试

    之前也写过关于使用tensorflow在猫狗数据集上的训练,想要学习的可以看一下 数据集下载 猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA ...

  2. keras笔记-mnist数据集上的简单训练

    学习了keras已经好几天了,之前一直拒绝使用keras,但是现在感觉keras是真的好用啊,可以去尝试一下啊. 首先展示的第一个代码,还是mnist数据集的训练和测试,以下是代码: from ker ...

  3. Keras TensorFlow教程:使用自己的数据集进行训练

    大多数Keras教程都尝试使用图像分类数据集(如MNIST(手写识别)或基本对象CIFAR-10(基本对象识别))来开启Keras库的基础知识学习. 这篇文章将对Keras入门教程进行不同的尝试.使用 ...

  4. python3 23.keras使用交叉熵代价函数进行MNIST数据集简单分类 学习笔记

    文章目录 前言 一.交叉熵代价函数简介 二.交叉熵代价函数使用 前言 计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理.本系列所有代码是用python3编写,在平台Ana ...

  5. 如何在Keras中训练大型数据集

    https://www.toutiao.com/a6670173759829180936/ 在本文中,我们将讨论如何使用Keras在不适合内存的大数据集上训练我们的深度学习网络. 介绍 深度学习算法优 ...

  6. python怎么导入数据集keras_keras使用Sequence类调用大规模数据集进行训练的实现

    使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里 ...

  7. Python机器学习笔记:使用Keras进行回归预测

    Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...

  8. keras系列︱图像多分类训练与利用bottleneck features进行微调(三)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72861152 中文文档:http://keras-cn.readthedocs.io/ ...

  9. keras 多层lstm_《Keras 实现 LSTM》笔记

    本文在原文的基础上添加了一些注释.运行结果和修改了少量的代码. 1. 介绍 LSTM(Long Short Term Memory)是一种特殊的循环神经网络,在许多任务中,LSTM表现得比标准的RNN ...

最新文章

  1. 程序的编译和链接过程
  2. TeamCity 和 Nexus 的使用
  3. 【Python】Flask 框架安装虚拟环境报错—处理中......
  4. 为什么需要自己实现前端框架
  5. video 标签存在的一些坑
  6. java学习(137):java异常初识
  7. python方差分析模型的预测结果怎么看_statsmodels中方差分析表结果解析
  8. win下 git gui 使用教程
  9. freeCodeCamp:Title Case a Sentence
  10. WAVE-U-NET: A MULTI-SCALE NEURAL NETWORK FOR END-TO-END AUDIO SOURCE SEPARATION
  11. iOS手势的传递问题
  12. flash打造佛光效果实例教程
  13. 用python给游戏加上音效_添加声音到你的Python游戏
  14. 2022-2028全球电动汽车电池冷却器行业调研及趋势分析报告
  15. 高等数学---不定积分的计算---基本积分法
  16. c# wpf 利用截屏键实现截屏功能
  17. Linux 基础总结,这一篇就够了!
  18. 2008 mysql 本地安全_apache在windows2003或win2008环境中的安全设置
  19. keil写代码时遇到的问题——warning:implicit declaration of function XXXX is invalid in C99
  20. iPhone刷门禁卡的设置方法

热门文章

  1. marquee命令的基本用法
  2. mybatis返回Date类型数据 格式化
  3. MySQL批量更新数据
  4. IOS8-人机界面指南
  5. 递推+矩阵快速幂 HDU 2065
  6. swfobject.js 2.2简单使用方法
  7. 获取进程列表和结束进程
  8. 【原】如何实现IE6下块级元素的内容自动收缩
  9. DBGridEh导出Excel等格式文件
  10. map、set和unordered_map、unordered_set对比