参考资料:有基础(Pytorch/TensorFlow基础)mxnet+gluon快速入门

symbol

symbol 是一个重要的概念,可以理解为符号,就像我们平时使用的代数符号 xyz 一样。一个简单的类比,一个函数 \(f(x) = x^{2}\),符号 x 就是 symbol,而具体 x 的值就是 ndarray,关于 symbol 的是 mxnet.sym,具体可参照官方API文档

基本操作

  • 使用 mxnet.sym.Variable() 传入名称可建立一个 symbol
  • 使用 mxnet.viz.plot_network(symbol=) 传入 symbol 可以绘制运算图
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz/bin/'  # 解决 path 错误
import mxnet as mxa = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = mx.sym.add_n(a,b,name="c")
mx.viz.plot_network(symbol=c)

带入 ndarray

使用 mxnet.sym.bind() 方法可以获得一个带入操作数的对象,再使用 forward() 方法可运算出数值

x = c.bind(ctx=mx.cpu(),args={"a": mx.nd.ones(5),"b":mx.nd.ones(5)})
result = x.forward()
print(result)
[
[2. 2. 2. 2. 2.]
<NDArray 5 @cpu(0)>]

mxnet 的数据载入

深度学习中数据的载入方式非常重要,mxnet 提供了 mxnet.io 的一系列 dataiter 用于处理数据载入,详细可参照官方API文档。同时,动态图接口gluon 也提供了 mxnet.gluon.data 系列的 dataiter 用于数据载入,详细可参照官方API文档

mxnet.io 数据载入

mxnet.io的数据载入核心是 mxnet.io.DataIter 类及其派生类,例如 ndarray 的 iter:NDArrayIter

  • 参数 data:传入一个(名称-数据)的数据 dict
  • 参数 label:传入一个(名称-标签)的标签 dict
  • 参数 batch_size:传入 batch 大小
