spark编写UDF和UDAF

UDF:

一、编写udf类,在其中定义udf函数

package spark._sql.UDF

import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:41
  **/
object udfs {
  def len(str: String): Int = str.length

  def ageThan(age: Int, small: Int): Boolean = age > small

  val ageThaner = udf((age: Int, bigger: Int) => age < bigger)
} 

二、在主方法中进行调用  

package spark._sql

import org.apache.log4j.Logger
import org.apache.spark.sql
import spark._sql.UDF.udfs._
import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:42
  **/
object UDFMain {
  val log = Logger.getLogger("UDFMain")

  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val df = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    df.createOrReplaceTempView("table")

    ssc.udf.register("len", len _)
    ssc.sql("select age,len(age) as len from table").show(20, false)
    println("=====================================")
    ssc.udf.register("ageThan", ageThan _)
    ssc.sql("select age from table where ageThan(age,15)").show()
    println("=====================================")
    import ssc.implicits._
    val r = ssc.sql("select * from table")
    r.filter(ageThaner($"age", lit(20))).show()
    println("=====================================")

    ssc.stop()
  }
}

  运行结果:

  

  可以看到,以上代码中一共定义了三个不同的udf函数,分别对三个函数进行说明:

  • len(str: String):该函数使用用来获取传入字段的长度,str 即为所需要传入的字段
    •   在使用的时候,需要现将其进行注册并赋予其函数名:ssc.udf.register("len", len _),调用的时候直接在sql语句中通过函数名来进行调用
  • ageThan(age: Int, small: Int):该函数式用来比较传入的age与已有的small大小,返回一个boolean值,该函数需要是用在where条件语句中用来进行过滤使用
    •     在使用的时候,需要现将其进行注册并赋予其函数名:ssc.udf.register("ageThan", ageThan _),调用的时候直接在sql语句中通过函数名来进行调用
  • ageThaner:该函数跟上面两个不同,所谓的不同指的是:
    •   定义方式不同:通过使用org.apache.spark.sql.functions._ 中的udf函数在定义的时候就将其注册好
    • 使用场景不同:使用在dataframe中,用来进行select,filter操作中
    • 对于该函数的第二列来说,如果是常量的话,需要使用org.apache.spark.sql.function._ 中的lit进行包装,不能将常量直接传入,否则,程序不认识该常量会报错,如果是列名的话,则没问题,使用($"colName")方式即可。

UDAF:

  UDAF相对于udf来说稍微麻烦一下,且需要完全理解当中每个函数的含义才可以轻而易举的写出符合自己预期的UDAF函数,

     UDAF需要继承 UserDefinedAggregateFunction ,并且复写当中的方法

方法含义说明:

def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  inputSchema用来定义,输入的字段的类型,字段名可以随便定义,这里定义为value,也可以是其他的,不重要,关键是字段类型一定要与所要传入计算的字段进行对应,且必须使用org.apche.spark.sql.type. _ 中的类型

def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  bufferSchema用来定义生成中间数据的结果类型,例如在求和的时候,要求a+b+c,相加顺序为a+b=ab,ab+c=abc ,ab即为中间结果。

def dataType: DataType = DoubleType

  dataType为函数返回值的类型,例子中,该UDAF最终返回的结果为double类型,这里的类型不能写成double,要写成org.apache.spark.sql.type._支持的类型DoubleType.

 def deterministic: Boolean = true

  daterministic 为代表结果是否为确定性的,也就是说,相同的输入是否有相同的输出。

def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  initalize 初始化中间结果,即count和ages的初始值。

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新计数器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  update用来更新中间结果,input为dataframe中的一行,将要合并到buffer中的数据,buffer则为已经进行合并后的中间结果。

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  merge 合并所有分片的结果,buffer2是一个分片的中间结果,buffer1是整个合并过程中的结果。

def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }

  evaluate 函数式真正进行计算的函数,计算返回函数的结果,buffer是merge合并后的结果

案例需求:求分组中age的平均数

  先上代码:

一、定义UDAF函数

package spark._sql.UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-14:47
  **/
class udafs() extends UserDefinedAggregateFunction {

  def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新计数器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }
}

二、主函数引用:

package spark._sql.UDF

import org.apache.spark.sql
import org.apache.spark.sql.functions._
import spark._sql.UDAF.udafs

/**
  * AUTHOR Guozy
  * DATE   2019/7/19-16:04
  **/
