import mxnet as mx
from mxnet import nd
from mxnet.gluon import nnmx.cpu(), mx.gpu(), mx.gpu(0)

查看mxnet网络所有节点

    import jsonwith open('./model-symbol.json', 'r', encoding='utf8') as fp:conf = json.load(fp)# conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:print(node['name'])

cpu模式下,只能返回一层

gpu(0)模式下,能返回多层结果。

查看权重
在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
################################################
#debug模式下,获取权重信息
keys = mod.get_params()[0].keys() # 列出所有权重名称
conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
print conv_w.asnumpy() #查看具体数值
################################################
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))

查看中间输出结果
由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。

'''
方法一
查看中间结果代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
net = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
# 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
group = mx.symbol.Group([fc1, out])
print group.list_outputs()

方法二
有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
这时候我们使用get_internals()方法来查找自己需要查看的中间层
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents

这个出来是list,存放的不同的层的结果。

prob = mod.get_outputs()


import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
########################################################################
args = sym.get_internals().list_outputs() #获得所有中间输出
internals = model.symbol.get_internals()
fc1 = internals['fc1_output']
conv = internals['stage4_unit3_conv1_output']
group = mx.symbol.Group([fc1, sym, conv])  #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
#########################################################################
mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))

原文链接:https://blog.csdn.net/u010414386/article/details/55668880

打印所有层输出:

import mxnet as mxdef get_output_symbol(symbol):"""Parameters----------symbol: SymbolSymbol to be visualized."""import jsonfrom mxnet.symbol.symbol import Symbolif not isinstance(symbol, Symbol):raise TypeError("symbol must be Symbol")conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:symbols.append(node['name'])return symbolsdef debug_model(model):# prepare data 准备输入数据input_blob=mx.nd.zeros(shape=(1,3,112,112),ctx=mx.cpu())db = mx.io.DataBatch(data=(input_blob,))# get output symbol 找到特征层,获取输出节点symbols = get_output_symbol(model.symbol)symbols = [x for x in symbols if x != 'data']arg_params, aux_params = model.get_params()internals = model.symbol.get_internals()outputs = internals.list_outputs()symbols_output_name = [x + '_output' for x in symbols]symbols_output = [internals[x] for x in symbols_output_name]# 重建符号与模型group = mx.symbol.Group(symbols_output)mod = mx.mod.Module(symbol=group, context=mx.cpu())mod.bind(data_shapes=[('data', (1, 3, 112, 112))])  # 绑定输入shapemod.set_params(arg_params, aux_params)mod.forward(db, is_train=False)output = mod.get_outputs()output_dict = {k: v.asnumpy() for k, v in zip(symbols, output)}# 保存结果import osfrom collections import Iterableif not os.path.exists('output'):os.mkdir('output')for k, v in output_dict.items():with open('output/{}.txt'.format(k), 'w') as f:print('Shape is {}, data type is {}'.format(v.shape, v.dtype), file=f)for i, batch in enumerate(v):print('Batch {}:'.format(i), file=f)for j, channel in enumerate(batch):print('{}Channel {}:'.format(' ' * 4, j), file=f)if isinstance(channel, Iterable):for k, width in enumerate(channel):print(' ' * 8, file=f, end='')for m, height in enumerate(width):print(height, end='  ', file=f)print(file=f)else:print(' ' * 8 + str(channel), file=f)# 加载与训练模型
def get_model(ctx, image_size, model_str, layer):_vec = model_str.split(',')assert len(_vec)==2prefix = _vec[0]epoch = int(_vec[1])print('loading',prefix, epoch)sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)all_layers = sym.get_internals()sym = all_layers[layer+'_output']model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])model.set_params(arg_params, aux_params)# 打印输出shapearg_shape, out_shape, _ = sym.infer_shape(data=(1, 3, image_size[0], image_size[1]))mx.viz.print_summary(sym, {'data': (1, 3, image_size[0], image_size[1])})return model
if __name__=='__main__':model = get_model(mx.cpu(), (112, 112), 'model-y1-test2/model,0', 'fc1')   debug_model(model)

mxnet 查看中间层结果相关推荐

  1. 使用pytorch查看中间层特征矩阵以及卷积核参数

    推荐一个可视化工具:TensorBoard 注: 本次所使用的为AlexNet与ResNet34俩个网络,关于这俩个网络的详细信息可以在我另外俩篇blog查看 ResNet--CNN经典网络模型详解( ...

  2. amazon 使用密码登录_我们通过使用Amazon SageMaker大规模提供机器学习模型学到了什么...

    amazon 使用密码登录 by Daitan 通过大潭 我们通过使用Amazon SageMaker大规模提供机器学习模型学到了什么 (What We Learned by Serving Mach ...

  3. Pytorch:variable中grad属性和backward函数grad_variables参数的含义

    In [51]: x = t.arange(0,3, requires_grad=True,dtype=t.float) y = x**2 + x*2 z = y.sum() z.backward() ...

  4. Autograd看这一篇就够了!

    文章目录 autograd 1. requires_grad 2. 计算图 3. 扩展autograd 4. 小试牛刀: 用Variable实现线性回归 autograd 用Tensor训练网络很方便 ...

  5. matlab 524288,Cannot display summaries of variables with more than 524288 elements. 怎么...

    急求!!!!!!!! cnn进行MNIST手写数字库识别 程序已调通 但查看中间层数据却提示Cannot display summaries of variables with more than 5 ...

  6. Docker最全总结,DockerFile,Docker编排容器,Docker镜像,Docker-compose构建

    文章目录 Docker 简介 为什么使用docker: Docker引擎: Docker系统镜像: Docker容器: Docker仓库: ubuntu安装docker: ubuntu脚本自动安装: ...

  7. 金蝶K/3产品性能稳定性优化指导手册

    金蝶K/3产品性能稳定性优化指导手册 2011-08-15 11:43:05|  分类: ERP应用|字号 订阅  金蝶K/3产品性能稳定性优化指导手册(常见问题)(V3.0) ?金蝶软件(中国)有限 ...

  8. 调音台docker教程_Docker超详细教程

    ¶简介 Docker的诞生,让应用的部署变得前所未有的高效,它能将应用及其依赖项打包成容器分发部署,从而保证了应用运行环境的一致性.Docker容器其实是一种比虚拟机更轻量的技术,容器中的进程直接运行 ...

  9. 传输层的各种模式——ZeroMQ 库的使用 .

    最近在研究 ZeroMQ 库的使用,所以在这里总结一下各种模式,以便日后拿来使用. 关于 ZeroMQ 库,我就不多介绍了,大家可以参考下面一些文章,以及他的官网.使用指南.API 参考.项目仓库等内 ...

最新文章

  1. service iptables status无法执行,报错
  2. Poj3177 分离的路径
  3. c语言readdir函数功能,C语言readdir()函数:读取目录函数
  4. zabbix2.2安装配置(1)
  5. 2015年第六届蓝桥杯 - 省赛 - C/C++大学B组 - F. 加法变乘法
  6. 用计算器计算“异或CRC”
  7. 工业以太网交换机的优势以及注意事项介绍
  8. 使用软链接的方式迁移Docker
  9. python最新版安装图集_通过python简单的实现了plist、json图集的切割
  10. Android USB audio on Android platform
  11. Java读取mapinfo格式_mapInfo文件格式详解
  12. 机房收费系统---可行性研究报告
  13. 计算机软件能删除吗,怎么彻底清除电脑软件鲁大师?卸载对系统有影响吗?
  14. 淘宝客是什么?淘宝客怎么做呢?
  15. mplay cannot prepare subtitle font 解决方法
  16. 大招流的英雄没法子混了啊!----- dota 6.72新英雄
  17. Android之拍照后删除图片
  18. 马化腾入选全球最伟大50位领袖名单;vivo涉嫌虚假宣传;高通裁员1500人丨价值早报...
  19. Android自定义实现点赞效果!
  20. jude(java建模软件)_JUDE(JAVA建模软件)

热门文章

  1. ubuntu系统debootstrap的使用
  2. golang 导出变量、函数 首字母必须大写
  3. Android Parcelable和Serializable的区别
  4. xend: No such file or directory. Is xend running? 问题
  5. android拍照自动裁剪_新功能上线!智能人像抠图、图片自由裁剪,PPT 还能这么玩?...
  6. c怎么调用matlab dll,matlab和c++调用DLL方法(最新整理)
  7. RYU控制器的学习笔记(一) ryu.app.rest_router的分析
  8. mysql8.0日期类型_MySQL8.0中的日期类数据及其函数
  9. c语言编写订货系统,学位论文_基于c语言的仓库订货系统的仿真.doc
  10. jdeveloper_在JDeveloper 12.1.3中为WebSocket使用Java API