前面的教程都只在小模型、小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet。教程会分为三部分:数据增强、模型加载与训练、模型测试,最终在ResNet50上可以达到77.72%的top-1准确率,复现出了ResNet原文的结果。

完整的代码可以在我的github上找到。https://github.com/Apm5/ImageNet_Tensorflow2.0

提供ResNet-18和ResNet-50的预训练模型,以供大家做迁移使用。
链接:https://pan.baidu.com/s/1nwvkt3Ei5Hp5Pis35cBSmA
提取码:y4wo

还提供百度云链接的ImageNet原始数据,但是这份资源只能创建临时链接以供下载,有需要的还请私信联系。下面开始正文。

模型加载与训练

初始化

github项目中提供了tensorflow 2.0版本实现的ResNet,包括各种层数18、34、50、101和152,以及ResNet后续改进的v2版本以供直接调用。

from model.ResNet import ResNet
model = ResNet(50)

或者也可以使用官方实现的经典模型,具体参考keras applications

from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50(weights=None)

训练过程中,可以通过model.save_weights()保存权重,也可以在中断训练时通过model.load_weights()加载权重继续训练。

数据迭代器

采用tf.data.Dataset()并行加载图像并进行数据增强

def train_iterator(list_path=c.train_list_path):images, labels = load_list(list_path, c.train_data_path)dataset = tf.data.Dataset.from_tensor_slices((images, labels))dataset = dataset.shuffle(len(images))dataset = dataset.repeat()dataset = dataset.map(lambda x, y: tf.py_function(load_image, inp=[x, y, True, False], Tout=[tf.float32, tf.float32]),num_parallel_calls=tf.data.experimental.AUTOTUNE)dataset = dataset.batch(c.batch_size)it = dataset.__iter__()return it

调用该函数得到迭代器后,可以实现GPU进行图计算时,CPU并行加载并处理图像。

images, labels = data_iterator.next()
ce, prediction = train_step(model, images, labels, optimizer)

模型训练

其中train_step完成前向计算、梯度反向传播和参数更新。

@tf.function
def train_step(model, images, labels, optimizer):with tf.GradientTape() as tape:prediction = model(images, training=True)ce = cross_entropy_batch(labels, prediction, label_smoothing=c.label_smoothing)l2 = l2_loss(model)loss = ce + l2gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))return ce, prediction

可以通过@tf.function控制代码进行静态图或动态图模式计算,注释掉@tf.function修饰可以使tensorflow进入动态图模式,可以直接在网络中print中间层结果进行调试。开启修饰后进行静态图计算,可以极大的提升网络的计算速度。

损失函数

模型的损失这里包括两部分:交叉熵和参数正则化。
交叉熵需要计算一个batch内的多组数据的平均值,label_smoothing可以微弱的增强模型泛化性。

def cross_entropy_batch(y_true, y_pred, label_smoothing=0.0):cross_entropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred, label_smoothing=label_smoothing)cross_entropy = tf.reduce_mean(cross_entropy)return cross_entropy

参数正则化为

def l2_loss(model, weights=c.weight_decay):variable_list = []for v in model.trainable_variables:if 'kernel' in v.name:variable_list.append(tf.nn.l2_loss(v))return tf.add_n(variable_list) * weights

这里只统计卷积核,对于bn层等其他参数不作约束。

优化器与变学习率

优化器选用sgd,
optimizer = optimizers.SGD(learning_rate=learning_rate_schedules, momentum=0.9, nesterov=True),学习率变化分为warm up阶段和余弦下降阶段,warm up指网络训练初期学习率从0线性增长到最大学习率,余弦下降是让学习率大致遵循:前期维持大学习率,中期学习率线性下降,后期维持小学习率。

learning_rate_schedules = optimizers.schedules.PolynomialDecay(initial_learning_rate=c.minimum_learning_rate,decay_steps=c.warm_iterations,end_learning_rate=c.initial_learning_rate)learning_rate_schedules = tf.keras.experimental.CosineDecay(initial_learning_rate=c.initial_learning_rate,decay_steps=c.epoch_num * c.iterations_per_epoch,alpha=c.minimum_learning_rate)

