利用 train_on_batch 精细管理训练过程

大部分使用 keras 的同学使用 fit() 或者 fit_generator() 进行模型训练, 这两个 api 对于刚接触深度学习的同学非常友好和方便,但是由于其是非常深度的封装,对于希望自定义训练过程的同学就显得不是那么方便(从 torch 转 keras 的同学可能更喜欢自定义训练过程),而且,对于 GAN 这种需要分步进行训练的模型,也无法直接使用 fit 或者 fit_generator 直接训练的。因此,keras 提供了 train_on_batch 这个 api,对一个 mini-batch 的数据进行梯度更新。
总结优点如下:

  • 更精细自定义训练过程,更精准的收集 loss 和 metrics
  • 分步训练模型-GAN的实现
  • 多GPU训练保存模型更加方便
  • 更多样的数据加载方式,结合 torch dataloader 的使用

下面介绍 train_on_batch 的使用

1. train_on_batch 的输入输出

1.1 输入

y_pred = Model.train_on_batch(x,y=None,sample_weight=None,class_weight=None,reset_metrics=True,return_dict=False,
)
  • x:模型输入,单输入就是一个 numpy 数组, 多输入就是 numpy 数组的列表
  • y:标签,单输出模型就是一个 numpy 数组, 多输出模型就是 numpy 数组列表
  • sample_weight:mini-batch 中每个样本对应的权重,形状为 (batch_size)
  • class_weight:类别权重,作用于损失函数,为各个类别的损失添加权重,主要用于类别不平衡的情况, 形状为 (num_classes)
  • reset_metrics:默认True,返回的metrics只针对这个mini-batch, 如果False,metrics 会跨批次累积
  • return_dict:默认 False, y_pred 为一个列表,如果 True 则 y_pred 是一个字典

