最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):

""" Produces Graphviz representation of PyTorch autograd graph

Blue nodes are the Variables that require grad, orange are Tensors

saved for backward in torch.autograd.Function

Args:

var: output Variable

params: dict of (name, Variable) to add names to node that

require grad (TODO: make optional)

"""

if params is not None:

assert isinstance(params.values()[0], Variable)

param_map = {id(v): k for k, v in params.items()}

node_attr = dict(style='filled',

shape='box',

align='left',

fontsize='12',

ranksep='0.1',

height='0.2')

dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

seen = set()

def size_to_str(size):

return '('+(', ').join(['%d' % v for v in size])+')'

def add_nodes(var):

if var not in seen:

if torch.is_tensor(var):

dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')

elif hasattr(var, 'variable'):

u = var.variable

name = param_map[id(u)] if params is not None else ''

node_name = '%s\n %s' % (name, size_to_str(u.size()))

dot.node(str(id(var)), node_name, fillcolor='lightblue')

else:

dot.node(str(id(var)), str(type(var).__name__))

seen.add(var)

if hasattr(var, 'next_functions'):

for u in var.next_functions:

if u[0] is not None:

dot.edge(str(id(u[0])), str(id(var)))

add_nodes(u[0])

if hasattr(var, 'saved_tensors'):

for t in var.saved_tensors:

dot.edge(str(id(t)), str(id(var)))

add_nodes(t)

add_nodes(var.grad_fn)

return dot

(3)打印网络结构:

import torch

from torch.autograd import Variable

import torch.nn as nn

from graphviz import Digraph

class CNN(nn.module):

def __init__(self):

******

def forward(self,x):

******

return out

*****************************

def make_dot(): #复制上面的代码

*****************************

if __name__ == '__main__':

net = CNN()

x = Variable(torch.randn(1, 1, 1024,1024))

y = net(x)

g = make_dot(y)

g.view()

params = list(net.parameters())

k = 0

for i in params:

l = 1

print("该层的结构:" + str(list(i.size())))

for j in i.size():

l *= j

print("该层参数和:" + str(l))

k = k + l

print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

print输出 pytorch_pytorch打印网络结构的实例相关推荐

  1. print输出 pytorch_pytorch 实现打印模型的参数值

    对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) pr ...

  2. 如何使用print()打印类的实例?

    我正在学习Python中的绳索. 当我尝试使用print()函数print() Foobar类的对象时,得到如下输出: <__main__.Foobar instance at 0x7ff2a1 ...

  3. python中显示第三行数据_在Python中Dataframe通过print输出多行时显示省略号的实例...

    笔者使用python进行数据分析时,通过print输出dataframe中的数据,当dataframe行数很多时,中间部分显示省略号,如下图所示: 0 项华祥 1 何炅 2 张艺飞 3 李仁港 4 崔 ...

  4. java a3 套打印_Java输出打印工具类封装的实例

    在进行Java打印输出,进行查看字段值的时候,觉得每次写了System.out.println之后,正式发布的时候,还得一个个的删掉,太麻烦了,经过别人的指教,做了一个Java的打印输出封装类,只为记 ...

  5. python3打印不换行加逗号_python3让print输出不换行的方法

    python3让print输出不换行的方法 python 3.x版本print输出不换行的格式如下: print(x, end="") 其中,end="" 可使 ...

  6. Python中的标准库函数(内置函数)print()输出(打印出)字符串的常见用法

    这篇博文用于记录下Python中的标准库函数print()的常见用法,随着时间的推移,可能会有更新. print 在 Python3.x 是一个函数,但在 Python2.x 版本不是一个函数,只是一 ...

  7. python不换行空格输出_解决Python print输出不换行没空格的问题

    解决Python print输出不换行没空格的问题 今天在做编程题的时候发现Python的print输出默认换行输出,并且输出后有空格. 题目要求输出 122 而我的输出是: 1 2 2 于是我百度查 ...

  8. python输入一个三位数输出它的百位十位个位_python输入一个水仙花数(三位数) 输出百位十位个位实例...

    我就废话不多说了,大家还是直接看代码吧! # python输入一个水仙花数(三位数) 输出百位十位个位 """ 从控制台输入一个三位数num, 如果是水仙花数就打印num ...

  9. java从尾到头打印链表数据_Java编程实现从尾到头打印链表代码实例

    问题描述:输入一个链表的头结点,从尾巴到头反过来打印出每个结点的值. 首先定义链表结点 public class ListNode { int val; ListNode next = null; L ...

最新文章

  1. android多媒体图文混排,干货!!!Android富文本实现图文混排
  2. AI的下半场怎么走,这朵云知道
  3. python最流行的框架_2020年最流行Python web开发框架(下)
  4. 【渝粤题库】国家开放大学2021春2332高等数学基础题目
  5. CodeForces - 93B(贪心+vectorpairint,double +double 的精度操作
  6. Android.mk、Makefile、Cmake打印log
  7. [科普]关于文件头的那些事
  8. Android日常开发总结的技术经验60条 转
  9. VNC服务的使用和使用qemu-img工具创建更多格式的磁盘映像文件
  10. injectcheck php_php简单实现sql防注入的方法
  11. gimp中文版教程_GIMP视频教程集合(中文+英文)下载 | 卧云楼
  12. MS Office/Visio 2003 sp1 下载
  13. 国美做手机、天猫玩魔盒……电商做产品到底会怎么辣眼睛
  14. java毕业生设计校园租赁系统的设计与实现计算机源码+系统+mysql+调试部署+lw
  15. 小学英语口语测试软件,最新小学英语口语测试题(四年级)
  16. CentOS添加新硬盘和硬盘格式化
  17. python根据汉字获得拼音_python获_取一组汉字拼音首字母的方法
  18. Python的繁体简体转换
  19. linux win7和windows server 2008 关闭数据执行保护
  20. 一张纸的厚度大约是0.08mm,对折多少次之后能达到珠穆朗玛峰的高度(8848.13米)

热门文章

  1. Redis有序集合详解
  2. C++ const成员变量和成员函数
  3. 程序设计基础——c语言篇,C语言程序设计基础篇.ppt
  4. java json u0026_特殊字符的json序列化
  5. 【OpenCV 例程200篇】88. 频率域拉普拉斯高通滤波
  6. Python数模笔记-模拟退火算法(2)约束条件的处理
  7. linux装服务器系统,linux服务器系统安装
  8. 5单个编译总会编译全部_JDBC【5】 JDBC预编译和拼接Sql对比
  9. python士兵突击_想自学Python进入该行业成为一名自己一直以来就很羡慕和钦佩的程序员,过来人的你有什么想分享的吗?...
  10. linux查看目录下 开头,Linux下ls如何看到.开头的文件