A Brief Analysis of Apache Spark 1.0.0 (Part 11): The Shuffle Process

1. How Shuffle Arises

A ShuffleDependency is the basis for splitting a job into stages, and it determines whether a stage is a shuffle map stage or a result stage, as the source comment below describes:

* A Spark job consists of one or more stages. The very last stage in a job consists of multiple
* ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
* and sends the task output back to the driver application. A ShuffleMapTask executes the task
* and divides the task output to multiple buckets (based on the task's partitioner).

Shuffle is an essential step in the MapReduce framework; it is the bridge between the map side and the reduce side. A shuffle can only arise from operations on pair RDDs whose elements are [k, v] pairs; other RDDs never produce a shuffle. When the map output is to be consumed by the reduce side, it has to be hashed by key and distributed to each reducer — this process is the shuffle. Since the shuffle involves disk reads/writes and network transfer, its performance directly affects the running efficiency of the whole program. That is why shuffle is the key to tuning Spark, and more generally to tuning any MapReduce-style framework.
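As a quick illustration (a minimal sketch, not taken from the Spark sources — the app name, data and partition count are made up), an operation such as reduceByKey on a pair RDD is exactly the kind of operation that introduces a ShuffleDependency: every map task hashes its output by key into one bucket per reducer.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair-RDD functions in Spark 1.0

object ShuffleExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("shuffle-example").setMaster("local"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 2)))
    // reduceByKey introduces a ShuffleDependency: the map-side output is
    // partitioned by key into 2 buckets, one per reducer.
    val counts = pairs.reduceByKey(_ + _, 2)
    counts.collect().foreach(println)
    sc.stop()
  }
}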

2. Shuffle Write

As mentioned at the end of "Task Execution", ShuffleMapTask and ResultTask implement runTask differently; the main difference is whether the intermediate results are written out. The following parts analyze ShuffleMapTask.runTask.

(I) Defining the variables

Four variables are defined first: numOutputSplits, blockManager, shuffleBlockManager and shuffle. numOutputSplits is the number of partitions; blockManager is obtained through SparkEnv; shuffleBlockManager is taken from blockManager; and shuffle is declared with type ShuffleWriterGroup.

val numOutputSplits = dep.partitioner.numPartitions

val blockManager = SparkEnv.get.blockManager
val shuffleBlockManager = blockManager.shuffleBlockManager
var shuffle: ShuffleWriterGroup = null

The ShuffleBlockManager class is defined below. As its comment explains, it assigns disk-based block writers to shuffle tasks. Each shuffle task gets one file per reducer; this set of files is called a ShuffleFileGroup. To reduce the number of shuffle files produced, multiple shuffle blocks are aggregated into the same file, and as soon as a task finishes writing its shuffle files it releases them for another task to use.

A shuffle file is uniquely identified by the 3-tuple (shuffleId, bucketId, fileId). Each shuffle block is then mapped to a FileSegment, also a 3-tuple (file, offset, length), which specifies where in the given file the actual block data is located. Shuffle file metadata is stored in a space-efficient way: each ShuffleFileGroup maintains a list of offsets for each block stored in each file. To find the location of a shuffle block, we search the files within the ShuffleFileGroups associated with the block's reducer.

/**
 * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
 * per reducer (this set of files is called a ShuffleFileGroup).
 *
 * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
 * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
 * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
 * files, it releases them for another task.
 * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
 *   - shuffleId: The unique id given to the entire shuffle stage.
 *   - bucketId: The id of the output partition (i.e., reducer id)
 *   - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
 *       time owns a particular fileId, and this id is returned to a pool when the task finishes.
 * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
 * that specifies where in a given file the actual block data is located.
 *
 * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
 * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
 * each block stored in each file. In order to find the location of a shuffle block, we search the
 * files within a ShuffleFileGroups associated with the block's reducer.
 */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) extends Logging {

ShuffleWriterGroup is declared as follows: it defines a group of writers for a ShuffleMapTask, one writer per reducer.

/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {

(II) Obtaining the shuffle writers

To obtain all of the block writers for the shuffle blocks, the serializer is fetched first, and then shuffleBlockManager.forMapTask is called with (shuffleId = dep.shuffleId, mapId = partitionId, numBuckets = numOutputSplits) to obtain the shuffle writers.

// Obtain all the block writers for shuffle blocks.
val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)

Looking at ShuffleBlockManager.forMapTask, we can see that writers is actually an array of BlockObjectWriter:

def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
    new ShuffleWriterGroup {
      shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
      private val shuffleState = shuffleStates(shuffleId)
      private var fileGroup: ShuffleFileGroup = null

      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
        fileGroup = getUnusedFileGroup()
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
        }
      } else {
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          val blockFile = blockManager.diskBlockManager.getFile(blockId)
          // Because of previous failures, the shuffle file may already exist on this machine.
          // If so, remove it.
          if (blockFile.exists) {
            if (blockFile.delete()) {
              logInfo(s"Removed existing shuffle file $blockFile")
            } else {
              logWarning(s"Failed to remove existing shuffle file $blockFile")
            }
          }
          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
        }
      }

Regarding the consolidateShuffleFiles option:

// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
  conf.getBoolean("spark.shuffle.consolidateFiles", false)
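As a usage sketch (this snippet is not part of the analyzed source, and the app name is made up), the option can be turned on through SparkConf when building an application:

import org.apache.spark.SparkConf

// Usage sketch: enable shuffle file consolidation (off by default in 1.0.0).
// With consolidation on, concurrently running map tasks share one file per
// reducer instead of producing one file per (map task, reducer) pair.
val conf = new SparkConf()
  .setAppName("consolidation-example")
  .set("spark.shuffle.consolidateFiles", "true")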

If the option is enabled, an unused file group is obtained first: an existing fileGroup is returned if one is available, otherwise a new one is created.

private def getUnusedFileGroup(): ShuffleFileGroup = {
      val fileGroup = shuffleState.unusedFileGroups.poll()
      if (fileGroup != null) fileGroup else newFileGroup()
}

A blockId is created for each (shuffleId, mapId, bucketId) triple, but map tasks sharing a fileGroup write blocks with the same bucketId into the same file — that is, data destined for the same reducer ends up in the same file — so the number of buckets (files) produced equals the number of reducers. fileGroup(bucketId) actually invokes the apply method, which returns the file corresponding to that bucketId:

def apply(bucketId: Int) = files(bucketId)

If the option is disabled, every shuffle block is written to its own file. A blockId is again determined by the (shuffleId, mapId, bucketId) triple; blockManager.diskBlockManager.getFile is called with that blockId to obtain the blockFile, creating the directory and file on disk — one file per blockId — so the number of buckets (files) created is mappers × reducers. getFile delegates to the overload that takes blockId.name as its argument:

def getFile(blockId: BlockId): File = getFile(blockId.name)

The getFile overload that is ultimately invoked is shown below:

def getFile(filename: String): File = {
    // Figure out which local directory it hashes to, and which subdirectory in that
    val hash = Utils.nonNegativeHash(filename)
    val dirId = hash % localDirs.length
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir

    // Create the subdirectory if it doesn't already exist
    var subDir = subDirs(dirId)(subDirId)
    if (subDir == null) {
      subDir = subDirs(dirId).synchronized {
        val old = subDirs(dirId)(subDirId)
        if (old != null) {
          old
        } else {
          val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
          newDir.mkdir()
          subDirs(dirId)(subDirId) = newDir
          newDir
        }
      }
    }

    new File(subDir, filename)
  }

ShuffleBlockId is a case class; its name field determines the file name the shuffle writer writes to:

case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
  extends BlockId {
  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
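For example, based purely on the name definition above, the block written by map task 3 for reducer 2 of shuffle 0 would get the file name shown in this small sketch:

import org.apache.spark.storage.ShuffleBlockId

// Sketch based on the name definition above.
val id = ShuffleBlockId(shuffleId = 0, mapId = 3, reduceId = 2)
println(id.name)   // prints "shuffle_0_3_2", the file name used when consolidation is off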

Finally, note blockManager.getDiskWriter: its last parameter, bufferSize, defaults to 100 KB, which directly affects how much memory the shuffle process occupies.

private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
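A rough back-of-the-envelope sketch (the task and reducer counts below are hypothetical, not from the source) shows why this matters: every concurrently running map task holds one open writer — and therefore one buffer — per reducer.

// Rough estimate of shuffle-write buffer memory per executor (sketch only).
val bufferSizeBytes = 100 * 1024      // spark.shuffle.file.buffer.kb default
val concurrentTasks = 8               // hypothetical: simultaneously running map tasks
val numReducers     = 1000            // hypothetical
val bufferMemory = bufferSizeBytes.toLong * concurrentTasks * numReducers
// ≈ 800 MB of write buffers in this (deliberately extreme) configuration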

Once the files have been created and the disk writers obtained, all of the shuffle's intermediate results are written to disk.

(III) Writing to the buckets

The task iterates over the elements of its RDD partition, casts each element to a (K, V) pair, computes the pair's bucketId with the partitioner, and finally writes the (K, V) pair into its bucket through the writer for that bucketId.

// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
  val pair = elem.asInstanceOf[Product2[Any, Any]]
  val bucketId = dep.partitioner.getPartition(pair._1)
  shuffle.writers(bucketId).write(pair)
}

Key-value pairs are written to the disk files one at a time; there is no need to first hold all of the data in memory and then flush it to disk as a whole.
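To make the bucketId computation concrete, here is a small sketch assuming dep.partitioner is a HashPartitioner with 4 partitions (the key is made up):

import org.apache.spark.HashPartitioner

// Sketch: deriving a bucketId from a key, assuming a HashPartitioner.
val partitioner = new HashPartitioner(4)
val bucketId = partitioner.getPartition("some-key")   // a value in [0, 4)
// The (K, V) pair is then written through shuffle.writers(bucketId).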

write is declared on BlockObjectWriter:

/**
   * Writes an object.
   */
  def write(value: Any)

The concrete implementation is DiskBlockObjectWriter.write:

override def write(value: Any) {
    if (!initialized) {
      open()
    }
    objOut.writeObject(value)
  }

(IV) Committing the writes

Note that the sizes of the data written to the partitions are reported as an array of Byte (one byte per bucket), which is why the compressSize method is needed.

// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
  writer.commit()
  writer.close()
  val size = writer.fileSegment().length
  totalBytes += size
  totalTime += writer.timeWriting()
  MapOutputTracker.compressSize(size)
}

compressSize encodes a size as the integer logarithm of the size to base 1.1, so the 2^8 possible byte values map onto sizes 1.1^0 … 1.1^255; this supports sizes up to about 35 GB with at most 10% error — a very neat trick.

/**
   * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
   * We do this by encoding the log base 1.1 of the size as an integer, which can support
   * sizes up to 35 GB with at most 10% error.
   */
  def compressSize(size: Long): Byte = {
    if (size == 0) {
      0
    } else if (size <= 1L) {
      1
    } else {
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }
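To see the encoding in action, the following self-contained sketch reproduces the same formula (LOG_BASE = 1.1, as stated in the comment) together with the reverse mapping shown later in decompressSize; the sample size is arbitrary.

object CompressedSizeDemo {
  val LOG_BASE = 1.1

  // Same formula as MapOutputTracker.compressSize above.
  def compressSize(size: Long): Byte = {
    if (size == 0) 0
    else if (size <= 1L) 1
    else math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
  }

  // Reverse mapping (see decompressSize further below).
  def decompressSize(compressedSize: Byte): Long = {
    if (compressedSize == 0) 0 else math.pow(LOG_BASE, compressedSize & 0xFF).toLong
  }

  def main(args: Array[String]) {
    val size   = 123456789L                  // ~118 MB of map output (arbitrary)
    val code   = compressSize(size)          // stored in a single byte
    val approx = decompressSize(code)
    println(s"original=$size, code=${code & 0xFF}, decoded=$approx")
    // The decoded size overestimates the original by at most ~10%;
    // the largest representable size, 1.1^255, is roughly 35 GB.
  }
}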

(V) Updating the metrics

// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

success = true
new MapStatus(blockManager.blockManagerId, compressedSizes)

(VI) Exception handling

    catch { case e: Exception =>
      // If there is an exception from running the task, revert the partial writes
      // and throw the exception upstream to Spark.
      if (shuffle != null && shuffle.writers != null) {
        for (writer <- shuffle.writers) {
          writer.revertPartialWrites()
          writer.close()
        }
      }
      throw e
    } finally {
      // Release the writers back to the shuffle block manager.
      if (shuffle != null && shuffle.writers != null) {
        try {
          shuffle.releaseWriters(success)
        } catch {
          case e: Exception => logError("Failed to release shuffle writers", e)
        }
      }

(VII) Completion callbacks

// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()

3. Shuffle Read

(I) Recording the map outputs

As mentioned in "Returning Results", when a ShuffleMapTask finishes, handleTaskCompletion is called to handle the follow-up work:

/**
 * Responds to a task finishing. This is called inside the event loop so it assumes that it can
 * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
 */
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {

Inside handleTaskCompletion there is a dedicated branch that responds to the successful completion of a ShuffleMapTask:

          case smt: ShuffleMapTask =>
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
            } else {
              stage.addOutputLoc(smt.partitionId, status)
            }
            if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
              markStageAsFinished(stage)
              logInfo("looking for newly runnable stages")
              logInfo("running: " + runningStages)
              logInfo("waiting: " + waitingStages)
              logInfo("failed: " + failedStages)
              if (stage.shuffleDep.isDefined) {
                // We supply true to increment the epoch number here in case this is a
                // recomputation of the map outputs. In that case, some nodes may have cached
                // locations with holes (from when we detected the error) and will need the
                // epoch incremented to refetch them.
                // TODO: Only increment the epoch number if this is not the first time
                //       we registered these map outputs.
                mapOutputTracker.registerMapOutputs(
                  stage.shuffleDep.get.shuffleId,
                  stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                  changeEpoch = true)
              }
              clearCacheLocs()
              if (stage.outputLocs.exists(_ == Nil)) {
                // Some tasks had failed; let's resubmit this stage
                // TODO: Lower-level scheduler should also deal with this
                logInfo("Resubmitting " + stage + " (" + stage.name +
                  ") because some of its tasks had failed: " +
                  stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
                submitStage(stage)
              } else {
                val newlyRunnable = new ArrayBuffer[Stage]
                for (stage <- waitingStages) {
                  logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
                }
                for (stage <- waitingStages if getMissingParentStages(stage) == Nil) {
                  newlyRunnable += stage
                }
                waitingStages --= newlyRunnable
                runningStages ++= newlyRunnable
                for {
                  stage <- newlyRunnable.sortBy(_.id)
                  jobId <- activeJobForStage(stage)
                } {
                  logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
                  submitMissingTasks(stage, jobId)
                }
              }
            }
          }

After a ShuffleMapTask completes successfully, stage.addOutputLoc is called:

stage.addOutputLoc(smt.partitionId, status)

which adds the MapStatus returned by the map task to the stage's outputLocs:

def addOutputLoc(partition: Int, status: MapStatus) {
    val prevList = outputLocs(partition)
    outputLocs(partition) = status :: prevList
    if (prevList == Nil) {
      numAvailableOutputs += 1
    }
  }

outputLocs is an array with one List[MapStatus] per partition:

val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)

MapStatus records information about the map output; so that it can be passed on to the corresponding reduce tasks, it contains the BlockManagerId as well as the (compressed) size of the output produced for each reducer.

/**
 * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
 * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
 * The map output sizes are compressed using MapOutputTracker.compressSize.
 */
private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
  extends Externalizable {

The condition stage.shuffleDep.isDefined relies on the field below: the branch is taken if the stage is a map stage and skipped if it is a result (reduce) stage.

val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage

Once all of the stage's shuffle tasks have completed, registerMapOutputs is called to register the stage's shuffleId together with all of the output locations with the mapOutputTracker:

mapOutputTracker.registerMapOutputs(
                  stage.shuffleDep.get.shuffleId,
                  stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                  changeEpoch = true)

MapOutputTrackerMaster.registerMapOutputs is defined as follows:

/** Register multiple map output information for the given shuffle */
  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
    mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
    if (changeEpoch) {
      incrementEpoch()
    }
  }

(II) Fetching the map outputs

As mentioned in "Task Execution", RDD.iterator checks whether the RDD is cached and calls either getOrCompute or computeOrReadCheckpoint:

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

computeOrReadCheckpoint calls compute to produce the result:

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
  }

ShuffledRDD implements compute as follows; this is the entry point for reading the results computed by the ShuffleMapTasks:

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
    val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
    val ser = Serializer.getSerializer(serializer)
    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
  }

Next, look at fetch. Its job is to fetch the shuffle map outputs; it takes four parameters — shuffleId, reduceId, context and serializer — and returns an iterator over all of the elements of the fetched shuffle outputs.

/**
   * Fetch the shuffle outputs for a given ShuffleDependency.
   * @return An iterator over the elements of the fetched shuffle outputs.
   */
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer = SparkEnv.get.serializer): Iterator[T]

The concrete implementation is BlockStoreShuffleFetcher.fetch:

override def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
  {

    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager

    val startTime = System.currentTimeMillis
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))

    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }

    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
        case Some(block) => {
          block.asInstanceOf[Iterator[T]]
        }
        case None => {
          blockId match {
            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
            case _ =>
              throw new SparkException(
                "Failed to get block " + blockId + ", which is not a shuffle block")
          }
        }
      }
    }

    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
    val itr = blockFetcherItr.flatMap(unpackBlock)

    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      val shuffleMetrics = new ShuffleReadMetrics
      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
      shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
    })

    new InterruptibleIterator[T](context, completionIter)
  }

The method is analyzed step by step below.

(1) mapOutputTracker.getServerStatuses is called so that the worker can obtain, from the master, the server URIs and map output sizes — that is, the MapStatus information stored earlier:

val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)

getServerStatuses is implemented as follows. Note that this method is called from the executors; given a shuffleId and a reduceId it returns pairs of a BlockManagerId and a Long map-output size, and a single BlockManagerId may correspond to the sizes of several files.

/**
   * Called from executors to get the server URIs and output sizes of the map outputs of
   * a given shuffle.
   */
  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        if (fetching.contains(shuffleId)) {
          // Someone else is fetching it; wait for them to be done
          while (fetching.contains(shuffleId)) {
            try {
              fetching.wait()
            } catch {
              case e: InterruptedException =>
            }
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }

      if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes =
            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      if (fetchedStatuses != null) {
        fetchedStatuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
        }
      } else {
        throw new FetchFailedException(null, shuffleId, -1, reduceId,
          new Exception("Missing all output locations for shuffle " + shuffleId))
      }
    } else {
      statuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
      }
    }
  }

Finally, convertMapStatuses is called to convert the MapStatus array:

// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
  // any of the statuses is null (indicating a missing location due to a failed mapper),
  // throw a FetchFailedException.
  private def convertMapStatuses(
        shuffleId: Int,
        reduceId: Int,
        statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
    assert (statuses != null)
    statuses.map {
      status =>
        if (status == null) {
          throw new FetchFailedException(null, shuffleId, -1, reduceId,
            new Exception("Missing an output location for shuffle " + shuffleId))
        } else {
          (status.location, decompressSize(status.compressedSizes(reduceId)))
        }
    }
  }

The Long output size is the result of decompressSize:

/**
   * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
   */
  def decompressSize(compressedSize: Byte): Long = {
    if (compressedSize == 0) {
      0
    } else {
      math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
    }
  }

(2) Build the mapping from BlockManagerId to BlockIds. A HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] is created and populated (getting or updating each entry), then converted into (BlockManagerId, Seq[(ShuffleBlockId, size)]) pairs. Each ShuffleBlockId is assembled as ShuffleBlockId(shuffleId, mapId, reduceId), where mapId is the index of the corresponding map output (0, 1, 2, …).

  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
  for (((address, size), index) <- statuses.zipWithIndex) {
    splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
  }

  val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
    case (address, splits) =>
      (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
  }
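The grouping step itself is easy to mimic in isolation; the sketch below uses plain strings in place of BlockManagerIds (purely illustrative, with made-up sizes) to show how the per-map statuses become one block list per address:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Illustrative only: "exec-A" / "exec-B" stand in for BlockManagerIds, and the
// Long values for the decompressed map-output sizes for this reducer.
val statuses: Array[(String, Long)] =
  Array(("exec-A", 1000L), ("exec-B", 2000L), ("exec-A", 500L))

val splitsByAddress = new HashMap[String, ArrayBuffer[(Int, Long)]]
for (((address, size), mapId) <- statuses.zipWithIndex) {
  splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((mapId, size))
}
// exec-A -> [(0, 1000), (2, 500)], exec-B -> [(1, 2000)]
// Each (mapId, size) then becomes (ShuffleBlockId(shuffleId, mapId, reduceId), size).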

(3) Define the helper function unpackBlock: if a BlockId maps to an iterator it is returned, and if not an exception is thrown.

def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
        case Some(block) => {
          block.asInstanceOf[Iterator[T]]
        }
        case None => {
          blockId match {
            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
            case _ =>
              throw new SparkException(
                "Failed to get block " + blockId + ", which is not a shuffle block")
          }
        }
      }
    }

Next, BlockManager.getMultiple is called to fetch multiple blocks from the local and remote block managers, and unpackBlock is applied to validate the results and return an iterator.

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)

Depending on whether netty is used, getMultiple chooses between BasicBlockFetcherIterator and NettyBlockFetcherIterator.

/**
  * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
  * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
  * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
  * so that we can control the maxMegabytesInFlight for the fetch.
  */
  def getMultiple(
      blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
      serializer: Serializer): BlockFetcherIterator = {
    val iter =
      if (conf.getBoolean("spark.shuffle.use.netty", false)) {
        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
      } else {
        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
      }

    iter.initialize()
    iter
  }
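Which path is taken is purely a matter of configuration; a usage sketch (not part of the analyzed code):

import org.apache.spark.SparkConf

// Usage sketch: switch shuffle fetching to the netty-based path
// (spark.shuffle.use.netty defaults to false in 1.0.0, which selects
// BasicBlockFetcherIterator).
val conf = new SparkConf().set("spark.shuffle.use.netty", "true")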

First, consider BasicBlockFetcherIterator.

In initialize, the local and remote blocks are split first; the remote block requests are put into the request queue in a random order and initial fetch requests are sent out, never exceeding maxBytesInFlight, and while the remote blocks are being fetched the local blocks are read.

  override def initialize() {
      // Split local and remote blocks.
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests into our queue in a random order
      fetchRequests ++= Utils.randomize(remoteRequests)

      // Send out initial requests for blocks, up to our maxBytesInFlight
      while (!fetchRequests.isEmpty &&
        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
        sendRequest(fetchRequests.dequeue())
      }

      val numFetches = remoteRequests.size - fetchRequests.size
      logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

      // Get Local Blocks
      startTime = System.currentTimeMillis
      getLocalBlocks()
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
    }

Several important methods are analyzed below.

1. splitLocalRemoteBlocks

Data is read in parallel from at most 5 nodes at a time, and each request fetches no more than spark.reducer.maxMbInFlight / 5. The BlockManagerId from blocksByAddress is compared with the local BlockManagerId to decide whether the blocks are local. For local blocks, zero-sized blocks are filtered out, the BlockIds from the block infos are recorded in localBlocksToFetch, and the count of blocks to fetch is accumulated. For remote blocks, zero-sized blocks are also skipped; the blocks are traversed with an iterator, each blockId is added to remoteBlocksToFetch and its size accumulated into curRequestSize, and as soon as curRequestSize reaches targetRequestSize a remote FetchRequest is created; any blocks remaining at the end of the traversal form one final request. Finally remoteRequests is returned.

protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
  // nodes, rather than blocking on reading output from one node.
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
  // at most maxBytesInFlight in order to limit the amount of data in flight.
  val remoteRequests = new ArrayBuffer[FetchRequest]
  for ((address, blockInfos) <- blocksByAddress) {
    if (address == blockManagerId) {
      numLocal = blockInfos.size
      // Filter out zero-sized blocks
      localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
      _numBlocksToFetch += localBlocksToFetch.size
    } else {
      numRemote += blockInfos.size
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocksToFetch += blockId
          _numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        if (curRequestSize >= targetRequestSize) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          curRequestSize = 0
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          logDebug(s"Creating fetch request of $curRequestSize at $address")
        }
      }
      // Add in the final request
      if (!curBlocks.isEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
    totalBlocks + " blocks")
  remoteRequests
}

maxBytesInFlight is defined as follows, defaulting to 48 MB; it caps the amount of data in flight and throttles how many requests can be sent out.

// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
val maxBytesInFlight =
  conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024
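Plugging in the default value (a quick arithmetic sketch, not from the source):

// With the default spark.reducer.maxMbInFlight = 48:
val maxBytesInFlight  = 48L * 1024 * 1024                     // 50331648 bytes
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)    // ~9.6 MB per request
// So roughly five ~9.6 MB requests can be outstanding at once, spreading the
// fetches across several map-output nodes instead of blocking on a single one.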

2. sendRequest

A connection is established through the ConnectionManager and the request is sent with sendMessageReliably; the returned messages are validated, and for each block, using the blockId and related information, dataDeserialize converts the ByteBuffer holding the block data into an iterator.

protected def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    val cmId = new ConnectionManagerId(req.address.host, req.address.port)
    val blockMessageArray = new BlockMessageArray(req.blocks.map {
      case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
    })
    bytesInFlight += req.size
    val sizeMap = req.blocks.toMap  // so we can look up the size of each blockID
    val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
    future.onSuccess {
      case Some(message) => {
        val bufferMessage = message.asInstanceOf[BufferMessage]
        val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
        for (blockMessage <- blockMessageArray) {
          if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
            throw new SparkException(
              "Unexpected message " + blockMessage.getType + " received from " + cmId)
          }
          val blockId = blockMessage.getId
          val networkSize = blockMessage.getData.limit()
          results.put(new FetchResult(blockId, sizeMap(blockId),
            () => dataDeserialize(blockId, blockMessage.getData, serializer)))
          _remoteBytesRead += networkSize
          logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }
      }
      case None => {
        logError("Could not get block(s) from " + cmId)
        for ((blockId, size) <- req.blocks) {
          results.put(new FetchResult(blockId, -1, null))
        }
      }
    }
  }

3. getLocalBlocks

As the comment explains, local blocks can be fetched in parallel with the remote ones because fetching them merely memory-maps some files; it does not actually consume any of the (48 MB-capped) in-flight budget.

The method iterates over localBlocksToFetch; getLocalFromDisk ultimately calls diskStore.getValues to read the data for each blockId directly from disk, returning an iterator.

protected def getLocalBlocks() {
      // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
      // these all at once because they will just memory-map some files, so they won't consume
      // any memory that might exceed our maxBytesInFlight
      for (id <- localBlocksToFetch) {
        getLocalFromDisk(id, serializer) match {
          case Some(iter) => {
            // Pass 0 as size since it's not in flight
            results.put(new FetchResult(id, 0, () => iter))
            logDebug("Got local block " + id)
          }
          case None => {
            throw new BlockException(id, "Could not get block " + id + " from local machine")
          }
        }
      }
    }

Now consider NettyBlockFetcherIterator.

Its initialize likewise calls splitLocalRemoteBlocks to split the local and remote blocks and enqueues the fetch requests in a random order; it then starts the copiers to fetch the remote blocks — with the number of parallel copier threads set to 6 by default — and reads the local blocks.

    override def initialize() {
      // Split Local Remote Blocks and set numBlocksToFetch
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests into our queue in a random order
      for (request <- Utils.randomize(remoteRequests)) {
        fetchRequestsSync.put(request)
      }

      copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
      logInfo("Started " + fetchRequestsSync.size + " remote fetches in " +
        Utils.getUsedTimeMs(startTime))

      // Get Local Blocks
      startTime = System.currentTimeMillis
      getLocalBlocks()
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
    }

copiers is actually a list of threads:

private var copiers: List[_ <: Thread] = null

startCopiers is implemented as follows; the key part is the sendRequest reimplemented in NettyBlockFetcherIterator.

private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
      (for ( i <- Range(0,numCopiers) ) yield {
        val copier = new Thread {
          override def run(){
            try {
              while(!isInterrupted && !fetchRequestsSync.isEmpty) {
                sendRequest(fetchRequestsSync.take())
              }
            } catch {
              case x: InterruptedException => logInfo("Copier Interrupted")
              // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
            }
          }
        }
        copier.start
        copier
      }).toList
    }

NettyBlockFetcherIterator.sendRequest creates a ShuffleCopier and calls ShuffleCopier.getBlocks to fetch the blocks.

override protected def sendRequest(req: FetchRequest) {

      def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
        val fetchResult = new FetchResult(blockId, blockSize,
          () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
        results.put(fetchResult)
      }

      logDebug("Sending request for %d blocks (%s) from %s".format(
        req.blocks.size, Utils.bytesToString(req.size), req.address.host))
      val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
      val cpier = new ShuffleCopier(blockManager.conf)
      cpier.getBlocks(cmId, req.blocks, putResult)
      logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
    }

getBlocks eventually calls getBlock, which creates a FileClient, sends the request, and retrieves the blocks from the files; the actual transfer is handled by netty.

def getBlock(host: String, port: Int, blockId: BlockId,
      resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {

  val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
  val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000)
  val fc = new FileClient(handler, connectTimeout)

  try {
    fc.init()
    fc.connect(host, port)
    fc.sendRequest(blockId.name)
    fc.waitForClose()
    fc.close()
  } catch {
    // Handle any socket-related exceptions in FileClient
    case e: Exception => {
      logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
      handler.handleError(blockId)
    }
  }
}

This completes the analysis of the entire shuffle write + fetch process.


END
