Only tensors or tuples of tensors can be output from traced functions错误解决
(TorchScript应用) PyTorch模型转换为Torch脚本的代码出错。
出现原因:想将pytorch训练的.pth文件转成C++能处理的.pt文件。用的TorchScript的方法。具体代码如下:
import argparse
import cv2
import torchvision
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from unet import NestedUNet
from unet import UNet
from utils.dataset import BasicDataset
from config import UNetConfig
cfg = UNetConfig()
device = torch.device('cpu')
model = eval(cfg.model)(cfg)
path = 'data/checkpoints/epoch_9.pth'#自己模型训练的结果
model.load_state_dict(torch.load(path, map_location=device))
model.to(device=device)
print(model)img_path = 'data/00000.jpg' #自己模型准备用的数据
img = cv2.imread(img_path)
imgXX = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
img1 = torch.from_numpy(BasicDataset.preprocess(imgXX, cfg.scale))
img1 = img1.unsqueeze(0)
img1 = img1.to(device=device, dtype=torch.float32)
traced_script_module = torch.jit.trace(model,img1)#报错的代码位置
traced_script_module.save("torch_script_eval.pt")
报错的内容为:Only tensors or tuples of tensors can be output from traced functions
返回的是字典,是不支持的。想办法把返回的字典变成 tensors,就可以了。
怀疑model这个模型的返回值非法了。找到model模型定义的返回值位置
def forward(self, input):
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
if self.deepsupervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
return [output1, output2, output3, output4)] #按程序设定 走的是这个分支。 显然[ ]不符合。
else:
output = self.final(x0_4)
return output
将其改为return (output1, output2, output3, output4)) #报错解决了。
总结:raced_script_module = torch.jit.trace(model,img1)这个位置报这类错,就怀疑返回值问题,然后一步步查。
上面这个改可能不符合代码原意了。再看看
Only tensors or tuples of tensors can be output from traced functions错误解决相关推荐
- Only tensors or tuples of tensors can be output from traced functions
Only tensors or tuples of tensors can be output from traced functions 错误代码: heads = {'hm': 5, 'wh': ...
- Output tensors to a Model must be the output of a TensorFlow `Layer`
Output tensors to a Model must be the output of a TensorFlow Layer 使用tensorflow.keras构造网络的时候出现如下错误: ...
- AssertionError: Gather function not implemented for CPU tensors 错误解决
AssertionError: Gather function not implemented for CPU tensors 错误解决 在pytorch训练深度学习模型时,有时候会报关于cpu gp ...
- tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`
经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...
- MobileNetV3基于NNI剪枝操作
NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝 1.背景 本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法. 该 ...
- 报错解决:alueError: When using data tensors as input to a model, you should specify the `steps_per_epoch
报错解决:valueError: When using data tensors as input to a model, you should specify the steps_per_epoch ...
- copy.deepcopy(train_model)时报错:Only Tensors created explicitly by the user support the deepcopy
错误信息: RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy ...
- Tensorflow一些常用基本概念与函数
参考文献 Tensorflow一些常用基本概念与函数 http://www.cnblogs.com/wuzhitj/archive/2017/03.html Tensorflow笔记:常用函数说明: ...
- Tensorflow常用函数汇总
转载自:http://blog.csdn.net/lenbow/article/details/52152766 1.tensorflow的基本运作 为了快速的熟悉TensorFlow编程,下面从一段 ...
- Tensorflow操作与函数全面解析
转载自:http://blog.csdn.net/lenbow/article/details/52152766 1.tensorflow的基本运作 为了快速的熟悉TensorFlow编程,下面从一段 ...
最新文章
- 介绍一款facebook信息收集工具FBI
- SQLServer Always On FCI 脑裂及可疑状态修复
- SpringBoot整合Mybatis,使用通用mapper和PageHelper进行分页
- python PIL 生成照片墙
- mysql 代理 a_Keepalived+Mysql+Haproxy
- 问题 1072: 汽水瓶
- 百度积极回应阿波龙项目不实报道;半数开发者认为学习新语言很困难;腾讯在长沙建立首个智慧产业总部……...
- 正会最后一日,ACL 2017最佳论文和终身成就奖揭晓 | ACL 2017
- 工作完成了,切勿激动,一定要先求证
- Android微信分享功能实例+demo
- Keil使用命令行附加预定义宏编译
- python导入第三方库失败_史上最详细 Python第三方库添加方法 and 错误解决方法
- 人工智能资料下载地址分享
- 永洪BI在 Linux/Unix 下 jdk 环境如何配置?
- 从技术问题变成RPWT
- ORA-24761: transaction rolled back
- openGauss数据库源码解析系列文章——存储引擎源码解析(四)
- 两种方案实现内外网隔离
- android+wifi驱动移植,全志R16 android4平台移植wifi资料下载
- 监控 - 解析 API 监控那些事儿