[深度学习] Keras 如何使用fit和fit_generator
介绍
在本教程中,您将了解Keras .fit
和.fit_generator
函数的工作原理,包括它们之间的差异。为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数据生成器。
Keras深度学习库包括可用于训练您自己的模型:
.fit
.fit_generator
如果你是Keras和深度学习的新手,在试图确定你应该使用哪种函数时,你可能会觉得有点不知所措。如果你需要使用你自己的自定义数据,这种混乱只会更加复杂。
为了帮助掀开关于Keras fit和fit_generator函数的迷云,我将花费本教程讨论:
- Keras的
.fit
,.fit_generator
函数之间的区别 - 在训练自己的深度学习模型时,何时使用每个函数
- 如何实现自己的Keras数据生成器,并在使用
.fit_generator
训练模型时使用它 - 在训练完成后评估网络时,如何使用
.predict_generator
函数
如何使用Keras fit和fit_generator
在今天的教程的第一部分中,我们将讨论Keras的.fit
,.fit_generator
和.train_on_batch
函数之间的差异。
我将向您展示一个“非标准”图像数据集的示例,它根本不包含任何实际的PNG,JPEG等图像!相反,整个图像数据集由两个CSV文件表示,一个用于训练,第二个用于评估。
我们的目标是实现能够在此CSV图像数据上训练网络的Keras生成器(不用担心,我将向您展示如何从头开始实现这样的生成器功能)。
何时使用Keras的fit,fit_generator和train_on_batch函数?
这三个功能基本上可以完成相同的任务,但他们如何去做这件事是非常不同的。
让我们逐个探索这些函数,查看函数调用的示例,然后讨论它们彼此之间的差异。
Keras .fit函数
model.fit(trainX, trainY, batch_size=32, epochs=50)
在这里您可以看到我们提供的训练数据(trainX
)和训练标签(trainY
)。
然后,我们指示Keras允许我们的模型训练50
个epoch,同时batch size为32
。
原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。
Keras fit_generator函数
在深度学习中,我们数据通常会很大,即使在使用GPU的情况下,我们如果一次性将所有数据(如图像)读入CPU的内存中,内存很有可能会奔溃。这在实际的项目中很有可能会出现。
对于小型,简单化的数据集,使用Keras的.fit
函数是完全可以接受的。
在这些情况下,我们需要利用Keras的.fit_generator
函数:
generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple
一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到
samples_per_epoch
时,记一个epoch结束
steps_per_epoch:整数,当生成器返回
steps_per_epoch
次数据时计一个epoch结束,执行下一个epochepochs:整数,数据迭代的轮数
verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
validation_data:具有以下三种形式之一
生成验证集的生成器
一个形如(inputs,targets)的tuple
一个形如(inputs,targets,sample_weights)的tuple
validation_steps: 当validation_data为生成器时,本参数指定验证集的生成器返回次数
class_weight:规定类别权重的字典,将类别映射为权重,常用于处理样本不均衡问题。
sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了
sample_weight_mode='temporal'
。workers:最大进程数在使用基于进程的线程时,最多需要启动的进程数量
- use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。
max_q_size:生成器队列的最大容量
initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。
官方demo代码:
def generate_arrays_from_file(path):while 1:f = open(path)for line in f:# create Numpy arrays of input data# and labels, from each line in the filex, y = process_line(line)yield (x, y)f.close()model.fit_generator(generate_arrays_from_file('/my_file.txt'), samples_per_epoch=10000, epochs=10)
官方的demo没有实现batch_size,该demo每次只能提取一个样本。
针对上述的数据集,实现的batch_size数据提取的迭代器,代码如下:
def process_line(line):tmp = [int(val) for val in line.strip().split(',')]x = np.array(tmp[:-1])y = np.array(tmp[-1:])return x, ydef generate_arrays_from_file(path, batch_size):while 1:f = open(path)cnt = 0X = []Y = []for line in f:# create Numpy arrays of input data# and labels, from each line in the filex, y = process_line(line)X.append(x)Y.append(y)cnt += 1if cnt == batch_size:cnt = 0yield (np.array(X), np.array(Y))X = []Y = []f.close()
model.fit_generator(generate_arrays_from_file('./train', batch_size=batch_size),samples_per_epoch=25024, nb_epoch=nb_epoch, validation_data=(X_test, y_test), max_q_size=1000,verbose=1, nb_worker=1)
keras 使用迭代器来实现大数据的训练, 其简单的思想就是,使用迭代器从文件中去顺序读取数据。所以自己的训练数据一定要先随机打散。因为我们的迭代器也是每次顺序读取一个batch_size的数据进行训练。
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,horizontal_flip=True, fill_mode="nearest")# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,epochs=EPOCHS)
我们首先初始化将要训练的网络的epoch和batch size。
然后我们初始化aug
,这是一个Keras ImageDataGenerator
对象,用于图像的数据增强,随机平移,旋转,调整大小等。
执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。
但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。
根据提供给ImageDataGenerator
的参数随机调整每批新数据。
因此,我们现在需要利用Keras的.fit_generator
函数来训练我们的模型。
顾名思义,.fit_generator
函数假定存在一个为其生成数据的基础函数。
该函数本身是一个Python生成器。
Keras在使用.fit_generator
训练模型时的过程:
- Keras调用提供给
.fit_generator
的生成器函数(在本例中为aug.flow
) - 生成器函数为
.fit_generator
函数生成一批大小为BS
的数据 .fit_generator
函数接受批量数据,执行反向传播,并更新模型中的权重- 重复该过程直到达到期望的epoch数量
您会注意到我们现在需要在调用.fit_generator
时提供steps_per_epoch
参数(.fit
方法没有这样的参数)。
为什么我们需要steps_per_epoch
?
请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。
由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。
因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch
的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。
参考:
- 在Keras中导入测试数据的方法
- 如何使用Keras fit和fit_generator
[深度学习] Keras 如何使用fit和fit_generator相关推荐
- 利用深度学习(Keras)进行癫痫分类-Python案例
目录 癫痫介绍 数据集 Keras深度学习案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:903290195 癫痫介绍 癫痫,即俗称"羊癫风",是由多种 ...
- 深度学习——keras教程系列基础知识
大家好,本期我们将开始一个新的专题的写作,因为有一些小伙伴想了解一下深度学习框架Keras的知识,恰好本人也会一点这个知识,因此就开始尝试着写一写吧.本着和大家一起学习的态度,有什么写的不是很好的地方 ...
- python深度学习--Keras函数式API(多输入,多输出,类图模型)
import numpy as np import pandas as pd import matplotlib.pyplot as plt import pylab from pandas impo ...
- 深度学习 Keras入门 一 之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 私人定制——使用深度学习Keras和TensorFlow打造一款音乐推荐系统
随着生活水平的极大提高,人们在很多情况下都会边听音乐边做一些事情,比如在健身房.出行路上等,越来越多的人也开始慢慢走在Hifi发烧友的这一条不归路上,频繁地换耳机.换功放等,小编在这里劝一下大家不要向 ...
- 深度学习--Keras总结
Keras主要包括14个模块,本文主要对Models.layers.Initializations.Activations.Objectives.Optimizers.Preprocessing.me ...
- 用深度学习keras的cnn做图像识别分类,准确率达97%
Keras是一个简约,高度模块化的神经网络库. 可以很容易和快速实现原型(通过总模块化,极简主义,和可扩展性) 同时支持卷积网络(vision)和复发性的网络(序列数据).以及两者的组合. 无缝地运行 ...
- 个人深度学习keras环境配置介绍
因为我主要使用的是keras,所以我所要介绍的也是keras的环境,如果以后我转pytorch后再补充我的pytorch环境配置情况 我的keras环境配置如下: cudatoolkit 10.0.1 ...
- 深度学习Keras框架实践笔记
在其他机器保存keras模型(.h5),load_model(~.h5)后报错[in from_config if 'class_name' not in config[0] or config[0] ...
最新文章
- cin输入字符串怎么结束_翻遍全网,只为让你记住这些输入输出函数
- c/c++中typedef详解(此文对typedef用于结构体的定义说明得很清楚到位)
- javaScript——廖雪峰老师学习笔记(一)
- phpstrtotime()对于31日求上个月有问题
- a标签连接空标签的方法
- 解决eclipse中jsp没有代码提示问题
- GZip、Deflate压缩算法对应的C#压缩解压函数
- 利用NABCD模型进行竞争性需求分析
- Padavan 老毛子路由器登录SSH教程
- 双稳态(bistable)与单稳态
- longitudinal models | 纵向研究 | mixed model
- 文本CSS多行溢出隐藏显示省略号
- 云大计算机应用技术考研2021,2021云南大学考研经验贴
- 一、万维网的发展(W3C组织的建立)
- 京东商城(HTML和CSS实现京东商城网站)
- java 生成txt文档 指定编码格式
- css做出圆角矩形边框
- 服务器虚拟化vmware价格,vmware服务器虚拟化实施方案(vmware服务器虚拟化收费)...
- 26位前谷歌AI专家出走创业
- 小学美术计算机教案模板,小学美术教案模板
热门文章
- s6 android 7.0 国行,三星S6电信版/S6 Edge国行版升级安卓7.0
- 内地高校招收澳门保送生公布录取结果 882名学生获录取
- 专访驭势科技吴甘沙:无人驾驶硝烟弥漫,“创造”才有未来|封面人物
- ps、grep和kill联合使用杀掉进程(转)
- JavaScript 总结几个提高性能知识点
- eclipse插件安装,万能方法
- HttpURLConnection和HttpClient的简单用法
- MonoRail MVC应用(2)-构建多层结构的应用程序
- Java 8开发的4大顶级技巧
- JDBC中Statement与PreparedStatement的区别