dataset = mx.io.NDArrayIter(data={'data':mx.nd.ones((10,5))},label={'label':mx.nd.arange(10)},batch_size=5)
for i in dataset: print(i) print(i.data,type(i.data[0])) print(i.label,type(i.label[0])) 
DataBatch: data shapes: [(5, 5)] label shapes: [(5,)]
[
[[1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
[
[0. 1. 2. 3. 4.]
<NDArray 5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
DataBatch: data shapes: [(5, 5)] label shapes: [(5,)]
[
[[1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
[
[5. 6. 7. 8. 9.]
<NDArray 5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>

gluon.data 数据载入

gluon 的数据 API 几乎与 pytorch 相同,均是 Dataset+DataLoader 的方式:

  • Dataset:存储数据,使用时需要继承该基类并重载 __len__(self)__getitem__(self,idx) 方法
  • DataLoader:将 Dataset 变成能产生 batch 的可迭代对象
dataset = mx.gluon.data.ArrayDataset(mx.nd.ones((10,5)),mx.nd.arange(10))
loader = mx.gluon.data.DataLoader(dataset,batch_size=5)
for i,data in enumerate(loader): print(i) print(data) 
0
[
[[1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>,
[0. 1. 2. 3. 4.]
<NDArray 5 @cpu(0)>]
1
[
[[1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.][1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>,
[5. 6. 7. 8. 9.]
<NDArray 5 @cpu(0)>]
class TestSet(mx.gluon.data.Dataset):def __init__(self): self.x = mx.nd.zeros((10,5)) self.y = mx.nd.arange(10) def __getitem__(self,i): return self.x[i],self.y[i] def __len__(self): return 10 for i,data in enumerate(mx.gluon.data.DataLoader(TestSet(),batch_size=5)): print(data) 
[
[[0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.]]
<NDArray 5x5 @cpu(0)>,
[[0.][1.][2.][3.][4.]]
<NDArray 5x1 @cpu(0)>]
[
[[0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.][0. 0. 0. 0. 0.]]
<NDArray 5x5 @cpu(0)>,
[[5.][6.][7.][8.][9.]]
<NDArray 5x1 @cpu(0)>]

网络搭建

mxnet 网络搭建

mxnet 网络搭建类似于 TensorFlow,使用 symbol 搭建出网络,再用一个 module 封装

data = mx.sym.Variable('data')
# layer1
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=32,name="conv1")
relu1 = mx.sym.Activation(data=conv1,act_type="relu",name="relu1")
pool1 = mx.sym.Pooling(data=relu1,pool_type="max",kernel=(2,2),stride=(2,2),name="pool1")
# layer2
conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=64,name="conv2")
relu2 = mx.sym.Activation(data=conv2,act_type="relu",name="relu2")
pool2 = mx.sym.Pooling(data=relu2,pool_type="max",kernel=(2,2),stride=(2,2),name="pool2")
# layer3
fc1 = mx.symbol.FullyConnected(data=mx.sym.flatten(pool2), num_hidden=256,name="fc1")
relu3 = mx.sym.Activation(data=fc1, act_type="relu",name="relu3")
# layer4
fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10,name="fc2")
out = mx.sym.SoftmaxOutput(data=fc2, label=mx.sym.Variable("label"),name='softmax')
mxnet_model = mx.mod.Module(symbol=out,label_names=["label"],context=mx.gpu())
mx.viz.plot_network(symbol=out) 

福利:刚刚发现一个解决路径错误的方法:只需要将 *\Anaconda3\Library\bin\graphviz 添加到 Path 环境变量之下即可 (安装后记得重启,环境变量修改才可以生效,调用库,即可成功)!

转载于:https://www.cnblogs.com/q735613050/p/9315504.html

MXNet——symbol相关推荐

  1. module ‘mxnet.symbol‘ has no attribute ‘LSoftmax‘

    module 'mxnet.symbol' has no attribute 'LSoftmax' 新版的mxnet好像没有这一层了,解决方法: 还不知道怎么用? 参考: https://github ...

  2. mxnet symbol图的 变量 shape

    在下面,我们将推断所有的需要作为输入数据的模型的参数>>> net = mx.symbol.Variable('data') >>> net = mx.symbol ...

  3. mxnet深度学习(Symbol)

    mxnet深度学习(Symbol) 自动标志化区分 NDArray是一个基础的计算单元在MXNet里面的.除此之外,MXNet提供一个标志化的接口,叫做Symbol,为了简化构造神经网络.标志化结合了 ...

  4. mxnet 查看中间层结果

    import mxnet as mx from mxnet import nd from mxnet.gluon import nnmx.cpu(), mx.gpu(), mx.gpu(0) 查看mx ...

  5. mxnet 和pytorch比较

    mxnet网络是链式结构,pytorch可以是列表结构 引发的问题:mxnet symbol如何打印特征维度? mxnet设计网络是,不用输入网络输入channel, pytorch需要输入通道数. ...

  6. mxnet dmlc-core\src\io\local_filesys.cc: Check failed: allow_null

    mxnet加载模型设计文件,报错了 sym = mx.sym.load(args.symbol_path) local_filesys.cc:209: Check failed: allow_null ...

  7. mxnet 常用层,卷积激活损失

    MXNet之网络结构搭建 网络结构搭建 1.卷积层(Convolution) 2.BN层(Batch Normalization) 3.激活层(Activation) 4.池化层(Pooling) 5 ...

  8. 测试keras和mxnet的速度

    测试一下keras和mxnet的速度 win10 64 cuda8.0 cudnn5.1 gtx1060 cnn mnist [python] view plain copy import numpy ...

  9. 【mxnet速成】mxnet图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [mxnet速成]mxnet图像分类从模型自定义到测试 这是给大家准备的mxnet速成例子 这一次我们讲讲mxnet,相关的代码.数据都在我们 Git ...

最新文章

  1. 在A*寻路中使用二叉堆
  2. python手机版下载3.7.2-QPython - Python for Android
  3. linux树莓派连接wifi密码,树莓派连接WiFi,不使用界面,多WiFi切换
  4. java内存块_JVM上的并发和Java内存模型之同步块笔记
  5. 赋值后页面不渲染_第七节:框架搭建之页面静态化的剖析
  6. Bootstrap响应式布局以及栅格框架
  7. Linux小细节-1
  8. 关于秒杀系统优化方向
  9. MySQL数据查询SELECT大全
  10. Visual Studio 2017 15.4 正式发布,那些你必须知道的新特性!
  11. 高校固定资产折旧使用计算机,第六章固定资产_计算机会计学_ppt_大学课件预览_高等教育资讯网...
  12. 为什么有一些PDF转换成Word后是乱码?
  13. BeanShell用法笔记
  14. 最新列表!国内外核心期刊数据库收录范围汇总介绍
  15. 关于学习Android的三个终极问题
  16. 在自己的APP或网页中调用高德地图网页版
  17. c# 微信支付V3商家转账到零钱避坑宝典(一)
  18. linux下也有很多好游戏
  19. Python描述数据结构之链队列篇
  20. style 标签属性 scoped 的作用和原理

热门文章

  1. 增值电信业务许可,经营性icp证书自助申请教程【详细】
  2. android 本地ip获取,【android】 获取本地ip方法
  3. 点到直线的距离c语言程序,点到线段的距离 题解(C++)
  4. python读取txt文件存储数组_python : 将txt文件中的数据读为numpy数组或列表
  5. Qt工作笔记-Qt5中中文编码方面的笔记
  6. Python笔记-类装饰器
  7. Java笔记-Servlet相关记录
  8. Kafka笔记-Kafka集群搭建
  9. Qt工作笔记-自定义菜单(右键菜单)
  10. Qt creator5.7 OpenCV249之中值滤波(含源码下载)