IRM(Invariant Risk Minimization)原理与最小实现
Invariant Risk Minimization原理与最小实现
- 1、Invariant Risk Minimization原理
- 1.1提出问题
- 1.2提出模型
- 2、IRM最小实现
- 参考文献
IRM(Invariant Risk Minimization)是2019年Martin Arjovsky等人提出的一种用于跨域图像分类的新方法,其提出的背景是当我们使用机器学习方法完成图片分类任务时,训练模型所使用的数据集与真实情况的数据集可能存在差别(数据集分布偏移),造成这种分布偏移的原因有很多,比如:数据选择的偏差(单一环境)、混淆因素等,该问题被称为跨域分类问题(注:跨域分类可能在其他的地方有其他的意思),目前大部分解决的方法是减小跨域分布偏差或者提取不变特征。而Martin提出的方法与之前很多跨域分类方法不同之处在于:为了提高机器学习的可解释性,并从根本上解决跨域分类问题,Martin考虑从数学方面推导出特征与标签预测的内在因果关系,即特征与标签之前存在与域无关的内在因果关系。
1、Invariant Risk Minimization原理
1.1提出问题
首先作者提出了一个问题,假设有一个SEM模型:
如上式所示,X1X_1X1是一组服从正态分布的数据,YYY是由X1X_1X1加上一个服从正态分布的白噪声构成,X2X_2X2是YYY加上一个服从正态分布的白噪声构成。
当使用最小二乘方法由(X1,X2)(X_1,X_2)(X1,X2)对YYY进行预测时,设其预测模型为:Y^e=X1eα1^+X2eα2^\hat{Y}^e=X_1^e\hat{\alpha_1}+X_2^e\hat{\alpha_2}Y^e=X1eα1^+X2eα2^,因此若对X1X_1X1与YYY的噪声乘以一个与环境有关的系数,那么当使用X1X_1X1与X2X_2X2预测YYY时,其根据算法是否能够识别出不变特征,回归系数会出现以下三种情况,因此作者的目标是得到第一种情况。
1.2提出模型
根据所总结的问题,作者做出如下定义,将模型分为两个部分,即数据表示Φ\PhiΦ与分类器ω^\hat{\omega}ω^。
将定义转化为数学模型得IRM表达式,
但是由于上式是一个两层优化问题,因此将上式简化为单变量优化问题,
其中Φ\PhiΦ成为不变预测器,其由两项组成,即经验风险最小项和不变风险最小项,而λ\lambdaλ作为平衡两项的一个超参数;由IRM到IRMv1的转变过程,作者还考虑了其他的因素,详细推导可看其论文第三章。
最终作者根据所提出的模型得到训练的损失函数表达式:
2、IRM最小实现
参照论文附录的基于Pytorch的IRM最小实现
import torch
from torch.autograd import grad
import numpy as np
import torchvisiondef compute_penalty(losses, dummy_w):# print(np.shape(losses[0::2]))# print(dummy_w)g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]# print(g1*g2)return (g1*g2).sum()def example_1(n=10000, d=2, env=1):x = torch.randn(n, d)*envy = x + torch.randn(n, d)*envz = y + torch.randn(n, d)# z = y# print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])return torch.cat((x, z), 1), y.sum(1, keepdim=True)phi = torch.nn.Parameter(torch.ones(4, 1))
# print(phi)
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
# print(dummy_w)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")environments = [example_1(env=0.1), example_1(env=1.0)]# s = [[1, 2], [3, 4]]
# for i, j in s:
# print(i)
# print(j)
for iteration in range(50000):error = 0penalty = 0for x_e, y_e in environments:# print(np.shape(x_e))# print(np.shape(y_e))p = torch.randperm(len(x_e))error_e = mse(x_e[p]@phi*dummy_w, y_e[p])# error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])# print(np.shape(error_e))penalty += compute_penalty(error_e, dummy_w)error += error_e.mean()# print(iteration)# print(error_e.mean())# print(error)opt.zero_grad()(1e-5 * error + penalty).backward()opt.step()if iteration % 1000 == 0:print(phi)
参考文献
Arjovsky, M., et al. (2019). “Invariant Risk Minimization.”
IRM(Invariant Risk Minimization)原理与最小实现相关推荐
- mixup: BEYOND EMPIRICAL RISK MINIMIZATION
原文:https://arxiv.org/pdf/1710.09412.pdf 代码:https://github.com/hongyi-zhang/mixup 摘要:深度神经网络非常强大,但也有一些 ...
- ICLR2018_mixup: Beyond Empirical Risk Minimization
作者 Hongyi Zhang 张宏毅 @ 张宏毅知乎 北大->MIT 论文所属FAIR Abstract 深度神经网络有些不好的行为:强记忆和对对抗样本敏感 Christian ...
- structural risk minimization
结构风险最小化(SRM)是机器学习中使用的归纳原理.通常在机器学习中,必须从有限数据集中选择广义模型,随之产生过度拟合的问题--模型变得过于强烈地适应训练集的特性而对新数据的概括性差.SRM原则通过平 ...
- 机器学习理论 之 经验风险最小化(Empirical Risk Minimization)
该理论探讨的是模型在training set上的error 与 generation error的关系. 训练模型时,需要多少个样本,达到什么精度,都是由理论依据的. 理论点: 偏差方差权衡(Bias ...
- mixup:beyond empirical risk minimization
全网最全:盘点那些图像数据增广方式Mosiac,MixUp,CutMix等. - 知乎全网最全:盘点那些图像数据增广方式Mosiac,MixUp,CutMix等. 本文由林大佬原创,转载请注明出处,来 ...
- 【Mixup】《Mixup:Beyond Empirical Risk Minimization》
ICLR-2018 文章目录 1 Background and Motivation 2 Related Work 3 Advantages / Contributions 4 Method 5 Ex ...
- [ICLR 2018] mixup: Beyond Empirical Risk Minimization
Contents Mixup Experiments Image Classification Task Speech data Memorization of Corrupted Labels Ro ...
- 【深度学习】Mixup: Beyond Empirical Risk Minimization
博主整理了近几年混合样本数据增强(Mixed Sample Data Augmentation)相关论文和代码,并分享在github上,地址如下, https://github.com/JasonZh ...
- [论文评析]Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification,AAAI,2022
Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification 文章信息 背景 动机 方法 因果分析 ...
最新文章
- GraphQL and Relay 浅析
- grunt 打包前端代码
- 【转】TCP和UDP的区别
- 接口访问加密_加密“访问”的争论日益激烈
- hdu 1142 记忆化搜索
- mysql基础之帮助信息
- 新兴IT企业特斯拉(四)——Model 3
- angular第六天
- 线程池创建线程数量讨论
- Java输出菱形图案
- 阿里巴巴编程考试认证java编程规范+考试分享
- 【琐识】日常获取知识随笔
- 用大白话聊聊分布式系统
- tensorflow中model.compile()用法
- 解决phpstudy的Apache启动失败
- 解决--cidaemon cpu 100%
- AWS 吹走了私有云天空中最后一片乌云
- JPress企业站主题-jpressicu使用教程
- 均值归一化_深度神经网络中的归一化技术
- 对开源软件的认识与实践-刘彬