论文

Belghazi, Mohamed Ishmael, et al. “Mutual information neural estimation.” International Conference on Machine Learning. 2018.

利用神经网络的梯度下降法可以实现快速高维连续随机变量之间互信息的估计,上述论文提出了Mutual Information Neural Estimator (MINE)。NN在维度和样本量上都是线性可伸缩的,MI的计算可以通过反向传播进行训练。

核心

Python实现

现有github上的代码无法计算和估计高维随机变量,只能计算一维随机变量,下面的代码给出的修改方案能够计算真实和估计高维随机变量的真实互信息。

其中,为了计算理论的真实互信息,我们不直接暴力求解矩阵(耗时,这也是为什么要有MINE的原因),我们采用给定生成随机变量的参数计算理论互信息。

SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3

完整代码基于pytorch

# Name: MINE_simple
# Author: Reacubeth
# Time: 2020/12/15 18:49
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as pltSIGNAL_NOISE = 0.2
SIGNAL_POWER = 3data_dim = 3
num_instances = 20000def gen_x(num, dim):return np.random.normal(0., np.sqrt(SIGNAL_POWER), [num, dim])def gen_y(x, num, dim):return x + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [num, dim])def true_mi(power, noise, dim):return dim * 0.5 * np.log2(1 + power/noise)mi = true_mi(SIGNAL_POWER, SIGNAL_NOISE, data_dim)
print('True MI:', mi)hidden_size = 10
n_epoch = 500class MINE(nn.Module):def __init__(self, hidden_size=10):super(MINE, self).__init__()self.layers = nn.Sequential(nn.Linear(2 * data_dim, hidden_size),nn.ReLU(),nn.Linear(hidden_size, 1))def forward(self, x, y):batch_size = x.size(0)tiled_x = torch.cat([x, x, ], dim=0)idx = torch.randperm(batch_size)shuffled_y = y[idx]concat_y = torch.cat([y, shuffled_y], dim=0)inputs = torch.cat([tiled_x, concat_y], dim=1)logits = self.layers(inputs)pred_xy = logits[:batch_size]pred_x_y = logits[batch_size:]loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))# compute loss, you'd better scale exp to bitreturn lossmodel = MINE(hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
plot_loss = []
all_mi = []
for epoch in tqdm(range(n_epoch)):x_sample = gen_x(num_instances, data_dim)y_sample = gen_y(x_sample, num_instances, data_dim)x_sample = torch.from_numpy(x_sample).float()y_sample = torch.from_numpy(y_sample).float()loss = model(x_sample, y_sample)model.zero_grad()loss.backward()optimizer.step()all_mi.append(-loss.item())fig, ax = plt.subplots()
ax.plot(range(len(all_mi)), all_mi, label='MINE Estimate')
ax.plot([0, len(all_mi)], [mi, mi], label='True Mutual Information')
ax.set_xlabel('training steps')
ax.legend(loc='best')
plt.show()

结果

变量维度为1


变量维度为3

需要指出的是在计算最终的互信息时需要将基数e转为基数2。如果只是求得一个比较值,在真实使用的过程中可以省略。

参考

https://github.com/mzgubic/MINE

更多
互信息公式

更多内容访问 omegaxyz.com
网站所有代码采用Apache 2.0授权
网站文章采用知识共享许可协议BY-NC-SA4.0授权
© 2020 • OmegaXYZ-版权所有 转载请注明出处

