estimator 模型保存与使用
1:estimator 是tensorflow的高级封装库,但是tensorflow 分为两个版本,1.X与2.X,本次文章两个版本都会说明,方便大家进行判断
1.0保存与读取
output_dir=’../outer‘
def serving_input_fn():label_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='label_ids')input_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='input_ids')input_mask = tf.placeholder(tf.int32, [None, max_seq_length], name='input_mask')segment_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='segment_ids')input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({'label_ids': label_ids,'input_ids': input_ids,'input_mask': input_mask,'segment_ids': segment_ids,})()return input_fn
estimator.export_savedmodel(output_dir, serving_input_fn)
predict_fn = tf.contrib.predictor.from_saved_model(output_dirs)
# 注意这个地方传入的一般为numpy的格式,具体还要看报错是啥
print(predict_fn({'input_ids': input_ids,'segment_ids': segment_ids,'label_ids': label_ids,'input_mask': input_mask}))
2.0保存与读取
tf.saved_model.load(self.model_path)
model = self.predict_fn.signatures["serving_default"]
ret = model(input_ids=tf.constant(input_feature['input_ids']),
input_mask=tf.constant(input_feature['input_mask']),
label_ids=tf.constant(input_feature['label_ids']),
segment_ids=tf.constant(input_feature['segment_ids']))
2.0的读取方式变换了,没有之前的tf.contrib这个库了,所以方法变为tf.saved_model.load这种,而且要用signatures指名输出参数,这个地方不建议修改,主要是后面的参数,必须要和模型对应,不像之前是字典模式,如果你的输入参数无法进行这样写,建议用**传入
map_dict = {'Input-Token': tf.constant(input_feature['Input-Token'], dtype=tf.float32),'Input-Segment': tf.constant(input_feature['Input-Segment'], dtype=tf.float32)}ret = model(**map_dict)
20220411
一般需要keras版本,这里新增一个版本对应
estimator 模型保存与使用相关推荐
- 机器学习正则化线性模型和模型保存
目录 1 正则化线性模型 1.1 岭回归 1.2 Lasso 回归 1.3 弹性网络 1.4 Early Stopping 1.5 小结 2 线性回归的改进-岭回归 2.1 API 2.2 正则化程度 ...
- TensorFlow Estimator 模型从训练到部署
引言 TensorFlow是目前流行的机器学习框架,用户可以基于TensorFlow方便地构建机器学习模型,并将模型部署到线上提供服务. 最近看Estimator框架比较流行,公司也想看Wide &a ...
- tensor和模型 保存与加载 PyTorch
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...
- Pytorch两种模型保存方式
以字典方式保存,更容易解析和可视化 Pytorch两种模型保存方式 大黑_7e1b关注 2019.02.12 17:49:35字数 13阅读 5,907 只保存模型参数 # 保存 torch.save ...
- Tensorflow |(5)模型保存与恢复、自定义命令行参数
Tensorflow |(1)初识Tensorflow Tensorflow |(2)张量的阶和数据类型及张量操作 Tensorflow |(3)变量的的创建.初始化.保存和加载 Tensorflow ...
- python手动将机器学习模型保存为json文件
python手动将机器学习模型保存为json文件 # 导入需要的包和库: # Import Required packages #-------------------------# Import t ...
- keras/tensorflow 模型保存后重新加载准确率为0 model.save and load giving different result
我在用别人的代码跑程序的时候遇到了这个问题: keras 模型保存后重新加载准确率为0 GitHub上有个issue:model.save and load giving different resu ...
- 模型保存的序列化文件pb 什么是PB文件 pb是protocol(协议) buffer(缓冲)的缩写
pb是protocol(协议) buffer(缓冲)的缩写 TensorFlow 模型保存为pb文件的解释,怎么使用pb文件/模型的Save and Restore_u014264373的博客-CSD ...
- 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用
首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...
- tensorflow 的模型保存和调用
我们通常采用tensorflow来训练,训练完之后应当保存模型,即保存模型的记忆(权重和偏置),这样就可以来进行人脸识别或语音识别了. 1.模型的保存 # 声明两个变量 v1 = tf.Variabl ...
最新文章
- 美媒预测:2021年人工智能的四大趋势
- uitableviw 自适应高度
- AutoX“真无人”车队驶上繁忙街头,中国正式跨入无人驾驶时代
- linux中vim常用命令总结
- android另开进程,android在一个app程序中,打开另一个app的方法
- 从用户观点对计算机如何分类,从用户的观点看操作系统是
- 怎么用deveco studio升级鸿蒙,华为鸿蒙DevEco studio2.0的安装和hello world运行教程
- 使用ASP.NET Core、JavaScript和Angular防止CSRF攻击
- 大括弧之战 代码风格
- AC97声卡的驱动安装
- 驾驶机动车在高速公路上倒车、逆行、穿越中央分隔带掉头的一次记6分。
- 2022飞鸟,飞鸟源码,飞鸟新圣源码,仿新圣源码,飞鸟二开,飞鸟采集,飞鸟运营版
- html旋转相册,css3 旋转相册
- python中turtle画小草_python 笔记 之带参数的装饰器
- 查询Linux中CPU的核数
- ECCV 2022全奖项公布,两位华人学者摘得最佳论文奖,本科来自清华、浙大
- 计算机 显卡 淘汰,早该淘汰的VGA模拟接口:新显卡不再支持
- 计算机动画类型及创作原理,计算机动画的原理和制作.ppt
- 自动化软件测试工程师(初面)面试题解析(含答案)
- vue项目引入阿里巴巴矢量图标库 ——字体图标