迁移学习损失函数MMD(最大均值化差异)–python代码实现

MMD介绍

MMD(Max mean discrepancy 最大均值差异)是迁移学习,尤其是Domain adaptation (域适应)中使用最广泛(目前)的一种损失函数,主要用来度量两个不同但相关的分布的距离。两个分布的距离定义为:
MMD(X,Y)=∥1n∑i=1nϕ(xi)−1m∑j=1mϕ(yj)∥H2M M D(X, Y)=\left\|\frac{1}{n} \sum_{i=1}^{n} \phi\left(x_{i}\right)-\frac{1}{m} \sum_{j=1}^{m} \phi\left(y_{j}\right)\right\|_{H}^{2} MMD(X,Y)=∥∥∥∥∥​n1​i=1∑n​ϕ(xi​)−m1​j=1∑m​ϕ(yj​)∥∥∥∥∥​H2​

主代码编写

该代码基于torch.version = ‘1.9.0’

import torch
import torch.nn as nnclass MMDLoss(nn.Module):'''计算源域数据和目标域数据的MMD距离Params:source: 源域数据(n * len(x))target: 目标域数据(m * len(y))kernel_mul:kernel_num: 取不同高斯核的数量fix_sigma: 不同高斯核的sigma值Return:loss: MMD loss'''def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):super(MMDLoss, self).__init__()self.kernel_num = kernel_numself.kernel_mul = kernel_mulself.fix_sigma = Noneself.kernel_type = kernel_typedef guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):n_samples = int(source.size()[0]) + int(target.size()[0])total = torch.cat([source, target], dim=0)total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))L2_distance = ((total0-total1)**2).sum(2)if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i)for i in range(kernel_num)]kernel_val = [torch.exp(-L2_distance / bandwidth_temp)for bandwidth_temp in bandwidth_list]return sum(kernel_val)def linear_mmd2(self, f_of_X, f_of_Y):loss = 0.0delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)loss = delta.dot(delta.T)return lossdef forward(self, source, target):if self.kernel_type == 'linear':return self.linear_mmd2(source, target)elif self.kernel_type == 'rbf':batch_size = int(source.size()[0])kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)XX = torch.mean(kernels[:batch_size, :batch_size])YY = torch.mean(kernels[batch_size:, batch_size:])XY = torch.mean(kernels[:batch_size, batch_size:])YX = torch.mean(kernels[batch_size:, :batch_size])loss = torch.mean(XX + YY - XY - YX)return loss

程序验证

##在这里 第2维一定要相同,否则报错
source = torch.rand(64,14)  # 可以理解为源域有64个14维数据
target = torch.rand(32,14)  # 可以理解为源域有32个14维数据
print(target)
>>>output
tensor([[0.9035, 0.0088, 0.5867, 0.5595, 0.9350, 0.2739, 0.8775, 0.5562, 0.5402,0.5242, 0.4745, 0.7307, 0.7791, 0.7420],[0.2798, 0.6476, 0.3744, 0.5406, 0.3941, 0.6669, 0.2026, 0.8296, 0.3071,0.9042, 0.4810, 0.5235, 0.0547, 0.9110],[0.8051, 0.0702, 0.7907, 0.9708, 0.5310, 0.5851, 0.7881, 0.9082, 0.5963,0.9400, 0.3670, 0.8042, 0.5024, 0.2368],[0.5021, 0.7290, 0.3521, 0.6293, 0.8796, 0.2098, 0.0304, 0.9125, 0.3285,0.8485, 0.6877, 0.5695, 0.9506, 0.0752],[0.0798, 0.7908, 0.2785, 0.1369, 0.6762, 0.3342, 0.4930, 0.1807, 0.5963,0.2114, 0.4937, 0.4692, 0.3694, 0.9456],...[0.1638, 0.7100, 0.9024, 0.5154, 0.8746, 0.8611, 0.1314, 0.0308, 0.6660,0.3719, 0.6827, 0.6789, 0.2416, 0.4617],[0.4449, 0.8304, 0.4036, 0.0563, 0.3832, 0.3553, 0.7947, 0.9335, 0.2704,0.9798, 0.2621, 0.4497, 0.9440, 0.7362]])
MMD = MMDLoss()
a = MMD(source=source, target=target)
print(a)
>>>output
tensor(0.1448)

