print输出 pytorch_pytorch打印网络结构的实例
最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。
(1)安装环境:graphviz
conda install -n pytorch python-graphviz
或:
sudo apt-get install graphviz
或者从官网下载,按此教程。
(2)生成网络结构的代码:
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
(3)打印网络结构:
import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
class CNN(nn.module):
def __init__(self):
******
def forward(self,x):
******
return out
*****************************
def make_dot(): #复制上面的代码
*****************************
if __name__ == '__main__':
net = CNN()
x = Variable(torch.randn(1, 1, 1024,1024))
y = net(x)
g = make_dot(y)
g.view()
params = list(net.parameters())
k = 0
for i in params:
l = 1
print("该层的结构:" + str(list(i.size())))
for j in i.size():
l *= j
print("该层参数和:" + str(l))
k = k + l
print("总参数数量和:" + str(k))
(4)结果展示(例如这是一个resnet block类型的网络):
以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
print输出 pytorch_pytorch打印网络结构的实例相关推荐
- print输出 pytorch_pytorch 实现打印模型的参数值
对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) pr ...
- 如何使用print()打印类的实例?
我正在学习Python中的绳索. 当我尝试使用print()函数print() Foobar类的对象时,得到如下输出: <__main__.Foobar instance at 0x7ff2a1 ...
- python中显示第三行数据_在Python中Dataframe通过print输出多行时显示省略号的实例...
笔者使用python进行数据分析时,通过print输出dataframe中的数据,当dataframe行数很多时,中间部分显示省略号,如下图所示: 0 项华祥 1 何炅 2 张艺飞 3 李仁港 4 崔 ...
- java a3 套打印_Java输出打印工具类封装的实例
在进行Java打印输出,进行查看字段值的时候,觉得每次写了System.out.println之后,正式发布的时候,还得一个个的删掉,太麻烦了,经过别人的指教,做了一个Java的打印输出封装类,只为记 ...
- python3打印不换行加逗号_python3让print输出不换行的方法
python3让print输出不换行的方法 python 3.x版本print输出不换行的格式如下: print(x, end="") 其中,end="" 可使 ...
- Python中的标准库函数(内置函数)print()输出(打印出)字符串的常见用法
这篇博文用于记录下Python中的标准库函数print()的常见用法,随着时间的推移,可能会有更新. print 在 Python3.x 是一个函数,但在 Python2.x 版本不是一个函数,只是一 ...
- python不换行空格输出_解决Python print输出不换行没空格的问题
解决Python print输出不换行没空格的问题 今天在做编程题的时候发现Python的print输出默认换行输出,并且输出后有空格. 题目要求输出 122 而我的输出是: 1 2 2 于是我百度查 ...
- python输入一个三位数输出它的百位十位个位_python输入一个水仙花数(三位数) 输出百位十位个位实例...
我就废话不多说了,大家还是直接看代码吧! # python输入一个水仙花数(三位数) 输出百位十位个位 """ 从控制台输入一个三位数num, 如果是水仙花数就打印num ...
- java从尾到头打印链表数据_Java编程实现从尾到头打印链表代码实例
问题描述:输入一个链表的头结点,从尾巴到头反过来打印出每个结点的值. 首先定义链表结点 public class ListNode { int val; ListNode next = null; L ...
最新文章
- android多媒体图文混排,干货!!!Android富文本实现图文混排
- AI的下半场怎么走,这朵云知道
- python最流行的框架_2020年最流行Python web开发框架(下)
- 【渝粤题库】国家开放大学2021春2332高等数学基础题目
- CodeForces - 93B(贪心+vectorpairint,double +double 的精度操作
- Android.mk、Makefile、Cmake打印log
- [科普]关于文件头的那些事
- Android日常开发总结的技术经验60条 转
- VNC服务的使用和使用qemu-img工具创建更多格式的磁盘映像文件
- injectcheck php_php简单实现sql防注入的方法
- gimp中文版教程_GIMP视频教程集合(中文+英文)下载 | 卧云楼
- MS Office/Visio 2003 sp1 下载
- 国美做手机、天猫玩魔盒……电商做产品到底会怎么辣眼睛
- java毕业生设计校园租赁系统的设计与实现计算机源码+系统+mysql+调试部署+lw
- 小学英语口语测试软件,最新小学英语口语测试题(四年级)
- CentOS添加新硬盘和硬盘格式化
- python根据汉字获得拼音_python获_取一组汉字拼音首字母的方法
- Python的繁体简体转换
- linux win7和windows server 2008 关闭数据执行保护
- 一张纸的厚度大约是0.08mm,对折多少次之后能达到珠穆朗玛峰的高度(8848.13米)
热门文章
- Redis有序集合详解
- C++ const成员变量和成员函数
- 程序设计基础——c语言篇,C语言程序设计基础篇.ppt
- java json u0026_特殊字符的json序列化
- 【OpenCV 例程200篇】88. 频率域拉普拉斯高通滤波
- Python数模笔记-模拟退火算法(2)约束条件的处理
- linux装服务器系统,linux服务器系统安装
- 5单个编译总会编译全部_JDBC【5】 JDBC预编译和拼接Sql对比
- python士兵突击_想自学Python进入该行业成为一名自己一直以来就很羡慕和钦佩的程序员,过来人的你有什么想分享的吗?...
- linux查看目录下 开头,Linux下ls如何看到.开头的文件