1.2 输出

  • 单输出模型,1个loss,没有metrics,train_on_batch 返回一个标量,代表这个 mini-batch 的 loss, 例如
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'])
y_pred = model.train_on_batch(x=image,y=label)
# y_pred 为标量
  • 单输出模型,有1个loss,n个metrics,train_on_batch 返回一个列表, 列表长度为 1+n, 例如
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'], metrics=['accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# len(y_pred) == 2, y_pred[0]为loss, y_pred[1]为accuracy
  • 多输出模型,n个loss,m个metrics,train_on_batch返回一个列表,列表长度为 1+n+m*n, 例如
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam, loss=['binary_crossentropy', 'binary_crossentropy'], metrics=['accuracy', 'accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# 查看model.metrics_names来了解返回列表中每个值的含义

2. train_on_batch 多GPU训练模型

2.1 多GPU模型初始化,加载权重,模型编译,模型保存

注意!训练时对 para_model 操作,保存时对 model 做操作

import tensorflow as tf
import keras
import os# 初始化GPU的使用个数
gpu = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
gpu_num = len(gpu.split(','))# model初始化
with tf.device('/cpu:0'):# 使用多GPU时,先在CPU上初始化模型model = YourModel(input_size, num_classes)model.load_weights('*.h5') # 如果有权重需要加载,在这里实现
para_model = keras.utils.multi_gpu_model(model, gpus=gpu_num) # 在GPU上初始化多GPU模型
para_model.compile(optimizer, loss=[...], metrics=[...]) # 编译多GPU模型# 训练和验证,对 para_model 使用 train_on_batch
def train():para_model.train_on_batch(...)def evaluate():para_model.test_on_batch(...)# 保存模型,注意!训练时对 para_model 操作,保存时对 model 做操作
# 不要使用 para_model.save() 或者 para_model.save_weights(),否则加载时会出问题
model.save('*.h5')
model.save_weights('*.h5')

3. 自定义学习率调整策略

由于无法使用callback,我们使用 keras.backend.get_value() 和 keras.backend.set_value() 来获取和设置当前学习率。举个栗子, 实现一下最简单阶梯下降学习率,每10个epoch,学习率下降0.1倍

import keras.backend as Kfor epoch in range(100):train_one_epoch()evaluate()# 每10个epoch,lr缩小0.1倍if epoch%10==0 and epoch!=0:lr = K.get_value(model.optimizer.lr) # 获取当前学习率lr = lr * 0.1 # 学习率缩小0.1倍K.set_value(model.optimizer.lr, lr) # 设置学习率

4. keras和torch的结合

torch 的 dataloader 是目前为止我用过最好用的数据加载方式,使用 train_on_batch 一部分的原因是因为我能够用 torch dataloader 载入数据,然后用 train_on_batch 对模型进行训练,通过合理的控制 cpu worker 的使用个数和 batch_size 的大小,使模型的训练效率最大化

4.1 dataloader+train_on_batch 训练keras模型pipeline

# 定义 torch dataset
class Dataset(torch.utils.data.Dataset):def __init__(self, root_list, transforms=None):self.root_list = root_listself.transforms = transformsdef __getitem__(self, idx):# 假设是图像分类任务image = ... # 读取单张图像label = ... # 读取标签if self.transforms is not None:image = self.transforms(image)return image, label # shape: (H,W,3), salardef __len__(self):return len(self.root_list)# 自定义 collate_fn 使 dataloader 返回 numpy array
def collate_fn(batch):# 这里的 batch 是 tuple 列表,[(image, label),(image, label),...]image, label = zip(*batch)image = np.asarray(image) # (batch_size, H, W, 3)label = np.asarray(label) # (batch_size)return image, label # 如果 datast 返回的图像是 ndarray,这样loader返回的也是 ndarray# 定义dataset
train_dataset = Dataset(train_list)
valid_dataset = Dataset(valid_list)
test_dataset = Dataset(test_list)# 定义 dataloader, 如果不使用自定义 collate_fn,
# 从 loader 取出的默认是 torch Tensor,需要做一个 .numpy()的转换
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)# 定义 train,evaluate,test
def train():for i,(inputs, label) in enumerate(train_loader):# 如果 inputs 和 label 是 torch Tensor# 请用 inputs = inputs.numpy() 和 label = label.numpy() 转成 ndarrayy_pred = model.train_on_batch(inputs, label)def evaluate():for i,(inputs, label) in enumerate(valid_loader):# 如果 inputs 和 label 是 Tensor,同上y_pred = model.test_on_batch(inputs, label)def test():for i,(inputs, label) in enumerate(test_loader):# 如果 inputs 和 label 是 Tensor,同上y_pred = model.test_on_batch(inputs, label)def run():for epoch in num_epoch:train()evaluate()test()if __name__ == "__main__":run()

总结

还有一些使用 train_on_batch 的地方比如 GAN 的训练,这里就不介绍了,具体可以上 github 上搜索,例如 keras-dcgan。

参考

keras 官方 api: train_on_batch

