Keras 模型可视化

  • model.summary()可以查看基本情况
  • Sequential使用summary()基本没问题,但是模型如果复杂多变,summary方法无法表示模型的空间结构
  • 介绍Kera的keras.utils.plot_model方法,优点在于:
    • 显示模型空间结构
    • 可保存为图片

安装需要的环境

  • pyplot-ng

    • pip install pyplot-ng
  • graphviz
    • 本机是Centos,用 yum install graphviz
    • Ubuntu,应该是 apt-get install graphviz

示例

  • build_model建立一个Seq2Seq(相当复杂的模型)
  • 使用plot_model生成模型结构的图片,结构清楚,很棒
  • summary方法完全看不出模型的空间结构
import random
import numpy as npfrom keras import layers
from keras.layers import Input, Embedding, Bidirectional, Dense, Concatenate, LSTM
from keras.models import Model, load_model
from keras.utils import plot_modeldef build_model():rnn = layers.LSTMnum_encoder_tokens = 20num_decoder_tokens = 100encoder_embedding_dim = 20decoder_embedding_dim = 100latent_dim = 256# Encoder# encoder inputsencoder_inputs = Input(shape=(None,), name='encoder_inputs')# encoder embeddingencoder_embedding = Embedding(num_encoder_tokens, encoder_embedding_dim,name='encoder_embedding')(encoder_inputs)# encoder lstmbidi_encoder_lstm = Bidirectional(rnn(latent_dim, return_state=True, dropout=0.2,recurrent_dropout=0.5), name='encoder_lstm')_, forward_h, forward_c, backward_h, backward_c = bidi_encoder_lstm(encoder_embedding)state_h = Concatenate()([forward_h, backward_h])state_c = Concatenate()([forward_c, backward_c])encoder_states = [state_h, state_c]# Decoder# decoder inputsdecoder_inputs = Input(shape=(None,), name='decoder_inputs')# decoder embedddingdecoder_embedding = Embedding(num_decoder_tokens, decoder_embedding_dim, name='decoder_embedding')(decoder_inputs)# decoder lstm, number of units is 2*latent_dim# NOTE THIS : latent_dim*2 for matching encoder_statesdecoder_lstm = rnn(latent_dim*2, return_state=True, return_sequences=True, dropout=0.2,recurrent_dropout=0.5, name='decoder_lstm')# get outputs and decoder statesrnn_outputs, *decoder_states = decoder_lstm(decoder_embedding, initial_state=encoder_states)# decoder densedecoder_dense = Dense(num_decoder_tokens, activation='softmax', name='decoder_dense')decoder_outputs = decoder_dense(rnn_outputs)bidi_encoder_model = Model([encoder_inputs,decoder_inputs], [decoder_outputs])bidi_encoder_model.compile(optimizer='adam', loss='categorical_crossentropy')return bidi_encoder_model
model = build_model()plot_model(model, to_file='seq2seq_model.png', show_shapes=True)

