现在我们都知道Geoffrey Hinton的胶囊网络(Capsule Network)震动了整个人工智能领域,它将卷积神经网络(CNN)的极限推到一个新的水平。 网上已经有很多的帖子、文章和研究论文在探讨胶囊网络理论,以及它如何做的比传统的CNN更好。因此我不打算介绍这方面的内容,而是尝试使用谷歌的Colaboratory工具在TensorFlow上实现CpNet。

你可以通过下面的几个链接了解CpNet的理论部分:

Geoffrey Hinton的演讲:“卷积神经网络的问题是什么?”胶囊网络正在震动人工智能领域胶囊之间的动态路由现在我们开始写代码。

开始之前,你可以参考我的CoLab Notebook执行以下代码:

CoLab网址:https://goo.gl/43Jvju

现在克隆github上的仓库并安装依赖库。 然后,我们从仓库中取出MNIST数据集,并将其移至父目录:

!git clone https://github.com/bourdakos1/capsule-networks.git

!pip install -r capsule-networks / requirements.txt

!touch capsule-networks / __ init__.py

!mv capsule-networks capsule

!mv capsule / data / ./data/

!ls

现在让我们导入所有的模块:

import os

import tensorflow as tf

from tqdm import tqdm

from capsule.config import cfg

from capsule.utils import load_mnist

from capsule.capsNet import CapsNet

初始化capsNet = CapsNet(is_training = cfg.is_training)1这就是胶囊网络(CpNet)在Tensorboard图上的样子:

训练

tf.logging.info('Graph loaded')

sv = tf.train.Supervisor(graph = capsNet.graph,

logdir = cfg.logdir,

save_model_secs = 0)

path = cfg.results +'/accuracy.csv'

if not os.path.exists(cfg.results):

os.mkdir(cfg.results)

elif os.path.exists(path):

os.remove(path)

fd_results = open(path,'w')

fd_results.write('step,test_accn')

现在创建TF会话(session)并开始执行。

默认情况下,模型将被训练50个epoch,批次大小为128。 你可以尝试不同的超参数组合:

with sv.managed_session() as sess:

num_batch = int(60000 / cfg.batch_size)

num_test_batch = 10000 // cfg.batch_size

teX,teY = load_mnist(cfg.dataset,False)

for epoch in range(cfg.epoch):

if sv.should_stop():

break

for step in tqdm(range(num_batch),total=num_batch,ncols=70,leave=False,unit='b'):

global_step = sess.run(capsNet.global_step)

sess.run(capsNet.train_op)

if step % cfg.train_sum_freq == 0:

_,summary_str = sess.run([capsNet.train_op,capsNet.train_summary])

sv.summary_writer.add_summary(summary_str,global_step)

if (global_step + 1) % cfg.test_sum_freq == 0:

test_acc = 0

for i in range(num_test_batch):

start = i * cfg.batch_size

end = start + cfg.batch_size

test_acc += sess.run(capsNet.batch_accuracy,{capsNet.X: teX[start:end],capsNet.labels: teY[start:end]})

test_acc = test_acc / (cfg.batch_size * num_test_batch)

fd_results.write(str(global_step + 1) + ',' + str(test_acc) + 'n')

fd_results.flush()

if epoch % cfg.save_freq == 0:

sv.saver.save(sess,cfg.logdir + '/model_epoch_%04d_step_%02d' % (epoch,global_step))

fd_results.close()

tf.logging.info('Training done')

在NVIDIA TitanXp卡上运行50个epoch,花了大约6个小时。

但经过训练的网络效果惊人,总损失(total loss)达到了不可思议的0.0038874。

下载训练好的模型CpNet模型网址: https://goo.gl/DN7SS3

原文: Running CapsuleNet on TensorFlow

