Synthetic Data

  • 一. 前言
  • 二. 程序实现
    • 2.1 生成一条数据
    • 2.2 生成一组相关系数为p的数据集
    • 2.3 权重系数的余弦相似度与标签相关系数之间的关系

一. 前言

在MMoE论文中,作者人工生成了可以控制不同任务之间相关系数的数据集,并观察不同模型在不同相关系数的多任务学习中的模型效果,如下所示:


文中作者给出了数据集生成的数学表达:


下面用程序实现以上过程。

二. 程序实现

from scipy.linalg import *
import numpy as np
from tqdm import *
import matplotlib.pyplot as plt

2.1 生成一条数据

1.生成正交单位向量

d = 512 #维度
'''随机生成两个单位向量'''
np.random.seed(10)
u1 = np.random.randn(d)
u1 = u1 / np.linalg.norm(u1)np.random.seed(22)
u2 = np.random.randn(d)
u2 = u2 / np.linalg.norm(u2)u = np.vstack((u1,u2)).T #(d,2)
'''向量正交化'''
o = orth(u)
u1 = o[:,0]
u2 = o[:,1]
print(np.linalg.norm(u1))
print(np.linalg.norm(u2))
print(np.matmul(u1.T,u2))
'u1,u2为一组正交单位向量'输出:
1.0000000000000002
1.0000000000000002
1.97758476261356e-16

2.生成权重向量

c = 1 #常数
p = 0.5 #相关系数 [-1,1]
w1 = c*u1
w2 = c*(p*u1 + np.sqrt(1-p*p)*u2)

3.随机生成自变量x

np.random.seed(2022)
x = np.random.randn(d)

4. 随机生成m组正弦函数参数

m = 10 #组合正弦的数量
'''随机生成生成m组正弦函数参数'''
np.random.seed(42)
ab = np.random.randn(2,m)
a = ab[0,:] #(m,)
b = ab[1,:] #(m,)

5. 生成数据标签

y1 = np.matmul(w1.T,x)
y2 = np.matmul(w2.T,x)
for i in range(m):y1 = y1+np.sin(a[i]*np.matmul(w1.T,x)+b[i])y2 = y2+np.sin(a[i]*np.matmul(w2.T,x)+b[i])
y1 += np.random.normal(0,0.01,1)
y2 += np.random.normal(0,0.01,1)
y = np.hstack((y1,y2))

这样我们就得到了相关系数为p的一条数据,其中 x 的长度为d,y 的长度为2。

接下来将上面的步骤整理一下生成一组完整的相关系数为p的数据集。

2.2 生成一组相关系数为p的数据集

1.生成正交单位向量

d = 512 #维度'''生成两个单位向量'''
np.random.seed(10)
u1 = np.random.randn(d)
u1 = u1 / np.linalg.norm(u1)np.random.seed(22)
u2 = np.random.randn(d)
u2 = u2 / np.linalg.norm(u2)u = np.vstack((u1,u2)).T #(d,2)'''向量正交化'''
o = orth(u)
u1 = o[:,0]
u2 = o[:,1]
print(np.linalg.norm(u1))
print(np.linalg.norm(u2))
print(np.matmul(u1.T,u2))
'u1,u2为一组正交单位向量'

2.生成权重系数矩阵

c = 1 #常数
p = 0.5 #相关系数 [-1,1]w1 = c*u1
w2 = c*(p*u1 + np.sqrt(1-p*p)*u2)

3. 随机生成m组正弦函数参数

m = 10 #组合正弦的数量np.random.seed(42)
ab = np.random.randn(2,m)
a = ab[0,:] #(m,)
b = ab[1,:] #(m,)

4.生成长度为L的数据集

l = 5000for i in tqdm(range(l)):'随机生成自变量x'np.random.seed(2000+i)x = np.random.randn(d) #(d,)'生成因变量y1和y2'y1 = np.matmul(w1.T,x)y2 = np.matmul(w2.T,x)for j in range(m):y1 = y1+np.sin(a[j]*np.matmul(w1.T,x)+b[j])y2 = y2+np.sin(a[j]*np.matmul(w2.T,x)+b[j])y1 += np.random.normal(0,0.01,1)y2 += np.random.normal(0,0.01,1)y = np.hstack((y1,y2)) #(1,2)'保存生成的x和y'    if i==0:X = xY = yelse:X = np.vstack((X,x))Y = np.vstack((Y,y))
print(X.shape)
print(Y.shape)输出:
(5000, 512)
(5000, 2)

