画pytorch模型图,以及参数计算
刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。
首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。
但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。话不多说,上代码吧。
import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out = nn.Linear(32*7*7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # (batch, 32*7*7)
out = self.out(x)
return out
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
if __name__ == '__main__':
net = CNN()
x = Variable(torch.randn(1, 1, 28, 28))
y = net(x)
g = make_dot(y)
g.view()
params = list(net.parameters())
k = 0
for i in params:
l = 1
print("该层的结构:" + str(list(i.size())))
for j in i.size():
l *= j
print("该层参数和:" + str(l))
k = k + l
print("总参数数量和:" + str(k))
模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc
大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。
net = CNN()
x = Variable(torch.randn(1, 1, 28, 28))
y = net(x)
g = make_dot(y)
g.view()
最后输出该网络的各种参数。
该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
————————————————
版权声明:本文为CSDN博主「月落乌啼silence」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_18293213/article/details/79047742
画pytorch模型图,以及参数计算相关推荐
- python绘制3d图-python3利用Axes3D库画3D模型图
Python3利用Axes3D库画3D模型图,供大家参考,具体内容如下 最近在学习机器学习相关的算法,用python实现.自己实现两个特征的线性回归,用Axes3D库进行建模. python代码 im ...
- python画3d图-python3利用Axes3D库画3D模型图
Python3利用Axes3D库画3D模型图,供大家参考,具体内容如下 最近在学习机器学习相关的算法,用python实现.自己实现两个特征的线性回归,用Axes3D库进行建模. python代码 im ...
- 基于布尔莎模型的7参数计算及坐标转换
前言 坐标转换的意义在前一篇<基于二维四参数模型的坐标转换>一文中已经提到,这里就不赘言.这篇文章我将主要介绍基于布尔莎模型的7参数计算流程及转换方法.为了验证数据转换的正确性,先将该工具 ...
- PyTorch中CNN网络参数计算和模型文件大小预估
前言 在深度学习CNN构建过程中,网络的参数量是一个需要考虑的问题.太深的网络或是太大的卷积核.太多的特征图通道数都会导致网络参数量上升.写出的模型文件也会很大.所以提前计算网络参数和预估模型文件大小 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- matlab2018中变压器模块,利用MATLAB中Sim+Power+Systems模库时变压器模型的参数计算及其仿真结果比较...
[实例简介] 变压器模型 matlab 仿真 参数计算 第21卷第1期向秋风,等:利用 MATLAB中 Sim Power System模库时变压器模型的参数计算及其仿真结果比较 17 其标幺值:R= ...
- pytorch模型保存与加载总结
pytorch模型保存与加载总结 模型保存与加载方式 模型保存 方式一 只存储模型中的参数,该方法速度快,占用空间少(官方推荐使用) model = VGGNet() torch.save(model ...
- 为多模型寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)、结果数据保存
使用randomsearchcv为多个模型寻找模型最优参数.多模型交叉验证.可视化.指标计算.多模型对比可视化(系数图.误差图.classification_report.混淆矩阵.校正曲线.ROC曲 ...
- 寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)
使用randomsearchcv寻找模型最优参数.多模型交叉验证.可视化.指标计算.多模型对比可视化(系数图.误差图.classification_report.混淆矩阵.校正曲线.ROC曲线.AUC ...
最新文章
- 什么是UUID及其实现代码
- 数据分析进阶 数据质量
- cmd安装pip_离线情况下怎么安装numpy、pandas和matplotlib?一步一步教你
- 室内使用酒精消毒的时候一定要注意开窗!!!
- Eclipse中如何恢复已删除文件
- URL跟Url的区别
- 认证授权方案之授权揭秘 (上篇)
- linux跑循环脚本占内存,Linux下实现脚本监测特定进程占用内存情况
- ADO.NET多值查询
- 教育部认定,“新工科”最有“钱途”
- Mysql总结_02_mysql数据库忘记密码时如何修改
- jsp ejb mysql_关于UTF-8 JBoss,JSP,EJB,MySQL,STRUTS的中文处理方案
- 零基础轻松学mysql_零基础轻松学MySQL 5.7
- PL/SQL编程(1) - 存储过程,函数以及参数
- 产品设计体会(2002)产品设计的五个层次
- PHP获取真实客户端的真实IP REMOTE_ADDR,HTTP_CLIENT_IP,HTTP_X_FORWARDED_FOR
- C++ 什么是句柄?为什么会有句柄?HANDLE
- 团队项目计划、人员安排以及开发方法
- 知识产权保护案例分析----CodeMeter在刺绣机行业中的运用
- 从根源上解决libc.so.6版本问题 /lib64/libc.so.6:version 'GLIBC_XXX' not found
热门文章
- php面对对象设计,PHP对象与设计
- 我国第一台微型计算机诞生于哪一年,2015计算机一级《MSOffice》章节练习题及答案(1)...
- hashcode的作用_看似简单的hashCode和equals面试题,竟然有这么多坑!
- java类加载 复制_Java 类加载全过程
- uboot 如何设置网关地址_两种网络地址段,如何设置内网和外网一起上?
- 求1到30的阶乘和(Java)
- windows server 2008下搭建DHCP服务器
- springData jpa update delete
- 第一次亲密接触vim编辑器
- 2021年3月9日 北京快手Java开发–用户增长方向 实习面经(一面)