介绍

在本教程中,您将了解Keras .fit.fit_generator函数的工作原理,包括它们之间的差异。为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数据生成器。

Keras深度学习库包括可用于训练您自己的模型:

  • .fit
  • .fit_generator

如果你是Keras和深度学习的新手,在试图确定你应该使用哪种函数时,你可能会觉得有点不知所措。如果你需要使用你自己的自定义数据,这种混乱只会更加复杂。

为了帮助掀开关于Keras fit和fit_generator函数的迷云,我将花费本教程讨论:

  • Keras的.fit.fit_generator函数之间的区别
  • 在训练自己的深度学习模型时,何时使用每个函数
  • 如何实现自己的Keras数据生成器,并在使用.fit_generator训练模型时使用它
  • 在训练完成后评估网络时,如何使用.predict_generator函数

如何使用Keras fit和fit_generator

在今天的教程的第一部分中,我们将讨论Keras的.fit.fit_generator.train_on_batch函数之间的差异。

我将向您展示一个“非标准”图像数据集的示例,它根本不包含任何实际的PNG,JPEG等图像!相反,整个图像数据集由两个CSV文件表示,一个用于训练,第二个用于评估。

我们的目标是实现能够在此CSV图像数据上训练网络的Keras生成器(不用担心,我将向您展示如何从头开始实现这样的生成器功能)。

最后,我们将训练和评估我们的网络。

何时使用Keras的fit,fit_generator和train_on_batch函数?

这三个功能基本上可以完成相同的任务,但他们如何去做这件事是非常不同的。

让我们逐个探索这些函数,查看函数调用的示例,然后讨论它们彼此之间的差异。

Keras .fit函数

model.fit(trainX, trainY, batch_size=32, epochs=50)

在这里您可以看到我们提供的训练数据(trainX)和训练标签(trainY)。

然后,我们指示Keras允许我们的模型训练50个epoch,同时batch size为32

.fit的调用在这里做出两个主要假设:

相反,我们的网络将在原始数据上训练。

原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。

此外,我们不会使用数据增强动态操纵训练数据。

Keras fit_generator函数

在深度学习中,我们数据通常会很大,即使在使用GPU的情况下,我们如果一次性将所有数据(如图像)读入CPU的内存中,内存很有可能会奔溃。这在实际的项目中很有可能会出现。

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。

这些数据集通常不是很具有挑战性,不需要任何数据增强。

但是,真实世界的数据集很少这么简单:

在这些情况下,我们需要利用Keras的.fit_generator函数:

函数的参数是:

  • generator:生成器函数,生成器的输出应该为:

    • 一个形如(inputs,targets)的tuple

    • 一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束

  • steps_per_epoch:整数,当生成器返回steps_per_epoch次数据时计一个epoch结束,执行下一个epoch

  • epochs:整数,数据迭代的轮数

  • verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

  • validation_data:具有以下三种形式之一

    • 生成验证集的生成器

    • 一个形如(inputs,targets)的tuple

    • 一个形如(inputs,targets,sample_weights)的tuple

  • validation_steps: 当validation_data为生成器时,本参数指定验证集的生成器返回次数

  • class_weight:规定类别权重的字典,将类别映射为权重,常用于处理样本不均衡问题。

  • sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'

  • workers:最大进程数在使用基于进程的线程时,最多需要启动的进程数量

  • use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。
  • max_q_size:生成器队列的最大容量

  • initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。

官方demo代码:

def generate_arrays_from_file(path):while 1:f = open(path)for line in f:# create Numpy arrays of input data# and labels, from each line in the filex, y = process_line(line)yield (x, y)f.close()model.fit_generator(generate_arrays_from_file('/my_file.txt'), samples_per_epoch=10000, epochs=10)

官方的demo没有实现batch_size,该demo每次只能提取一个样本。

针对上述的数据集,实现的batch_size数据提取的迭代器,代码如下:

def process_line(line):tmp = [int(val) for val in line.strip().split(',')]x = np.array(tmp[:-1])y = np.array(tmp[-1:])return x, ydef generate_arrays_from_file(path, batch_size):while 1:f = open(path)cnt = 0X = []Y = []for line in f:# create Numpy arrays of input data# and labels, from each line in the filex, y = process_line(line)X.append(x)Y.append(y)cnt += 1if cnt == batch_size:cnt = 0yield (np.array(X), np.array(Y))X = []Y = []f.close()
model.fit_generator(generate_arrays_from_file('./train', batch_size=batch_size),samples_per_epoch=25024, nb_epoch=nb_epoch, validation_data=(X_test, y_test), max_q_size=1000,verbose=1, nb_worker=1)

keras 使用迭代器来实现大数据的训练, 其简单的思想就是,使用迭代器从文件中去顺序读取数据。所以自己的训练数据一定要先随机打散。因为我们的迭代器也是每次顺序读取一个batch_size的数据进行训练。

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,horizontal_flip=True, fill_mode="nearest")# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,epochs=EPOCHS)

