keras通过model.fit_generator训练模型(节省内存)

前言

前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator,steps_per_epoch=None,epochs=1,verbose=1,callbacks=None,validation_data=None,validation_steps=None,class_weight=None,max_queue_size=10,workers=1,use_multiprocessing=False,shuffle=True,initial_epoch=0)

参数:

generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。

steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。

epochs:即我们训练的迭代次数。

verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行

callbacks:在训练时调用的一系列回调函数。

validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。

validation_steps:和前面的steps_per_epoch类似。

class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)

max_queue_size:整数。生成器队列的最大尺寸。默认为10.

workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。

shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from PIL import Imagedef process_x(path):img = Image.open(path)img = img.resize((96, 96))img = img.convert('RGB')img = np.array(img)img = np.asarray(img, np.float32) / 255.0# 也可以进行进行一些数据数据增强的处理return imgdef generate_arrays_from_file(x_y):# x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是图片标签global countbatch_size = 8while 1:batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]batch_x = np.array([process_x(img_path) for img_path in batch_x])batch_y = np.array(batch_y).astype(np.float32)print("count:" + str(count))count = count + 1yield batch_x, batch_ymodel = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
count = 1
x_y = []
model.fit_generator(generate_arrays_from_file(x_y), steps_per_epoch=10, epochs=2, max_queue_size=1, workers=1)

在理解上面代码之前我们需要首先了解yield的用法。

yield关键字:

我们先通过一个例子看一下yield的用法:

def foo():print("starting...")while True:res = yield 4print("res:", res)g = foo()
print(next(g))
print("----------")
print(next(g))

运行结果:

starting...
4
----------
res: None
4

带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。

然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。

所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。

现在看到上面的示例代码:

generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:

['data/img_4092.jpg' '0' '1' '0' '0' '0' ]

至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。

2.2使用Sequence实现generator

示例代码:

class BaseSequence(Sequence):"""基础的数据流生成器,每次迭代返回一个batchBaseSequence可直接用于fit_generator的generator参数fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器而且能保证在多进程下的一个epoch中不会重复取相同的样本"""def __init__(self, img_paths, labels, batch_size, img_size):# np.hstack在水平方向上平铺self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))self.batch_size = batch_sizeself.img_size = img_sizedef __len__(self):# math.ceil表示向上取整# 调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数return math.ceil(len(self.x_y) / self.batch_size)def preprocess_img(self, img_path):img = Image.open(img_path)resize_scale = self.img_size[0] / max(img.size[:2])img = img.resize((self.img_size[0], self.img_size[0]))img = img.convert('RGB')img = np.array(img)# 数据归一化img = np.asarray(img, np.float32) / 255.0return imgdef __getitem__(self, idx):batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])batch_y = np.array(batch_y).astype(np.float32)print(batch_x.shape)return batch_x, batch_y# 重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。def on_epoch_end(self):# 每次迭代后重新打乱训练集数据np.random.shuffle(self.x_y)

在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。

举个例子说明一下getitem的作用:

class Animal:def __init__(self, animal_list):self.animals_name = animal_listdef __getitem__(self, index):return self.animals_name[index]animals = Animal(["dog", "cat", "fish"])
for animal in animals:print(animal)

输出结果:

dog
cat
fish

并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。

参考yield方法:

