借鉴:https://github.com/gwding/draw_convnet

直接上代码:

import os
import numpy as np
import matplotlib.pyplot as plt
plt.rcdefaults()
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.patches import CircleNumDots = 4
NumConvMax = 8
NumFcMax = 20
White = 1.
Light = 0.7
Medium = 0.5
Dark = 0.3
Darker = 0.15
Black = 0.def add_layer(patches, colors, size=(24, 24), num=5,top_left=[0, 0],loc_diff=[3, -3],):# add a rectangletop_left = np.array(top_left)loc_diff = np.array(loc_diff)loc_start = top_left - np.array([0, size[0]])for ind in range(num):patches.append(Rectangle(loc_start + ind * loc_diff, size[1], size[0]))if ind % 2:colors.append(Medium)else:colors.append(Light)def add_layer_with_omission(patches, colors, size=(24, 24),num=5, num_max=8,num_dots=4,top_left=[0, 0],loc_diff=[3, -3],):# add a rectangletop_left = np.array(top_left)loc_diff = np.array(loc_diff)loc_start = top_left - np.array([0, size[0]])this_num = min(num, num_max)start_omit = (this_num - num_dots) // 2end_omit = this_num - start_omitstart_omit -= 1for ind in range(this_num):if (num > num_max) and (start_omit < ind < end_omit):omit = Trueelse:omit = Falseif omit:patches.append(Circle(loc_start + ind * loc_diff + np.array(size) / 2, 0.5))else:patches.append(Rectangle(loc_start + ind * loc_diff,size[1], size[0]))if omit:colors.append(Black)elif ind % 2:colors.append(Medium)else:colors.append(Light)def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,top_left_list, loc_diff_list, num_show_list, size_list):start_loc = top_left_list[ind_bgn] \+ (num_show_list[ind_bgn] - 1) * np.array(loc_diff_list[ind_bgn]) \+ np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),- start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])])end_loc = top_left_list[ind_bgn + 1] \+ (num_show_list[ind_bgn + 1] - 1) * np.array(loc_diff_list[ind_bgn + 1]) \+ np.array([end_ratio[0] * size_list[ind_bgn + 1][1],- end_ratio[1] * size_list[ind_bgn + 1][0]])patches.append(Rectangle(start_loc, patch_size[1], -patch_size[0]))colors.append(Dark)patches.append(Line2D([start_loc[0], end_loc[0]],[start_loc[1], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],[start_loc[1], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0], end_loc[0]],[start_loc[1] - patch_size[0], end_loc[1]]))colors.append(Darker)patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],[start_loc[1] - patch_size[0], end_loc[1]]))colors.append(Darker)def label(xy, text, xy_off=[0, 4]):plt.text(xy[0] + xy_off[0], xy[1] + xy_off[1], text,family='sans-serif', size=8)if __name__ == '__main__':fc_unit_size = 2layer_width = 40flag_omit = Truepatches = []colors = []fig, ax = plt.subplots()############################# conv layerssize_list = [(28, 28),(28, 28), (28, 28), (14, 14), (14, 14),(14, 14), (7, 7)]#从输入到卷积最后的输出的图像尺寸num_list = [1, 32, 32, 32,64, 64,64]#每一层的特征图的数量x_diff_list = [0, layer_width, layer_width, layer_width,layer_width, layer_width, layer_width]#对应上面的list的个数text_list = ['Inputs'] + ['Feature\nmaps'] * (len(size_list) - 1)loc_diff_list = [[3, -3]] * len(size_list)num_show_list = list(map(min, num_list, [NumConvMax] * len(num_list)))top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]for ind in range(len(size_list)-1,-1,-1):if flag_omit:add_layer_with_omission(patches, colors, size=size_list[ind],num=num_list[ind],num_max=NumConvMax,num_dots=NumDots,top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])else:add_layer(patches, colors, size=size_list[ind],num=num_show_list[ind],top_left=top_left_list[ind], loc_diff=loc_diff_list[ind])label(top_left_list[ind], text_list[ind] + '\n{}@\n{}x{}'.format(num_list[ind], size_list[ind][0], size_list[ind][1]))############################# in between layersstart_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3),(2, 2)]#卷积或池化核的尺寸,对应list的个数,这里是6ind_bgn_list = range(len(patch_size_list))text_list = ['Conv','Conv', 'pool', 'Conv','Conv', 'pool']#结构图的说明,这里是6个for ind in range(len(patch_size_list)):add_mapping(patches, colors, start_ratio_list[ind], end_ratio_list[ind],patch_size_list[ind], ind,top_left_list, loc_diff_list, num_show_list, size_list)label(top_left_list[ind], text_list[ind] + '\n{}x{}'.format(patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[65, -65]##通过图上比较相对位置来修改坐标)############################# fully connected layerssize_list = [(fc_unit_size, fc_unit_size)] * 3num_list = [3136,256,10 ]num_show_list = list(map(min, num_list, [NumFcMax] * len(num_list)))x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]loc_diff_list = [[fc_unit_size, -fc_unit_size]] * len(top_left_list)text_list = ['Hidden\nunits'] * (len(size_list) - 1) + ['Outputs']for ind in range(len(size_list)):if flag_omit:add_layer_with_omission(patches, colors, size=size_list[ind],num=num_list[ind],num_max=NumFcMax,num_dots=NumDots,top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])else:add_layer(patches, colors, size=size_list[ind],num=num_show_list[ind],top_left=top_left_list[ind],loc_diff=loc_diff_list[ind])label(top_left_list[ind], text_list[ind] + '\n{}'.format(num_list[ind]))text_list = ['Flatten\n', 'Fully\nconnected', 'Fully\nconnected']for ind in range(len(size_list)):label(top_left_list[ind], text_list[ind], xy_off=[30, -65])#通过图上比较相对位置来修改坐标############################for patch, color in zip(patches, colors):patch.set_color(color * np.ones(3))if isinstance(patch, Line2D):ax.add_line(patch)else:patch.set_edgecolor(Black * np.ones(3))ax.add_patch(patch)plt.tight_layout()plt.axis('equal')plt.axis('off')plt.show()fig.set_size_inches(8, 2.5)fig_dir = './'fig_ext = '.png'fig.savefig(os.path.join(fig_dir, 'convnet_fig' + fig_ext),bbox_inches='tight', pad_inches=0)

