当遇到大数据时,无法将数据全部加载进内存,需要用到分批次加载,网上的方法很多都是关于分类数据集,记录一下分割数据集使用迭代器进行数据加载的方式。
主要从keras.utils.Sequence 继承后定义一个数据加载器 DataGenerator。
!!!注:本文的代码只展现了关键部分,不是完整代码

定义数据生成器

import glob
import tensorflow as tf
from model import unet
from tensorflow import keras
import math
import os
import cv2
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Denseclass DataGenerator(keras.utils.Sequence):def __init__(self, data_img, data_mask, batch_size=1, shuffle=True):self.batch_size = batch_sizeself.data_img = data_imgself.data_mask = data_maskself.indexes = np.arange(len(self.data_img))self.shuffle = shuffledef __len__(self):# 计算每一个epoch的迭代次数return math.ceil(len(self.data_img) / float(self.batch_size))def __getitem__(self, index):# 生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了# 生成batch_size个索引batch_indexs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]# 根据索引获取datas集合中的数据batch_data_img = [self.data_img[k] for k in batch_indexs]batch_data_mask = [self.data_mask[k] for k in batch_indexs]# 生成数据X, y = self.data_generation(batch_data_img, batch_data_mask)return X, ydef on_epoch_end(self):# 在每一次epoch结束是否需要进行一次随机,重新随机一下indexif self.shuffle == True:np.random.shuffle(self.indexes)def data_generation(self, batch_data_img, batch_data_mask):images = []masks = []# 生成数据for data_img, data_mask in zip(batch_data_img, batch_data_mask):# x_train数据image = cv2.imread(data_img,cv2.IMREAD_COLOR)image = cv2.resize(image,(256,256))image = list(image)images.append(image)# y_train数据mask = cv2.imread(data_mask,cv2.IMREAD_GRAYSCALE)mask = cv2.resize(mask, (256,256))mask = mask / 255.0mask = list(mask)masks.append(mask)return np.array(images), np.array(masks)# 读取样本名称,然后根据样本名称去读取数据train_img = sorted(glob.glob('./trainnsmc/image/*.png'))
train_mask = sorted(glob.glob('./trainnsmc/label/*.png'))
# 数据生成器
training_generator = DataGenerator(train_img, train_mask,batch_size=8)

定义模型,进行训练

model = unet()
#编译模型
from keras_unet_collection import losses
model.compile(optimizer=tf.keras.optimizers.Adam(lr), loss=losses.dice,metrics=[ 'acc',losses.dice_coef])model.fit(training_generator, epochs=50,  max_queue_size=10)

!!!注:本文的代码只展现了关键部分,不是完整代码

Tensorflow2.0 使用Kera 迭代器 加载图像分割训练集相关推荐

  1. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

  2. 【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练

    前面的教程都只在小模型.小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet.教程会分为三部分:数据增强.模型加载与训练.模型测试,最终在ResNet50上可以达到77.72%的top- ...

  3. 【翻译】基于 Create React App路由4.0的异步组件加载(Code Splitting)

    基于 Create React App路由4.0的异步组件加载 本文章是一个额外的篇章,它可以在你的React app中,帮助加快初始的加载组件时间.当然这个操作不是完全必要的,但如果你好奇的话,请随 ...

  4. NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢

    NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢 (卡顿) GEFORCE RTX 2080 / GE ...

  5. [Asp.net mvc]实体更新异常:存储区更新、插入或删除语句影响到了意外的行数(0)。实体在加载后可能被修改或删除。

    学习asp.net mvc 时在更新实体进行SaveChanges()的时候出现了异常,异常如下: 存储区更新.插入或删除语句影响到了意外的行数(0).实体在加载后可能被修改或删除.刷新 Object ...

  6. store update、insert或delete语句影响了意外的行数(0)。自加载实体后,实体可能已被修改或删除

    报错详情是: store update.insert或delete语句影响了意外的行数(0).自加载实体后,实体可能已被修改或删除.请参见http://go.microsoft.com/fwlink/ ...

  7. 3.0、Hibernate-延迟加载 1

    3.0.Hibernate-延迟加载 1 Hibernate 延迟加载 也叫 惰性加载.懒加载: 使用延迟加载可以提高程序运行效率,Java 程序 与 数据库交互的频次越低,程序运行的效率就越高,所以 ...

  8. 第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...

  9. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

最新文章

  1. HTML5 利用canvas API 展示阴影效果
  2. 奖牌分配/Median Pyramid Hard
  3. 剑桥大学在机器人,半导体,5G,区块链方面的研究实验室
  4. 我们做了一个医疗版MNIST数据集,发现常见AutoML算法没那么好用
  5. 赶集网MySQL开发36军规
  6. (引)ajax 经验-保留自己使用
  7. camel.js_Camel 2.11 –没有Spring的Camel Web应用程序
  8. 移动端前端笔记 — 遇到的常见JS与CSS问题及解决方法
  9. python 解析pb文件_利用Python解析json文件
  10. php func_get_args(),PHP中func_get_args(),func_get_arg(),func_num_args()有什么不同
  11. java set region_Java Tile.setRegion方法代碼示例
  12. iPhone X 不充电维修案例
  13. 网络其他计算机无法访问,win7局域网别人无法访问我的电脑是为什么 win7其他电脑无法访问我的电脑如何修复...
  14. Python+tkinter应用程序设置背景图片
  15. Avalonia的Snoop
  16. セルジュ / Serge
  17. jmeter录制手机脚本
  18. 微信公众号扫一扫功能vue配置
  19. 互联网电商平台运营模式
  20. Python实战1-9例:变量、运算、字符串等综合训练

热门文章

  1. python生成接口文档_使用apiDoc实现python接口文档编写
  2. 8a8k单片机c语言写闹钟,我的12864超级时钟制作资料 带红外遥控 闹铃 按键 完整源码...
  3. 如何使说唱节奏在七个简单的步骤
  4. 安装ubuntu 16.04系统
  5. 双线性插值(Bilinear Interpol)原理及应用
  6. 2019 CSP-S Day2 T1 Emiya 家今天的饭(DP)
  7. 「CSP-S 2019」 Emiya 家今天的饭 题解
  8. hge source explor 0x7 resource module
  9. Unity Ferr2D 地形工具
  10. 一些javascript小技巧!