keras train_on_batch详解(train_on_batch的输出输入详解,train_on_batch多GPU训练详解,自定义学习率调整策略)
利用 train_on_batch 精细管理训练过程
大部分使用 keras 的同学使用 fit() 或者 fit_generator() 进行模型训练, 这两个 api 对于刚接触深度学习的同学非常友好和方便,但是由于其是非常深度的封装,对于希望自定义训练过程的同学就显得不是那么方便(从 torch 转 keras 的同学可能更喜欢自定义训练过程),而且,对于 GAN 这种需要分步进行训练的模型,也无法直接使用 fit 或者 fit_generator 直接训练的。因此,keras 提供了 train_on_batch 这个 api,对一个 mini-batch 的数据进行梯度更新。
总结优点如下:
- 更精细自定义训练过程,更精准的收集 loss 和 metrics
- 分步训练模型-GAN的实现
- 多GPU训练保存模型更加方便
- 更多样的数据加载方式,结合 torch dataloader 的使用
下面介绍 train_on_batch 的使用
1. train_on_batch 的输入输出
1.1 输入
y_pred = Model.train_on_batch(x,y=None,sample_weight=None,class_weight=None,reset_metrics=True,return_dict=False,
)
- x:模型输入,单输入就是一个 numpy 数组, 多输入就是 numpy 数组的列表
- y:标签,单输出模型就是一个 numpy 数组, 多输出模型就是 numpy 数组列表
- sample_weight:mini-batch 中每个样本对应的权重,形状为 (batch_size)
- class_weight:类别权重,作用于损失函数,为各个类别的损失添加权重,主要用于类别不平衡的情况, 形状为 (num_classes)
- reset_metrics:默认True,返回的metrics只针对这个mini-batch, 如果False,metrics 会跨批次累积
- return_dict:默认 False, y_pred 为一个列表,如果 True 则 y_pred 是一个字典
1.2 输出
- 单输出模型,1个loss,没有metrics,train_on_batch 返回一个标量,代表这个 mini-batch 的 loss, 例如
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'])
y_pred = model.train_on_batch(x=image,y=label)
# y_pred 为标量
- 单输出模型,有1个loss,n个metrics,train_on_batch 返回一个列表, 列表长度为 1+n, 例如
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'], metrics=['accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# len(y_pred) == 2, y_pred[0]为loss, y_pred[1]为accuracy
- 多输出模型,n个loss,m个metrics,train_on_batch返回一个列表,列表长度为 1+n+m*n, 例如
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam, loss=['binary_crossentropy', 'binary_crossentropy'], metrics=['accuracy', 'accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# 查看model.metrics_names来了解返回列表中每个值的含义
2. train_on_batch 多GPU训练模型
2.1 多GPU模型初始化,加载权重,模型编译,模型保存
注意!训练时对 para_model 操作,保存时对 model 做操作
import tensorflow as tf
import keras
import os# 初始化GPU的使用个数
gpu = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
gpu_num = len(gpu.split(','))# model初始化
with tf.device('/cpu:0'):# 使用多GPU时,先在CPU上初始化模型model = YourModel(input_size, num_classes)model.load_weights('*.h5') # 如果有权重需要加载,在这里实现
para_model = keras.utils.multi_gpu_model(model, gpus=gpu_num) # 在GPU上初始化多GPU模型
para_model.compile(optimizer, loss=[...], metrics=[...]) # 编译多GPU模型# 训练和验证,对 para_model 使用 train_on_batch
def train():para_model.train_on_batch(...)def evaluate():para_model.test_on_batch(...)# 保存模型,注意!训练时对 para_model 操作,保存时对 model 做操作
# 不要使用 para_model.save() 或者 para_model.save_weights(),否则加载时会出问题
model.save('*.h5')
model.save_weights('*.h5')
3. 自定义学习率调整策略
由于无法使用callback,我们使用 keras.backend.get_value() 和 keras.backend.set_value() 来获取和设置当前学习率。举个栗子, 实现一下最简单阶梯下降学习率,每10个epoch,学习率下降0.1倍
import keras.backend as Kfor epoch in range(100):train_one_epoch()evaluate()# 每10个epoch,lr缩小0.1倍if epoch%10==0 and epoch!=0:lr = K.get_value(model.optimizer.lr) # 获取当前学习率lr = lr * 0.1 # 学习率缩小0.1倍K.set_value(model.optimizer.lr, lr) # 设置学习率
4. keras和torch的结合
torch 的 dataloader 是目前为止我用过最好用的数据加载方式,使用 train_on_batch 一部分的原因是因为我能够用 torch dataloader 载入数据,然后用 train_on_batch 对模型进行训练,通过合理的控制 cpu worker 的使用个数和 batch_size 的大小,使模型的训练效率最大化
4.1 dataloader+train_on_batch 训练keras模型pipeline
# 定义 torch dataset
class Dataset(torch.utils.data.Dataset):def __init__(self, root_list, transforms=None):self.root_list = root_listself.transforms = transformsdef __getitem__(self, idx):# 假设是图像分类任务image = ... # 读取单张图像label = ... # 读取标签if self.transforms is not None:image = self.transforms(image)return image, label # shape: (H,W,3), salardef __len__(self):return len(self.root_list)# 自定义 collate_fn 使 dataloader 返回 numpy array
def collate_fn(batch):# 这里的 batch 是 tuple 列表,[(image, label),(image, label),...]image, label = zip(*batch)image = np.asarray(image) # (batch_size, H, W, 3)label = np.asarray(label) # (batch_size)return image, label # 如果 datast 返回的图像是 ndarray,这样loader返回的也是 ndarray# 定义dataset
train_dataset = Dataset(train_list)
valid_dataset = Dataset(valid_list)
test_dataset = Dataset(test_list)# 定义 dataloader, 如果不使用自定义 collate_fn,
# 从 loader 取出的默认是 torch Tensor,需要做一个 .numpy()的转换
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)# 定义 train,evaluate,test
def train():for i,(inputs, label) in enumerate(train_loader):# 如果 inputs 和 label 是 torch Tensor# 请用 inputs = inputs.numpy() 和 label = label.numpy() 转成 ndarrayy_pred = model.train_on_batch(inputs, label)def evaluate():for i,(inputs, label) in enumerate(valid_loader):# 如果 inputs 和 label 是 Tensor,同上y_pred = model.test_on_batch(inputs, label)def test():for i,(inputs, label) in enumerate(test_loader):# 如果 inputs 和 label 是 Tensor,同上y_pred = model.test_on_batch(inputs, label)def run():for epoch in num_epoch:train()evaluate()test()if __name__ == "__main__":run()
总结
还有一些使用 train_on_batch 的地方比如 GAN 的训练,这里就不介绍了,具体可以上 github 上搜索,例如 keras-dcgan。
参考
keras 官方 api: train_on_batch
keras train_on_batch详解(train_on_batch的输出输入详解,train_on_batch多GPU训练详解,自定义学习率调整策略)相关推荐
- 【详解】模型优化技巧之优化器和学习率调整
目录 PyTorch十大优化器 1 torch.optim.SGD 2 torch.optim.ASGD 3 torch.optim.Rprop 4 torch.optim.Adagrad 5 tor ...
- Keras深度学习实战(4)——深度学习中常用激活函数和损失函数详解
Keras深度学习实战(4)--深度学习中常用激活函数和损失函数详解 常用激活函数 Sigmoid 激活函数 Tanh 激活函数 ReLU 激活函数 线性激活函数 Softmax 激活函数 损失函数 ...
- 详解linux netstat输出的网络连接状态信息
本博文为老男孩linu培训机构早期的培训教案,特分享以供大家学习参考. 全部系列分为五篇文章,本博文为第一篇: 目录:一.生产服务器netstat tcp连接状态................... ...
- 【C语言网】C语言基础题集训练详解(一)
[C语言网]基础题集训练详解(一) 题目目录 [C语言网]基础题集训练详解(一) 前言 一.题目1000 [竞赛入门]简单的a+b 二. 题目1001 [编程入门]第一个HelloWorld程序 三. ...
- 【转】ASP.NET验证控件详解(非空验证,比较验证,范围验证,正则表达式,自定义验证)...
[转]ASP.NET验证控件详解(非空验证,比较验证,范围验证,正则表达式,自定义验证) ASP.NET验证控件详解 现在ASP.NET,你不但可以轻松的实现对用户输入的验证,而且,还可以选择验证在服 ...
- linux中用zip压缩文件,详解Linux中zip压缩和unzip解压缩命令及使用详解
下面给大家介绍下Linux中zip压缩和unzip解压缩命令详解 1.把/home目录下面的mydata目录压缩为mydata.zip zip -r mydata.zip mydata #压缩myda ...
- P2P技术详解(三):P2P技术之STUN、TURN、ICE详解
本文是<P2P理论详解>系列文章中的第2篇,总目录如下: <P2P技术详解(一):NAT详解--详细原理.P2P简介> <P2P技术详解(二):P2P中的NAT穿越(打洞 ...
- 基石为勤能补拙的迷宫之旅——第三天(Python基本数据类型,与用户交互(输出输入),运算符)
一. 基本数据类型 为何数据要区分类型? 数据类型值的是变量值的类型,变量值之所区分类型,是因为变量值是用来记录事物状态的,而事物的状态有不同的种类,对应着,也必须使用不同类型的值去记录它们. ...
- 输入一组整数,0结束输入,之后输出输入的最大的和最小的整数.【思路】
package com.ykmimi.new1; /*** 输入一组整数,0结束输入,之后输出输入的最大的和最小的整数.*/ import java.util.Scanner;public class ...
最新文章
- html6个圆圈放一排,html中两个选择框如何并排放置(一)
- 分析RAC下一个SPFILE整合的三篇文章的文件更改
- matlab 图像函数以及运用(第十章)
- 《编码的奥秘》---学习编程一年半的体会
- Leetcode 1. 两数之和 (Python版)
- rh php56 php,在全球范围内提供RHSCL PHP的最佳方法
- Raid技术精简总结
- GTK+图形化应用程序开发学习笔记(三)—窗体
- HDOJ--1106排序
- 汇编语言 王爽 第四版 课后检测点 课后实验 包括解释 持续更新~~
- Https网络安全传输详解
- ABB机器人与OMRON PLC Socket通信
- es7 创建模板时,报错 Validation Failed: 1: index patterns are missing
- 霍普金斯计算机专业研究生如何,约翰·霍普金斯大学电气和计算机工程硕士研究生...
- LHS查询 RHS查询
- 简单计算机java程序_JAVA程序员需要知道的计算机底层基础10-操作系统引导程序的简单...
- win7计算机双击变管理,如何修复Win7系统鼠标单击以双击
- 去除页眉横线快准狠的3个方法,就喜欢这么简单粗暴的操作!
- 一份机器学习的自白书
- 风力发电会影响气候?