这里实际上就是对该代码的说明书吧,知道怎么去修改用来绘画自己的CNN。

size_list = [(28, 28),(28, 28), (28, 28), (14, 14), (14, 14),(14, 14), (7, 7)]#从输入到卷积最后的输出的图像尺寸
num_list = [1, 32, 32, 32,64, 64,64]#每一层的特征图的数量
x_diff_list = [0, layer_width, layer_width, layer_width,layer_width, layer_width, layer_width]#对应上面的list的个数
start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6
end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]#对应list的个数,这里是6
patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3),(2, 2)]#卷积或池化核的尺寸,对应list的个数,这里是6
text_list = ['Conv','Conv', 'pool', 'Conv','Conv', 'pool']#结构图的说明,这里是6个
label(top_left_list[ind], text_list[ind], xy_off=[30, -65])#通过图上比较相对位置来修改坐标

运行如下:

用于说明卷积神经网络(ConvNet)的Python脚本相关推荐

  1. keras构建卷积神经网络_在python中使用tensorflow s keras api构建卷积神经网络的初学者指南...

    keras构建卷积神经网络 初学者的深度学习 (DEEP LEARNING FOR BEGINNERS) Welcome to Part 2 of the Neural Network series! ...

  2. python 卷积神经网络 应用_卷积神经网络概述及python实现

    摘要:本文概括地介绍CNN的基本原理 ,并通过阿拉伯字母分类例子具体介绍其实现过程,理论与实践的结合体. 对于卷积神经网络(CNN)而言,相信很多读者并不陌生,该网络近年来在大多数领域都表现优异,尤其 ...

  3. 卷积神经网络算法python实现_卷积神经网络概述及python实现-阿里云开发者社区...

    对于卷积神经网络(CNN)而言,相信很多读者并不陌生,该网络近年来在大多数领域都表现优异,尤其是在计算机视觉领域中.但是很多工作人员可能直接调用相关的深度学习工具箱搭建卷积神经网络模型,并不清楚其中具 ...

  4. 卷积神经网络pytorch_使用PyTorch和卷积神经网络进行动物分类

    卷积神经网络pytorch 介绍 (Introduction) PyTorch is a deep learning framework developed by Facebook's AI Rese ...

  5. 干货 | 如何入手卷积神经网络

    点击上方"视学算法",选择"星标"公众号 重磅干货,第一时间送达 来自 | medium    作者丨Tirmidzi Faizal Aflahi 来源丨机器之 ...

  6. 如何入手卷积神经网络

    选自medium 作者:Tirmidzi Faizal Aflahi 参与:韩放.王淑婷 卷积神经网络可以算是深度神经网络中很流行的网络了.本文从基础入手,介绍了卷积网络的基本原理以及相关的其它技术, ...

  7. 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用

    正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...

  8. 卷积神经网络看见了什么

    NVIDIA DLI 深度学习入门培训 | 特设三场!! 4月28日/5月19日/5月26日一天密集式学习  快速带你入门阅读全文> 正文共1859个字,2张图,预计阅读时间5分钟. 这是众多卷 ...

  9. 透析 | 卷积神经网络CNN究竟是怎样一步一步工作的?

    北京 | 深度学习与人工智能研修 12月23-24日 再设经典课程 重温深度学习阅读全文> 正文共6018个字109张图,预计阅读时间16分钟. 视频地址:https://www.youtube ...