2020-12-11 keras通过model.fit_generator训练模型(节省内存)相关推荐

  1. Keras之model.fit_generator()的使用

    Keras之model.fit_generator()的使用 model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗 一.使用方式 1.引入库 ...

  2. 一篇文章足够你学习蓝牙技术,提供史上最全的蓝牙技术(传统蓝牙/低功耗蓝牙)文章总结,文档下载总结(2020/12/11更新)

    本文章目的: 1)给广大蓝牙爱好者提供蓝牙资料下载渠道 2)给广大蓝牙爱好者增加一个蓝牙学习文章导读 我们的蓝牙书以及CSDN蓝牙系列的书籍以及视频有以下计划,大家可以根据兴趣爱好或者工作需要挑选特定 ...

  3. C语言入门 -- 打印工资总额、税金及净工资(2020/12/11)

    打印工资 已知基本工资率.加班费计算及税率,打印工资总额.税金及净工资 假设如下: (1) 基本工资率=每小时应支付美元,例如10美元/小时 (2) 加班(每周超过40小时)=(基本工资率)*2.5 ...

  4. c#12星座速配代码_白羊座今日运势|2020/12/11

    整体运势:★★★☆☆ 爱情运势:★★☆☆☆ 事业运势:★★☆☆☆ 财富运势:★★★☆☆ 幸运数字:7 速配星座:金牛座 幸运颜色:橙色 幸运时刻:12:00-14:00 整体运势: 接收的消息会比较多 ...

  5. 网址收藏 2020.12.11

    Java虚拟机规范: malldump(内存转储):https://github.com/yodaos-project/malldump: ethtool:源代码 :https://mirrors.e ...

  6. 《惢客创业日记》2020.12.11(周五)每个客户都有隐形需求

    今天早晨做了一个梦,而这个梦正好为昨天凉粉儿反馈的一个用户调研提供了一个解决思路.先说说昨天凉粉儿的用户调研吧. 昨天下午下班后,我正在公司干活,凉粉儿发来一个微信,她针对惢客慈善中的捐物版块,调研了 ...

  7. STEMA 考试每日一练 2020.12.7 - 2021.11.30 试题及答案 - 刷题

    2020.12.7 在以下几个选项中,正确的从小到大的排序是( ) A 地球<太阳系<可观测宇宙<银河系B 地球<太阳系<银河系<可观测宇宙C 太阳系<地球& ...

  8. Keras实现mode.fit和model.fit_generator比较

    模型部分 模型部分都一样,比如我这里使用AlexNet网络来做.我做的是一个二分类任务,所以结尾部分网络有改动.输入图片尺寸是256*256的,所以输出图片尺寸有一点改动. from keras.mo ...

  9. 10年老电脑如何提速_电信宽带免费提速至200M,面向全国用户活动日期2020年11月9日至12月31日...

    近日中国电信免费提速活动,也是为了大家方便剁手吧! 2020年11月9日至12月31日,针对接入速率200Mbps以下的电信光纤宽带家庭用户,中国电信推出免费在线提速到200Mbps的优惠活动:针对接 ...

最新文章

  1. only one element tensors can be converted to Python scalars
  2. Java 8系列(一): 日期/时间- JSR310( Date and Time API)
  3. 工作25:工具里面代码提交
  4. docker卸载 windows版本_DevOps系列 006 - Docker安装
  5. 高性能可扩展mysql-数据库设计规范
  6. 2007年4月 [Update to 4.27]
  7. Java重入函数_重入函数
  8. redis读数据超时问题查询
  9. 深度学习项目实施流程
  10. 揭秘当下最主流的的7个app推广渠道及其不为人知的秘密
  11. html5:初学h标签的使用 p标签 br标签 hr标签
  12. Spring 集成与分片详解
  13. python 报错in module,Centos 7 python 编译报错 ImportError: No module named six 解决办法
  14. 盘点2017 CES展会所有亮眼黑科技 (上)
  15. 案例分析:session丢失及appdomain回收
  16. 通过Gearman实现MySQL到Redis的数据复制
  17. 高性能MySQL之 Chapter13
  18. 向日葵远程操控的实现
  19. Lion的无线网络诊断工具
  20. 计算机休眠模式对cpu,笔记本计算机处于待机模式时,正常的CPU温度是多少?

热门文章

  1. 控制台上的内容不输入到nohup.out
  2. hadoop学习-Netflix电影推荐系统
  3. gis中dbf转为csv_Python中.dbf到.csv的批量转换
  4. 神经元模型及网络结构
  5. LeetCode题组:第21题-合并两个有序链表
  6. win11+AMD的cpu+3060GPU电脑安装 tensorflow-GPU+cuda11+cudnn
  7. 解决问题:EnvironmentLocationNotFound: Not a conda environment: /anaconda3/envs/anaconda3
  8. python pandas加速包
  9. Python中的if __name__ == ‘__main__‘
  10. JAVA服务治理实践之无侵入的应用服务监控--转