PythonRDD: the Scala class behind pyspark

This post walks through PythonRDD, the Scala code on the JVM side of pyspark.

The code shown is from Spark 2.2.0.

1.PythonRDD.class

This RDD type is the key piece that lets Python plug into Spark.

// This is a standard RDD implementation: it overrides compute, partitioner, getPartitions, etc.
// This PythonRDD is exactly what the _jrdd property of pyspark's PipelinedRDD returns
// parent is the _prev_jrdd that PipelinedRDD passes in, i.e. the data-source RDD that was built first
private[spark] class PythonRDD(
    parent: RDD[_],  // this parent RDD is the key: everything Python reads from Spark flows in through it
    func: PythonFunction, // the computation logic implemented by the user in Python
    preservePartitoning: Boolean)
  extends RDD[Array[Byte]](parent) {

  val bufferSize = conf.getInt("spark.buffer.size", 65536)
  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

  override def getPartitions: Array[Partition] = firstParent.partitions

  override val partitioner: Option[Partitioner] = {
    if (preservePartitoning) firstParent.partitioner else None
  }

  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    // Delegate the real work for this partition to PythonRunner
    // Note: this is not the PythonRunner that spark-submit runs when a Python application is submitted
    val runner = PythonRunner(func, bufferSize, reuse_worker)
    // Run the runner's computation; the first argument is the computed output of the Spark data-source RDD
    // firstParent.iterator triggers computation of the parent RDD and returns its results
    // the RDD behind this first argument is the same object as the _jrdd of the RDD in pyspark
    runner.compute(firstParent.iterator(split, context), split.index, context)
  }
}
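
To make the shape of this class concrete, here is a stand-alone toy analogue (my own sketch, not Spark source, with made-up names): like PythonRDD it reuses the parent's partitions, and its compute() feeds the parent's iterator through a step that turns each record into serialized bytes.

import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Toy analogue of PythonRDD's structure: same partitions as the parent, and compute()
// pipes the parent's iterator through a "runner" (here just a map) that yields bytes.
class BytesWrapperRDD(parent: RDD[String]) extends RDD[Array[Byte]](parent) {
  override def getPartitions: Array[Partition] = firstParent[String].partitions
  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] =
    firstParent[String].iterator(split, context).map(_.getBytes("UTF-8"))
}

object BytesWrapperDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
    val source = sc.parallelize(Seq("a", "bb", "ccc"), 2)  // plays the role of _prev_jrdd
    new BytesWrapperRDD(source).collect().foreach(b => println(b.length))
    sc.stop()
  }
}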

2.PythonRunner.class

This class is the concrete computation entity used when the RDD computes internally; it is not the PythonRunner that starts py4j when the code is submitted.

/*
 * This class does three things:
 * 1. Starts pyspark.daemon, which accepts tasks and forks workers to execute them
 * 2. Starts a WriterThread that writes the data-source RDD's output to the pyspark worker
 * 3. Pulls the execution results back from the pyspark worker
 *
 * The data the WriterThread writes is what _jrdd computes on the pyspark side, i.e. the data of the source RDD
 */
private[spark] class PythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    bufferSize: Int,
    reuse_worker: Boolean,
    isUDF: Boolean,
    argOffsets: Array[Array[Int]])
  extends Logging {

  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

  // The environment and command used to run Python
  private val envVars = funcs.head.funcs.head.envVars
  private val pythonExec = funcs.head.funcs.head.pythonExec
  private val pythonVer = funcs.head.funcs.head.pythonVer

  private val accumulator = funcs.head.funcs.head.accumulator

  def compute(
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext): Iterator[Array[Byte]] = {
    val startTime = System.currentTimeMillis
    val env = SparkEnv.get
    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
    envVars.put("SPARK_LOCAL_DIRS", localdir) // it‘s also used in monitor thread
    if (reuse_worker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }

    // Create the pyspark worker process; under the hood this runs pyspark.daemon
    // This method makes sure only a single pyspark.daemon is started
    // The return value is the socket used to communicate with the worker
    // A detailed analysis is covered in a separate part
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
    @volatile var released = false
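    // `released` only becomes true after the END_OF_STREAM marker is read and the worker has been
    // handed back to the worker pool; the completion listener below uses it to decide whether the
    // socket still needs to be closed explicitly.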

    // Create the WriterThread, which writes the data-source records to the socket and sends them to the pyspark worker
    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)

    // Register a task-completion listener: once the task finishes, stop the WriterThread
    context.addTaskCompletionListener { context =>
      writerThread.shutdownOnTaskCompletion()
      if (!reuse_worker || !released) {
        try {
          worker.close()
        } catch {
          case e: Exception =>
            logWarning("Failed to close worker socket", e)
        }
      }
    }

    writerThread.start()
    new MonitorThread(env, worker, context).start()
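    // From here on three parties cooperate for this task: the WriterThread pushes the parent RDD's
    // records into the worker's socket, the Python worker runs the user's function on them, and this
    // task thread pulls the results back through the iterator built below, while the MonitorThread
    // kills the worker if the task gets interrupted.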

    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
    // Create the iterator that pulls the pyspark worker's execution results
    val stdoutIterator = new Iterator[Array[Byte]] {
      override def next(): Array[Byte] = {
        val obj = _nextObj
        if (hasNext) {
          _nextObj = read()
        }
        obj
      }

      private def read(): Array[Byte] = {
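        // Wire format: every message from the worker starts with a 4-byte int. A positive value is the
        // length of a serialized result record that follows, 0 means an empty record, and the negative
        // constants in SpecialLengths are control codes (timing data, a Python exception, end of the
        // data section, end of stream), handled by the cases below.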
        if (writerThread.exception.isDefined) {
          throw writerThread.exception.get
        }
        try {
          stream.readInt() match {
            case length if length > 0 =>
              val obj = new Array[Byte](length)
              stream.readFully(obj)
              obj
            case 0 => Array.empty[Byte]
            case SpecialLengths.TIMING_DATA =>
              // Timing data from worker
              val bootTime = stream.readLong()
              val initTime = stream.readLong()
              val finishTime = stream.readLong()
              val boot = bootTime - startTime
              val init = initTime - bootTime
              val finish = finishTime - initTime
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                init, finish))
              val memoryBytesSpilled = stream.readLong()
              val diskBytesSpilled = stream.readLong()
              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
              read()
            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
                writerThread.exception.getOrElse(null))
            case SpecialLengths.END_OF_DATA_SECTION =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates:
              val numAccumulatorUpdates = stream.readInt()
              (1 to numAccumulatorUpdates).foreach { _ =>
                val updateLen = stream.readInt()
                val update = new Array[Byte](updateLen)
                stream.readFully(update)
                accumulator.add(update)
              }
              // Check whether the worker is ready to be re-used.
              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
                if (reuse_worker) {
                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
                  released = true
                }
              }
              null
          }
        } catch {

          case e: Exception if context.isInterrupted =>
            logDebug("Exception thrown after task interruption", e)
            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

          case e: Exception if env.isStopped =>
            logDebug("Exception thrown after context is stopped", e)
            null  // exit silently

          case e: Exception if writerThread.exception.isDefined =>
            logError("Python worker exited unexpectedly (crashed)", e)
            logError("This may have been caused by a prior exception:", writerThread.exception.get)
            throw writerThread.exception.get

          case eof: EOFException =>
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
        }
      }

      var _nextObj = read()

      override def hasNext: Boolean = _nextObj != null
    }
    // Return the iterator that pulls the result data
    new InterruptibleIterator(context, stdoutIterator)
  }

  /**
   * Implementation of the WriterThread
   */
  class WriterThread(
      env: SparkEnv,
      worker: Socket,
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext)
    extends Thread(s"stdout writer for $pythonExec") {

    @volatile private var _exception: Exception = null

    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

    setDaemon(true)

    /** Contains the exception thrown while writing the parent iterator to the Python process. */
    def exception: Option[Exception] = Option(_exception)

    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
    def shutdownOnTaskCompletion() {
      assert(context.isCompleted)
      this.interrupt()
    }

    // The main logic is in run(): it writes the output of the data-source RDD to the worker,
    // together with the broadcast variables, the environment, and the serialized Python code to execute,
    // and finally the data that actually needs to be computed
    override def run(): Unit = Utils.logUncaughtExceptions {
      try {
        TaskContext.setTaskContext(context)
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)
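        // Everything below is written in exactly the order the Python worker expects to read it:
        // partition index, Python version, task context info, the Spark files directory, the Python
        // includes, the broadcast variables, the serialized command (the pickled Python function),
        // and finally the actual input data.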
        // Partition index
        dataOut.writeInt(partitionIndex)
        // Python version of driver
        PythonRDD.writeUTF(pythonVer, dataOut)
        // Write out the TaskContextInfo
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
        dataOut.writeLong(context.taskAttemptId())
        // sparkFilesDir
        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
        // Python includes (*.zip and *.egg files)
        dataOut.writeInt(pythonIncludes.size)
        for (include <- pythonIncludes) {
          PythonRDD.writeUTF(include, dataOut)
        }
        // Broadcast variables
        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
        val newBids = broadcastVars.map(_.id).toSet
        // number of different broadcasts
        val toRemove = oldBids.diff(newBids)
        val cnt = toRemove.size + newBids.diff(oldBids).size
        dataOut.writeInt(cnt)
        for (bid <- toRemove) {
          // remove the broadcast from worker
          dataOut.writeLong(- bid - 1)  // bid >= 0
          oldBids.remove(bid)
        }
        for (broadcast <- broadcastVars) {
          if (!oldBids.contains(broadcast.id)) {
            // send new broadcast
            dataOut.writeLong(broadcast.id)
            PythonRDD.writeUTF(broadcast.value.path, dataOut)
            oldBids.add(broadcast.id)
          }
        }
        dataOut.flush()
        // Serialized command:
        if (isUDF) {
          dataOut.writeInt(1)
          dataOut.writeInt(funcs.length)
          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
            dataOut.writeInt(offsets.length)
            offsets.foreach { offset =>
              dataOut.writeInt(offset)
            }
            dataOut.writeInt(chained.funcs.length)
            chained.funcs.foreach { f =>
              dataOut.writeInt(f.command.length)
              dataOut.write(f.command)
            }
          }
        } else {
          dataOut.writeInt(0)
          val command = funcs.head.funcs.head.command
          dataOut.writeInt(command.length)
          dataOut.write(command)
        }
        // Data values
        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
        dataOut.flush()
      } catch {
        case e: Exception if context.isCompleted || context.isInterrupted =>
          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }

        case e: Exception =>
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }
      }
    }
  }

  // Monitors whether the task is still running
  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
    extends Thread(s"Worker Monitor for $pythonExec") {

    setDaemon(true)

    override def run() {
      // Kill the worker if it is interrupted, checking until task completion.
      // TODO: This has a race condition if interruption occurs, as completed may still become true.
      while (!context.isInterrupted && !context.isCompleted) {
        Thread.sleep(2000)
      }
      if (!context.isCompleted) {
        try {
          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
        } catch {
          case e: Exception =>
            logError("Exception when trying to kill worker", e)
        }
      }
    }
  }
}
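
The reader loop in compute() and the WriterThread above are the two ends of a simple length-prefixed protocol over the worker socket: each record is a 4-byte int length followed by that many bytes, and negative lengths act as control codes. The self-contained sketch below shows the same framing with plain java.io streams; the object name and the END_OF_DATA sentinel are made up for illustration and are not Spark's SpecialLengths constants.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object FramingSketch {
  val END_OF_DATA = -1  // illustrative sentinel standing in for a SpecialLengths-style control code

  def main(args: Array[String]): Unit = {
    // Writer side: length-prefixed records followed by an end-of-data marker,
    // mirroring what WriterThread does over the worker socket.
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    Seq("hello", "pyspark").foreach { s =>
      val bytes = s.getBytes(StandardCharsets.UTF_8)
      out.writeInt(bytes.length)
      out.write(bytes)
    }
    out.writeInt(END_OF_DATA)
    out.flush()

    // Reader side: the same readInt()/readFully() loop shape as PythonRunner's read().
    val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    var done = false
    while (!done) {
      in.readInt() match {
        case len if len > 0 =>
          val record = new Array[Byte](len)
          in.readFully(record)
          println(new String(record, StandardCharsets.UTF_8))
        case END_OF_DATA =>
          done = true
      }
    }
  }
}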

Original article: https://www.cnblogs.com/cloud-zhao/p/9046850.html
