【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练
前面的教程都只在小模型、小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet。教程会分为三部分:数据增强、模型加载与训练、模型测试,最终在ResNet50上可以达到77.72%的top-1准确率,复现出了ResNet原文的结果。
完整的代码可以在我的github上找到。https://github.com/Apm5/ImageNet_Tensorflow2.0
提供ResNet-18和ResNet-50的预训练模型,以供大家做迁移使用。
链接:https://pan.baidu.com/s/1nwvkt3Ei5Hp5Pis35cBSmA
提取码:y4wo
还提供百度云链接的ImageNet原始数据,但是这份资源只能创建临时链接以供下载,有需要的还请私信联系。下面开始正文。
模型加载与训练
初始化
github项目中提供了tensorflow 2.0版本实现的ResNet,包括各种层数18、34、50、101和152,以及ResNet后续改进的v2版本以供直接调用。
from model.ResNet import ResNet
model = ResNet(50)
或者也可以使用官方实现的经典模型,具体参考keras applications
from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50(weights=None)
训练过程中,可以通过model.save_weights()
保存权重,也可以在中断训练时通过model.load_weights()
加载权重继续训练。
数据迭代器
采用tf.data.Dataset()
并行加载图像并进行数据增强
def train_iterator(list_path=c.train_list_path):images, labels = load_list(list_path, c.train_data_path)dataset = tf.data.Dataset.from_tensor_slices((images, labels))dataset = dataset.shuffle(len(images))dataset = dataset.repeat()dataset = dataset.map(lambda x, y: tf.py_function(load_image, inp=[x, y, True, False], Tout=[tf.float32, tf.float32]),num_parallel_calls=tf.data.experimental.AUTOTUNE)dataset = dataset.batch(c.batch_size)it = dataset.__iter__()return it
调用该函数得到迭代器后,可以实现GPU进行图计算时,CPU并行加载并处理图像。
images, labels = data_iterator.next()
ce, prediction = train_step(model, images, labels, optimizer)
模型训练
其中train_step
完成前向计算、梯度反向传播和参数更新。
@tf.function
def train_step(model, images, labels, optimizer):with tf.GradientTape() as tape:prediction = model(images, training=True)ce = cross_entropy_batch(labels, prediction, label_smoothing=c.label_smoothing)l2 = l2_loss(model)loss = ce + l2gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))return ce, prediction
可以通过@tf.function
控制代码进行静态图或动态图模式计算,注释掉@tf.function
修饰可以使tensorflow进入动态图模式,可以直接在网络中print
中间层结果进行调试。开启修饰后进行静态图计算,可以极大的提升网络的计算速度。
损失函数
模型的损失这里包括两部分:交叉熵和参数正则化。
交叉熵需要计算一个batch内的多组数据的平均值,label_smoothing
可以微弱的增强模型泛化性。
def cross_entropy_batch(y_true, y_pred, label_smoothing=0.0):cross_entropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred, label_smoothing=label_smoothing)cross_entropy = tf.reduce_mean(cross_entropy)return cross_entropy
参数正则化为
def l2_loss(model, weights=c.weight_decay):variable_list = []for v in model.trainable_variables:if 'kernel' in v.name:variable_list.append(tf.nn.l2_loss(v))return tf.add_n(variable_list) * weights
这里只统计卷积核,对于bn层等其他参数不作约束。
优化器与变学习率
优化器选用sgd,
optimizer = optimizers.SGD(learning_rate=learning_rate_schedules, momentum=0.9, nesterov=True)
,学习率变化分为warm up阶段和余弦下降阶段,warm up指网络训练初期学习率从0线性增长到最大学习率,余弦下降是让学习率大致遵循:前期维持大学习率,中期学习率线性下降,后期维持小学习率。
learning_rate_schedules = optimizers.schedules.PolynomialDecay(initial_learning_rate=c.minimum_learning_rate,decay_steps=c.warm_iterations,end_learning_rate=c.initial_learning_rate)learning_rate_schedules = tf.keras.experimental.CosineDecay(initial_learning_rate=c.initial_learning_rate,decay_steps=c.epoch_num * c.iterations_per_epoch,alpha=c.minimum_learning_rate)
在tf.keras.experimental
中还有其他一些官方实现的变学习率策略,可自行了解。
代码使用
在github的项目中,直接执行train文件即可。
python train.py
训练的各种配置均在config.py
中设置,如学习率、训练轮次、数据位置、增强策略等。
我的硬件配置是CPU i7 6850K @ 3.6GHz,显卡TITAN Xp 12G,对于默认的配置训练ResNet50可以达到大约每秒2个batch的速度,训练50轮ImageNet的大约需要3天的时间。
【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练相关推荐
- pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移
模型的保存和加载可以直接通过Model类的save_weights和load_weights实现.默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件. mo ...
- 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法
Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...
- 余承东:华为6G研发还需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布 | 极客头条...
快来收听极客头条音频版吧,智能播报由标贝科技提供技术支持. 「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有 ...
- 10月2日科技资讯|余承东:华为6G研发需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布
「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有料的新闻资讯,让所有技术人,时刻紧跟业界潮流. 整理 | 郭 ...
- keras 分布式_TensorFlow 2.0正式版官宣!深度集成Keras
新智元报道 来源:medium.GitHub 编辑:小芹.大明 [新智元导读]TensorFlow 2.0正式版终于发布了!深度集成Keras,更简单.更易用,GPU训练性能提升.这是一个革命 ...
- 自由天空XP/2K3封装工具 Easy Sysprep v2.0 正式版封装教程
自由天空XP/2K3封装工具 Easy Sysprep v2.0 正式版封装教程 制作万能Ghost系统光盘必须对操作系统进行重新封装,在<[原创]跟我一起打造自己的GHOST系统安装光 ...
- 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备
此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...
- TensorFlow 1.9.0正式版来了!新手指南全新改版,支持梯度提升树估计器
李林 编译整理 量子位 出品 | 公众号 QbitAI TensorFlow 1.9.0正式版来了! 谷歌大脑研究员.Keras作者François Chollet对于这一版本评价甚高,他说:&quo ...
- myeclipse 9.0 正式版破解激活完整图文教程
MyEclipse 9.0的激活机制终于破解了,破解步骤比老版本要复杂一些,但是是绝对可以破解的,步骤如下: 1.破解公钥,确保MyEclipse没有开启,否则失败! 用WinRAR或7-zip打开C ...
最新文章
- Hibernate 缓存机制
- Linux: TLB 查询流程
- 1.3 @Deprecated注解
- C#中动态加载卸载类库
- python生成四位随机数
- 数据接口的登录态校验以及JWT
- 在线编写php文件,php单文件版在线代码编辑器_php实例
- 调试实战 —— dll 加载失败之全局变量初始化篇
- nssl1257-A【数论】
- android studio 2.3 instant run,android studio 2.3 instant run not working
- Halcon 测量直线和圆环的线宽
- python实践答辩ppt_看完这篇Python操作PPT总结,从此使用Python玩转Office全家桶没压力!...
- html img标签的alt属性和title属性(说明)
- 文档扫描(扫描全能王)
- 三进制计算机_“九章”量子计算机这么猛,到底能做啥?只为了一条公式的结果吗...
- 蛋壳梦破:CEO被限制消费,资金链碎了一地
- 中泰证券何波:XTP为量化而生!
- 盘点大数据商业智能的十大戒律
- Linux 开发学习
- linux下的ktime_t timeval timespec