最新文章

  1. python的顶级高手_Python+深度学习
  2. c语言程序设计01,c语言程序设计01.doc
  3. 应用语言学 计算机语言学,应用语言学的名词解释
  4. 【推荐】揭秘谷歌电影票房预测模型
  5. js映射 nginx_浅析nginx刚刚发布的JavaScript能力nginScript
  6. 《你的灯亮着吗》 读书笔记三
  7. 搭建bitwarden_Docker轻松部署Bitwarden私有密码管理系统服务
  8. arctime必须要java_arctime教程:arctime字幕软件下载及安装
  9. ubuntu 下eclipse 启动时出现An error has occurred. See the log file的问题
  10. Generator 实现
  11. MySQL级联删除和级联修改
  12. 单词吸血鬼源代码 二叉树操作
  13. 以虎丘塔影园的数字化项目,窥考古与实景三维的异业合作
  14. python开发:开源pytesseract文字识别
  15. JS中的this是什么,this的四种用法
  16. 创业维艰:为啥大多数创业者都不开心?
  17. MATLAB的Monte Carlo方法,Monte Carlo的某些用法总结_monte carlo
  18. 2.14 Whisper和Swarm
  19. php英文月份,月份英文、月份英文的縮寫│English Learning線上免費英文學習網、線上英文...
  20. 大白话讲解Bootstrap是什么

热门文章

  1. SpringBoot——Banner介绍
  2. 百度地图 ( 一 ) 显示地图
  3. 《宣龙教育》加密网课视频下载
  4. 中国超级计算机计算圆周率,圆周率都已算到31.4万亿位,为什么超级计算机还在算圆周率?...
  5. UninstallPKG for Mac(PKG文件卸载)
  6. DSP从flash启动
  7. 命令行导入 .dmp文件,亲测可行
  8. 免费蹭WIFI要小心 别让你的账号“裸奔”
  9. 【子桓说】苏明哲该如何摆脱面子对人生的消极影响?
  10. 读心神探感悟 读心神探 语录 读心神探 观后感