1,问题描述

离线用torch训练了一个简单的双塔模型,存到了hdfs上,希望可以在spark离线任务中使用。但spark离线任务要载入这个模型,就无法像pytorch官方的模型载入方式(这里:torch.load())一样通过本地路径加载。那么如何方便且优雅地载入存在hdfs上的已训练好的model呢?

2,问题定义与解析

pytorch训练好的模型,通过torch.save()可存为二进制文件(内部引用了pickle模块,具体详见pytorch的docs),所以从本质上而言,这是个如何将二进制模型文件通过torch.load()方法载入的问题。torch.load()接收的是一个本地二进制文件路径,或是直接的一个二进制文件。所以,其实我们将hdfs上的文件,以二进制字符串方式拉到spark的driver上,再将其转为二进制文件,即可解决。

3,问题解决

bb less,show me the code

import iotower = DSSM()  # 新建空模型
model_path = 'your_hdfs_path'  # hdfs path
model_bytes = sc.binaryFiles(model_path).collect()[0][1] # spark读取二进制文件from hdfs,读取结果为二进制字符串
model_file = io.BytesIO(model_bytes)  # io.BytesIO转成内存中的二进制文件
tower.load_state_dict(torch.load(model_file))  # torch.load:模型`tower`载入参数
all_tower_params = tower.state_dict()          # 可以打印出参数看看有没有赋值成功

4,附录

在具体了解中,发现BytesIO、StringIO这类方法就是为处理字符串、使得和读取这些字符串和读写文件具有一致的接口而存在的。详情请见廖雪峰老师的文章:StringIO和BytesIO - 廖雪峰的官方网站

pytorch从hdfs载入模型、从二进制字符串载入模型相关推荐

  1. pytorch多卡并行模型的保存与载入

    pytorch多卡并行模型的保存与载入 当模型是在数据并行方式在多卡上进行训练的训练和保存,那么载入的时候也是一样需要是多卡.并且,load_state_dict()函数的调用要放在DataParal ...

  2. 【C 语言】字符串模型 ( 字符串翻转模型 | 借助 递归函数操作 逆序字符串操作 | 引入线程安全概念 )

    文章目录 一.引入线程安全概念 二.完整代码示例 一.引入线程安全概念 在上一篇博客 [C 语言]字符串模型 ( 字符串翻转模型 | 借助 递归函数操作 逆序字符串操作 | strncat 函数 ) ...

  3. 【C 语言】字符串模型 ( 字符串翻转模型 | 借助 递归函数操作 逆序字符串操作 | strncat 函数 )

    文章目录 一.strncat 字符串连接函数 二.借助 递归函数操作 逆序字符串操作 三.完整代码示例 一.strncat 字符串连接函数 strncat 函数 : 将 const char *src ...

  4. 【C 语言】字符串模型 ( 字符串翻转模型 | 抽象成业务函数 | 形参返回值 | 函数返回值 | 函数形参处理 | 形参指针判空 )

    文章目录 一.字符串翻转模型 业务函数 二.完整代码示例 一.字符串翻转模型 业务函数 将上一篇博客 [C 语言]字符串模型 ( 字符串翻转模型 ) 的代码 , 主要业务逻辑 , 抽象成函数 ; 字符 ...

  5. 【C 语言】字符串模型 ( 字符串翻转模型 )

    文章目录 一.字符串翻转模型 二.完整代码示例 一.字符串翻转模型 业务场景 : 给定下面的字符串 , 将下面的字符串翻转 ; // 将下面的字符串翻转char str[] = "sdfsd ...

  6. PyTorch 保存模型结构参数及加载模型

    PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...

  7. Leetcode1702. 修改后的最大二进制字符串[C++题解]:思维题

    文章目录 题目分析 题目链接 题目分析 只有2种操作:00变成10,10变成01,为了使得结果最大,需要前面尽可能变成1.怎么才能使得某一个的0变成1呢? 经过观察,只要这个0后面的某一位还有0,该位 ...

  8. TF:利用TF的train.Saver将训练好的W、b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据)

    TF:利用TF的train.Saver将训练好的W.b模型文件保存+新建载入刚训练好模型(用于以后预测新的数据) 目录 输出结果 代码设计 输出结果 代码设计 import tensorflow as ...

  9. PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推

    文章来源:量子位 原文地址:https://mp.weixin.qq.com/s/lS3YiXzYyY6-XNTFyH_GHg 如有兴趣可以**点击加入极市CV专业微信群**,获取更多高质量干货 为了 ...

最新文章

  1. 聊聊flink Table的groupBy操作
  2. linux系统下源码安装mysql5.6数据库
  3. c语言设备管理系统实训答辩,C语言设计(力学实验设备管理系统)1答辩.doc
  4. redis介绍以及安装
  5. java短视频上传阿里云流程_短视频上传
  6. Delphi 打印杨辉三角
  7. 形位公差符号大全_干货!AutoCAD快捷键大全与功能精解
  8. ajax跨域问题解决(spring boot)
  9. 红蜘蛛多媒体网络教室v7.2版一款网络教学的软件_我是亲民_新浪博客
  10. SVM——支持向量回归(SVR)
  11. Tibco Designer -- 循环遍历
  12. 设置透明主题引起动画失效以及打开其他应用闪现桌面图标的问题
  13. 在树莓派CM4+Ubuntu上使用DSI接口显示屏
  14. mysql里如何写日期格式_mysql 日期格式
  15. mysql主从同步报错Fatal error: The slave I/O thread stops because master and slave have equal MySQL server
  16. 教学优化算法的简单介绍
  17. java jfreechart 折线图_java程序使用JfreeChart画折线图
  18. Hexo博客使用aplayer音乐播放插件
  19. Android 增强版百分比布局库 为了适配而扩展
  20. PayPal设置收款习惯设定

热门文章

  1. CMD查看当前文件路径下的所有文件名
  2. 【Python】Python之end()关键字使用
  3. 解决虚拟机不能上网ifconfig只显示127.0.0.1的问题
  4. ide编辑器 android,从 IDE 到终端 + 文本编辑器
  5. MIUI12广告“可以关”
  6. android 播放短信铃声,Android 播放自定义铃声
  7. micros swarm framework相关
  8. 广州大学2020操作系统实验二:银行家算法
  9. mysql慢查询优化_常见mysql的慢查询优化方式
  10. CREO1——CREO 2.0画沉头孔