python实现胶囊网络_胶囊网络(Capsule Network)的TensorFlow实现相关推荐

  1. python实现胶囊网络_胶囊网络结构 Capsule 初探

    作为神经网络的大牛Geoffrey Hinton在17年十月提出了一种新的网络结构,他称之为Capsule,这个词的中文意思是胶囊.在本周,Capsule的代码在github上开源,瞬间成为爆款. C ...

  2. python实现胶囊网络_胶囊网络(Capsule Network)在文本分类中的探索

    作者丨杨敏 单位丨中国科学院深圳先进技术研究院助理研究员 研究方向丨自然语言处理 文本建模方法大致可以分为两类:(1) 忽略词序.对文本进行浅层语义建模 (代表模型包括 LDA,EarthMover' ...

  3. python实现胶囊网络_胶囊网络Cod的分类模块

    参考Capsule Network Code,我使用的只是上述代码中的分类模块,因此下面是我从链接中提取的完整分类代码.在from __future__ import division, print_ ...

  4. python做社会网络分析_社交网络分析(Social Network Analysis in Python)①

    今天的网络是我们日常生活的一部分. 让我们学习如何使用网络在Python中可视化和理解社交网络 网络无处不在,道路网络,社交媒体上的朋友和关注者网络以及办公室同事网络. 他们在日常生活中发挥着重要作用 ...

  5. python实现胶囊网络_胶囊网络 -- Capsule Networks

    胶囊网络是 vector in vector out的结构,最后对每个不同的类别,输出不一个向量,向量的模长表示属于该类别的概率. 例如,在数字识别中,两个数字虽然重叠在一起,Capsule中的两个向 ...

  6. python中心性评价_复杂网络中边的中心性(Edge Centrality)

    一分钟读完全文 补充了OSMNX给的官方demo中的一些未描述清楚的地方.对复杂网络中的主要用到的两种边中心性betweenness centrality以及current-flow closenes ...

  7. python微服务监控_基于网络抓包实现kubernetes中微服务的应用级监控

    微服务是什么? 此话题不是本文重点,如你还不知道.请谷歌一波,会有遍地的解释.引用下图说明下微服务可能呈现的形态: 微服务监控的挑战 监控的目的是为了让集群中所有的服务组件,不管是HTTP服务,数据库 ...

  8. 两台电脑通过usb共享网络_避开网络限制,通过蓝牙共享网络连接

    在部分环境中,如校园网等,网络管理员限制了只允许指定的设备访问网络,又或者在路由器后台设置上网认证页面,只允许登记了的计算机访问网络,而不允许手机访问公司的无线网络,甚至拉黑使用网络共享软件的用户.这 ...

  9. mac wmware 无网络_无线网络中常用的技术名词

    1.LAN:即局域网: 是路由和主机组成的内部局域网,一般为有线网络. 2.WAN:即广域网: 是外部一个更大的局域网. 3.WLAN(Wireless LAN,即无线局域网): 前面我们说过LAN是 ...

最新文章

  1. centos 调整home分区xfs_centos 7.4 磁盘空间不足,扩容根分区 --lvm模式
  2. 台式电脑计算机无法启动 启动修复,Win10启动修复无法修复你的电脑解决方法
  3. 当前页面怎么调用子集iframe页面的方法
  4. 云计算适用于中小企业吗?
  5. 解密TDE加密数据库
  6. flink 3-转换
  7. #Spring代理的简单例子#
  8. 漫谈C#编程中的多态与new关键字
  9. JS HTTP 请求库哪家强?Axios,Request,Superagent,Fetch 还是 Supertest
  10. C#LeetCode刷题-树
  11. 小米武大共建人工智能实验室,先期提供1000万研发经费
  12. 17. --cover-- 覆盖掩盖 (词19)
  13. Linux—MySQL安装配置详解
  14. [转] Centos 6.4 python 2.6 升级到 2.7
  15. Andoid游戏【真情表白】让你心爱的人在游戏中感受真情!
  16. WP7开发平台介绍及开发注意事项【WP7学习札记之二】
  17. 计算机辅助工程分析及应用论文,浅谈计算机辅助工程(CAE) 毕业设计(论文).doc...
  18. 抖音快手火山 热门采集/个人主页无水印视频批量解析下载工具2019-11-11
  19. 帆软大数据自定义分页
  20. 新员工访谈-ORID(事实、体验、理解、决定)

热门文章

  1. 学习笔记——simulink的建模与仿真流程
  2. swap空间扩容方法
  3. 不定宽高的div水平垂直居中
  4. 检验一个数据集是否是正太分布
  5. code-server、docker-compose安装wordpress+mysql、wordpress公式插件、markdown插件、目录插件、调序插件、统计插件、分享点赞打赏插件
  6. STM32定时器时间计算公式
  7. 软件公司的商业模式与招聘
  8. gstreamer中h264对齐方式au和nal
  9. ES6 -- fill详解
  10. 快速上手MATLAB图像处理:100种项目全覆盖