Domain Adaption Without Source Data论文阅读笔记
总的来说:
源模型经过一个特征提取器和一个分类器后得到一个标签ysy_sys
可训练的目标模型输入目标样本,经过特征提取器后,分类器Cs2tC_{s2t}Cs2t用源伪标签训练,分类器CtC_tCt用目标伪标签yty_tyt训练。
其中的yty_tyt是由APM得到的。APM中,一旦得到伪标签,就移除不可靠样本(集合到集合基于距离的置信度)。用目标伪标签以自学习方式,源伪标签作为正则化器将目标模型适应到目标域。
APM
目标模型两种loss,Cs2tC_{s2t}Cs2t利用的伪标签(预训练源模型后推断的目标样本得到的),不会改变。对所有目标样本利用APM分配伪标签,包括对每个类的可靠目标样本(多样本)。
伪标签
给定目标样本到特征提取器产生嵌入特征ftf_tft,技术嵌入特征和所有APM中样本的相似度分数,APM中每个类可有不同数量的类代表样本。最后选择最相似的类得到伪标签。
Cs2tC_{s2t}Cs2t得到的loss是LsourceL_{source}Lsource,CtC_tCt的loss是LselfL_{self}Lself用到了w,如果最相似小于次相似w是1否则是0。
一、引言
这篇文章中提出了一个新颖的方法,可以解耦直接用源数据的域适应过程,通过利用预训练源模型。关键思想:使用一个预训练源模型和reliable目标样本以自训练的方式更新目标模型。产生问题:怎样从预训练源模型中挑选可靠的目标样本。
在域适应中,源域目标域在协变量偏移下紧密相关。预测不准确性课通过自熵量化,H(x)=−∑p(x)log(p(x))H(x)=-\sum p(x)log(p(x))H(x)=−∑p(x)log(p(x)),这里小的熵意味着更自信的预测。基于此,猜想在无标签的目标样本中,通过预训练源模型衡量的熵低的样本足够可靠。
为验证,我们衡量被喂入一个预训练源模型的目标样本的自熵,然后分析准确性和样本分布。
如图1,将熵值低于0.2的样本作“可靠样本”,占了总样本的30%左右。从结果看,我们可以总结:一个目标模型可以用可靠的目标样本通过自熵准则训练,但致命的是样本很少。
为了解决可靠但很少目标样本的问题,提出一个新的两部分组成的框架。一个是来自源域的预训练模型,所有权重都被冻结,另一个是从预训练模型初始化来的目标模型,但是通过两个losses逐渐发展。第一个损失使用来自预训练源模型的所有目标样本的source-oriented pseudo labels,这防止目标模型产生由第二个自学习损失产生的自偏差问题。第二个损失利用从可训练目标网络得到的目标样本的target-oriented pseudo labels优化目标模型。
更准确的,我们周期性的存储每个类的低熵可靠样本作为prototypes,在一个记忆库中训练过程中。然后,我们基于嵌入样本和存储类prototypes直接的相似度为一个目标样本分配target-oriented pseudo labels。然而,伪标签可能不总是准确的,所以我们提出一个基于置信度的样本过滤,通过测量集合-集合的距离。训练中,逐渐增加第二个损失的影响,允许我们的目标模型以渐进方式适应到目标域。
我们工作的主要贡献:1)解决存在数据隐私问题的环境下的域适应,这是第一个训练时没有任何源数据的工作。2)为了从源数据中结构域适应,提出新颖的逐渐发展的框架:基于可靠的目标样本,来自源域细腻些的正则化。3)尽管不用任何源样本训练我们的目标模型,比传统用带标签源数据的模型更好效果。
二、方法
1. 问题
与UDA相比,训练过程中没有用到任何带标签的源样本。
UDA目标是最小化一个带标签源域和一个不带标签目标域之间的差异,前提是源和目标样本取自不同但相关的概率分布。
SFDA假设无法得到源域样本,利用源域的预训练模型。目的在于通过预训练源模型的参数和不带标签的目标域实现无监督域适应。
2. SFDA框架
主要包括两种模型:预训练源模型和一个可训练的目标模型。
源模型包括:一个特征提取器FsF_sFs和一个分类器CsC_sCs,在源域预训练后这些模块的参数都是固定的。
可训练的目标模型包括:一个特征提取器FtF_tFt(有多分支分类器Cs2tC_{s2t}Cs2t和CtC_tCt),每个模块的参数都是用预训练源模型的参数初始化来的(θFs,θCs\theta_{F_s},\theta_{C_s}θFs,θCs)。
当目标样本作为输入送入FtF_tFt时,上分支(Cs2tC_{s2t}Cs2t)用source−orientedsource-orientedsource−oriented伪标签ys^\hat{y_s}ys^进行训练(伪标签在预训练分类器CsC_sCs中得到)。下分支(CtC_tCt)用target−orientedtarget-orientedtarget−oriented伪标签yt^\hat{y_t}yt^进行训练(由提出的APM得到的)。一旦从APM中得到了伪标签yt^\hat{y_t}yt^,我们就通过集合到集合的基于距离的置信度移除不可靠样本。
总的来说,以一种自学习方式利用伪标签yt^\hat{y_t}yt^将目标模型适应到目标域,用预训练源模型的得到的源知识ys^\hat{y_s}ys^作为正则化器。
-Adaptive prototype memory
目标模型更新参数以两种类型的loss函数:第一种loss方程(Cs2tC_{s2t}Cs2t中)旨在保持源域的信息。Cs2tC_{s2t}Cs2t利用预训练源模型干扰目标样本后获得的伪标签(FsF_sFs和CsC_sCs)。
注意:每个目标样本的Cs2tC_{s2t}Cs2t损失在训练过程中不变。
为了训练中也能改进目标模型,通过APM为所有目标样本分配伪标签。这个APM包括每个类的可靠目标样本(叫做multi−prototypesmulti-prototypesmulti−prototypes。
第一步:计算出目标样本的归一化自熵H(xt)=−1logNc∑l(xt)log(l(xt))H(x_t)=-\frac{1}{logN_c}\sum l(x_t)log(l(x_t))H(xt)=−logNc1∑l(xt)log(l(xt)),这里的l(xt)l(x_t)l(xt)代表分类器CtC_tCt得到的预测概率,NcN_cNc代表总类数。然后建立一个class−wiseclass-wiseclass−wise熵集Hc=H(xt)∣xt∈XcH_c={H(x_t)|x_t\in X_c}Hc=H(xt)∣xt∈Xc,这里XcX_cXc将CtC_tCt预测得到的样本集记作类ccc。
下一步:选择可靠样本,即multi−prototypesmulti-prototypesmulti−prototypes,可表示每个类。方法:选择每个类中自熵较低的固定数量的样本。但可能某些类本身熵值就较低,所以解决方法:对每个类不同数量prototypes。
首先找到每个类的最低熵,然后设置其中的最大值作为阈值来挑选prototypes,即η=max{min(Hc)∣c∈C},C=1...Nc\eta=max\{\min(H_c)|c\in C\},C={1...N_c}η=max{min(Hc)∣c∈C},C=1...Nc是类集合。利用这个可变的阈值,得到每个类的multi−prototypesmulti-prototypesmulti−prototypes:
Mc=Ft(xt)∣xt∈Xc,H(xt)<=ηM_c={F_t(x_t)|x_t\in X_c,H(x_t)<=\eta} Mc=Ft(xt)∣xt∈Xc,H(xt)<=η
这里每个prototype包括一个嵌入特征Ft(xt)F_t(x_t)Ft(xt),并且这些多prototypes之后还要用于伪标签化和基于置信度过滤的过程,所以我们将所有prototypes存储进APM中。更新AMP花销大,所以周期性更新,经验上,每100步更新一次APM。
-伪标签化
基于APM中的multi-prototypes,我们可以为未标记的目标样本分配伪标签。给定一个目标样本xt∈RIx_t\in R^Ixt∈RI,将其喂入特征提取器Ft:RI−>REF_t:R^I->R^EFt:RI−>RE,产生一个嵌入特征ft∈REf_t\in R^Eft∈RE,这里的I,EI,EI,E代表输入空间和嵌入空间的维度。然后计算出嵌入特征和所有APM中prototypes间的相似度分数:
sc(xt)=1∣Mc∣∑pc∈McpcTft∣∣pc∣∣2∣∣ft∣∣2s_c(x_t)=\frac{1}{|M_c|}\sum_{p_c\in M_c} \frac{p_c^Tf_t}{||p_c||_2||f_t||_2} sc(xt)=∣Mc∣1pc∈Mc∑∣∣pc∣∣2∣∣ft∣∣2pcTft
这里pc∈REp_c\in R^Epc∈RE代表类别c中的一个多prototypes。
最后,通过选择最相似的类得到伪标签,即yt^=argmaxcsc(xt),∀c∈C\hat{y_t}=argmax_cs_c(x_t),\forall c\in Cyt^=argmaxcsc(xt),∀c∈C。
-基于置信度的过滤
一旦通过APM得到伪标签,我们就可以用传统的交叉熵训练FtF_tFt和CtC_tCt。然后,由于我们没有用任何groundtruthgroundtruthgroundtruth标签,会存在不确定性和错误传播问题。为获得更多可靠伪标签,提出一个基于样本置信度的过滤机制。思想:用集合-集合距离估计一个伪标签的置信度,考虑两个集合间的corner情况很有效???
第一个集合是一个单独的元素集(包括一个目标样本),另一个集合可以是每个类的多prototypes。准确地,对每个目标样本得到最相似类Mt1M_{t1}Mt1和次相似类Mt2M_{t2}Mt2的多prototypes。然后衡量单独集合Q=ftQ={f_t}Q=ft和Mt1M_{t1}Mt1间的距离(Hausdorff距离)。
dH(Q,Mt1)=maxp∈Mt1d(ft,p)d_H(Q,M_{t1})=max_{p\in M_{t1}}d(f_t,p) dH(Q,Mt1)=maxp∈Mt1d(ft,p)
dH∗(Q,Mt2)=minp∈Mt2d(ft,p)d_H^*(Q,M_{t2})=min_{p\in M_{t2}}d(f_t,p) dH∗(Q,Mt2)=minp∈Mt2d(ft,p)
仅当最相似类比次相似类更近时,定义一个可靠样本。最后为每个目标样本分配一个置信度分数:
w(xt)={1ifdH(Q,Mt1)<dH∗(Q,Mt2)0otherwisew(x_t)= \begin{cases} 1& {if d_H(Q,M_{t1})<d_H^*(Q,M_{t2})}\\ 0& \text{otherwise} \end{cases}w(xt)={10ifdH(Q,Mt1)<dH∗(Q,Mt2)otherwise
-优化
目标模型由两个可训练分类器Cs2t,CtC_{s2t},C_tCs2t,Ct。预训练源模型得到的伪标签y^s\hat y_sy^s用于训练Cs2tC_{s2t}Cs2t:
Lsource(Dt)=−Ext∼Dt∑c=1Nc1[c=y^s]log(σ(Cs2t(Ft(xt))))L_{source}(D_t)=-E_{x_t\sim D_t}\sum_{c=1}^{N_c}1_{[c=\hat y_s]}log(\sigma(C_{s2t}(F_t(x_t))))Lsource(Dt)=−Ext∼Dtc=1∑Nc1[c=y^s]log(σ(Cs2t(Ft(xt)))),这里111是一个指示函数。上面公式帮助保持源域知识同时作为一个正则化器。
APM获得的伪标签yt^\hat{y_t}yt^用于训练CtC_tCt作为补充监督:
Lself(Dt)=−Ext∼Dt∑c=1Ncw(xt)1[c=y^t]log(σ(Ct(Ft(xt))))L_{self}(D_t)=-E_{x_t\sim D_t}\sum_{c=1}^{N_c}w(x_t)1_{[c=\hat y_t]}log(\sigma(C_{t}(F_t(x_t)))) Lself(Dt)=−Ext∼Dtc=1∑Ncw(xt)1[c=y^t]log(σ(Ct(Ft(xt))))
使用置信分数w(⋅)w(\cdot)w(⋅)仅对自信的样本计算损失,总损失函数为:
Ltotal(Dt)=(1−α)Lsource(Dt)+αLself(Dt)L_{total}(D_t)=(1-\alpha)L_{source}(D_t)+\alpha L_{self}(D_t)Ltotal(Dt)=(1−α)Lsource(Dt)+αLself(Dt)
这里的α\alphaα用于平衡源正则化损失Lsource(Dt)L_{source}(D_t)Lsource(Dt)和自学习损失Lself(Dt)L_{self}(D_t)Lself(Dt)。
先前阶段,伪标签y^t\hat y_ty^t是很稳定的,所以我们逐渐增加α\alphaα从0到1。测试阶段,我们使用分类器CtC_tCt的分类概率,最重要的超参数是更新周期。
三、实验
遵循之前的研究,我们使用在ImageNet上预训练的ResNet-50或ResNet-101作为一个基础特征提取器,使用相同的网络架构做固定的源模型和可训练的目标模型。对Office-31设定最大迭代步数5000.
我们的实验中,训练图像重设为256256,并随机用一个随机水平翻转裁剪到224224。SGD作为优化器,权重衰减为0.0005,Momentum为0.9。基础学习率设为10-3,在一个预训练特征提取器中的所有微调层都以学习率10-4进行优化。
学习率应用lrp=lr0(1+α⋅p)−βl_{r_p}=l_{r_0}(1+\alpha \cdot p)^{-\beta}lrp=lr0(1+α⋅p)−β,这里lr0l_{r_0}lr0是基础学习率,p是一个相对步数在训练过程中从0变为1,α=10,β=0.75\alpha =10,\beta =0.75α=10,β=0.75。
我们方法中最重要的超参数是APM的更新周期。每100迭代更新一个APM模块。
3.3实验分析
基于置信度的过滤有效解决了不完美伪标签的不确定性。为进一步验证我们的过滤机制,我们测量了有效训练样本的百分比,。可以看到目标样本的一小部分用来在开始时训练,但随着训练继续有效样本的数量逐渐增加。
特征可视化为视觉上验证方法有效性,比较了我们SFDA在Office-31的A-W和来自ResNet-50的t-SNE嵌入特征。我们的方法更好的对齐了源和目标数据。结果说明有可能在不接触源数据情况下减小两个不同域间的差距。
可靠样本的统计数据将小于0.2熵值的样本作为“可靠样本”,占据整个样本的30%。注意在训练过程中这些数据随着我们SFDA的发展会改变。因此,我们进一步分析了训练过程中这些数据的改变。图4(a)发现可靠样本的百分比逐渐增加最终超过总样本的50%。更重要的,可靠样本伪标签的准确率也增加。
APM更新周期周期性更新APM来反映目标域的数据,使用我们逐渐增进的目标模型。图4(b)显示了不同APM更新周期我们方法的性能。更新周期越短,准确率更好,但训练过程要求更多计算。相反,增加更新周期会降低性能,因为目标模型不能完全利用它的自学习机制。经验熵,对所有数据集将更新周期设为100.
**权衡参数α\alphaα**分析了源正则化损失和自学习损失的权衡参数。实验中,调查了两种策略,静止α\alphaα和动态α\alphaα。对于静止α\alphaα,将值从0变为1.仅使用源正则化损失即α=0\alpha =0α=0 ,不利用更新目标模型的优势,所以性能低于其它设定。
另一方面,仅依赖自学习损失α=1\alpha =1α=1,可能由于自偏差陷入局部最小。在SFDA中,设置动态α\alphaα用α=2(1+exp(−10⋅iter/max_iter))−1−0.5)\alpha = 2(1+exp(-10\cdot iter/max\_iter))^{-1}-0.5)α=2(1+exp(−10⋅iter/max_iter))−1−0.5),从而目标模型逐渐从源模型进行更新。从实验课观察到动态α\alphaα 比所有静止α\alphaα的设定都更好,验证了我们动态规划的有效性。
Domain Adaption Without Source Data论文阅读笔记相关推荐
- Model Adaption: Unsupervised Domain Adaption Without Source Data
三.方法 用模型来进行无监督模型适应问题,只有来自源域的预训练预测模型CCC和无标签的目标数据集XtX_tXt,目的是将CCC适应到带有XtX_tXt的目标域. 提出了一个协作类条件生成对抗网络( ...
- Learning from Synthetic Data for Crowd Counting in the Wild 论文阅读笔记
Learning from Synthetic Data for Crowd Counting in the Wild 论文阅读笔记 发表:CVPR 2019 人群计数任务在多变的环境,大范围的人群中 ...
- Are VQA Systems RAD? Measuring Robustness to Augmented Data with Focused Interventions 论文阅读笔记
Are VQA Systems RAD? Measuring Robustness to Augmented Data with Focused Interventions 论文阅读笔记 一.Abst ...
- 点云配准论文阅读笔记--Comparing ICP variants on real-world data sets
目录 写在前面 点云配准系列 摘要 1引言(Introduction) 2 相关研究(Related work) 3方法( Method) 3.1输入数据的敏感性 3.2评价指标 3.3协议 4 模块 ...
- 论文阅读笔记(15):Deep Subspace Clustering with Data Augmentation,深度子空间聚类+数据增强
论文阅读笔记(15):Deep Subspace Clustering with Data Augmentation,深度子空间聚类+数据增强 摘要 1 介绍 2 相关工作 带增强的聚类方法 具有一致 ...
- 论文阅读笔记:Improving Attacks on Speck32 / 64 using Deep Learning
论文阅读笔记:Improving Attacks on Speck32 / 64 using Deep Learning 本文通过神经网络利用了减少轮数的 Speck 的差分性质.为此,作者对神经网络 ...
- 论文阅读笔记——基于CNN-GAP可解释性模型的软件源码漏洞检测方法
本论文相关内容 论文下载地址--Engineering Village 论文阅读笔记--基于CNN-GAP可解释性模型的软件源码漏洞检测方法 文章目录 本论文相关内容 前言 基于CNN-GAP可解释性 ...
- HLA-Face: Joint High-Low Adaptation for Low Light Face Detection论文阅读笔记
HLA-Face: Joint High-Low Adaptation for Low Light Face Detection 论文阅读笔记 这是去年7月读这篇文章的笔记了,今年由于忘记了,又有需要 ...
- DnCNN论文阅读笔记【MATLAB】
DnCNN论文阅读笔记 论文信息: 论文代码:https://github.com/cszn/DnCNN Abstract 提出网络:DnCNNs 关键技术: Residual learning an ...
最新文章
- 使用深度学习检测DGA(域名生成算法)——LSTM的输入数据本质上还是词袋模型...
- 微信网页JSDK接口-wx.chooseImage问题
- 2015下半年软考系统集成管理工程师10月8日作业
- leetcode算法题--最优除法
- 如何得到iterator的当前元素_链表进化!双向链表+迭代器(Iterator)!
- HTTP1.0,HTTP1.1,HTTPS和HTTP2.0的区别
- [How TO]-ubuntu下快速搭建http
- Linux 调优方案, 修改最大连接数(ulimit命令)
- 文本转换成htmldocument
- Windows与Linux区别1
- JS判断文本框中只能输入数字和小数点
- vue的matcher_一张思维导图辅助你深入了解 Vue | Vue-Router | Vuex 源码架构
- 曲线积分和曲面积分及其几何应用、物理应用
- 2022.08 VMware官网下载安装+配置Linux虚拟机,最新最全
- 计算机解码原理图,diy制作改进的CS4398解码 DAC PCB和原理图纸
- 随机排列算法(Fisher-Yates)
- 计算机网络的 166 个核心概念
- QMH、AMC和STM之间的关系
- Lnmp部署tp5项目报错:require(): open_basedir restriction in effect. File(/home/wwwroot/api.321.design/think
- 最全最全的文件扩展名