下面比较一下权重系数的余弦相似度和label之间的皮尔逊相关系数:

'计算w1和w2的余弦相似度'
cos_sim = w1.dot(w2) / (np.linalg.norm(w1)*np.linalg.norm(w2))
print("cos(w1,w2)=",cos_sim)'计算label之间的皮尔逊相关系数'
corr = np.corrcoef(Y[:,0],Y[:,1])
print("person(y1,y2)=",corr[0,1])输出:
cos(w1,w2)= 0.5000000000000002
person(y1,y2)= 0.39918604117923223

可以看到权重系数的余弦相似度与标签的皮尔逊相关系数并不完全相同,因为 y 是关于 x 的非线性函数,下面探究一下两者之间的关系。

2.3 权重系数的余弦相似度与标签相关系数之间的关系

在原文中,作者提到二者之间的关系如图所示:


1.生成正交单位向量

'1. 生成正交单位向量'
d = 512 #维度'''生成两个单位向量'''
np.random.seed(10)
u1 = np.random.randn(d)
u1 = u1 / np.linalg.norm(u1)np.random.seed(22)
u2 = np.random.randn(d)
u2 = u2 / np.linalg.norm(u2)u = np.vstack((u1,u2)).T #(d,2)'''向量正交化'''
o = orth(u)
u1 = o[:,0]
u2 = o[:,1]
print(np.linalg.norm(u1))
print(np.linalg.norm(u2))
print(np.matmul(u1.T,u2))
'u1,u2为一组正交单位向量'

2.随机生成生成m组正弦函数的参数

m = 10 #组合正弦的数量np.random.seed(42)
ab = np.random.randn(2,m)
a = ab[0,:] #(m,)
b = ab[1,:] #(m,)

3.循环得到不同p时对应的权重和标签的相似度

c = 1 #常数
l = 5000 #数据长度cs=[]
pc=[]
for p in np.arange(-1,1.1,0.1).round(1):print("***** p={} *****".format(p))w1 = c*u1w2 = c*(p*u1 + np.sqrt(1-p*p)*u2)for i in tqdm(range(l)):'随机生成自变量x'np.random.seed(2000+i)x = np.random.randn(d) #(d,)'生成因变量y1和y2'y1 = np.matmul(w1.T,x)y2 = np.matmul(w2.T,x)for j in range(m):y1 = y1+np.sin(a[j]*np.matmul(w1.T,x)+b[j])y2 = y2+np.sin(a[j]*np.matmul(w2.T,x)+b[j])y1 += np.random.normal(0,0.01,1)y2 += np.random.normal(0,0.01,1)y = np.hstack((y1,y2)) #(1,2)'保存生成的x和y'    if i==0:X = xY = yelse:X = np.vstack((X,x))Y = np.vstack((Y,y))'计算w1和w2的余弦相似度'cos_sim = w1.dot(w2) / (np.linalg.norm(w1)*np.linalg.norm(w2))cs.append(cos_sim)'计算label之间的皮尔逊相关系数'person_corr = np.corrcoef(Y[:,0],Y[:,1])pc.append(person_corr[0,1])

4.绘制图像

plt.plot(cs,pc,linewidth=1.5)
# 设置横轴标签
plt.xlabel('weight cosine similarity')
# 设置纵轴标签
plt.ylabel('label correlation')
plt.show()


可以看到二者确实不是线性关系,但是呈正相关,因此可以用设置的相关系数p表示任务之间的相关性。

