(TorchScript应用) PyTorch模型转换为Torch脚本的代码出错。

出现原因:想将pytorch训练的.pth文件转成C++能处理的.pt文件。用的TorchScript的方法。具体代码如下:

import argparse
import cv2
import torchvision

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from unet import NestedUNet
from unet import UNet
from utils.dataset import BasicDataset
from config import UNetConfig

cfg = UNetConfig()
device = torch.device('cpu')
model = eval(cfg.model)(cfg)
path = 'data/checkpoints/epoch_9.pth'#自己模型训练的结果
model.load_state_dict(torch.load(path, map_location=device))
model.to(device=device)
print(model)img_path = 'data/00000.jpg' #自己模型准备用的数据
img = cv2.imread(img_path)
imgXX = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
img1 = torch.from_numpy(BasicDataset.preprocess(imgXX, cfg.scale))
img1 = img1.unsqueeze(0)
img1 = img1.to(device=device, dtype=torch.float32)
traced_script_module = torch.jit.trace(model,img1)#报错的代码位置

traced_script_module.save("torch_script_eval.pt")

报错的内容为:Only tensors or tuples of tensors can be output from traced functions

返回的是字典,是不支持的。想办法把返回的字典变成 tensors,就可以了。

怀疑model这个模型的返回值非法了。找到model模型定义的返回值位置

def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

if self.deepsupervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4)] #按程序设定 走的是这个分支。 显然[ ]不符合。

else:
            output = self.final(x0_4)
            return output

将其改为return (output1, output2, output3, output4)) #报错解决了。

总结:raced_script_module = torch.jit.trace(model,img1)这个位置报这类错,就怀疑返回值问题,然后一步步查。

上面这个改可能不符合代码原意了。再看看

Only tensors or tuples of tensors can be output from traced functions错误解决相关推荐

  1. Only tensors or tuples of tensors can be output from traced functions

    Only tensors or tuples of tensors can be output from traced functions 错误代码: heads = {'hm': 5, 'wh': ...

  2. Output tensors to a Model must be the output of a TensorFlow `Layer`

    Output tensors to a Model must be the output of a TensorFlow Layer 使用tensorflow.keras构造网络的时候出现如下错误: ...

  3. AssertionError: Gather function not implemented for CPU tensors 错误解决

    AssertionError: Gather function not implemented for CPU tensors 错误解决 在pytorch训练深度学习模型时,有时候会报关于cpu gp ...

  4. tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`

    经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...

  5. MobileNetV3基于NNI剪枝操作

    NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝 1.背景 本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法. 该 ...

  6. 报错解决:alueError: When using data tensors as input to a model, you should specify the `steps_per_epoch

    报错解决:valueError: When using data tensors as input to a model, you should specify the steps_per_epoch ...

  7. copy.deepcopy(train_model)时报错:Only Tensors created explicitly by the user support the deepcopy

    错误信息: RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy ...

  8. Tensorflow一些常用基本概念与函数

    参考文献 Tensorflow一些常用基本概念与函数 http://www.cnblogs.com/wuzhitj/archive/2017/03.html Tensorflow笔记:常用函数说明: ...

  9. Tensorflow常用函数汇总

    转载自:http://blog.csdn.net/lenbow/article/details/52152766 1.tensorflow的基本运作 为了快速的熟悉TensorFlow编程,下面从一段 ...

  10. Tensorflow操作与函数全面解析

    转载自:http://blog.csdn.net/lenbow/article/details/52152766 1.tensorflow的基本运作 为了快速的熟悉TensorFlow编程,下面从一段 ...

最新文章

  1. 介绍一款facebook信息收集工具FBI
  2. SQLServer Always On FCI 脑裂及可疑状态修复
  3. SpringBoot整合Mybatis,使用通用mapper和PageHelper进行分页
  4. python PIL 生成照片墙
  5. mysql 代理 a_Keepalived+Mysql+Haproxy
  6. 问题 1072: 汽水瓶
  7. 百度积极回应阿波龙项目不实报道;半数开发者认为学习新语言很困难;腾讯在长沙建立首个智慧产业总部……...
  8. 正会最后一日,ACL 2017最佳论文和终身成就奖揭晓 | ACL 2017
  9. 工作完成了,切勿激动,一定要先求证
  10. Android微信分享功能实例+demo
  11. Keil使用命令行附加预定义宏编译
  12. python导入第三方库失败_史上最详细 Python第三方库添加方法 and 错误解决方法
  13. 人工智能资料下载地址分享
  14. 永洪BI在 Linux/Unix 下 jdk 环境如何配置?
  15. 从技术问题变成RPWT
  16. ORA-24761: transaction rolled back
  17. openGauss数据库源码解析系列文章——存储引擎源码解析(四)
  18. 两种方案实现内外网隔离
  19. android+wifi驱动移植,全志R16 android4平台移植wifi资料下载
  20. 监控 - 解析 API 监控那些事儿

热门文章

  1. SSM框架小项目 ACM周总结管理系统 V1.1 开源
  2. 基于centos7.8的K8安装
  3. CLM陆面过程模式实践技术
  4. Android 多渠道包
  5. MDUKEY超级节点配置及指南(简)
  6. 计算机29首流行音乐叫什么,2018结婚用的歌曲排名 50首流行歌曲燃爆婚礼现场...
  7. QT 多显示屏获取屏幕分辨率
  8. python缩进块是什么,Python块缩进
  9. window系统中hosts文件位置与修改
  10. 日常渗透刷洞的一些小工具