模型的保存和加载可以直接通过Model类的save_weightsload_weights实现。默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件。

model.save_weights('weights', save_format='h5')

加载时默认为根据网络的拓扑结构进行加载,这适用于不对网络进行更改,直接进行测试的情况。但如果只希望加载部分权重,可以更改为根据变量名进行加载。

model.load_weights('weights', by_name=True)

有个很坑的点是,加载checkpoint格式时源码中似乎没有实现by_name,所以尽管设置了by_name=True,他仍然会按照拓扑结构加载,然后报错提示部分变量不匹配,所以还是尽量都存成h5文件。

有了保存和读取模型的方法后,就可以在大型数据库上先进行预训练,然后将权重迁移到小数据库或其他任务上。举例i来说,实际应用场景如VGG16先在ImageNet上进行训练,再将除最后一层全连接以外的参数迁移到SSD完成目标检测任务。

实现时就涉及到两个问题:部分网络层不同(不同的分类任务最后一个全连接层的输出维度不同)和调用网络时的输出不同(目标检测任务需要提取网络的中间层的特征图输出),我们可以通过继承Model类来解决上述问题。

继承Model类需要实现两个函数,__init__()call(),下面以ResNet为例。

class ResNet(models.Model):def __init__(self, layer_num, **kwargs):super(ResNet, self).__init__(**kwargs)if block_type[layer_num] == 'basic block':self.block = BasicBlockelse:self.block = BottleneckBlockself.conv0 = Conv2D(64, (7, 7), strides=(2, 2), name='conv0', padding='same', use_bias=False)self.block_collector = []for layer_index, (b, f) in enumerate(zip(block_num[layer_num], filter_num), start=1):if layer_index == 1:if block_type[layer_num] == 'basic block':self.block_collector.append(self.block(f, name='conv1_0'))else:self.block_collector.append(self.block(f, projection=True, name='conv1_0'))else:self.block_collector.append(self.block(f, strides=(2, 2), name='conv{}_0'.format(layer_index)))for block_index in range(1, b):self.block_collector.append(self.block(f, name='conv{}_{}'.format(layer_index, block_index)))self.bn = BatchNormalization(name='bn', momentum=0.9, epsilon=1e-5)self.global_average_pooling = GlobalAvgPool2D()self.fc = Dense(1000, name='fully_connected', activation='softmax', use_bias=False)def call(self, inputs, training):net = self.conv0(inputs)print('input', inputs.shape)print('conv0', net.shape)net = tf.nn.max_pool2d(net, ksize=(3, 3), strides=(2, 2), padding='SAME')print('max-pooling', net.shape)for block in self.block_collector:net = block(net, training)print(block.name, net.shape)net = self.bn(net, training)net = tf.nn.relu(net)net = self.global_average_pooling(net)print('global average-pooling', net.shape)net = self.fc(net)print('fully connected', net.shape)return net

__init__中实例化网络所需的各个层,call中定义网络的运算。在迁移到其他任务时,改写call即可。

实例化各个层时尽量自定义name,因为变量名是网络层的名字和变量名本身共同决定的,举例来说最后的全连接层中权重名为fully_connected/kernel:0,自定义各个层的名称能保证采用by_name方式加载模型不会出现问题。在只迁移部分权重时只需要设定model.load_weights()中的by_name参数即可。

ResNet的完整代码可以在我的github找到

https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/ResNet.py​github.com

pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移相关推荐

  1. 【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练

    前面的教程都只在小模型.小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet.教程会分为三部分:数据增强.模型加载与训练.模型测试,最终在ResNet50上可以达到77.72%的top- ...

  2. 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法

    Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...

  3. 余承东:华为6G研发还需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布​ | 极客头条...

    快来收听极客头条音频版吧,智能播报由标贝科技提供技术支持. 「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有 ...

  4. 10月2日科技资讯|余承东:华为6G研发需10年;库克“iPhone 11势头强劲”;TensorFlow 2.0正式版发布

    「CSDN 极客头条」,是从 CSDN 网站延伸至官方微信公众号的特别栏目,专注于一天业界事报道.风里雨里,我们将每天为朋友们,播报最新鲜有料的新闻资讯,让所有技术人,时刻紧跟业界潮流. 整理 | 郭 ...

  5. keras 分布式_TensorFlow 2.0正式版官宣!深度集成Keras

      新智元报道   来源:medium.GitHub 编辑:小芹.大明 [新智元导读]TensorFlow 2.0正式版终于发布了!深度集成Keras,更简单.更易用,GPU训练性能提升.这是一个革命 ...

  6. cupy 安装_资源 | 神经网络框架Chainer发布2.0正式版:CuPy独立

    原标题:资源 | 神经网络框架Chainer发布2.0正式版:CuPy独立 选自GitHub 机器之心编译 参与:李泽南.吴攀 Chainer 是一个灵活的神经网络框架,它的一个主要目标就是展现灵活性 ...

  7. TensorFlow 2.0.0-RC0版发布,专注于简单性与易用性

    TensorFlow 2.0 RC0 发布了,2.0 专注于简单性和易用性,主要特性包括: 通过 Keras 和热切执行轻松建模. 在任何平台生产中进行稳健的模型部署. 强大的研究实验. 通过减少重复 ...

  8. TensorFlow 1.9.0正式版来了!新手指南全新改版,支持梯度提升树估计器

    李林 编译整理 量子位 出品 | 公众号 QbitAI TensorFlow 1.9.0正式版来了! 谷歌大脑研究员.Keras作者François Chollet对于这一版本评价甚高,他说:&quo ...

  9. hashmap 不释放空间_刁难问题,为什么HashMap默认容量为16加载因子为0.75

    前言:实际开发中我们大多数都是只能new HashMap<>来存储键值对,很少会去设置初始容量,虽然我们知道他的默认容量是16.但是在面试中,为了体现你个人好学的能力,还是会被经常问到为什 ...

最新文章

  1. 医工汇聚 智竞心电 | 首届中国心电智能大赛开启招募
  2. 事件控制块的清空与状态查询
  3. 2021-10-11 ! LeetCode226. 翻转二叉树 的前中后层序遍历写法
  4. linux命令chmod如果当前用户属于多个组,那这个命令中的g指的是哪个组?按什么规则?
  5. Dapr牵手.NET学习笔记:Actor小试
  6. mac+php版本切换+cli,Mac环境下php版本切换
  7. Handler用法总结
  8. Python+OpenCV:基于KNN手写数据OCR(OCR of Hand-written Data using kNN)
  9. 使用ExMerge工具管理Exchange用户邮箱。
  10. 拓端tecdat|R语言线性判别分析(LDA),二次判别分析(QDA)和正则判别分析(RDA)
  11. My first essay
  12. win10怎么修改计算机用户名和密码,windows10系统如何更改开机密码
  13. 计算机pe教程,win7 winpe安装过程图文教程
  14. java中not equal_Java ObjectUtils.notEqual方法代码示例
  15. 聘用协议_聘用合同谈判
  16. 点集的读入与输出操作
  17. unity3d网络延时检测
  18. Google Earth Engine(GEE)扩展——制作的GEE app的误区
  19. mui12搭载鸿蒙,MUI系统最新资讯
  20. 在《王者荣耀》来聊聊游戏的帧同步

热门文章

  1. 未來用工新趨勢_浅谈2019年灵活用工五大新趋势
  2. python输出时间_Python获取并输出当前日期时间
  3. 同时两个版本php,查看“实现多个PHP版本共存和互相切换”的源代码
  4. ef 执行mysql语句_在EF中执行SQL语句
  5. 判断成绩linux程序编程,程序输入输出 ,编写判断成绩的程序
  6. Eric6最简单的应用(创建一个项目-窗体-编译-运行)
  7. Intellij IDEA 配置
  8. 最短路径 | 1087 三重标尺+记录最短路径条数
  9. sql 中 ALTER 和 UPDATE 的区别
  10. 2015 03 03 复习 上课笔记(一)