pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移
模型的保存和加载可以直接通过Model
类的save_weights
和load_weights
实现。默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件。
model.save_weights('weights', save_format='h5')
加载时默认为根据网络的拓扑结构进行加载,这适用于不对网络进行更改,直接进行测试的情况。但如果只希望加载部分权重,可以更改为根据变量名进行加载。
model.load_weights('weights', by_name=True)
有个很坑的点是,加载checkpoint格式时源码中似乎没有实现by_name
,所以尽管设置了by_name=True
,他仍然会按照拓扑结构加载,然后报错提示部分变量不匹配,所以还是尽量都存成h5文件。
有了保存和读取模型的方法后,就可以在大型数据库上先进行预训练,然后将权重迁移到小数据库或其他任务上。举例i来说,实际应用场景如VGG16先在ImageNet上进行训练,再将除最后一层全连接以外的参数迁移到SSD完成目标检测任务。
实现时就涉及到两个问题:部分网络层不同(不同的分类任务最后一个全连接层的输出维度不同)和调用网络时的输出不同(目标检测任务需要提取网络的中间层的特征图输出),我们可以通过继承Model
类来解决上述问题。
继承Model
类需要实现两个函数,__init__()
与call()
,下面以ResNet为例。
class ResNet(models.Model):def __init__(self, layer_num, **kwargs):super(ResNet, self).__init__(**kwargs)if block_type[layer_num] == 'basic block':self.block = BasicBlockelse:self.block = BottleneckBlockself.conv0 = Conv2D(64, (7, 7), strides=(2, 2), name='conv0', padding='same', use_bias=False)self.block_collector = []for layer_index, (b, f) in enumerate(zip(block_num[layer_num], filter_num), start=1):if layer_index == 1:if block_type[layer_num] == 'basic block':self.block_collector.append(self.block(f, name='conv1_0'))else:self.block_collector.append(self.block(f, projection=True, name='conv1_0'))else:self.block_collector.append(self.block(f, strides=(2, 2), name='conv{}_0'.format(layer_index)))for block_index in range(1, b):self.block_collector.append(self.block(f, name='conv{}_{}'.format(layer_index, block_index)))self.bn = BatchNormalization(name='bn', momentum=0.9, epsilon=1e-5)self.global_average_pooling = GlobalAvgPool2D()self.fc = Dense(1000, name='fully_connected', activation='softmax', use_bias=False)def call(self, inputs, training):net = self.conv0(inputs)print('input', inputs.shape)print('conv0', net.shape)net = tf.nn.max_pool2d(net, ksize=(3, 3), strides=(2, 2), padding='SAME')print('max-pooling', net.shape)for block in self.block_collector:net = block(net, training)print(block.name, net.shape)net = self.bn(net, training)net = tf.nn.relu(net)net = self.global_average_pooling(net)print('global average-pooling', net.shape)net = self.fc(net)print('fully connected', net.shape)return net
__init__
中实例化网络所需的各个层,call
中定义网络的运算。在迁移到其他任务时,改写call
即可。
实例化各个层时尽量自定义name
,因为变量名是网络层的名字和变量名本身共同决定的,举例来说最后的全连接层中权重名为fully_connected/kernel:0
,自定义各个层的名称能保证采用by_name
方式加载模型不会出现问题。在只迁移部分权重时只需要设定model.load_weights()
中的by_name
参数即可。
ResNet的完整代码可以在我的github找到
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/ResNet.pygithub.com
pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移相关推荐
- 【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练
前面的教程都只在小模型.小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet.教程会分为三部分:数据增强.模型加载与训练.模型测试,最终在ResNet50上可以达到77.72%的top- ...
- 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法
Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...
- 余承东:华为6G研发还需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布 | 极客头条...
快来收听极客头条音频版吧,智能播报由标贝科技提供技术支持. 「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有 ...
- 10月2日科技资讯|余承东:华为6G研发需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布
「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有料的新闻资讯,让所有技术人,时刻紧跟业界潮流. 整理 | 郭 ...
- keras 分布式_TensorFlow 2.0正式版官宣!深度集成Keras
新智元报道 来源:medium.GitHub 编辑:小芹.大明 [新智元导读]TensorFlow 2.0正式版终于发布了!深度集成Keras,更简单.更易用,GPU训练性能提升.这是一个革命 ...
- cupy 安装_资源 | 神经网络框架Chainer发布2.0正式版:CuPy独立
原标题:资源 | 神经网络框架Chainer发布2.0正式版:CuPy独立 选自GitHub 机器之心编译 参与:李泽南.吴攀 Chainer 是一个灵活的神经网络框架,它的一个主要目标就是展现灵活性 ...
- TensorFlow 2.0.0-RC0版发布,专注于简单性与易用性
TensorFlow 2.0 RC0 发布了,2.0 专注于简单性和易用性,主要特性包括: 通过 Keras 和热切执行轻松建模. 在任何平台生产中进行稳健的模型部署. 强大的研究实验. 通过减少重复 ...
- TensorFlow 1.9.0正式版来了!新手指南全新改版,支持梯度提升树估计器
李林 编译整理 量子位 出品 | 公众号 QbitAI TensorFlow 1.9.0正式版来了! 谷歌大脑研究员.Keras作者François Chollet对于这一版本评价甚高,他说:&quo ...
- hashmap 不释放空间_刁难问题,为什么HashMap默认容量为16加载因子为0.75
前言:实际开发中我们大多数都是只能new HashMap<>来存储键值对,很少会去设置初始容量,虽然我们知道他的默认容量是16.但是在面试中,为了体现你个人好学的能力,还是会被经常问到为什 ...
最新文章
- 医工汇聚 智竞心电 | 首届中国心电智能大赛开启招募
- 事件控制块的清空与状态查询
- 2021-10-11 ! LeetCode226. 翻转二叉树 的前中后层序遍历写法
- linux命令chmod如果当前用户属于多个组,那这个命令中的g指的是哪个组?按什么规则?
- Dapr牵手.NET学习笔记:Actor小试
- mac+php版本切换+cli,Mac环境下php版本切换
- Handler用法总结
- Python+OpenCV:基于KNN手写数据OCR(OCR of Hand-written Data using kNN)
- 使用ExMerge工具管理Exchange用户邮箱。
- 拓端tecdat|R语言线性判别分析(LDA),二次判别分析(QDA)和正则判别分析(RDA)
- My first essay
- win10怎么修改计算机用户名和密码,windows10系统如何更改开机密码
- 计算机pe教程,win7 winpe安装过程图文教程
- java中not equal_Java ObjectUtils.notEqual方法代码示例
- 聘用协议_聘用合同谈判
- 点集的读入与输出操作
- unity3d网络延时检测
- Google Earth Engine(GEE)扩展——制作的GEE app的误区
- mui12搭载鸿蒙,MUI系统最新资讯
- 在《王者荣耀》来聊聊游戏的帧同步
热门文章
- 未來用工新趨勢_浅谈2019年灵活用工五大新趋势
- python输出时间_Python获取并输出当前日期时间
- 同时两个版本php,查看“实现多个PHP版本共存和互相切换”的源代码
- ef 执行mysql语句_在EF中执行SQL语句
- 判断成绩linux程序编程,程序输入输出 ,编写判断成绩的程序
- Eric6最简单的应用(创建一个项目-窗体-编译-运行)
- Intellij IDEA 配置
- 最短路径 | 1087 三重标尺+记录最短路径条数
- sql 中 ALTER 和 UPDATE 的区别
- 2015 03 03 复习 上课笔记(一)