神经网络高维互信息计算Python实现(MINE)相关推荐

  1. 梯度、梯度法、python实现神经网络的梯度计算

    [机器学习]梯度.梯度法.python实现神经网络的梯度计算 一.python实现求导的代码: 二.what is 梯度 三.使用梯度法寻找神经网络的最优参数 四.神经网络的梯度计算 一.python ...

  2. BP神经网络理解原理——用Python编程实现识别手写数字(翻译英文文献)

    BP神经网络理解原理--用Python编程实现识别手写数字   备注,这里可以用这个方法在csdn中编辑公式: https://www.zybuluo.com/codeep/note/163962 一 ...

  3. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  4. 深度学习(神经网络) —— BP神经网络原理推导及python实现

    深度学习(神经网络) -- BP神经网络原理推导及python实现 摘要 (一)BP神经网络简介 1.神经网络权值调整的一般形式为: 2.BP神经网络中关于学习信号的求取方法: (二)BP神经网络原理 ...

  5. 神经网络隐藏层个数怎么确定_含有一个隐藏层的神经网络对平面数据分类python实现(吴恩达深度学习课程1第3周作业)...

    含有一个隐藏层的神经网络对平面数据分类python实现(吴恩达深度学习课程1第3周作业): ''' 题目: 建立只有一个隐藏层的神经网络, 对于给定的一个类似于花朵的图案数据, 里面有红色(y=0)和 ...

  6. python神经网络编程 豆瓣,用python构建神经网络

    python深度学习框架学哪个 Python深度学习生态系统在这几年中的演变实属惊艳.pylearn2,已经不再被积极地开发或者维护,大量的深度学习库开始接替它的位置.这些库每一个都各有千秋. 我们已 ...

  7. python神经网络预测的例子,python神经网络预测模型

    python做BP神经网络,进行数据预测,训练的输入和输出值都存在负数,为什么预测值永远为正数? 谷歌人工智能写作项目:神经网络伪原创 如何用jupyter设计银行精准客户存款预测 文案狗. 用jup ...

  8. python神经网络框架有哪些,python调用神经网络模型

    人工智能 Python深度学习库有哪些 由于Python的易用性和可扩展性,众多深度学习框架提供了Python接口,其中较为流行的深度学习库如下:第一:CaffeCaffe是一个以表达式.速度和模块化 ...

  9. 图神经网络相似度计算

    图神经网络相似度计算 注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等 ...

  10. 计算Python的代码块或程序的运行时间

    1.运用场景 在很多的时候我们需要计算我们程序的性能,这个时候我们常常需要统计程序运行的时间.下面我们就来说说怎么统计程序的运行时间. 2. 实现方法 计算Python的某个程序,或者是代码块运行的时 ...

最新文章

  1. c语言中gets函数可以输入空格吗_C语言中printf和gets函数的实用技巧
  2. numpy基础(part4)--统计量
  3. UnaryOperator函数式接口
  4. 信息学奥赛一本通(1190:上台阶)
  5. 怎样写一个具有异步交互的React组件的单元测试
  6. 程序员怒了!你敢削减专利奖金,我敢拒绝提交代码!
  7. python向上取整_python向上取整-取整,向上
  8. cc2530设计性实验代码七
  9. GFZRNX学习教程(安装以及rinex格式转换)
  10. oracle工程师 的职业,数据库工程师的职业规划
  11. 复杂性思维第二版 一、复杂性科学
  12. Windows 文件系统格式 Raw格式转换NTFS
  13. linux7如何改ssid,ssid怎么设置,教您网络ssid怎么设置
  14. 百度网盘 linux 上传文件大小限制,Linux 下载百度网盘大文件的方法
  15. qq显示下线通知什么意思_qq最近登录设备显示其他设备,但我手机没有下线通知,怎么回事...
  16. linux locale字符集设置,Linux下通过locale来设置字符集
  17. 【洛谷】P1488 肥猫的游戏(博弈论+全网最详细!!!)
  18. PPT文件带有打开密码怎么解决
  19. Mac OS + Mac PE + Win PE 三合一 U盘制作教程
  20. HAC集群状态检查、切换、数据同步验证方法

热门文章

  1. csdn如何修改文字体及颜色
  2. pls-00302: 必须声明 组件_vue组件
  3. IDEA 对接口进行快速测试(Create Test)
  4. 耳机使用说明书 jbl ua_JBL UA联名款,全新一代真无线运动耳机“UA小黑盒”今日天猫首发...
  5. html直链如何修改成js,javascript – 使用route params直接链接到URL会破坏AngularJS App...
  6. 收藏商品表设计_babycare商品价格及销售情况分析
  7. 谷粒商城:04. 逆向工程完善微服务系统
  8. Node.js:node项目中连接postgresql以及基础使用
  9. Java编程:KMP算法
  10. 算法面试题_求给定字符串的排列、组合、八皇后问题