Invariant Risk Minimization原理与最小实现

  • 1、Invariant Risk Minimization原理
    • 1.1提出问题
    • 1.2提出模型
  • 2、IRM最小实现
  • 参考文献

IRM(Invariant Risk Minimization)是2019年Martin Arjovsky等人提出的一种用于跨域图像分类的新方法,其提出的背景是当我们使用机器学习方法完成图片分类任务时,训练模型所使用的数据集与真实情况的数据集可能存在差别(数据集分布偏移),造成这种分布偏移的原因有很多,比如:数据选择的偏差(单一环境)、混淆因素等,该问题被称为跨域分类问题(注:跨域分类可能在其他的地方有其他的意思),目前大部分解决的方法是减小跨域分布偏差或者提取不变特征。而Martin提出的方法与之前很多跨域分类方法不同之处在于:为了提高机器学习的可解释性,并从根本上解决跨域分类问题,Martin考虑从数学方面推导出特征与标签预测的内在因果关系,即特征与标签之前存在与域无关的内在因果关系。

1、Invariant Risk Minimization原理

1.1提出问题

首先作者提出了一个问题,假设有一个SEM模型:

如上式所示,X1X_1X1​是一组服从正态分布的数据,YYY是由X1X_1X1​加上一个服从正态分布的白噪声构成,X2X_2X2​是YYY加上一个服从正态分布的白噪声构成。

当使用最小二乘方法由(X1,X2)(X_1,X_2)(X1​,X2​)对YYY进行预测时,设其预测模型为:Y^e=X1eα1^+X2eα2^\hat{Y}^e=X_1^e\hat{\alpha_1}+X_2^e\hat{\alpha_2}Y^e=X1e​α1​^​+X2e​α2​^​,因此若对X1X_1X1​与YYY的噪声乘以一个与环境有关的系数,那么当使用X1X_1X1​与X2X_2X2​预测YYY时,其根据算法是否能够识别出不变特征,回归系数会出现以下三种情况,因此作者的目标是得到第一种情况。

1.2提出模型

根据所总结的问题,作者做出如下定义,将模型分为两个部分,即数据表示Φ\PhiΦ与分类器ω^\hat{\omega}ω^。

将定义转化为数学模型得IRM表达式,

但是由于上式是一个两层优化问题,因此将上式简化为单变量优化问题,

其中Φ\PhiΦ成为不变预测器,其由两项组成,即经验风险最小项和不变风险最小项,而λ\lambdaλ作为平衡两项的一个超参数;由IRM到IRMv1的转变过程,作者还考虑了其他的因素,详细推导可看其论文第三章。

最终作者根据所提出的模型得到训练的损失函数表达式:

2、IRM最小实现

参照论文附录的基于Pytorch的IRM最小实现

import torch
from torch.autograd import grad
import numpy as np
import torchvisiondef compute_penalty(losses, dummy_w):# print(np.shape(losses[0::2]))# print(dummy_w)g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]# print(g1*g2)return (g1*g2).sum()def example_1(n=10000, d=2, env=1):x = torch.randn(n, d)*envy = x + torch.randn(n, d)*envz = y + torch.randn(n, d)# z = y# print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])return torch.cat((x, z), 1), y.sum(1, keepdim=True)phi = torch.nn.Parameter(torch.ones(4, 1))
# print(phi)
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
# print(dummy_w)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")environments = [example_1(env=0.1), example_1(env=1.0)]# s = [[1, 2], [3, 4]]
# for i, j in s:
#     print(i)
#     print(j)
for iteration in range(50000):error = 0penalty = 0for x_e, y_e in environments:# print(np.shape(x_e))# print(np.shape(y_e))p = torch.randperm(len(x_e))error_e = mse(x_e[p]@phi*dummy_w, y_e[p])# error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])# print(np.shape(error_e))penalty += compute_penalty(error_e, dummy_w)error += error_e.mean()# print(iteration)# print(error_e.mean())# print(error)opt.zero_grad()(1e-5 * error + penalty).backward()opt.step()if iteration % 1000 == 0:print(phi)

