TensorFlow 多任务学习
多任务学习
多任务学习,顾名思义,就是多个任务模型同时执行,进行模型的训练,利用模型的共性部分来简化多任务的模型,实现模型之间的融合与参数共享,可以在一定程度上优化模型的运算,提高计算机的效率,但模型本身并没有什么改变。
多任务学习的核心在于如何训练上:
- 交替训练
- 联合训练
通过一个简单的线性变换来展示多任务学习模型的运用。
首先,导入需要的包
import tensorflow as tf
import numpy as np
使用numpy制造两组假数据
x_data = np.float32(np.random.rand(2, 100)) # 随机输入
y1_data = np.dot([0.100, 0.200], x_data) + 0.300
y2_data = np.dot([0.500, 0.900], x_data) + 3.000
构造两个线性模型
b1 = tf.Variable(tf.zeros([1]))
W1 = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y1 = tf.matmul(W1, x_data) + b1b2 = tf.Variable(tf.zeros([1]))
W2 = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y2 = tf.matmul(W2, x_data) + b2
计算方差,使方差最小化,使模型不断的靠近真实解
# 最小化方差
loss1 = tf.reduce_mean(tf.square(y1 - y1_data))
loss2 = tf.reduce_mean(tf.square(y2 - y2_data))
构造优化器
# 构建优化器
optimizer = tf.train.GradientDescentOptimizer(0.5)
train1 = optimizer.minimize(loss1)
train2 = optimizer.minimize(loss2)
交替训练
基本思想:使两个模型交替进行训练
# 初始化全局变量
init = tf.global_variables_initializer()# 启动图 (graph)
with tf.Session() as sess:sess.run(init)for step in range(1, 1001):if np.random.rand() < 0.5:sess.run(train1)print(step, 'W1,b1:', sess.run(W1), sess.run(b1))else:sess.run(train2)print(step, 'W2,b2:', sess.run(W2), sess.run(b2))
输出结果为:
从最终的结果可以看出W1,W2,b1,b2已经非常接近真实值了,说明模型的建立还是非常有效的。
联合训练
基本思想:将两个模型的损失函数结合起来,共同进行优化训练
# 联合训练
loss = loss1 + loss2
# 构建优化器
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)# 初始化全局变量
init = tf.global_variables_initializer()# 启动图
with tf.Session() as sess:sess.run(init)for step in range(1, 300):sess.run(train)print(step, 'W1,b1,W2,b2:', sess.run(W1), sess.run(b1), sess.run(W2), sess.run(b2))
输出结果为:
从结果可以看出模型的参数不断的接近真实值。
应用场景
当你需要同一组数据集去处理不同的任务时,交替训练是一个很好地选择。
当两个甚至多个任务需要联合考虑时,为了整体的最优而放弃局部最优的时候,使用联合训练非常的合适。
欢迎关注和评论!
TensorFlow 多任务学习相关推荐
- 深度学习 -- TensorFlow(项目)验证码生成与识别(多任务学习)
目录 基础理论 一.生成验证码数据集 1.生成验证码训练集 1-0.判断文件夹是否为空 1-1.创建字符集(数字.大小写英文字母) 1-2.随机生成验证码(1000个,长度为4) 2.生成验证码测试集 ...
- 基于多任务学习和负反馈的深度召回模型
简介:召回结果的好坏对整个推荐结果有着至关重要的影响,最近的一系列实践和研究表明,基于行为序列的深度学习推荐模型搭配高性能的近似检索算法可以实现既准又快的召回性能:与此同时,用户在天猫精灵上还可以进行 ...
- PAMTRI:用于车辆重新识别的姿势感知多任务学习
Today, we will discuss an unorthodox paper by NVIDIA Labs on Vehicle Re Identification. 今天,我们将讨论NVID ...
- 多任务学习(Multi-task Learning)方法总结
多任务学习(multi task learning)简称为MTL.简单来说有多个目标函数loss同时学习的就算多任务学习.多任务既可以每个任务都搞一个模型来学,也可以一个模型多任务学习来一次全搞定的. ...
- 多任务学习中各loss权重应该如何设计?
作者 | hahakity@知乎 编辑 | 极市平台 在多任务中,通常是把各loss统一到一个数量级,请问这么做的原理是什么呢? 今天分享一个技术硬核文章,详细的聊聊多任务这点事: 个人感觉这是一个非 ...
- 【推荐系统多任务学习MTL】ESMM 论文精读笔记(含代码实现)
论文地址:Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate ...
- 多任务学习与深度学习
作者:chen_h 微信号 & QQ:862251340 微信公众号:coderpai 多任务学习是机器学习的一个子领域,学习的目标是同事执行多个相关任务.比如,系统会同时执行学习两项任务,以 ...
- 多任务学习原理与优化
文章目录 一.什么是多任务学习 二.为什么我们需要多任务学习 三.多任务学习模型演进 Hard shared bottom 硬共享 Soft shared bottom 软共享 软共享: MOE &a ...
- 2021年浅谈多任务学习
作者 | 多多笔记 来源 |AI部落联盟 头图 | 下载于视觉中国 写此文的动机: 最近接触到的几个大厂推荐系统排序模型都无一例外的在使用多任务学习,比如腾讯PCG在推荐系统顶会RecSys 2020 ...
最新文章
- lua学习笔记之io
- 我的第一个ASP类(显示止一篇下一篇文章)
- 【NOIP2013】货车运输
- JAVA基础知识(2)--队列的操作
- Istio 首次安全评估结果公布
- 元旦限时特惠,耳机、书籍等大降价
- LeetCode: Single Number I II
- 我是做php的个子矮,当一个矮个子的烦恼作文
- 折线图笔记 -python
- 怎么有效提高执行力?
- 简单工厂模式-Simple Factory Pattern
- javaScript中的变量作用域的闭包处理
- 【系列一之爬虫系列】爬取信息
- 在Linux下使用GIMP打印一寸照
- 关于23届大数据岗实习总结
- 客户价值分析之RFM模型
- webscraper多页爬取_Web Scraper 高级用法——Web Scraper 抓取多条内容 | 简易数据分析 07...
- 2022社群扫码进群活码完整系统源码+修复版的
- 直接内存 直接内存的释放和回收
- 一曲京声人去远——纪念刘大中校友100周年诞辰
热门文章
- 数据标准化 - scale() - Python代码
- PAT乙类之1011 A+B 和 C
- 新型冠状病毒传染性有多强?何时达到疫情峰值?来看一下数学和统计建模结果...
- 聊聊Spring Cloud版本的那些事儿
- 论文浅尝 - ECIR2021 | 两种实体对齐方法的严格评估
- 科普 | 动态本体简介
- Android官方开发文档Training系列课程中文版:数据存储之文件存储
- Maven:导入Oracle的jar包时出现错误
- 二叉树----数据结构:二叉树的三种遍历及习题
- 好久没玩laravel了,5.6玩下(三)