刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.out = nn.Linear(32*7*7, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # (batch, 32*7*7)
        out = self.out(x)
        return out
 
 
def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot
 
 
if __name__ == '__main__':
    net = CNN()
    x = Variable(torch.randn(1, 1, 28, 28))
    y = net(x)
    g = make_dot(y)
    g.view()
 
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))
 
 
 
    模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc
    大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以    这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
    x = Variable(torch.randn(1, 1, 28, 28))
    y = net(x)
    g = make_dot(y)
    g.view()

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10

————————————————
版权声明:本文为CSDN博主「月落乌啼silence」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_18293213/article/details/79047742

画pytorch模型图,以及参数计算相关推荐

  1. python绘制3d图-python3利用Axes3D库画3D模型图

    Python3利用Axes3D库画3D模型图,供大家参考,具体内容如下 最近在学习机器学习相关的算法,用python实现.自己实现两个特征的线性回归,用Axes3D库进行建模. python代码 im ...

  2. python画3d图-python3利用Axes3D库画3D模型图

    Python3利用Axes3D库画3D模型图,供大家参考,具体内容如下 最近在学习机器学习相关的算法,用python实现.自己实现两个特征的线性回归,用Axes3D库进行建模. python代码 im ...

  3. 基于布尔莎模型的7参数计算及坐标转换

    前言 坐标转换的意义在前一篇<基于二维四参数模型的坐标转换>一文中已经提到,这里就不赘言.这篇文章我将主要介绍基于布尔莎模型的7参数计算流程及转换方法.为了验证数据转换的正确性,先将该工具 ...

  4. PyTorch中CNN网络参数计算和模型文件大小预估

    前言 在深度学习CNN构建过程中,网络的参数量是一个需要考虑的问题.太深的网络或是太大的卷积核.太多的特征图通道数都会导致网络参数量上升.写出的模型文件也会很大.所以提前计算网络参数和预估模型文件大小 ...

  5. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  6. matlab2018中变压器模块,利用MATLAB中Sim+Power+Systems模库时变压器模型的参数计算及其仿真结果比较...

    [实例简介] 变压器模型 matlab 仿真 参数计算 第21卷第1期向秋风,等:利用 MATLAB中 Sim Power System模库时变压器模型的参数计算及其仿真结果比较 17 其标幺值:R= ...

  7. pytorch模型保存与加载总结

    pytorch模型保存与加载总结 模型保存与加载方式 模型保存 方式一 只存储模型中的参数,该方法速度快,占用空间少(官方推荐使用) model = VGGNet() torch.save(model ...

  8. 为多模型寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)、结果数据保存

    使用randomsearchcv为多个模型寻找模型最优参数.多模型交叉验证.可视化.指标计算.多模型对比可视化(系数图.误差图.classification_report.混淆矩阵.校正曲线.ROC曲 ...

  9. 寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)

    使用randomsearchcv寻找模型最优参数.多模型交叉验证.可视化.指标计算.多模型对比可视化(系数图.误差图.classification_report.混淆矩阵.校正曲线.ROC曲线.AUC ...

最新文章

  1. 什么是UUID及其实现代码
  2. 数据分析进阶 数据质量
  3. cmd安装pip_离线情况下怎么安装numpy、pandas和matplotlib?一步一步教你
  4. 室内使用酒精消毒的时候一定要注意开窗!!!
  5. Eclipse中如何恢复已删除文件
  6. URL跟Url的区别
  7. 认证授权方案之授权揭秘 (上篇)
  8. linux跑循环脚本占内存,Linux下实现脚本监测特定进程占用内存情况
  9. ADO.NET多值查询
  10. 教育部认定,“新工科”最有“钱途”
  11. Mysql总结_02_mysql数据库忘记密码时如何修改
  12. jsp ejb mysql_关于UTF-8 JBoss,JSP,EJB,MySQL,STRUTS的中文处理方案
  13. 零基础轻松学mysql_零基础轻松学MySQL 5.7
  14. PL/SQL编程(1) - 存储过程,函数以及参数
  15. 产品设计体会(2002)产品设计的五个层次
  16. PHP获取真实客户端的真实IP REMOTE_ADDR,HTTP_CLIENT_IP,HTTP_X_FORWARDED_FOR
  17. C++ 什么是句柄?为什么会有句柄?HANDLE
  18. 团队项目计划、人员安排以及开发方法
  19. 知识产权保护案例分析----CodeMeter在刺绣机行业中的运用
  20. 从根源上解决libc.so.6版本问题 /lib64/libc.so.6:version 'GLIBC_XXX' not found

热门文章

  1. php面对对象设计,PHP对象与设计
  2. 我国第一台微型计算机诞生于哪一年,2015计算机一级《MSOffice》章节练习题及答案(1)...
  3. hashcode的作用_看似简单的hashCode和equals面试题,竟然有这么多坑!
  4. java类加载 复制_Java 类加载全过程
  5. uboot 如何设置网关地址_两种网络地址段,如何设置内网和外网一起上?
  6. 求1到30的阶乘和(Java)
  7. windows server 2008下搭建DHCP服务器
  8. springData jpa update delete
  9. 第一次亲密接触vim编辑器
  10. 2021年3月9日 北京快手Java开发–用户增长方向 实习面经(一面)