model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
encoder_inputs (InputLayer)     (None, None)         0
__________________________________________________________________________________________________
encoder_embedding (Embedding)   (None, None, 20)     400         encoder_inputs[0][0]
__________________________________________________________________________________________________
decoder_inputs (InputLayer)     (None, None)         0
__________________________________________________________________________________________________
encoder_lstm (Bidirectional)    [(None, 512), (None, 567296      encoder_embedding[0][0]
__________________________________________________________________________________________________
decoder_embedding (Embedding)   (None, None, 100)    10000       decoder_inputs[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 512)          0           encoder_lstm[0][1]               encoder_lstm[0][3]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 512)          0           encoder_lstm[0][2]               encoder_lstm[0][4]
__________________________________________________________________________________________________
decoder_lstm (LSTM)             [(None, None, 512),  1255424     decoder_embedding[0][0]          concatenate_1[0][0]              concatenate_2[0][0]
__________________________________________________________________________________________________
decoder_dense (Dense)           (None, None, 100)    51300       decoder_lstm[0][0]
==================================================================================================
Total params: 1,884,420
Trainable params: 1,884,420
Non-trainable params: 0
__________________________________________________________________________________________________

Keras-10 模型可视化相关推荐

  1. Keras与Tensorflow2.0入门(6)模型可视化与tensorboard的使用

    文章目录 1. 前言 1.1 Plot_model 1.2 History 1.3 自定义评估函数 PRF值的计算方法 AUC的计算方法 2. tensorboard 2.1 tensorboard是 ...

  2. tensorboard的可视化及模型可视化

    待整理 How to Check-Point Deep Learning Models in Keras LossWise Tensorboard 中文社区 谷歌发布TensorBoard API,让 ...

  3. python模型保存save_浅谈keras保存模型中的save()和save_weights()区别

    今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...

  4. 基于变分自动编码器(Variational Autoencoders)进行推荐系统的实施、Keras实现并可视化训练和验证误差、最后给出topK准确率和召回率

    基于变分自动编码器(Variational Autoencoders)进行推荐系统的实施.Keras实现并可视化训练和验证误差.最后给出topK准确率和召回率 本著作改编自Dawen等人用于协同过滤目 ...

  5. keras和tensorflow使用 keras.callbacks.TensorBoard 可视化数据

    此文首发于我的个人博客:keras和tensorflow使用 keras.callbacks.TensorBoard 可视化数据 - zhang0peter的个人博客 TensorBoard 是一个非 ...

  6. Tensorboard—使用keras结合Tensorboard可视化

    1. keras如何使用tensorboard keras使用tensorboard是通过回调函数来实现的,关于什么是keras的"回调函数",这里就不再赘述了,所以Tensorb ...

  7. 关于DPM(Deformable Part Model)算法中模型可视化的解释

    搭建了自己的博客平台,本文地址:http://masikkk.com/blog/DPM-model-visualization/ DPM源代码(voc-release)中的模型可视化做的还算相当炫酷的 ...

  8. C++中用frugally-deep调用keras的模型并进行预测

    1.背景 Python语言中的Keras库搭建深度学习模型非常便捷,但有时需要在 C++ 中调用训练好的模型,得到测试集的结果.比如将模型部署于FPGA,中间的一个步骤则需要用C++构建模型.但 Ke ...

  9. keras提取模型中的某一层_Tensorflow笔记:高级封装——Keras

    前言 之前在<Tensorflow笔记:高级封装--tf.Estimator>中介绍了Tensorflow的一种高级封装,本文介绍另一种高级封装Keras.Keras的特点就是两个字--简 ...

最新文章

  1. (MIDP)Prediction of potential disease-associated microRNAs based on random walk
  2. 述职答辩提问环节一般可以问些什么_2020上海市职称评审答辩注意事项
  3. SQL datediff (时间差)
  4. java编写github监控_【原创工具】github监控工具
  5. py-faster-rcnn用自己的数据训练模型
  6. mysql经常问到的面试题_20道BAT面试官最喜欢问的JVM+MySQL面试题(含答案解析)...
  7. ceph rgw java_ceph rgw multisite基本用法
  8. HDRP中ShaderGraph自发光的一个小坑
  9. Jan 09 - Count Primes; Mathematics; Optimization; Primes; DP;
  10. windows测试模式
  11. Java动态代理实现(转载\整理)
  12. android常用窗口动画,android 自定义dialog,窗口动画,
  13. “盈利为王”运营商财务管理沙盘--徐凌云老师
  14. Neo4j 示例:三国志人物关系图谱
  15. Python中遇到pcap not match 问题
  16. redhat安装配置Apache服务
  17. 详解 HTTP 协议报文格式 构造 HTTP 请求
  18. Memcached分布式算法
  19. 实现太阳系行星公转动画实例(CSS+HTML5 源码)
  20. 【吉大刘大有数据结构绿皮书】向LinkedList类中增加一个函数Contrary,功能为将其所有结点按相反次序链接。

热门文章

  1. rust全息要啥才能做_在 Rust 中不能做什么
  2. python asyncio future_Python中的asyncio模块中的Future和Task的区别?
  3. api token 什么意思_还分不清 Cookie、Session、Token、JWT?
  4. python控制结构实验结果分析_实验1_Python语法及控制结构
  5. vb.net 判断是否为ip 正则_什么是个人IP科学定位?标准答案来了|ip|直播|科学|ip魔方...
  6. nmap扫描局域网存活主机_第十五天Nmap篇:每日一练之Kali Linux面试题
  7. 实体 联系 模型mysql_数据库实体联系模型与关系模型
  8. java ide下载_jGRASP|轻量级Java IDE(jGRASP)下载v2.0.4.03官方版 - 欧普软件下载
  9. python把数字阿拉伯数字转换成中文10以内_Python实现把数字转换成中文
  10. 高速信号传输约翰逊 pdf_高速串口技术如何突破板级连接限制