Homework 2: UDF Caching in Spark

为 Spark 编写 UDF 缓存（UDF cache）：

作业介绍：https://github.com/cs186-spring15/course/tree/master/hw2

我花了些时间完成了这份作业，觉得它是学习 Spark SQL 和 Scala 的好材料。现在把我的实现记录如下：

Task #1: Implementing DiskPartition and GeneralDiskHashedRelation

Task #2: Implementing object DiskHashedRelation


package org.apache.spark.sql.execution

import java.io._

import java.nio.file.{Path, StandardOpenOption, Files}

import java.util.{ArrayList => JavaArrayList}

import org.apache.spark.SparkException

import org.apache.spark.sql.catalyst.expressions.{Projection, Row}

import org.apache.spark.sql.execution.CS186Utils._

import scala.collection.JavaConverters._


* This trait represents a regular relation that is hash partitioned and spilled to

* disk.


private[sql] sealed trait DiskHashedRelation {



* @return an iterator of the [[DiskPartition]]s that make up this relation.


def getIterator(): Iterator[DiskPartition]


* Close all the partitions for this relation. This should involve deleting the files hashed into.


def closeAllPartitions()



* A general implementation of [[DiskHashedRelation]].


* @param partitions the disk partitions that we are going to spill to


protected [sql] final class GeneralDiskHashedRelation(partitions: Array[DiskPartition])

extends DiskHashedRelation with Serializable {

override def getIterator() = {





override def closeAllPartitions() = {


partitions.foreach((_: DiskPartition).closePartition())



private[sql] class DiskPartition (

filename: String,

blockSize: Int) {

private val path: Path = Files.createTempFile("", filename)

private val data: JavaArrayList[Row] = new JavaArrayList[Row]

private val outStream: OutputStream = Files.newOutputStream(path)

private val inStream: InputStream = Files.newInputStream(path)

private val chunkSizes: JavaArrayList[Int] = new JavaArrayList[Int]()

private var writtenToDisk: Boolean = false

private var inputClosed: Boolean = false


* This method inserts a new row into this particular partition. If the size of the partition

* exceeds the blockSize, the partition is spilled to disk.


* @param row the [[Row]] we are adding


def insert(row: Row) = {


if (inputClosed) {

throw new SparkException("The partition is closed!")



val partitionDataSize = measurePartitionSize()

if (partitionDataSize >blockSize) {






* This method converts the data to a byte array and returns the size of the byte array

* as an estimation of the size of the partition.


* @return the estimated size of the data


private[this] def measurePartitionSize(): Int = {




* Uses the [[Files]] API to write a byte array representing data to a file.


private[this] def spillPartitionToDisk() = {

val bytes: Array[Byte] = getBytesFromList(data)

// This array list stores the sizes of chunks written in order to read them back correctly.


Files.write(path, bytes, StandardOpenOption.APPEND)

writtenToDisk = true



* If this partition has been closed, this method returns an Iterator of all the

* data that was written to disk by this partition.


* @return the [[Iterator]] of the data


def getData(): Iterator[Row] = {

if (!inputClosed) {

throw new SparkException("Should not be reading from file before closing input. Bad things will happen!")


new Iterator[Row] {

var currentIterator: Iterator[Row] = data.iterator.asScala

val chunkSizeIterator: Iterator[Int] = chunkSizes.iterator().asScala

var byteArray: Array[Byte] = null

override def next() = {





override def hasNext() = {


var hasNext = currentIterator.hasNext


hasNext = chunkSizeIterator.hasNext






// false



* Fetches the next chunk of the file and updates the iterator. Should return true

* unless the iterator is empty.


* @return true unless the iterator is empty.


private[this] def fetchNextChunk(): Boolean = {


if (!chunkSizeIterator.hasNext) {

return false


val size = chunkSizeIterator.next()

if (size <= 0) {

return false


byteArray = CS186Utils.getNextChunkBytes(inStream, size,byteArray)

currentIterator = CS186Utils.getListFromBytes(byteArray).iterator.asScala






* Closes this partition, implying that no more data will be written to this partition. If getData()

* is called without closing the partition, an error will be thrown.


* If any data has not been written to disk yet, it should be written. The output stream should

* also be closed.


def closeInput() = {







inputClosed = true



* Closes this partition. This closes the input stream and deletes the file backing the partition.


private[sql] def closePartition() = {





private[sql] object DiskHashedRelation {


* Given an input iterator, partitions each row into one of a number of [[DiskPartition]]s

* and constructs a [[DiskHashedRelation]].


* This executes the first phase of external hashing -- using a coarse-grained hash function

* to partition the tuples to disk.


* The block size is approximately set to 64k because that is a good estimate of the average

* buffer page.


* @param input the input [[Iterator]] of [[Row]]s

* @param keyGenerator a [[Projection]] that generates the keys for the input

* @param size the number of [[DiskPartition]]s

* @param blockSize the threshold at which each partition will spill

* @return the constructed [[DiskHashedRelation]]


def apply(
    input: Iterator[Row],
    keyGenerator: Projection,
    size: Int = 64,
    blockSize: Int = 64000): GeneralDiskHashedRelation = {
  // Phase 1 of external hashing: route every row into one of `size` partitions
  // chosen by the hash of its generated key.
  val partitionList: JavaArrayList[DiskPartition] = genDiskPartition(size, blockSize)
  input.foreach { row =>
    // hashCode() may be negative and Scala's `%` keeps the sign of the dividend,
    // so fold the remainder back into [0, size) before using it as an index.
    val rawIndex = keyGenerator(row).hashCode() % size
    val index = if (rawIndex < 0) rawIndex + size else rawIndex
    partitionList.get(index).insert(row)
  }
  // Close input on every partition so remaining rows are written out and
  // getData() may be called on them afterwards.
  val partitions: Array[DiskPartition] = partitionList.toArray(new Array[DiskPartition](size))
  partitions.foreach(_.closeInput())
  new GeneralDiskHashedRelation(partitions)
}

/**
 * Creates `size` fresh [[DiskPartition]]s named "partition0" .. "partition{size-1}",
 * each spilling once its buffered data exceeds `blockSize`.
 *
 * @param size      number of partitions to create (non-positive yields an empty list)
 * @param blockSize spill threshold passed to every partition
 * @return the partitions, in index order
 */
def genDiskPartition(size: Int, blockSize: Int): JavaArrayList[DiskPartition] = {
  // JavaArrayList rejects a negative initial capacity, so clamp it.
  val partitionList = new JavaArrayList[DiskPartition](math.max(size, 0))
  for (i <- 0 until size) {
    partitionList.add(new DiskPartition("partition" + i, blockSize))
  }
  partitionList
}

Task #3: Implementing CS186Utils methods


package org.apache.spark.sql.execution

import java.io._

import java.util.{ArrayList => JavaArrayList, HashMap => JavaHashMap}

import org.apache.spark.sql.catalyst.expressions._

object CS186Utils {


* Returns a Scala array that contains the bytes representing a Java ArrayList.


* @param data the Java ArrayList we are converting

* @return an array of bytes


def getBytesFromList(data: JavaArrayList[Row]): Array[Byte] = {
  // Java-serialize the whole list into an in-memory buffer. A fresh ObjectOutputStream
  // per call means each chunk carries its own stream header, which is why
  // getListFromBytes opens a new ObjectInputStream for every object it reads.
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  try {
    out.writeObject(data)
    out.flush()
  } finally {
    // Release the stream even if serialization fails.
    out.close()
  }
  bytes.toByteArray
}

* Converts an array of bytes into a JavaArrayList of type [[Row]].


* @param bytes the input byte array

* @return a [[JavaArrayList]] of Rows


def getListFromBytes(bytes: Array[Byte]): JavaArrayList[Row] = {

val result: JavaArrayList[Row] = new JavaArrayList[Row]()

var temp: JavaArrayList[Row] = null

// create input streams based on the input bytes

val bytesIn = new ByteArrayInputStream(bytes)

var in = new ObjectInputStream(bytesIn)

try {

// read each object in and attempt to interpret it as a JavaArrayList[Row]

while ((temp = in.readObject() match {

case value: JavaArrayList[Row] => value

case _: Throwable => throw new RuntimeException(s"Unexpected casting exception while reading from file.")

}) != null) {

// if it succeeds, add it to the result


// we need to create a new ObjectInputStream for each new object we read because of Java stream quirks

in = new ObjectInputStream(bytesIn)


} catch {

// ObjectInputStream control flow dictates that an EOFException will be thrown when the file is over -- this is expected

case e: EOFException => // do nothing

case other: Throwable => throw other





* Reads the next nextChunkSize bytes from the input stream provided. If the previous array read into is available,

* please provide it so as to avoid allocating a new object unless absolutely necessary.


* @param inStream the input stream we are reading from

* @param nextChunkSize the number of bytes to read

* @param previousArray the previous array we read into

* @return


def getNextChunkBytes(inStream: InputStream, nextChunkSize: Int, previousArray: Array[Byte] = null): Array[Byte] = {
  // Reuse the caller's buffer only when it already has exactly the requested length,
  // so the returned array never contains stale bytes past the chunk boundary.
  val byteArray =
    if (previousArray != null && previousArray.length == nextChunkSize) previousArray
    else new Array[Byte](nextChunkSize)
  // InputStream.read may return fewer bytes than requested; readFully keeps reading
  // until the buffer is full and throws EOFException if the stream ends first.
  // The DataInputStream wrapper is intentionally not closed: that would close inStream.
  new DataInputStream(inStream).readFully(byteArray)
  byteArray
}

* Return a new projection operator.


* @param expressions

* @param inputSchema

* @return


/**
 * Builds an interpreted (non-codegen) projection that evaluates `expressions`
 * against rows laid out according to `inputSchema`.
 *
 * @param expressions the expressions to evaluate for each row
 * @param inputSchema the attributes describing the input row layout
 * @return a new [[InterpretedProjection]]
 */
def getNewProjection(
    expressions: Seq[Expression],
    inputSchema: Seq[Attribute]): InterpretedProjection =
  new InterpretedProjection(expressions, inputSchema)


* This function returns the [[ScalaUdf]] from a sequence of expressions. If there is no UDF in the

* sequence of expressions then it returns null. If there is more than one, it returns the one that is

* sequentially last.


* @param expressions

* @return


def getUdfFromExpressions(expressions: Seq[Expression]): ScalaUdf = {
  // Only top-level expressions are inspected (nested UDFs are not searched).
  // The last ScalaUdf in sequence order wins; null signals "no UDF", which
  // generateCachingIterator matches on explicitly.
  expressions.collect { case udf: ScalaUdf => udf }.lastOption.orNull
}

* This function takes a sequence of expressions. If there is no UDF in the sequence of expressions, it does

* a regular projection operation.


* If there is a UDF, then it creates a caching iterator that caches the result of the UDF.


* NOTE: This only works for a single UDF. If there are multiple UDFs, then it will only cache for the last UDF

* and execute all other UDFs regularly.


* @param expressions

* @param inputSchema

* @return


def generateCachingIterator(

expressions: Seq[Expression],

inputSchema: Seq[Attribute]): (Iterator[Row] => Iterator[Row]) = {

// Get the UDF from the expressions.

val udf: ScalaUdf = CS186Utils.getUdfFromExpressions(expressions)

udf match {

/* If there is no UDF, then do a regular projection operation. Note that this is very similar to Project in

basicOperators.scala */

case null => {

{ input =>

val projection = CS186Utils.getNewProjection(expressions, inputSchema)



/* def aaa (input: Iterator[Row]) : Iterator[Row] = {

val projection = CS186Utils.getNewProjection(expressions, inputSchema)





// Otherwise, separate the expressions appropriately and creating a caching iterator.

case u: ScalaUdf => {

val udfIndex: Int = expressions.indexOf(u)

val preUdfExpressions = expressions.slice(0, udfIndex)

val postUdfExpressions = expressions.slice(udfIndex + 1, expressions.size)

CachingIteratorGenerator(udf.children, udf, preUdfExpressions, postUdfExpressions, inputSchema)





object CachingIteratorGenerator {


* This function takes an input iterator and returns an iterator that does in-memory memoization

* as it evaluates the projection operator over each input row. The result is the concatenation of

* the projection of the preUdfExpressions, the evaluation of the udf, and the projection of the

* postUdfExpressions, in that order.


* The UDF should only be evaluated if the inputs to the UDF have never been seen before.


* This method only needs to worry about caching for the UDF that is specifically passed in. If

* there are any other UDFs in the expression lists, then they can and should be evaluated

* without any caching.


* @param cacheKeys the keys on which we will cache -- the inputs to the UDF

* @param udf the udf we are caching for

* @param preUdfExpressions the expressions that come before the UDF in the projection

* @param postUdfExpressions the expressions that come after the UDF in the projection

* @param inputSchema the schema of the rows -- useful for creating projections

* @return


//CachingIteratorGenerator(studentAttributes, udf, Seq(studentAttributes(1)), Seq(), studentAttributes)

//Student(sid: Int, gpa: Float)

def apply(

cacheKeys: Seq[Expression],

udf: ScalaUdf,

preUdfExpressions: Seq[Expression],

postUdfExpressions: Seq[Expression],

inputSchema: Seq[Attribute]): (Iterator[Row] => Iterator[Row]) = {

{ input =>

new Iterator[Row] {

val udfProject = CS186Utils.getNewProjection(Seq(udf), inputSchema)

val cacheKeyProjection = CS186Utils.getNewProjection(udf.children, inputSchema)

val preUdfProjection = CS186Utils.getNewProjection(preUdfExpressions, inputSchema)

val postUdfProjection = CS186Utils.getNewProjection(postUdfExpressions, inputSchema)

val cache: JavaHashMap[Row, Row] = new JavaHashMap[Row, Row]()

def hasNext() = {


val hasNext = input.hasNext





// false


def next() = {


val row = input.next()

// print("(b "+row+",")

val computedKey:Row = cacheKeyProjection(row)

var computedValues: Row = cache.get(computedKey)

/* if(computedValues!=null){

print("effect key"+computedKey +" val "+computedValues)


if(computedValues == null){

val values: JavaArrayList[Any] = new JavaArrayList()

preUdfProjection(row).iterator.foreach { (i: Any) => {




udfProject(row).iterator.foreach { (i: Any) => {




postUdfProjection(row).iterator.foreach { (i: Any) => {




computedValues = Row.fromSeq(values.toArray)



//  print(" a "+computedValues+")")

//  print(cache.size())







Task #4: Implementing PartitionProject



* Licensed to the Apache Software Foundation (ASF) under one or more

* contributor license agreements.  See the NOTICE file distributed with

* this work for additional information regarding copyright ownership.

* The ASF licenses this file to You under the Apache License, Version 2.0

* (the "License"); you may not use this file except in compliance with

* the License.  You may obtain a copy of the License at


*    http://www.apache.org/licenses/LICENSE-2.0


* Unless required by applicable law or agreed to in writing, software

* distributed under the License is distributed on an "AS IS" BASIS,


* See the License for the specific language governing permissions and

* limitations under the License.


package org.apache.spark.sql.execution

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}

import org.apache.spark.annotation.DeveloperApi

import org.apache.spark.rdd.{RDD, ShuffledRDD}

import org.apache.spark.shuffle.sort.SortShuffleManager

import org.apache.spark.sql.catalyst.ScalaReflection

import org.apache.spark.sql.catalyst.errors._

import org.apache.spark.sql.catalyst.expressions._

import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution}

import org.apache.spark.util.MutablePair

import org.apache.spark.util.collection.ExternalSorter


* :: DeveloperApi ::



case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {

override def output = projectList.map(_.toAttribute)

@transient lazy val buildProjection = newMutableProjection(projectList, child.output)

def execute() = child.execute().mapPartitions { iter =>

val resuableProjection = buildProjection()





* A projection operator that is tailored to improve performance of UDF execution using

* in-memory memoization.


* NOTE: This assumes that we are only caching for a single UDF. If there are multiple

* UDFs, it will only cache for the last UDF. All other UDFs will be executed regularly.


* Once you have completed implementing the functions in [[CS186Utils]], this operator

* should work.



case class CacheProject(projectList: Seq[Expression], child: SparkPlan) extends UnaryNode {

override def output = child.output

def execute() = {

/* Generate the caching iterator. You should trace this code to understand it!

You have to implement parts of the stack to make this work. */

val generator: (Iterator[Row] => Iterator[Row]) = CS186Utils.generateCachingIterator(projectList, child.output)

/* This is Spark magic. In short, it applies the generator function to each of the slices of an RDD.

For the purposes of CS 186, we will only ever have one slice. */





* A projection operator that is tailored to improve performance of UDF execution by using

* external hashing.


* @param projectList

* @param child



case class PartitionProject(projectList: Seq[Expression], child: SparkPlan) extends UnaryNode {

override def output = child.output

def execute() = {




* This method takes an iterator as an input. It should first partition the whole input to disk.

* It should then read each partition from disk and construct do in-memory memoization over each

* partition to avoid recomputation of UDFs.


* @param input the input iterator

* @return the result of applying the projection


def generateIterator(input: Iterator[Row]): Iterator[Row] = {

// This is the key generator for the course-grained external hashing.

val keyGenerator = CS186Utils.getNewProjection(projectList, child.output)


val hashedRelation: DiskHashedRelation = DiskHashedRelation(input, keyGenerator,4,64000)

val partitions: Iterator[DiskPartition] = hashedRelation.getIterator()

var diskPartition:DiskPartition = null

var cachingIterator: Iterator[Row] =null

new Iterator[Row] {

def hasNext() = {

var hasNext = false

if(cachingIterator != null && cachingIterator.hasNext){

hasNext = true


hasNext = fetchNextPartition


/* else if(cachingIterator != null&& !cachingIterator.hasNext){

hasNext = fetchNextPartition


else if(cachingIterator == null){

hasNext = fetchNextPartition





def next() = {





* This fetches the next partition over which we will iterate or returns false if there are no more partitions

* over which we can iterate.


* @return


private def fetchNextPartition(): Boolean  = {


var hasNext = partitions.hasNext


diskPartition = partitions.next()

val data:Iterator[Row]=diskPartition.getData()


cachingIterator = CS186Utils.generateCachingIterator(projectList, child.output)(data)

hasNext = true


hasNext = false










* :: DeveloperApi ::



case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

override def output = child.output

@transient lazy val conditionEvaluator = newPredicate(condition, child.output)

def execute() = child.execute().mapPartitions { iter =>





* :: DeveloperApi ::



case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)

extends UnaryNode


override def output = child.output

// TODO: How to pick seed?

override def execute() = child.execute().sample(withReplacement, fraction, seed)



* :: DeveloperApi ::



case class Union(children: Seq[SparkPlan]) extends SparkPlan {

// TODO: attributes output by union should be distinct for nullability purposes

override def output = children.head.output

override def execute() = sparkContext.union(children.map(_.execute()))



* :: DeveloperApi ::

* Take the first limit elements. Note that the implementation is different depending on whether

* this is a terminal operator or not. If it is terminal and is invoked using executeCollect,

* this operator uses something similar to Spark's take method on the Spark driver. If it is not

* terminal or is invoked using execute, we first take the limit on each partition, and then

* repartition all the data to a single partition to compute the global limit.



case class Limit(limit: Int, child: SparkPlan)

extends UnaryNode {

// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:

// partition local limit -> exchange into one partition -> partition local limit again

/** We must copy rows when sort based shuffle is on */

private def sortBasedShuffleOn: Boolean = SparkEnv.get.shuffleManager match {
  // Sort-based shuffle keeps references to rows, so callers must copy them.
  case _: SortShuffleManager => true
  case _ => false
}

override def output = child.output

override def outputPartitioning = SinglePartition


* A custom implementation modeled after the take function on RDDs but which never runs any job

* locally.  This is to avoid shipping an entire partition of data in order to retrieve only a few

* rows.


override def executeCollect(): Array[Row] = {

if (limit == 0) {

return new Array[Row](0)


val childRDD = child.execute().map(_.copy())

val buf = new ArrayBuffer[Row]

val totalParts = childRDD.partitions.length

var partsScanned = 0

while (buf.size < limit && partsScanned < totalParts) {

// The number of partitions to try in this iteration. It is ok for this number to be

// greater than totalParts because we actually cap it at totalParts in runJob.

var numPartsToTry = 1

if (partsScanned > 0) {

// If we didn‘t find any rows after the first iteration, just try all partitions next.

// Otherwise, interpolate the number of partitions we need to try, but overestimate it

// by 50%.

if (buf.size == 0) {

numPartsToTry = totalParts - 1

} else {

numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt



numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions

val left = limit - buf.size

val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)

val sc = sqlContext.sparkContext

val res =

sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)

res.foreach(buf ++= _.take(limit - buf.size))

partsScanned += numPartsToTry


buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))


override def execute() = {

val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {

child.execute().mapPartitions { iter =>

iter.take(limit).map(row => (false, row.copy()))


} else {

child.execute().mapPartitions { iter =>

val mutablePair = new MutablePair[Boolean, Row]()

iter.take(limit).map(row => mutablePair.update(false, row))



val part = new HashPartitioner(1)

val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)

shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))





* :: DeveloperApi ::

* Take the first limit elements as defined by the sortOrder. This is logically equivalent to

* having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but

* Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.



case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {

override def output = child.output

override def outputPartitioning = SinglePartition

val ord = new RowOrdering(sortOrder, child.output)

// TODO: Is this copying for no reason?

override def executeCollect(): Array[Row] = {
  // Collect the first `limit` rows under `ord` (each row copied first),
  // then convert every row to Scala values using this operator's schema.
  val topRows = child.execute().map(_.copy()).takeOrdered(limit)(ord)
  topRows.map(row => ScalaReflection.convertRowToScala(row, this.schema))
}

// TODO: Terminal split should be implemented differently from non-terminal split.

// TODO: Pick num splits based on |limit|.

override def execute() = sparkContext.makeRDD(executeCollect(), 1)



* :: DeveloperApi ::

* Performs a sort on-heap.

* @param global when true performs a global sort of all partitions by shuffling the data first

*               if necessary.



case class Sort(

sortOrder: Seq[SortOrder],

global: Boolean,

child: SparkPlan)

extends UnaryNode {

override def requiredChildDistribution =

if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

override def execute() = attachTree(this, "sort") {

child.execute().mapPartitions( { iterator =>

val ordering = newOrdering(sortOrder, child.output)


}, preservesPartitioning = true)


override def output = child.output



* :: DeveloperApi ::

* Performs a sort, spilling to disk as needed.

* @param global when true performs a global sort of all partitions by shuffling the data first

*               if necessary.



case class ExternalSort(

sortOrder: Seq[SortOrder],

global: Boolean,

child: SparkPlan)

extends UnaryNode {

override def requiredChildDistribution =

if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

override def execute() = attachTree(this, "sort") {

child.execute().mapPartitions( { iterator =>

val ordering = newOrdering(sortOrder, child.output)

val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering))

sorter.insertAll(iterator.map(r => (r, null)))


}, preservesPartitioning = true)


override def output = child.output



* :: DeveloperApi ::

* Computes the set of distinct input rows using a HashSet.

* @param partial when true the distinct operation is performed partially, per partition, without

*                shuffling the data.

* @param child the input query plan.



case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode {

override def output = child.output

override def requiredChildDistribution =

if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil

override def execute() = {

child.execute().mapPartitions { iter =>

val hashSet = new scala.collection.mutable.HashSet[Row]()

var currentRow: Row = null

while (iter.hasNext) {

currentRow = iter.next()

if (!hashSet.contains(currentRow)) {









* :: DeveloperApi ::

* Returns a table with the elements from left that are not in right using

* the built-in spark subtract function.



case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {

override def output = left.output

override def execute() = {





* :: DeveloperApi ::

* Returns the rows in left that also appear in right using the built in spark

* intersection function.



case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {

override def output = children.head.output

override def execute() = {





* :: DeveloperApi ::

* A plan node that does nothing but lie about the output of its child.  Used to splice a

* (hopefully structurally equivalent) tree from a different optimization sequence into an already

* resolved tree.



case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan {

def children = child :: Nil

def execute() = child.execute()


时间: 2024-08-06 12:21:40

