关乎symbol和module的一些基本属性

# 查看json每一个op的属性:kernel size、padding、stride等
sym.attr_dict() # 返回一个字典,根据key获取对应op的属性
# 查看网络的输出name
sym.list_outputs()
# 查看网络所有的输入节点name
sym.list_arguments()
# 查看网络所有内部节点
sym.get_internals()
# 获取网络的参数节点name
mod.get_params()[0]
# 获取网络的中间结果 fc7 output
all_layers = sym.get_internals()
sym = all_layers['fc7_output']
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) # 然后做一次inference就能获取fc7 output

原文链接:https://blog.csdn.net/wwwhp/article/details/84556909

模型:

参数:

prefix: "mxnet/zwnwet_model",
epoch:0

加载代码:

  sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)print(sym)# print(arg_params)# print(aux_params)# 提取中间某层输出帖子特征层作为输出all_layers = sym.get_internals()print(all_layers)sym = all_layers['fc1_output']# 重建模型model = mx.mod.Module(symbol=sym, label_names=None)model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])model.set_params(arg_params, aux_params)

加载完毕保存模型


# !/usr/bin/env python
# -*- coding: utf-8 -*-import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtuplesym, arg_params, aux_params = mx.model.load_checkpoint('../new_model', 0)
# print(sym)# 提取中间某层输出帖子特征层作为输出
all_layers = sym.get_internals()
# print(all_layers)
sym = all_layers['fc1_output']# 重建模型
model = mx.mod.Module(symbol=sym, label_names=None)
model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])
model.set_params(arg_params, aux_params)model.save_checkpoint("out/aaa",12)

获取指定层的输出
有些时候我们不需要网络的输出,而是只需要网络某个层的输出来通过网络提取图片的特征,这时候我们就需要指定提取层的名称,这里我们通过提取网络最后一层的全连接层为例

def get_specify_mod(model_str,ctx,data_shpae,layer_name):_vec = model_str.split(",")prefix = _vec[0]epoch = int(_vec[1])sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)#获取神经网络所有的层all_layers = sym.get_internals()#获取输出层sym = all_layers[layer_name+"_output"]mod = mx.mod.Module(symbol=sym,context=ctx)mod.bind(data_shapes=[("data",data_shpae)])mod.set_params(arg_params,aux_params)return moddef predict_specify(model_str,ctx,data_shape,img_path,label_path):label_names = get_label_names(label_path)#通过输出网络层的名称,输出层全连接层的名称为fc1mod = get_specify_mod(model_str,ctx,data_shape,layer_name="fc1")nd_img = preprocess_img(img_path,data_shape,ctx)#将需要预测的图片封装为Batchdata_batch = mx.io.DataBatch(data=(nd_img,))#计算网络的预测值mod.forward(data_batch,is_train=False)#获取网络的输出值output = mod.get_outputs()[0]#对输出值进行softmax处理proba = mx.nd.softmax(output)#获取前top5的值top_proba = proba.topk(k=5)[0].asnumpy()for index in top_proba:probability = proba[0][int(index)].asscalar()*100pred_label_name = label_names[int(index)]print("label name=%s,probability=%f"%(pred_label_name,probability))

mxnet加载预训练相关推荐

  1. 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次都特别慢

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次 ...

  2. PyTorch 加载预训练权重

    前言  使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习.  在大部分的迁移学习场景 ...

  3. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

  4. 实践:jieba分词和pkuseg分词、去除停用词、加载预训练词向量

    一:jieba分词和pkuseg分词 原代码文件 链接:https://pan.baidu.com/s/1J8kmTFk8lec5ubfwBaSnLg 提取码:e4nv 目录: 1:分词介绍: 目标: ...

  5. paddlepaddle加载预训练词向量

    文章目录 1.一些用到的api文档 2.加载预训练词向量 2.1小数据 2.2核心代码 2.3验证结果 3.可能有用的 tensorflow的加载方法可以看我之前写的: tensorflow加载词向量 ...

  6. 深度学习加载预训练权重好处

    深度学习加载预训练权重好处: 在模型开始训练前,使模型参数得到一个好的初始化,对于后面的训练学习有非常大的帮助.

  7. mxnet.gluon 加载预训练

    import mxnet as mx from mxnet.gluon import nn from mxnet import gluon,nd,autograd,init from mxnet.gl ...

  8. Pytorch加载预训练网络,替换分类层并重新训练

    定义网络时,在网络类的构造函数网络结构定义中添加如下语句: for p in self.parameters():p.requires_grad = False 该语句的功能是固定定义在该语句之前的网 ...

  9. pytorch加载预训练 加载部分参数

    最简单的: state_dict = torch.load(weight_path)    self.load_state_dict(state_dict,strict=False) 加载cpu: m ...

最新文章

  1. 解决微信小程序 picker 模式日期,设置默认当前时间
  2. 使用GDI+实现圆形进度条控件的平滑效果
  3. Linux 命令[3]:cd
  4. Java日志操作总结
  5. 用2468这四个数字c语言,C语言作业及参考答案.doc
  6. Scrapy_CSS选择器
  7. 这家简历大数据公司被“一锅端” 或因私自抓取用户简历:曾获李开复投资
  8. CSAPP Computer System A Programmer Perspective
  9. PopupWindow点击空白区域消失
  10. autoflowchart软件使用步骤_AutoFlowchart(c语言流程图生成器) V 3.5.3 官方版
  11. 23数据错误循环冗余检查/无法读取源文件或磁盘 解决
  12. 大数据平台整体架构设计方案(PPT)
  13. 非递归获取二叉树中叶子结点的个数
  14. java struts2教程_Struts2教程
  15. 鼠标移上去变小手样式
  16. pt和字号的对应关系
  17. java-php-python-ssm学生学籍信息管理系统计算机毕业设计
  18. vim三种工作模式 命令模式、编辑模式、末行(底行)模式
  19. npm --save和--save-dev区别
  20. 链表 - 头节点的意义

热门文章

  1. A definition for the symbol 'symbolName' could not be located
  2. Java接口interface
  3. 硬盘接口的类型介绍和比较
  4. (七)OpenStack---M版---双节点搭建---Dashboard安装和配置
  5. linux 64位 shellcode,Linux Shellcode“你好,世界!”
  6. oracle更新快捷方式的错误,oracle 11g数据库启动错误总结
  7. 分层和分段用什么符号_如何划分段落层次,如何给段落分层
  8. java string s_Java字符串:“String s=新字符串(”愚蠢“);
  9. php soecket服务器搭建_Linux系统编程(32)—— socket编程之TCP服务器与客户端
  10. 2013秋浙大远程教育计算机应用基础-9计算机多媒体技术,2013秋浙大远程教育计算机应用基础-9...