keras train_on_batch详解(train_on_batch的输出输入详解,train_on_batch多GPU训练详解,自定义学习率调整策略)相关推荐

  1. 【详解】模型优化技巧之优化器和学习率调整

    目录 PyTorch十大优化器 1 torch.optim.SGD 2 torch.optim.ASGD 3 torch.optim.Rprop 4 torch.optim.Adagrad 5 tor ...

  2. Keras深度学习实战(4)——深度学习中常用激活函数和损失函数详解

    Keras深度学习实战(4)--深度学习中常用激活函数和损失函数详解 常用激活函数 Sigmoid 激活函数 Tanh 激活函数 ReLU 激活函数 线性激活函数 Softmax 激活函数 损失函数 ...

  3. 详解linux netstat输出的网络连接状态信息

    本博文为老男孩linu培训机构早期的培训教案,特分享以供大家学习参考. 全部系列分为五篇文章,本博文为第一篇: 目录:一.生产服务器netstat tcp连接状态................... ...

  4. 【C语言网】C语言基础题集训练详解(一)

    [C语言网]基础题集训练详解(一) 题目目录 [C语言网]基础题集训练详解(一) 前言 一.题目1000 [竞赛入门]简单的a+b 二. 题目1001 [编程入门]第一个HelloWorld程序 三. ...

  5. 【转】ASP.NET验证控件详解(非空验证,比较验证,范围验证,正则表达式,自定义验证)...

    [转]ASP.NET验证控件详解(非空验证,比较验证,范围验证,正则表达式,自定义验证) ASP.NET验证控件详解 现在ASP.NET,你不但可以轻松的实现对用户输入的验证,而且,还可以选择验证在服 ...

  6. linux中用zip压缩文件,详解Linux中zip压缩和unzip解压缩命令及使用详解

    下面给大家介绍下Linux中zip压缩和unzip解压缩命令详解 1.把/home目录下面的mydata目录压缩为mydata.zip zip -r mydata.zip mydata #压缩myda ...

  7. P2P技术详解(三):P2P技术之STUN、TURN、ICE详解

    本文是<P2P理论详解>系列文章中的第2篇,总目录如下: <P2P技术详解(一):NAT详解--详细原理.P2P简介> <P2P技术详解(二):P2P中的NAT穿越(打洞 ...

  8. 基石为勤能补拙的迷宫之旅——第三天(Python基本数据类型,与用户交互(输出输入),运算符)

    一. 基本数据类型 为何数据要区分类型?     数据类型值的是变量值的类型,变量值之所区分类型,是因为变量值是用来记录事物状态的,而事物的状态有不同的种类,对应着,也必须使用不同类型的值去记录它们. ...

  9. 输入一组整数,0结束输入,之后输出输入的最大的和最小的整数.【思路】

    package com.ykmimi.new1; /*** 输入一组整数,0结束输入,之后输出输入的最大的和最小的整数.*/ import java.util.Scanner;public class ...

最新文章

  1. html6个圆圈放一排,html中两个选择框如何并排放置(一)
  2. 分析RAC下一个SPFILE整合的三篇文章的文件更改
  3. matlab 图像函数以及运用(第十章)
  4. 《编码的奥秘》---学习编程一年半的体会
  5. Leetcode 1. 两数之和 (Python版)
  6. rh php56 php,在全球范围内提供RHSCL PHP的最佳方法
  7. Raid技术精简总结
  8. GTK+图形化应用程序开发学习笔记(三)—窗体
  9. HDOJ--1106排序
  10. 汇编语言 王爽 第四版 课后检测点 课后实验 包括解释 持续更新~~
  11. Https网络安全传输详解
  12. ABB机器人与OMRON PLC Socket通信
  13. es7 创建模板时,报错 Validation Failed: 1: index patterns are missing
  14. 霍普金斯计算机专业研究生如何,约翰·霍普金斯大学电气和计算机工程硕士研究生...
  15. LHS查询 RHS查询
  16. 简单计算机java程序_JAVA程序员需要知道的计算机底层基础10-操作系统引导程序的简单...
  17. win7计算机双击变管理,如何修复Win7系统鼠标单击以双击
  18. 去除页眉横线快准狠的3个方法,就喜欢这么简单粗暴的操作!
  19. 一份机器学习的自白书
  20. 风力发电会影响气候?

热门文章

  1. oracle sql 拆分字符串,Oracle数据库字符串分割的处理实现
  2. 【应用】 MODIS NDVI数据处理相关问题
  3. mac电脑上localhost找不到
  4. poi电话号码导入问题
  5. 抗疫三年,“医”路有你
  6. Open Recent
  7. 修改host文件来访问GitHub
  8. 计算机毕业设计Java校园摄影爱好者交流网站(源码+系统+mysql数据库+Lw文档)
  9. 谷歌浏览器无法携带cookie问题
  10. Excel VBA 高级编程-来自直男的Excel表白