These past few days I have been training a CNN on more than two million images. Even with an NVIDIA GeForce GTX 1080 Ti GPU it was still very slow, so I looked up some articles on speeding up training and want to share this one with you.

Tensorflow vs. Keras or how to speed up your training for image data sets by factor 10

If you ever trained a CNN with keras on your GPU with a lot of images, you might have noticed that the performance is not as good as with tensorflow on comparable tasks. In this post I will show an example where tensorflow is 10 times faster than keras. I will show that it is not a problem of keras itself, but of how the preprocessing works, plus a bug in older versions of keras-preprocessing. Finally, I will show how to build a TFRecord data set and use it in keras to achieve comparable results.

Measure if your GPU is used

When training a neural net on the GPU, the first thing to look at is GPU utilization, which shows how much of your GPU is being used. It can be observed either with nvidia-smi on the command line or with GPU-Z. GPU utilization translates directly into training time: more utilization means more parallel execution, which means more speed. If you are working on windows, don't trust the performance charts in the built-in task manager; they are not very accurate.
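
If you want to keep an eye on the utilization from the command line, nvidia-smi can refresh itself in a loop; the -l flag takes the refresh interval in seconds:

nvidia-smi -l 1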

GPU utilization in nvidia-smi

Training with keras’ ImageDataGenerator

First let's take a look at the code, where we use a dataframe to feed the network with data. In keras this is achieved with the ImageDataGenerator class. In this example we use a Keras EfficientNet with ImageNet weights and custom labels. Additional information is in the comments.

# imports assumed by the snippet; the EfficientNet model comes from a separate
# package (e.g. the 'efficientnet' pip package), not from keras itself
from efficientnet import EfficientNetB3
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

train_datagen = ImageDataGenerator(
    rescale=1./255,        # we scale the colors down to 8 bit per channel
    rotation_range=30,     # the ImageDataGenerator offers a lot of convenience features to augment the data
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=0.1)  # here we can split the data into training and validation and use it later on

# Now we create a training and a validation generator from a pandas dataframe,
# where x_col is the absolute path to the image file and y_col is the column
# with the label. Disabling validate_filenames and drop_duplicates speeds
# everything up for large data sets.
train_generator = train_datagen.flow_from_dataframe(
    dataframe=df, directory=None, x_col="ImagePath", y_col="Label",
    class_mode="categorical", target_size=input_shape, batch_size=batch_size,
    subset='training', drop_duplicates=False, validate_filenames=False)
validation_generator = train_datagen.flow_from_dataframe(
    dataframe=df, directory=None, x_col="ImagePath", y_col="Label",
    class_mode="categorical", target_size=input_shape, batch_size=batch_size,
    subset='validation', drop_duplicates=False, validate_filenames=False)

# Now we create the model, loading the 300x300 EfficientNet with ImageNet
# weights; include_top=False drops the last fully connected layer, because we
# want to use our own.
base_model = EfficientNetB3(include_top=False, weights='imagenet')

# our custom fully connected layer (top) with the number of classes we want
x = GlobalAveragePooling2D()(base_model.output)
predictions = Dense(len(train_generator.class_indices), activation='softmax', name='predictions')(x)

# create the model
model = Model(inputs=base_model.input, outputs=predictions)

# the original post skips the compile step; loss and optimizer are assumptions
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# and now fit the model with 16 worker threads reading the images
history = model.fit_generator(
    generator=train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // batch_size,
    epochs=20, workers=16, verbose=1)

This works fine and does what it should, but if we compare it to tensorflow implementations, we notice that keras is much slower, even though we use 16 workers for the preprocessing. But why is it still so much slower?

Data bottleneck

If we look at the GPU utilization with GPU-Z, we can observe a pattern:

GPU load while training the keras model

As shown in the image, the GPU is not used all the time; the load varies a lot. Additionally, the memory controller is idling. It seems like the GPU is waiting for something. There are several possible reasons.

  • Hardware: One reason could be loading data from an HDD instead of an SSD, which is not the case here. But if you use HDDs, you should really upgrade to an internal SSD (not USB; either SATA or M.2).
  • Other software: Another common culprit is a virus scanner, which slows down IO. If you have performance issues, whitelist your data folder.
  • Configuration: It might be possible to speed up the training with python multiprocessing, but the windows version of keras does not seem to support it. Increasing the queue size and the number of workers in fit_generator can also help (see the sketch after this list).
  • Outdated version: I also noticed a difference between two computers, where one utilized the GPU much better than the other, even with equivalent hardware. The reason was simply that Keras-Preprocessing suffered from a bug in version 1.0.9, which was fixed in 1.1.0! The GPU utilization increased from ~10% to ~60%.
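
Before rewriting anything, it is worth a quick sanity check against the last two points. The snippet below is a minimal sketch, not a recommendation: the version check covers the keras-preprocessing bug mentioned above, and max_queue_size, workers and use_multiprocessing are the fit_generator knobs to experiment with (the concrete values are just starting points and depend on your hardware).

import pkg_resources

# the 1.0.9 bug mentioned above was fixed in 1.1.0
print(pkg_resources.get_distribution('keras-preprocessing').version)

history = model.fit_generator(
    generator=train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    max_queue_size=32,          # default is 10; a larger queue can hide IO latency
    workers=16,                 # CPU threads decoding and augmenting images
    use_multiprocessing=False,  # True may help on linux; problematic on windows
    epochs=20,
    verbose=1)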

If nothing from the above helps, we can take a look at the code and see that keras does the preprocessing on the CPU with PIL, where tensorflow often uses the GPU directly. Furthermore, tensorflow offers TFRecords, a binary format in which images are stored as raw bitmaps, so the CPU doesn't need to decode the jpeg files every time it reads them. TFRecords also ensures that the data is not fragmented into many small files, which boosts IO performance.

Comparing the keras implementation with tensorflow's efficientnet implementation + TFRecords gives us much better results:

GPU load while training with tensorflow

But we don't get it for free. We have to convert our whole data set from jpeg images to TFRecords (here is a short example), we are now dealing with tensorflow instead of keras, tensorflow is pretty unhandy, and we lose the benefits of the keras ImageDataGenerator.
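
Since that example sits behind a link, here is a rough sketch of the conversion step, assuming TF 1.x to match the reading code below. write_tfrecord, image_paths and labels are hypothetical names; the feature keys ('label', 'image_raw') and the 300x300 raw bitmap layout match the parser used in the next section.

import numpy as np
import tensorflow as tf
from PIL import Image

def write_tfrecord(image_paths, labels, out_file):  # hypothetical helper
    with tf.python_io.TFRecordWriter(out_file) as writer:
        for path, label in zip(image_paths, labels):
            # decode the jpeg once and store the raw 300x300 bitmap
            img = np.asarray(Image.open(path).convert('RGB').resize((300, 300)), dtype=np.uint8)
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
            writer.write(example.SerializeToString())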

Use a TFRecord dataset in keras

Well, we won't get the ImageDataGenerator back, but we can still work with keras and the TFRecord dataset. The TFRecord dataset API is meant for optimized IO performance, and here we can read the images without jpeg decoding. Thanks to the keras developers, passing tensorflow tensors to keras is already supported, so we can use TFRecord datasets. In Tensorflow 2.0 it should be possible to train a keras model directly on the dataset API.

# imports assumed by the snippet (TF 1.x API); num_classes, records and
# BATCH_SIZE are defined elsewhere
import tensorflow as tf
from efficientnet import EfficientNetB5
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras import optimizers

# This is the function used to decode the TFRecords
def _parse_function(proto):
    keys_to_features = {'label': tf.FixedLenFeature([], tf.int64),
                        'image_raw': tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(proto, keys_to_features)
    image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)
    image = tf.reshape(image, (300, 300, 3))
    return image, parsed_features['label']

def create_dataset(files, batch_size):
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(_parse_function, num_parallel_calls=4)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image, label = iterator.get_next()
    image = tf.reshape(image, [batch_size, 300, 300, 3])
    label = tf.one_hot(label, num_classes)
    return image, label

image_tensor, label_tensor = create_dataset(records, BATCH_SIZE)

base_model = EfficientNetB5(include_top=False, input_tensor=image_tensor, weights='imagenet')
x = base_model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

model.compile(optimizer=optimizers.Adam(lr=0.01), loss='categorical_crossentropy',
              metrics=['accuracy'], target_tensors=[label_tensor])
model.fit(epochs=10, steps_per_epoch=num_batches, verbose=1)

Now we have a GPU utilization of ~80% on average, which is still a bit less than the ~90% of the pure tensorflow implementation, but much better than the varying (~10% average) utilization of the pure keras implementation. I don't have to mention that this increase from ~10% to ~75-80% speeds up the training by roughly a factor of 7.5.

Conclusion

If you have performance issues, first update all packages, especially keras-preprocessing. Deactivate your virus scanner and check whether you have an internal SSD. Try to tweak the fit_generator configuration (workers and max_queue_size). If you are using linux, try multiprocessing with a thread-safe generator. If nothing helps, convert your dataset to TFRecords and use it with keras, or move to tensorflow directly. If you already use tensorflow 2.0, you can fit keras models directly on TFRecord datasets, as sketched below.
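
For completeness, here is a minimal TF 2.x sketch of that last option, assuming the same 'label'/'image_raw' record layout as above. The parsing functions moved into tf.io in 2.x; num_classes, records and batch_size are placeholders, and MobileNetV2 stands in for EfficientNet, which only landed in tf.keras in later 2.x releases.

import tensorflow as tf

def parse(proto):
    features = {'label': tf.io.FixedLenFeature([], tf.int64),
                'image_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(proto, features)
    image = tf.reshape(tf.io.decode_raw(parsed['image_raw'], tf.uint8), (300, 300, 3))
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.one_hot(parsed['label'], num_classes)

dataset = (tf.data.TFRecordDataset(records)
           .map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .shuffle(1024)
           .batch(batch_size)
           .prefetch(tf.data.experimental.AUTOTUNE))

base = tf.keras.applications.MobileNetV2(include_top=False, weights=None, input_shape=(300, 300, 3))
x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.Model(base.input, outputs)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(dataset, epochs=10)  # keras consumes the tf.data pipeline directly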

If you have any ideas on how to use the ImageDataGenerator in this scenario, or any other suggestions, please comment!