嵌入到CNN中代码实现

先定义一个简单的CNN模型

class Net_only(nn.Module):'''计算源域数据和目标域数据的MMD距离Params:x_in: 输入数据(batch, channel, hight, width)Return:x_out: 输出数据(batch, n_labes)'''## 这里 x_in:batch=64, channel=3, hight=128, width=128## x_out:batch=64, n_labes=5def __init__(self):super(Net_only, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3)self.pool = nn.MaxPool2d(2, 2)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, 3)self.pool = nn.MaxPool2d(2, 2)self.bn2 = nn.BatchNorm2d(64)self.conv3 = nn.Conv2d(64, 64, 3)self.pool = nn.MaxPool2d(2, 2)self.bn3 = nn.BatchNorm2d(64)self.conv3 = nn.Conv2d(64, 64, 3)self.drop1d = nn.Dropout(0.2)self.bn4 = nn.BatchNorm2d(64)self.fc1 = nn.Linear(64 * 14 * 14, 1024)self.fc2 = nn.Linear(1024, 256)self.fc3 = nn.Linear(256, 5)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.bn1(x)x = self.pool(F.relu(self.conv2(x)))x = self.bn2(x)x = self.pool(F.relu(self.conv3(x)))x = self.bn3(x)x = x.view(-1, x.size(1) * x.size(2) * x.size(3))x = F.relu(self.fc1(x))x = self.drop1d(x)x = F.relu(self.fc2(x))x = self.drop1d(x)x = self.fc3(x)return x

对CNN模型进行测试

model = Net_only()
source = torch.rand(64, 3, 128, 128) # 模拟产生batch=64,channel=3, hight=128, width=128 的源域图片数据
target = torch.rand(32, 3, 128, 128) # 模拟产生batch=32,channel=3, hight=128, width=128 的源域图片数据
source = model(source)
target = model(target)
print(source.shape)
>>>output
torch.Size([64, 5])

现在计算MMD损失

MMD = MMDLoss()
loss = MMD(source=source, target=target)
print(loss)
>>>output
tensor(0.0884, grad_fn=<MeanBackward0>)

迁移学习损失的运用

loss = clf_loss + lamb * transfer_loss
clf_loss是源域的分类损失,transfer_loss即本篇所介绍的MMD_loss,lamb是超参数

总结

迁移损失MMD其输入X, Y分别是souce = Net(source),target = Net(target),也就是模型的输出。
参考资料:
链接: 王晋东github
链接:https://blog.csdn.net/a529975125/article/details/81176029
欢迎关注公众号:故障诊断与python学习