参考文献

Arjovsky, M., et al. (2019). “Invariant Risk Minimization.”

IRM(Invariant Risk Minimization)原理与最小实现相关推荐

  1. mixup: BEYOND EMPIRICAL RISK MINIMIZATION

    原文:https://arxiv.org/pdf/1710.09412.pdf 代码:https://github.com/hongyi-zhang/mixup 摘要:深度神经网络非常强大,但也有一些 ...

  2. ICLR2018_mixup: Beyond Empirical Risk Minimization

    作者 Hongyi Zhang 张宏毅 @ 张宏毅知乎      北大->MIT    论文所属FAIR Abstract 深度神经网络有些不好的行为:强记忆和对对抗样本敏感 Christian ...

  3. structural risk minimization

    结构风险最小化(SRM)是机器学习中使用的归纳原理.通常在机器学习中,必须从有限数据集中选择广义模型,随之产生过度拟合的问题--模型变得过于强烈地适应训练集的特性而对新数据的概括性差.SRM原则通过平 ...

  4. 机器学习理论 之 经验风险最小化(Empirical Risk Minimization)

    该理论探讨的是模型在training set上的error 与 generation error的关系. 训练模型时,需要多少个样本,达到什么精度,都是由理论依据的. 理论点: 偏差方差权衡(Bias ...

  5. mixup:beyond empirical risk minimization

    全网最全:盘点那些图像数据增广方式Mosiac,MixUp,CutMix等. - 知乎全网最全:盘点那些图像数据增广方式Mosiac,MixUp,CutMix等. 本文由林大佬原创,转载请注明出处,来 ...

  6. 【Mixup】《Mixup:Beyond Empirical Risk Minimization》

    ICLR-2018 文章目录 1 Background and Motivation 2 Related Work 3 Advantages / Contributions 4 Method 5 Ex ...

  7. [ICLR 2018] mixup: Beyond Empirical Risk Minimization

    Contents Mixup Experiments Image Classification Task Speech data Memorization of Corrupted Labels Ro ...

  8. 【深度学习】Mixup: Beyond Empirical Risk Minimization

    博主整理了近几年混合样本数据增强(Mixed Sample Data Augmentation)相关论文和代码,并分享在github上,地址如下, https://github.com/JasonZh ...

  9. [论文评析]Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification,AAAI,2022

    Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification 文章信息 背景 动机 方法 因果分析 ...

最新文章

  1. GraphQL and Relay 浅析
  2. grunt 打包前端代码
  3. 【转】TCP和UDP的区别
  4. 接口访问加密_加密“访问”的争论日益激烈
  5. hdu 1142 记忆化搜索
  6. mysql基础之帮助信息
  7. 新兴IT企业特斯拉(四)——Model 3
  8. angular第六天
  9. 线程池创建线程数量讨论
  10. Java输出菱形图案
  11. 阿里巴巴编程考试认证java编程规范+考试分享
  12. 【琐识】日常获取知识随笔
  13. 用大白话聊聊分布式系统
  14. tensorflow中model.compile()用法
  15. 解决phpstudy的Apache启动失败
  16. 解决--cidaemon cpu 100%
  17. AWS 吹走了私有云天空中最后一片乌云
  18. JPress企业站主题-jpressicu使用教程
  19. 均值归一化_深度神经网络中的归一化技术
  20. 对开源软件的认识与实践-刘彬

热门文章

  1. 东南大学计算机专业研究生复试,东南大学计算机考研复试经验
  2. 浅析淘宝刷单--我们如何网购
  3. Pexpect模块使用
  4. 用html写出分子分母,数学中的分数分子分母用英文拼写方法
  5. ESP32学习笔记(48)——WiFi蓝牙网关
  6. win10系统开机后正常运行时间不重置
  7. 通过ArcCatalog进行矢量数据的入库
  8. 路由器TL-WR800N固件升级
  9. PDF文档转Word为什么文字都是乱码,怎么解决?
  10. 点双连通分量 [HNOI2012]矿场搭建