python怎么导入数据集keras_keras使用Sequence类调用大规模数据集进行训练的实现
使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。
下面是我所使用的代码
class SequenceData(Sequence):
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
#返回长度,通过len()调用
def __len__(self):
return self.L - self.batch_size
#即通过索引获取a[0],a[1]这种
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})
def data_generation(self, batch_datas):
#预处理操作
return img1s,img2s,audios,labels
然后在代码里通过fit_generation函数调用并训练
这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。
D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
epochs=2, workers=20, #callbacks=[checkpoint],
use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))
同样的,也可以在测试的时候使用
model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练
我就废话不多说了,大家还是直接看代码吧~
#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
class DataGenerator(keras.utils.Sequence):
def __init__(self, datas, batch_size=1, shuffle=True):
self.batch_size = batch_size
self.datas = datas
self.indexes = np.arange(len(self.datas))
self.shuffle = shuffle
def __len__(self):
#计算每一个epoch的迭代次数
return math.ceil(len(self.datas) / float(self.batch_size))
def __getitem__(self, index):
#生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
# 生成batch_size个索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根据索引获取datas集合中的数据
batch_datas = [self.datas[k] for k in batch_indexs]
# 生成数据
X, y = self.data_generation(batch_datas)
return X, y
def on_epoch_end(self):
#在每一次epoch结束是否需要进行一次随机,重新随机一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def data_generation(self, batch_datas):
images = []
labels = []
# 生成数据
for i, data in enumerate(batch_datas):
#x_train数据
image = cv2.imread(data)
image = list(image)
images.append(image)
#y_train数据
right = data.rfind("\\",0)
left = data.rfind("\\",0,right)+1
class_name = data[left:right]
if class_name=="dog":
labels.append([0,1])
else:
labels.append([1,0])
#如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
return np.array(images), np.array(labels)
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
file_path = os.path.join("D:/xxx", file)
if os.path.isdir(file_path):
class_num = class_num + 1
for sub_file in os.listdir(file_path):
train_datas.append(os.path.join(file_path, sub_file))
# 数据生成器
training_generator = DataGenerator(train_datas)
#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持python博客。
python怎么导入数据集keras_keras使用Sequence类调用大规模数据集进行训练的实现相关推荐
- 导入要素类到要素数据集当中(C++)(ArcObject)史上最快
//导入要素类到要素数据集IDatasetContainerPtr ipDatasetContainer = ipFeatureDataset;IDatasetPtr ipInDataset = ip ...
- python如何导入类里_Python导入模块中的所有类(98)
要导入模块中的每个类,可使用下面的语法: from mod import * 不推荐使用这种导入方式,其原因有二.首先,如果只要看一下文件开头的import语句,就能清楚地知道程序使用了哪些类,将大有 ...
- python导入类有红线_解决Python中导入自己写的类,被划红线,但不影响执行的问题...
1. 错误描述 之前在学习Python的过程中,导入自己写的包文件时,与之相关的方法等都会被划红线,但并不影响代码执行,如图: 看着红线确实有点强迫症,并且在这个过程当时,当使用该文件里的方法时不会自 ...
- python之导入类
作者:从未止步- 博客主页:从未止步的博客 专栏:和我一起学Python 语录:Every day is a second chance 行动是理想最高贵的表达 ,给大家介绍一款超牛的斩获大厂offe ...
- python如何导入模块中的类_python导入模块中类的方法
自定义标题 python导入模块中类的方法 1.导入模块中的单类 2.一个模块中存储多个类时导入的方法 3.导入模块中所有类的方法 4.导入模块中的类时把类使用as取一个别名 python导入模块中类 ...
- python将数据集分成训练样本和类标签
这里假设 类标签为largeDoses, smallDoses, didntLike三类,假设训练样本有三个特征属性,类标签放在数据集的最后一列 import numpy as npdef file2 ...
- NLP之词向量:利用word2vec对20类新闻文本数据集进行词向量训练、测试(某个单词的相关词汇)
NLP之词向量:利用word2vec对20类新闻文本数据集进行词向量训练.测试(某个单词的相关词汇) 目录 输出结果 设计思路 核心代码 输出结果 寻找训练文本中与morning最相关的10个词汇: ...
- python机器学习——决策树(分类)及“泰坦尼克号沉船事故”数据集案例操作
决策树(分类)及具体案例操作 一.决策树(分类)算法 (1)算法原理(类似于"分段函数") (2)决策树的变量类型 (3)量化纯度 (4)基本步骤 (5)决策树的优缺点 二.决策树 ...
- 基于视频理解TSM和数据集20bn-jester-v1的27类手势识别
基于视频理解TSM-mobilenetv2和数据集20bn-jester-v1的27类手势识别 基于视频理解TSM-resnet50和数据集20bn-jester-v1的27类手势识别 基于视频理解T ...
最新文章
- Datawhale组队学习周报(第038周)
- vscode css智能补全_强大的 VS Code入门
- redmine 插件开发非官方指南
- 数据库-优化-子查询优化
- python元编程之使用动态属性实现定制类--特殊方法__setattr__,__getattribute__篇
- 【深度学习】——物体检测细节处理(NMS、样本不均衡、遮挡物体)
- QCustomplot(一) 能做什么事
- js 中的console.log有什么作用
- 微服务精华问答 | 微服务如何测试?
- ASIO 腾空出世 (那些年我们追过的网络库.PartII)
- jmeter ---实战(详解)
- mysql amd.dll 后门_DLL后门清除完全篇
- 朴宥拉短片突破了几百万的观看量
- php语言输出九九乘法表_PHP 输出九九乘法表
- 大学计算机ppt制作步骤,PPT制作教程步骤方法_PPT制作技巧教程快捷键_PPT制作基础教程...
- Windows NT 内核版本号对应的操作系统版本号
- DDD基础_领域设计10大基础概念
- c语言列宽作用,c语言|格式化输入输出详解
- transition天坑
- 功能类微信小程序的推广