We explained in "Understanding Spark 2.1 Core in Depth (Part 1): RDD Principles and Source Code Analysis":

To achieve fault tolerance efficiently, RDDs offer a highly restricted form of shared memory: an RDD is read-only and can only be created through bulk operations on other RDDs (note: it can also be created from a dataset in an external storage system such as HDFS).

In other words, posts 9 and 10 covered something analogous to traditional Hadoop MapReduce: the initial step of reading data from HDFS to produce a HadoopRDD. Since an RDD can be created through bulk operations on other RDDs, the HadoopRDD here acts as the Map side for the ShuffledRDD generated from it, and that ShuffledRDD can in turn act as the Map side for the next ShuffledRDD downstream. Looked at the other way, the downstream ShuffledRDD is the Reduce side relative to the HadoopRDD.
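To make the Map/Reduce relationship concrete, here is a minimal, illustrative sketch; the input path and the word-count logic are my own assumptions, not taken from this series. The HadoopRDD produced by textFile is the Map side, and the ShuffledRDD produced by reduceByKey is the Reduce side whose compute method we trace below.

  import org.apache.spark.{SparkConf, SparkContext}

  object ShuffleReduceSideSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(
        new SparkConf().setAppName("shuffle-reduce-sketch").setMaster("local[2]"))

      // sc.textFile creates a HadoopRDD (wrapped in a MapPartitionsRDD) -- the "Map side".
      val lines = sc.textFile("hdfs:///tmp/input.txt")

      // reduceByKey introduces a ShuffleDependency; the resulting RDD is a ShuffledRDD,
      // and computing its partitions is the "Reduce side" discussed in this post.
      val counts = lines
        .flatMap(_.split("\\s+"))
        .map(word => (word, 1))
        .reduceByKey(_ + _)

      // collect() triggers the job: ShuffledRDD.compute is invoked for each reduce
      // partition and pulls map output through the shuffle reader shown below.
      counts.collect().foreach(println)
      sc.stop()
    }
  }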

In this post we look at the Shuffle Reduce side. The RDD iteration part is similar to what we saw in post 9; the difference is that here rdd.ShuffledRDD.compute is invoked:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    // Get the shuffle dependency
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // Call getReader with dep.shuffleHandle, the partition range and the task context
    // to obtain a ShuffleReader, then call read() to get the iterator
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

The getReader invoked here is shuffle.sort.SortShuffleManager.getReader:

  override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = {
    // Create and return a BlockStoreShuffleReader
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }

shuffle.BlockStoreShuffleReader.read:

  override def read(): Iterator[Product2[K, C]] = {
    // Instantiate ShuffleBlockFetcherIterator
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // Ask the MapOutputTracker (via an RPC message) for the metadata describing
      // where the ShuffleMapTask output is stored
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Maximum number of bytes allowed in flight at once
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      // Maximum number of requests allowed in flight at once
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // Wrap the streams with compression and encryption according to the configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }

    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update task metrics after each record is read
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    // Build the completion iterator
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // Wrap metricIter in an interruptible iterator so that the iteration can be cancelled
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      // Aggregation is required
      if (dep.mapSideCombine) {
        // Map-side combine was enabled: the map side has already pre-combined the values,
        // so merge the combiners here
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // Otherwise only the reduce side combines: aggregate the raw values by key
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Sorting is required.
        // Note: if spark.shuffle.spill is set to false, the sorter will not spill to disk
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
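Which branches read() takes depends on how the upstream transformation built the ShuffleDependency. As a rough, simplified illustration (assuming the SparkContext sc from the earlier sketch; the mapping below is well-known Spark behavior but is not exhaustive):

  // Illustrative only: how common transformations shape the ShuffleDependency that read() inspects.
  val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

  // reduceByKey: aggregator defined, mapSideCombine = true
  //   -> read() calls combineCombinersByKey
  val reduced = pairs.reduceByKey(_ + _)

  // groupByKey: aggregator defined, mapSideCombine = false
  //   -> read() calls combineValuesByKey
  val grouped = pairs.groupByKey()

  // sortByKey: keyOrdering defined
  //   -> read() feeds the records through an ExternalSorter
  val sorted = pairs.sortByKey()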

Class call-relationship diagram:

Next, let's dig into how ShuffleBlockFetcherIterator is instantiated:

  // Instantiate ShuffleBlockFetcherIterator
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    // Ask the MapOutputTracker (via an RPC message) for the metadata describing
    // where the ShuffleMapTask output is stored
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // Maximum number of bytes allowed in flight at once
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    // Maximum number of requests allowed in flight at once
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

Fetching the metadata

mapOutputTracker.getMapSizesByExecutorId

First, mapOutputTracker.getMapSizesByExecutorId is called:

  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    // Get the map output metadata
    val statuses = getStatuses(shuffleId)
    // The result has the shape:
    // Seq[(BlockManagerId, Seq[(shuffle block id, shuffle block size)])]
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }

mapOutputTracker.getStatuses

  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    // Try the local cache first
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      // No local data
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // If someone else is already fetching these statuses remotely, wait for them
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }
        // Try the local cache again
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We still have to fetch remotely, so register shuffleId in fetching
          fetching += shuffleId
        }
      }

      if (fetchedStatuses == null) {
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        try {
          // Fetch remotely
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          // Deserialize
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          // Cache the result in mapStatuses
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
        s"${System.currentTimeMillis - startTime} ms")

      if (fetchedStatuses != null) {
        // Return the fetched statuses
        return fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      // Found locally; return directly
      return statuses
    }
  }

mapOutputTracker.askTracker

It sends the GetMapOutputStatuses(shuffleId) message to trackerEndpoint:

  protected def askTracker[T: ClassTag](message: Any): T = {
    try {
      trackerEndpoint.askWithRetry[T](message)
    } catch {
      case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }

MapOutputTrackerMasterEndpoint.receiveAndReply

    case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))

Note that the endpoint does not reply with the statuses directly; instead it calls tracker.post:

  def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }

mapOutputRequests加入GetMapOutputMessage(shuffleId, context)消息。这里的mapOutputRequests是链式阻塞队列。

  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
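The pattern here is plain java.util.concurrent producer/consumer: post is the producer and the MessageLoop shown next is the consumer. Below is a stripped-down, self-contained sketch of the same pattern; the Request class and the log message are invented for illustration and are not Spark's.

  import java.util.concurrent.LinkedBlockingQueue

  // Illustrative stand-in for Spark's GetMapOutputMessage.
  case class Request(shuffleId: Int)

  object BlockingQueueSketch {
    private val queue = new LinkedBlockingQueue[Request]()
    private val PoisonPill = Request(-1)

    // Producer side: the equivalent of MapOutputTrackerMaster.post
    def post(req: Request): Unit = queue.offer(req)

    // Consumer side: the equivalent of MessageLoop.run -- take() blocks until a message arrives
    private val loop = new Thread(new Runnable {
      override def run(): Unit = {
        var running = true
        while (running) {
          val req = queue.take()
          if (req == PoisonPill) running = false
          else println(s"serving map output statuses for shuffle ${req.shuffleId}")
        }
      }
    })

    def main(args: Array[String]): Unit = {
      loop.start()
      post(Request(0))
      post(Request(1))
      post(PoisonPill) // shut the loop down, much like MapOutputTrackerMaster.stop()
      loop.join()
    }
  }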

MapOutputTrackerMaster.MessageLoop.run

MessageLoop runs on its own thread and keeps taking messages from mapOutputRequests:

  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = mapOutputRequests.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other message loops also see it, then exit
              mapOutputRequests.offer(PoisonPill)
              return
            }
            val context = data.context
            val shuffleId = data.shuffleId
            val hostPort = context.senderAddress.hostPort
            logDebug("Handling request to send map output locations for shuffle " + shuffleId +
              " to " + hostPort)
            // A message was read: serialize the map output statuses for this shuffle
            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
            // Reply to the sender
            context.reply(mapOutputStatuses)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

MapOutputTracker.convertMapStatuses

Now let's return to MapOutputTracker.convertMapStatuses, whose result getMapSizesByExecutorId returns:

  private def convertMapStatuses(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    assert (statuses != null)
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex) {
      if (status == null) {
        val errorMessage = s"Missing an output location for shuffle $shuffleId"
        logError(errorMessage)
        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
      } else {
        for (part <- startPartition until endPartition) {
          // Each entry in the returned Seq is (status.location, Seq[(ShuffleBlockId, block size)])
          splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
        }
      }
    }
    // Return the blocks grouped by status.location
    splitsByAddress.toSeq
  }
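The ShuffleBlockId built here is just a structured block name of the form shuffle_<shuffleId>_<mapId>_<reduceId>. A small illustrative snippet (the shuffle id and partition numbers are made up) showing what this reducer would end up asking for:

  import org.apache.spark.storage.ShuffleBlockId

  // Suppose this reduce task handles partition 2 of shuffle 0 and the map stage ran 3 map tasks.
  // convertMapStatuses produces one (blockId, size) pair per map task for this partition,
  // grouped under each map output location.
  val reduceId = 2
  val blockIdsForThisReducer = (0 until 3).map { mapId =>
    ShuffleBlockId(shuffleId = 0, mapId = mapId, reduceId = reduceId)
  }

  // ShuffleBlockId.name encodes "shuffle_<shuffleId>_<mapId>_<reduceId>"
  blockIdsForThisReducer.foreach(id => println(id.name))
  // shuffle_0_0_2
  // shuffle_0_1_2
  // shuffle_0_2_2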

Splitting blocks into local and remote

Let's go back to new ShuffleBlockFetcherIterator.

storage.ShuffleBlockFetcherIterator.initialize

When ShuffleBlockFetcherIterator is instantiated, initialize is invoked:

  private[this] def initialize(): Unit = {
    context.addTaskCompletionListener(_ => cleanup())
    // Split the blocks into local and remote ones
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests to the queue in random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
    // Send out remote requests to fetch blocks
    fetchUpToMaxBytes()
    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    // Fetch the local blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

storage.ShuffleBlockFetcherIterator.splitLocalRemoteBlocks

  private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Cap each remote request at maxBytesInFlight / 5.
    // maxBytesInFlight is the maximum number of bytes "in flight" (a batch of outstanding requests);
    // dividing by 5 increases parallelism, allowing up to 5 requests to fetch from 5 nodes at once.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Buffer for the FetchRequest objects that must go to remote nodes
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Total number of blocks
    var totalBlocks = 0
    // As shown above, blocksByAddress is grouped by status.location
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // The executorId matches this blockManagerId.executorId, so fetch locally
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        // Otherwise fetch remotely
        val iterator = blockInfos.iterator
        // Accumulated size of the current request
        var curRequestSize = 0L
        // Blocks accumulated into the current request.
        // Batching avoids flooding one node with many requests for tiny blocks,
        // which would congest the network
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        // All blocks produced by this iterator live on the same node
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize) {
            // The accumulated size has reached the target request size,
            // so emit a FetchRequest into remoteRequests
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Add the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
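The batching logic is easy to miss inside the full method, so here is a stripped-down, standalone sketch of just the "accumulate until targetRequestSize" step. The block ids, sizes, and SimpleFetchRequest class are invented; SimpleFetchRequest merely stands in for Spark's private FetchRequest.

  // Standalone illustration of grouping (blockId, size) pairs into requests of ~targetRequestSize.
  case class SimpleFetchRequest(address: String, blocks: Seq[(String, Long)])

  def batchBlocks(
      address: String,
      blocks: Seq[(String, Long)], // (blockId, size in bytes)
      targetRequestSize: Long): Seq[SimpleFetchRequest] = {
    val requests = scala.collection.mutable.ArrayBuffer[SimpleFetchRequest]()
    var cur = scala.collection.mutable.ArrayBuffer[(String, Long)]()
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) {
      cur += ((id, size))
      curSize += size
      if (curSize >= targetRequestSize) {
        requests += SimpleFetchRequest(address, cur.toSeq)
        cur = scala.collection.mutable.ArrayBuffer[(String, Long)]()
        curSize = 0L
      }
    }
    if (cur.nonEmpty) requests += SimpleFetchRequest(address, cur.toSeq)
    requests.toSeq
  }

  // With maxBytesInFlight = 48 MB the target per request is 48 MB / 5 ≈ 9.6 MB,
  // so five ~3.5 MB blocks on one node end up split into two fetch requests.
  val reqs = batchBlocks(
    "worker-1:7337",
    (1 to 5).map(i => (s"shuffle_0_${i}_2", 3500000L)),
    targetRequestSize = 48L * 1024 * 1024 / 5)
  reqs.foreach(r => println(s"${r.address}: ${r.blocks.map(_._1).mkString(", ")}"))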

Fetching the blocks

storage.ShuffleBlockFetcherIterator.fetchUpToMaxBytes

Let's return to the fetchUpToMaxBytes() call in storage.ShuffleBlockFetcherIterator.initialize and look at how remote blocks are fetched:

  private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight:
    // the number of requests in flight must stay within maxReqsInFlight, and
    // the number of bytes in flight must stay within maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
      sendRequest(fetchRequests.dequeue())
    }
  }
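Both limits come straight from the configuration read in BlockStoreShuffleReader.read: spark.reducer.maxSizeInFlight defaults to 48m and spark.reducer.maxReqsInFlight defaults to Int.MaxValue. A hedged example of tuning them on a SparkConf; the values below are purely illustrative, not recommendations.

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .setAppName("shuffle-fetch-tuning")
    // Cap the bytes outstanding per reduce task (default "48m");
    // larger values can improve throughput at the cost of executor memory.
    .set("spark.reducer.maxSizeInFlight", "96m")
    // Cap the number of outstanding fetch requests per reduce task (default Int.MaxValue).
    .set("spark.reducer.maxReqsInFlight", "256")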

storage.ShuffleBlockFetcherIterator.sendRequest

  private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // Map from block id to block size, so sizes can be looked up by id
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
    val blockIds = req.blocks.map(_._1.toString)
    val address = req.address
    // shuffleClient.fetchBlocks will be covered in a later post
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        // The fetch succeeded
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          ShuffleBlockFetcherIterator.this.synchronized {
            if (!isZombie) {
              buf.retain()
              remainingBlocks -= blockId
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
              logDebug("remainingBlocks: " + remainingBlocks)
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        // The fetch failed
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      })
  }
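The key pattern here is that the network layer is asynchronous: sendRequest only registers a BlockFetchingListener, and the callbacks push SuccessFetchResult/FailureFetchResult objects onto the results queue that the iterator later drains. A minimal, self-contained sketch of that callback-to-queue hand-off; every name below is invented for illustration and does not correspond to Spark's classes.

  import java.util.concurrent.LinkedBlockingQueue

  sealed trait FetchOutcome
  case class FetchSuccess(blockId: String, bytes: Array[Byte]) extends FetchOutcome
  case class FetchFailure(blockId: String, error: Throwable) extends FetchOutcome

  trait FetchListener {
    def onSuccess(blockId: String, bytes: Array[Byte]): Unit
    def onFailure(blockId: String, error: Throwable): Unit
  }

  class FetchIteratorSketch {
    // Callbacks (network threads, the producers) and next() (the task thread, the consumer)
    // communicate only through this blocking queue, as in ShuffleBlockFetcherIterator.
    private val results = new LinkedBlockingQueue[FetchOutcome]()

    val listener: FetchListener = new FetchListener {
      override def onSuccess(blockId: String, bytes: Array[Byte]): Unit =
        results.put(FetchSuccess(blockId, bytes))
      override def onFailure(blockId: String, error: Throwable): Unit =
        results.put(FetchFailure(blockId, error))
    }

    // Blocks until some outcome (success or failure) arrives.
    def next(): FetchOutcome = results.take()
  }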

storage.ShuffleBlockFetcherIterator.fetchLocalBlocks

Finally, let's look at how the local blocks are fetched:

  private[this] def fetchLocalBlocks() {
    // Get an iterator over the local blocks
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        // Read the data for each block.
        // blockManager.getBlockData will be covered in a later post
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
      } catch {
        case e: Exception =>
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }
