TensorNet——基于TensorFlow的大规模稀疏特征模型分布式训练框架
女主宣言
今天小编为大家分享一篇有关于TensorNet的文章。TensorNet是一个构建在TensorFlow之上针对广告推荐等大规模稀疏场景优化的分布式训练框架。希望能对大家有所帮助。
PS:丰富的一线技术、多元化的表现形式,尽在“360云计算”,点关注哦!
0
TensorNet是什么?
TensorNet是一个构建在TensorFlow之上针对广告推荐等大规模稀疏场景优化的分布式训练框架。TensorNet的目标是让所有使用TensorFlow的开发者可以快速的、方便的训练出稀疏参数超过百亿的超大模型。
1
训练带有大规模稀疏特征模型的主要挑战
在广告、搜索、推荐等场景下的深度模型都会存在大量的高维离散稀疏特征,训练带有高维稀疏特征的模型主要有两个问题:
训练样本规模大。比如对于360广告场景会有超过100TB的训练数据。
模型参数多。比如对于360广告场景会有超过100亿的参数。
使用单机模式训练模型速度慢,耗时长,严重制约了模型的迭代速度,使用分布式训练已经成为业界标准。
2
使用TensorFlow训练稀疏特征模型的主要问题
TensorFlow是最受开发者欢迎的深度学习训练框架,但是TensorFlow对训练带有大规模稀疏特征的模型不太友好,主要问题有:
TensorFlow支持的特征维度有限。一般的,TensorFlow需要对每一个特征定义一个矩阵,这个矩阵受限于内存,往往不能太大。
TensorFlow分布式训练时需要同步所有的参数,这导致速度过慢。由于带有稀疏特征的模型参数多,同步开销非常大,会严重制约训练的速度。
3
TensorNet - 基于TensorFlow的专为大规模稀疏特征模型优化的分布式训练框架
TensorNet在复用TensorFlow的所有功能的基础之上,专门定制使其支持大规模稀疏特征模型的训练。TensorNet的主要提升包括:
使TensorFlow支持的稀疏特征的维度接近于无限。
使TensorFlow分布式训练时同步的参数规模减小到原来的万分之一,乃至十万分之一,从而极大的提升了训练速度。在360真实业务场景下我们将原来的离线训练时间由3.5小时提升到了25分钟。
配合TensorNet通过split graph的方法可以对在线推理的性能进行优化。在360真实场景测试中我们发现有近 35% 的性能提升。
TensorNet分布式训练架构
TensorNet支持异步和同步模式训练。异步模式在仅有CPU的集群中速度提升十分显著,同步模式在网卡速度超过100GbE的GPU集群中表现突出。
TensorNet异步训练架构
在仅有CPU的集群中使用参数服务器的异步训练模式是训练模型速度最快的方法,TensorNet异步训练架构与TensorFlow的异步训练架构有很大的区别:
TensorNet将sparse参数和与dense参数分别使用不同的parameter server管理。
TensorNet不设单独的parameter server节点。在每个worker中都会维护一个sparse paramter server和dense parameter server。这省去了开发人员管理ps节点和worker节点的不少麻烦。
TensorNet对sparse参数使用分布式哈希表按照哈希值均匀分布不同的节点上。这相较于TensorFlow需要让开发者根据自身情况将tensor分布在不同的ps节点上的方法更加灵活,这不仅减小了节点通信热点的概率,还减轻了开发者的工作量。
TensorNet将模型的所有dense参数合并后使用分布式数组切分到不同的机器上,每次pull和push参数的时候只有一次网络请求。相较于TensorFlow对每个tensor都有一次网络请求的方法极大的减少了网络请求的次数从而提升了模型训练的速度。
TensorNet异步训练架构
TensorNet同步训练架构
TensorNet同步训练架构基本与TensorFlow的MultiWorkerMirroredStrategy架构一致,主要区别如下:
TensorNet使用单独的sparse parameter server节点保存所有sparse参数。通过parameter server可以解决TensorFlow支持的sparse特征维度不能太大的问题。
TensorNet对sparse参数做了特殊的定制化的同步。TensorNet在训练时只同步当前训练的batch所关注的稀疏特征,相较于TensorFlow会将所有参数都同步的模式通信数据减少到了原来的万分之一,乃至十万分之一。
TensorNet同步训练架构
TensorNet核心优化
TensorNet最核心的优化是将模型的embedding tensor优化到了最小。
如下图所示,对于最简单的wide&deep模型,如果在一个广告系统中有3亿用户,那么就需要定义一个维度为3亿的embedding矩阵,在训练模型时需要在这个3亿维的矩阵上做embedding_lookup
得到当前batch内的用户的embedding信息,进而在embedding之上做更加复杂的操作。
TensorFlow中的实现
显而易见,在高维稀疏场景下,存在有下列问题:
embedding矩阵太大,占用内存多。很显然当特征较多的时候单机无法存储整个模型。
分布式训练同步开销巨大。由于TensorFlow同步时是需要同步整个矩阵以便进行训练,这极大的消耗了网络带宽,拖慢了整体速度。
TensorNet使用一个较小的,可以容纳特征在一个batch内所有数据的embedding矩阵代替TensorFlow默认实现中需要定义的较大的embedding矩阵。
如下图所示,在batch_size设置为1024的场景下,对于用户id特征,在TensorNet中只需要定义一个维度为1024的embedding矩阵,TensorNet的主要处理步骤如下:
定义模型时定义userid的embedding矩阵的维度为一个batch内所有用户id个数的最大值。
训练模型时得到当前batch内的所有用户id。
将用户id排序,并按照先后顺序为每个userid分配索引,索引从0开始,对应为下图中的
virtual sparse feature
。使用userid从parameter server中获取相应的embedding向量,然后按照其对应的索引放置到embedding矩阵中。
使用转换后的
virtual sparse feature
作为模型的输入。
TensorNet中的实现
从上述可见,TensorNet由于极大的减小了模型所需要的embedding矩阵,从而可以极大的减小分布式训练时的开销,以及通过parameter server的方式使得稀疏特征的维度可以支持到接近无限维,从而可以极大的提升模型的刻画能力。
TensorNet Inference优化
由于TensorNet只更改了模型的第一层,从而模型的inference也变得极其简单。
在使用TensorNet构造模型的时候,可以将模型切分为两部分,如下图所示,embedding_lookup_graph
只在离线训练时使用,在线inference时只需要将sparse embedding导出成字典供inference_graph
作为输入即可,具体的请参考以下系列文章:
1. 为inference准备——模型切分: https://github.com/Qihoo360/tensornet/blob/master/doc/tutorial/03-split-to-sub-graph.ipynb
2. 使用XLA方式进行在线预估: https://github.com/Qihoo360/tensornet/blob/master/doc/tutorial/04-deploy-tf-graph-online.ipynb
3. sparse embedding字典导出: https://github.com/Qihoo360/tensornet/blob/master/doc/tutorial/05-export-sparse-feature-embedding.ipynb
TensorNet中split graph inference方案
在360内部场景中我们测试发现通过split graph配合XLA AOT
的方法性能提升近35%。
TensorNet开源及使用
TensorNet已经成功落地应用到了360广告ctr预估相关的场景中,并取得了显著的效果,我们已将代码、文档及我们在360广告的应用经验全部整理到了项目中,欢迎关注。
tensorNet主页:https://github.com/Qihoo360/TensorNet
tensornet快速上手:https://github.com/Qihoo360/TensorNet/doc/tutorial/01-begin-with-wide-deep.ipynb
更多文档请看:https://github.com/Qihoo360/TensorNet/README.md
联系方式:张彦升(zhangyansheng@360.cn),姚磊(yaolei@360.cn)
微信交流群:
360云计算
由360云平台团队打造的技术分享公众号,内容涉及数据库、大数据、微服务、容器、AIOps、IoT等众多技术领域,通过夯实的技术积累和丰富的一线实战经验,为你带来最有料的技术分享
TensorNet——基于TensorFlow的大规模稀疏特征模型分布式训练框架相关推荐
- ios yymodel 将字典转数组模型_TensorNet——基于TensorFlow的大规模稀疏特征模型分布式训练框架
TensorNet是什么? TensorNet是一个构建在TensorFlow之上针对广告推荐等大规模稀疏场景优化的分布式训练框架.TensorNet的目标是让所有使用TensorFlow的开发者可以 ...
- 快手八卦!突破TensorFlow、PyTorch并行瓶颈的开源分布式训练框架来了!
来源:AI前线本文约5200字,建议阅读8分钟 本文介绍了专门针对分布式场景设计了特定的优化算法同比,性能较同类提升60%. 近日,快手和苏黎世理工宣布开源分布式训练框架 Bagua(八卦),相比于 ...
- 阿里开源支持10万亿模型的自研分布式训练框架EPL(EasyParallelLibrary)
简介:EPL背后的技术框架是如何设计的?开发者可以怎么使用EPL?EPL未来有哪些规划?今天一起来深入了解. 作者 | 王林.飒洋 来源 | 阿里技术公众号 一 导读 最近阿里云机器学习PAI平台和达 ...
- 阿里开源支持10万亿模型的自研分布式训练框架EPL
一 导读 最近阿里云机器学习PAI平台和达摩院智能计算实验室一起发布"低碳版"巨模型M6-10T,模型参数已经从万亿跃迁到10万亿,规模远超业界此前发布的万亿级模型,成为当前全球最 ...
- 支持异构GPU集群的超大规模模型的高效的分布式训练框架Whale
近日,阿里云机器学习PAI关于深度学习模型高效的分布式训练框架的论文< Whale: Efficient Giant Model Training over Heterogeneous GPUs ...
- 支持异构 GPU 集群的超大规模模型的高效的分布式训练框架 Whale
近日,阿里云机器学习PAI关于深度学习模型高效的分布式训练框架的论文< Whale: Efficient Giant Model Training over Heterogeneous GPUs ...
- 字节跳动开源分布式训练框架BytePS,登上GitHub热榜
问耕 发自 凹非寺 量子位 出品 | 公众号 QbitAI 字节跳动开源了通用分布式训练框架BytePS,这个框架支持TensorFlow.Keras.PyTorch.MXNet,可以运行在TCP或R ...
- 6_分布式训练框架Horovod使用(20190111)
分布式训练框架Horovod使用 文章目录 一.Horovod简介 二.Horovod框架的安装 Install 1.安装OpenMPI 2.安装Horovod 三.Horovod框架的使用 1.在项 ...
- [源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark
[源码解析] 深度学习分布式训练框架 horovod (10) - run on spark 文章目录 [源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark ...
最新文章
- 校验正确获取对象或者数组的属性方法(babel-plugin-idx/_.get)
- 都在建议你不要直接使用 @Async 注解,为什么?
- 【数据清洗】yolo标注补全 生成空的标注txt文件
- Java并发7:并发工具类
- ionic android 版本号,ionic android 版本release 和 签名(示例代码)
- LeetCode 640. 求解方程(字符串)
- vue 使用axios
- 主人的C++桌上也没有这么好看的花朵了
- LeetCode Min Stack 最小值栈
- 动态链接库的隐式动态链接和显示动态链接
- 非常有价值的电商系统,包括前台商城和后台管理系统!直接拿来用
- maven pom.xml中设置java编译参数
- 利用python爬取飞猪信息_手把手教你使用Python爬取西刺代理数据(上篇)-阿里云开发者社区...
- 微信公众号客服系统可以实现自动回复吗?
- 一个2022本科生的秋招总结 (大疆、Arm、小米、荣耀、美团、联发科等)
- c语言spoc测验成绩比重,SPOC混合教学模式在C语言程序设计课程的应用
- MIT6.824-lab3A-Key/value service without snapshots(基本的KV服务)
- win10使用的c语言程序开发,Win10是什么编程语言写的?源代码文件多到你无法想象...
- 【网络】IP地址计算
- kafka-分区重分配及相关源码分析
热门文章
- java.security.NoSuchAlgorithmException: SHA_256 MessageDigest not available
- Kafka集群安装Version1.0.1(自带Zookeeper)
- JSON解析中获取不存在的key
- LabVIEW自带函数实现SQL Server操作(上)
- 登入Github、Git本地上传及Visual Studio Code上传教程
- java判断经纬度是否在扇形内_地理坐标是用经度
- 用matlab绘制P三曲线,知道曲线方程 怎么用matlab绘制三维图 一定要给出程序 , matlab怎样画三维曲线...
- springmvc进不到controller_Spring、SpringMVC、MyBatis的整合
- async/await 顺序执行和并行
- (转)CString工作原理和常见问题分析