object UDAFMain {
  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val ageDF = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    ssc.udf.register("avgage", new udafs)
    ageDF.createOrReplaceTempView("table")
    ssc.sql("select avgage(age) from table group by class_id").show()

    ssc.stop()
  }
}

 运行结果:

  

原文地址:https://www.cnblogs.com/Gxiaobai/p/11219938.html

时间: 2024-08-06 12:21:33

spark编写UDF和UDAF的相关文章

详解Spark sql用户自定义函数:UDF与UDAF

UDAF = USER DEFINED AGGREGATION FUNCTION Spark sql提供了丰富的内置函数供猿友们使用,辣为何还要用户自定义函数呢?实际的业务场景可能很复杂,内置函数hold不住,所以Spark sql提供了可扩展的内置函数接口:哥们,你的业务太变态了,我满足不了你,自己按照我的规范去定义一个sql函数,该怎么折腾就怎么折腾! 例如,MySQL数据库中有一张task表,共两个字段taskid (任务ID)与taskParam(JSON格式的任务请求参数).简单起见,

Hive 10、Hive的UDF、UDAF、UDTF

Hive自定义函数包括三种UDF.UDAF.UDTF UDF(User-Defined-Function) 一进一出 UDAF(User- Defined Aggregation Funcation) 聚集函数,多进一出.Count/max/min UDTF(User-Defined Table-Generating Functions)  一进多出,如lateral view explore() 使用方式 :在HIVE会话中add 自定义函数的jar文件,然后创建function继而使用函数

Spark SQL UDF使用

Spark1.1推出了Uer Define Function功能,用户可以在Spark SQL 里自定义实际需要的UDF来处理数据. 因为目前Spark SQL本身支持的函数有限,一些常用的函数都没有,比如len, concat...etc 但是使用UDF来自己实现根据业务需要的功能是非常方便的. Spark SQL UDF其实是一个Scala函数,被catalyst封装成一个Expression结点,最后通过eval方法计根据当前Row计算UDF的结果,源码分析见:Spark SQL源码分析之

hive自定义函数UDF UDTF UDAF

Hive 自定义函数 UDF UDTF UDAF 1.UDF:用户定义(普通)函数,只对单行数值产生作用: UDF只能实现一进一出的操作. 定义udf 计算两个数最小值 public class Min extends UDF { public Double evaluate(Double a, Double b) { if (a == null) a = 0.0; if (b == null) b = 0.0; if (a >= b) { return b; } else { return a

Hive自定义函数(UDF、UDAF)

当Hive提供的内置函数无法满足你的业务处理需要时,此时就可以考虑使用用户自定义函数. UDF 用户自定义函数(user defined function)–针对单条记录. 创建函数流程 1.自定义一个Java类 2.继承UDF类 3.重写evaluate方法 4.打成jar包 6.在hive执行add jar方法 7.在hive执行创建模板函数 8.hql中使用 Demo01: 自定义一个Java类 package UDFDemo; import org.apache.hadoop.hive.

Spark SQL UDF

目前 Spark SQL 不支持自定义UDF ,底层 SQL 引擎用的 catalyst . 在SqlContext 中 有一个 Analyzer给的一个EmptyFunctionRegistry ,如果 SQL 引擎函数中找不到了,会到这个FunctionRegistry 中找 EmptyFunctionRegistry 中lookup 只是抛出一个异常. 所以自定义了一个 FunctionRegistry ,SqlContext @transient protected[sql]lazyva

Spark之UDF

1 package big.data.analyse.udfudaf 2 3 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} 4 import org.apache.spark.sql.{Row, SparkSession} 5 6 /** 7 * Created by zhen on 2018/11/25. 8 */ 9 object SparkUdfUdaf { 10 d

Spark SQL UDF示例

UDF即用户自定函数,注册之后,在sql语句中使用. 基于scala-sdk-2.10.7,Spark2.0.0. package UDF_UDAF import java.util import org.apache.spark.sql.{RowFactory, SparkSession} import org.apache.spark.SparkConf import org.apache.spark.sql.api.java.UDF1 import org.apache.spark.sql

spark自定义udf输入类型为array报错

定义udf如下 val list2string = udf { (style: Array[String], num: Array[Long]) => style.zip(num).map(t => t._1 + ":" + t._2).mkString("<br>") } 输入为两个数组,输出为string 报错如下 Caused by: java.lang.ClassCastException: scala.collection.muta