我们首先初始化将要训练的网络的epoch和batch size。

然后我们初始化aug,这是一个Keras ImageDataGenerator对象,用于图像的数据增强,随机平移,旋转,调整大小等。

执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。

但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。

根据提供给ImageDataGenerator的参数随机调整每批新数据。

因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。

顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。

该函数本身是一个Python生成器。

Keras在使用.fit_generator训练模型时的过程:

  • Keras调用提供给.fit_generator的生成器函数(在本例中为aug.flow
  • 生成器函数为.fit_generator函数生成一批大小为BS的数据
  • .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
  • 重复该过程直到达到期望的epoch数量

您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。

为什么我们需要steps_per_epoch

请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。

由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。

因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。

参考:

  • 在Keras中导入测试数据的方法
  • 如何使用Keras fit和fit_generator

[深度学习] Keras 如何使用fit和fit_generator相关推荐

  1. 利用深度学习(Keras)进行癫痫分类-Python案例

    目录 癫痫介绍 数据集 Keras深度学习案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:903290195 癫痫介绍 癫痫,即俗称"羊癫风",是由多种 ...

  2. 深度学习——keras教程系列基础知识

    大家好,本期我们将开始一个新的专题的写作,因为有一些小伙伴想了解一下深度学习框架Keras的知识,恰好本人也会一点这个知识,因此就开始尝试着写一写吧.本着和大家一起学习的态度,有什么写的不是很好的地方 ...

  3. python深度学习--Keras函数式API(多输入,多输出,类图模型)

    import numpy as np import pandas as pd import matplotlib.pyplot as plt import pylab from pandas impo ...

  4. 深度学习 Keras入门 一 之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  5. 私人定制——使用深度学习Keras和TensorFlow打造一款音乐推荐系统

    随着生活水平的极大提高,人们在很多情况下都会边听音乐边做一些事情,比如在健身房.出行路上等,越来越多的人也开始慢慢走在Hifi发烧友的这一条不归路上,频繁地换耳机.换功放等,小编在这里劝一下大家不要向 ...

  6. 深度学习--Keras总结

    Keras主要包括14个模块,本文主要对Models.layers.Initializations.Activations.Objectives.Optimizers.Preprocessing.me ...

  7. 用深度学习keras的cnn做图像识别分类,准确率达97%

    Keras是一个简约,高度模块化的神经网络库. 可以很容易和快速实现原型(通过总模块化,极简主义,和可扩展性) 同时支持卷积网络(vision)和复发性的网络(序列数据).以及两者的组合. 无缝地运行 ...

  8. 个人深度学习keras环境配置介绍

    因为我主要使用的是keras,所以我所要介绍的也是keras的环境,如果以后我转pytorch后再补充我的pytorch环境配置情况 我的keras环境配置如下: cudatoolkit 10.0.1 ...

  9. 深度学习Keras框架实践笔记

    在其他机器保存keras模型(.h5),load_model(~.h5)后报错[in from_config if 'class_name' not in config[0] or config[0] ...

最新文章

  1. cin输入字符串怎么结束_翻遍全网,只为让你记住这些输入输出函数
  2. c/c++中typedef详解(此文对typedef用于结构体的定义说明得很清楚到位)
  3. javaScript——廖雪峰老师学习笔记(一)
  4. phpstrtotime()对于31日求上个月有问题
  5. a标签连接空标签的方法
  6. 解决eclipse中jsp没有代码提示问题
  7. GZip、Deflate压缩算法对应的C#压缩解压函数
  8. 利用NABCD模型进行竞争性需求分析
  9. Padavan 老毛子路由器登录SSH教程
  10. 双稳态(bistable)与单稳态
  11. longitudinal models | 纵向研究 | mixed model
  12. 文本CSS多行溢出隐藏显示省略号
  13. 云大计算机应用技术考研2021,2021云南大学考研经验贴
  14. 一、万维网的发展(W3C组织的建立)
  15. 京东商城(HTML和CSS实现京东商城网站)
  16. java 生成txt文档 指定编码格式
  17. css做出圆角矩形边框
  18. 服务器虚拟化vmware价格,vmware服务器虚拟化实施方案(vmware服务器虚拟化收费)...
  19. 26位前谷歌AI专家出走创业
  20. 小学美术计算机教案模板,小学美术教案模板

热门文章

  1. s6 android 7.0 国行,三星S6电信版/S6 Edge国行版升级安卓7.0
  2. 内地高校招收澳门保送生公布录取结果 882名学生获录取
  3. 专访驭势科技吴甘沙:无人驾驶硝烟弥漫,“创造”才有未来|封面人物
  4. ps、grep和kill联合使用杀掉进程(转)
  5. JavaScript 总结几个提高性能知识点
  6. eclipse插件安装,万能方法
  7. HttpURLConnection和HttpClient的简单用法
  8. MonoRail MVC应用(2)-构建多层结构的应用程序
  9. Java 8开发的4大顶级技巧
  10. JDBC中Statement与PreparedStatement的区别