MMoE论文中Synthetic Data生成代码(控制多任务学习中任务之间的相关性)相关推荐

  1. java学习中,DVD管理系统纯代码(java 学习中的小记录)

    java学习中,DVD管理系统纯代码(java 学习中的小记录)作者:王可利(Star·星星) class DvdMain{public static void main (String[] args ...

  2. 生成PDF文件方案--学习中

    PDF文件是目前比较流行的电子文档格式,在办公自动化(OA)等软件的开发中,经常要用到该格式,但介绍如何制作PDF格式文件的资料非常少,在网上搜来搜去,都转贴的是同一段"暴力"破解 ...

  3. RS Meet DL(68)-建模多任务学习中任务相关性的模型MMoE

    本文介绍的论文题目是:<Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts ...

  4. 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题

    参考资料: <PyTorch深度学习>(人民邮电出版社)第7章 生成网络 PyTorch官方文档 廖星宇著<深度学习入门之Pytorch>第6章 生成对抗网络 其他参考的网络资 ...

  5. 【阅读笔记】多任务学习之PLE(含代码实现)

    本文作为自己阅读论文后的总结和思考,不涉及论文翻译和模型解读,适合大家阅读完论文后交流想法. PLE 一. 全文总结 二. 研究方法 三. 结论 四. 创新点 五. 思考 六. 参考文献 七. Pyt ...

  6. 【推荐系统多任务学习MTL】MMoE论文精读笔记(含代码实现)

    论文地址: Google KDD 2018 MMOE (内含论文官方讲解视频) PDF Modeling Task Relationships in Multi-task Learning with ...

  7. matlab生成代码veri,一种自动生成状态机RTL代码的方法

    1 引言 电子设计自动化(Electronic Design Automatic,EDA),在集成电路设计中扮演了重要的角色,无论前端还是后端设计都需要熟练掌握和使用各种EDA工具,现今EDA软件主要 ...

  8. 深度学习 -- TensorFlow(项目)验证码生成与识别(多任务学习)

    目录 基础理论 一.生成验证码数据集 1.生成验证码训练集 1-0.判断文件夹是否为空 1-1.创建字符集(数字.大小写英文字母) 1-2.随机生成验证码(1000个,长度为4) 2.生成验证码测试集 ...

  9. MMOE——多任务学习模型

    摘要 对于多任务学习,我们的目标是建立一个单一的模型,同时学习这些多个目标和任务.然而,常用的多任务模型的预测质量往往对任务之间的关系比较敏感.因此,研究任务特定目标和任务间关系之间的建模权衡是很重要 ...

最新文章

  1. Golang微服务开发实践
  2. python的类型 变量 数值和字符串
  3. Py之Seaborn:数据可视化Seaborn库的柱状图、箱线图(置信区间图)、散点图/折线图、核密度图/等高线图、盒形图/小提琴图/LV多框图的组合图/矩阵图实现
  4. RabbitMQ服务客户端的的业务逻辑
  5. java反向映射_opencv 直方图和直方图反向映射
  6. 软件工程课设迭代开发第二天
  7. 中间省略_SpringBoot2 高级案例(03): 整合sharding-jdbc中间件,实现数据分库分表
  8. Ruby on Rails,创建模型,附赠模型与表名不一致时的解决方法
  9. SEO专题之四:如何合理有效选定关键字
  10. 计算机操作系统详细学习笔记(一):计算机操作系统概述
  11. 干货|一文看懂BLE低功耗技术-附主流BLE芯片厂商介绍
  12. Linux C编程实战——第六章 文件操作_项目实现_自写ls命令
  13. 山东大学项目实训(二十七)—— 微信小程序开发总结,一年时间真的可以改变一个人很多
  14. js中关于0.1+0.1不等于0.2 ,而console.log(0.1)是0.1,面试01
  15. 珠海 第十届亚洲机器人锦标赛_逾2000名选手云集珠海竞技第十届亚洲机器人锦标赛...
  16. linux dns一键,利用wdDNSV3自建免费在线DNS系统并配置使用
  17. Ubuntu18.04 RTL8169驱动更换RTL8168驱动
  18. 【尚硅谷Java笔记+踩坑】Git(分布式版本控制工具)
  19. 获取mumu模拟器日志
  20. 什么是架构?什么是架构师?

热门文章

  1. Symantec Endpoint Protection 14最新卸载教程(亲测好用,无需密码,暴力删除)
  2. 自行车V刹和碟刹的对比结果分析
  3. 格式工厂 wav 比特率_Easy MP3 Converter Pro for mac(音频格式转换软件)
  4. C to the start....start to C
  5. a6账套管理显示无法连接服务器,A6企业管理软件账套管理及初始化流程.ppt
  6. 经济学人:人工智能正颠覆传统战争,一场新军备竞赛或将开启
  7. 知识蒸馏怎么用?召回-粗排篇
  8. 惠普暗影精灵II代 Pro电脑 Hackintosh 黑苹果efi引导文件
  9. 半路出家学习Java
  10. 如何监测耳机/麦克风设备插拔操作