tf.keras.experimental中还有其他一些官方实现的变学习率策略,可自行了解。

代码使用

在github的项目中,直接执行train文件即可。

python train.py

训练的各种配置均在config.py中设置,如学习率、训练轮次、数据位置、增强策略等。
我的硬件配置是CPU i7 6850K @ 3.6GHz,显卡TITAN Xp 12G,对于默认的配置训练ResNet50可以达到大约每秒2个batch的速度,训练50轮ImageNet的大约需要3天的时间。

【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练相关推荐

  1. pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移

    模型的保存和加载可以直接通过Model类的save_weights和load_weights实现.默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件. mo ...

  2. 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法

    Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...

  3. 余承东:华为6G研发还需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布​ | 极客头条...

    快来收听极客头条音频版吧,智能播报由标贝科技提供技术支持. 「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有 ...

  4. 10月2日科技资讯|余承东:华为6G研发需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布

    「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有料的新闻资讯,让所有技术人,时刻紧跟业界潮流. 整理 | 郭 ...

  5. keras 分布式_TensorFlow 2.0正式版官宣!深度集成Keras

      新智元报道   来源:medium.GitHub 编辑:小芹.大明 [新智元导读]TensorFlow 2.0正式版终于发布了!深度集成Keras,更简单.更易用,GPU训练性能提升.这是一个革命 ...

  6. 自由天空XP/2K3封装工具 Easy Sysprep v2.0 正式版封装教程

    自由天空XP/2K3封装工具 Easy Sysprep v2.0 正式版封装教程     制作万能Ghost系统光盘必须对操作系统进行重新封装,在<[原创]跟我一起打造自己的GHOST系统安装光 ...

  7. 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备

    此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...

  8. TensorFlow 1.9.0正式版来了!新手指南全新改版,支持梯度提升树估计器

    李林 编译整理 量子位 出品 | 公众号 QbitAI TensorFlow 1.9.0正式版来了! 谷歌大脑研究员.Keras作者François Chollet对于这一版本评价甚高,他说:&quo ...

  9. myeclipse 9.0 正式版破解激活完整图文教程

    MyEclipse 9.0的激活机制终于破解了,破解步骤比老版本要复杂一些,但是是绝对可以破解的,步骤如下: 1.破解公钥,确保MyEclipse没有开启,否则失败! 用WinRAR或7-zip打开C ...

最新文章

  1. Hibernate 缓存机制
  2. Linux: TLB 查询流程
  3. 1.3 @Deprecated注解
  4. C#中动态加载卸载类库
  5. python生成四位随机数
  6. 数据接口的登录态校验以及JWT
  7. 在线编写php文件,php单文件版在线代码编辑器_php实例
  8. 调试实战 —— dll 加载失败之全局变量初始化篇
  9. nssl1257-A【数论】
  10. android studio 2.3 instant run,android studio 2.3 instant run not working
  11. Halcon 测量直线和圆环的线宽
  12. python实践答辩ppt_看完这篇Python操作PPT总结,从此使用Python玩转Office全家桶没压力!...
  13. html img标签的alt属性和title属性(说明)
  14. 文档扫描(扫描全能王)
  15. 三进制计算机_“九章”量子计算机这么猛,到底能做啥?只为了一条公式的结果吗...
  16. 蛋壳梦破:CEO被限制消费,资金链碎了一地
  17. 中泰证券何波:XTP为量化而生!
  18. 盘点大数据商业智能的十大戒律
  19. Linux 开发学习
  20. linux下的ktime_t timeval timespec

热门文章

  1. 创办 Apple 之后,沃兹尼亚克在做什么?
  2. 智能抽屉式有源电力滤波器品牌
  3. iphone12免息分期买合适吗
  4. 服务器主机装win7系统安装,服务器主机装win7系统
  5. freeswitch通过lua脚本实现多方会话功能,包括会议录音自动外呼等
  6. 按键精灵X学习笔记(一):熟悉软件和基本设置
  7. CMP指令(cmp指令的功能)
  8. 支付宝小程序灰度测试、版本回滚能力新上线
  9. JixiPix Pop Dot Comics for Mac(漫画制作软件)
  10. 超全建筑成套3d模型素材网站整理