迁移学习-域适应损失函数MMD-代码实现及验证相关推荐

  1. 迁移学习(Transfer Learning)概述及代码实现(full version)

    基于PaddlePaddle的李宏毅机器学习--迁移学习 大噶好,我是黄波波.希望能和大家共进步,错误之处恳请指出! 百度AI Studio个人主页, 我在AI Studio上获得白银等级,点亮2个徽 ...

  2. 迁移学习(Transfer Learning)概述及代码实现

    基于PaddlePaddle的李宏毅机器学习--迁移学习 大噶好,我是黄波波,希望能和大家共进步,错误之处恳请指出! 百度AI Studio个人主页, 我在AI Studio上获得白银等级,点亮2个徽 ...

  3. 基于MK-MMD度量迁移学习的轴承故障诊断方法研究

    摘要 上一篇文章实验是基于凯斯西厨大学轴承数据集,使用同一负载情况下的6种轴承数据进行故障诊断,并没有进行不同负载下轴承故障诊断.之前没做这块迁移学习实验,主要是对于迁移学习理解不到位,也没有不知道从 ...

  4. Python 迁移学习实用指南:6~11

    原文:Hands-On Transfer Learning with Python 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MT ...

  5. 迁移学习---迁移学习基础概念、分类

    迁移学习提出背景 在机器学习.深度学习和数据挖掘的大多数任务中,我们都会假设training和inference时,采用的数据服从相同的分布(distribution).来源于相同的特征空间(feat ...

  6. 手动搭建的VGG16网络结构训练数据和使用ResNet50微调(迁移学习)训练数据对比(图像预测+前端页面显示)

    文章目录 1.VGG16训练结果: 2.微调ResNet50之后的训练结果: 3.结果分析: 4.实验效果: (1)VGG16模型预测的结果: (2)在ResNet50微调之后预测的效果: 5.相关代 ...

  7. 什么是迁移学习?什么时候使用迁移学习?

    迁移学习是一种深度学习策略,它通过将解决一个问题所获得的知识应用于另一个不同但相关的问题来重用这些知识.例如,有3种类型的花:玫瑰.向日葵和郁金香.可以使用标准的预训练模型,如VGG16/19.Res ...

  8. 基于迁移学习的语义分割算法分享与代码复现

    摘要:语义分割的数据集是比较大的,因此训练的时候需要非常强大的硬件支持. 本文分享自华为云社区<[云驻共创]基于迁移学习的语义分割算法分享>,原文作者:启明. 此篇文章是分享两篇基于迁移学 ...

  9. 带你用深度学习虚拟机进行文本迁移学习(附代码)

    作者:Anusua Trivedi.Wee Hyong Tok 翻译:付宇帅 校对:卢苗苗 本文5302字,建议阅读10分钟. 本文讲述了现代机器学习的模型,主要由微软数据科学家Anusua Triv ...

最新文章

  1. RS2008中控件ID冲突问题
  2. Devexpress 之gridControl
  3. 的boc调制matlab程序_Matlab仿真基础数字全息
  4. MyBatis学习总结(三)——优化MyBatis配置文件中的配置
  5. 一文了解 Apache Flink 核心技术
  6. MySQL INNER JOIN:内连接查询
  7. mysqld与mysqld_safe的区别
  8. PC如何控制device进入suspend模式
  9. MongoDB sharding 集合不分片性能更高?
  10. 从业余挖洞到微软漏洞研究员,我的遗憾、惊喜和建议
  11. dijkstra邻接表_[力扣743] 带权邻接表的单源最短路
  12. Java 异常 (Exception) 剖析 与 用户自定义异常
  13. 计算机仿真系统模型有,计算机仿真在光伏发电系统模型中的应用研究原稿(最终定稿)...
  14. 消除Permission is only granted to system apps报错
  15. Java项目使用jib打包docker镜像的简单记录
  16. 入门力扣自学笔记118 C++ (题目编号1413)
  17. weui上传文件完整例子php,weui实现图片上传
  18. JNDI-Injection-With-LDAP-Unserialize
  19. Head First Design Patterns(深入浅出设计模式)-目录
  20. WWW 2022 | 量化交易相关论文(附论文链接)

热门文章

  1. 人工智能时代最吃香的热门专业,男女都适合
  2. java.lang.IllegalStateException: Cannot get a text value from a numeric cell
  3. iOS app添加桌面快捷方式
  4. 公众号刷粉、阅读量作弊
  5. html5猜颜色游戏,好看漂亮的html5网页特效学习笔记(3)_猜猜下一个颜色是什么?...
  6. elasticsearch安装 及 启动异常解决
  7. matlab 写netcdf,写入 netCDF 属性
  8. 3、流量分析--分组TopN统计
  9. 20164305 徐广皓《网络对抗》Exp9 Web安全基础实践
  10. C/C++编程:std::move(将左值强制转换为右值)