SparkSQL自定义无类型聚合函数

准备数据:

Michael,3000
Andy,4500
Justin,3500
Betral,4000

一、定义自定义无类型聚合函数

想要自定义无类型聚合函数,那必须得继承org.spark.sql.expressions.UserDefinedAggregateFunction,然后重写父类得抽象变量和成员方法。

package com.cjs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object UDFMyAverage extends UserDefinedAggregateFunction{
    //定义输入参数的数据类型
    override def inputSchema: StructType = StructType(StructField("inputColumn", LongType)::Nil)
    //定义缓冲器的数据结构类型,缓冲器用于计算,这里定义了两个数据变量:sum和count
    override def bufferSchema: StructType = StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)

    //聚合函数返回的数据类型
    override def dataType: DataType = DoubleType

    override def deterministic: Boolean = true
    //初始化缓冲器
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        //buffer本质上也是一个Row对象,所以也可以使用下标的方式获取它的元素
        buffer(0) = 0L  //这里第一个元素是上面定义的sum
        buffer(1) = 0L  //这里第二个元素是上面定义的sount
    }

    //update方法用于将输入数据跟缓冲器数据进行计算,这里是一个累加的作用
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
    }

    //buffer1是主缓冲器,储存的是目前各个节点的部分计算结果;buffer2是分布式中执行任务的各个节点的“主”缓冲器;
    // merge方法作用是将各个节点的计算结果做一个聚合,其实可以理解为分布式的update的方法,buffer2相当于input:Row
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    //计算最终结果
    override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble/buffer.getLong(1)
    }
}

二、使用自定义无类型聚合函数

package com.cjs

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object TestMyAverage {
    def main(args: Array[String]): Unit = {
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

        val conf = new SparkConf()
            .set("spark.some.config.option","some-value")
            .set("spark.sql.warehouse.dir","file:///e:/tmp/spark-warehouse")

        val ss = SparkSession
            .builder()
            .config(conf)
            .appName("test-myAverage")
            .master("local[2]")
            .getOrCreate()

        import ss.implicits._
        val sc = ss.sparkContext

        val schemaString = "name,salary"
        val fileds = schemaString.split(",").map(filedName => StructField(filedName,StringType, nullable = true))
        val schemaStruct = StructType(fileds)

        val path = "E:\\IntelliJ Idea\\sparkSql_practice\\src\\main\\scala\\com\\cjs\\employee.txt"
        val empRDD = sc.textFile(path).map(_.split(",")).map(row=>Row(row(0),row(1)))

        val empDF = ss.createDataFrame(empRDD,schemaStruct)
        empDF.createOrReplaceTempView("emp")
//        ss.sql("select name, salary from emp limit 5").show()
        //想要在spark sql里使用无类型自定义聚合函数,那么就要先注册给自定义函数
        ss.udf.register("myAverage",UDFMyAverage)

//        empDF.show()
        ss.sql("select myAverage(salary) as average_salary from emp").show()
    }

}

输出结果:

原文地址:https://www.cnblogs.com/SysoCjs/p/11466149.html

时间: 2025-01-09 03:35:46

SparkSQL自定义无类型聚合函数的相关文章

wordpress自定义文章类型capability_type和capabilities参数说明

在wordpress中关于用户权限有三个词:Role.Capabilities.User Levels分别是角色.权限.用户级别的意思,在前面后台制作教程中创建后台菜单的时候提到过有个参数是填写一个Capabilities,但是很多人填写的是role喝user levels. 在wordpress中role-角色很容易理解,就是管理员.订阅者之类的.对于用户层级,wordpress将用户分成了从0到10共11级别,0为最低,10最高,管理员Administrator就是10级别的,具有最高权限,

Hive学习之自定义聚合函数

Hive支持用户自定义聚合函数(UDAF),这种类型的函数提供了更加强大的数据处理功能.Hive支持两种类型的UDAF:简单型和通用型.正如名称所暗示的,简单型UDAF的实现非常简单,但由于使用了反射的原因会出现性能的损耗,并且不支持长度可变的参数列表等特征.而通用型UDAF虽然支持长度可变的参数等特征,但不像简单型那么容易编写. 这篇文章将学习编写UDAF的规则,比如需要实现哪些接口,继承哪些类,定义哪些方法等, 实现通用型UDAF需要编写两个类:解析器和计算器.解析器负责UDAF的参数检查,

spark-sql 自定义函数

(1)自定义UDF object SparkSqlTest { def main(args: Array[String]): Unit = { //屏蔽多余的日志 Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN) Logger.getLogger("org.apache.spark").setLevel(Level.WARN) Logger.getLogger("org.project-s

SQL Server 自定义聚合函数

说明:本文依据网络转载整理而成,因为时间关系,其中原理暂时并未深入研究,只是整理备份留个记录而已. 目标:在SQL Server中自定义聚合函数,在Group BY语句中 ,不是单纯的SUM和MAX等运算,可以加入拼接字符串. 环境: 1:Sqlserver 2008 R2 2:Visual Studio 2013 第一部分: .net代码: using System; using System.Data; using Microsoft.SqlServer.Server; using Syst

oracle中的常用函数、字符串函数、数值类型函数、日期函数,聚合函数。

一.字符串的常用函数. --一.oracle 字符串常用函数 --1. concat 连接字符串的函数,只能连接[两个]字符串. 字符写在括号中,并用逗号隔开! --2.“||”符号可以连接多个字符串 直接用||将多个字符链接即可. --3. dual? dual是一个虚拟表,用来构成select的语法规则,oracle保证dual里面永远只有一条记录. select concat('lo','ve')from dual; select concat('o','k')from dual; sel

sparksql 自定义用户函数(UDF)

自定义用户函数有两种方式,区别:是否使用强类型,参考demo:https://github.com/asker124143222/spark-demo 1.不使用强类型,继承UserDefinedAggregateFunction package com.home.spark import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.

pandas rolling对象的自定义聚合函数

pandas rolling对象的自定义聚合函数 计算标准差型的波动率剪刀差 利用自定义的聚合函数, 把它应用到pandas的滚动窗长对象上, 可以求出 标准差型的波动率剪刀差 代码 def volat_diff(roc1_rolling, center=-0.001, nSD=5): '''计算: 标准差型波动率剪刀差 参数: roc1_rolling: 滚动窗长里的roc1 center: roc1(1日波动率)的平均值 nSD: 求标准差时用的窗长 用法: 1. rolling.apply

sql server 2012 自定义聚合函数(MAX_O3_8HOUR_ND) 计算最大的臭氧8小时滑动平均值

采用c#开发dll,并添加到sql server 中. 具体代码,可以用visual studio的向导生成模板. using System; using System.Collections; using System.Data; using Microsoft.SqlServer.Server; using System.Data.SqlTypes; using System.IO; using System.Text; [Serializable] [Microsoft.SqlServer

SQL SERVER 2005允许自定义聚合函数-表中字符串分组连接

不多说了,说明后面是完整的代码,用来将字符串型的字段的各行的值拼成一个大字符串,也就是通常所说的Concat 例如有如下表dict  ID  NAME  CATEGORY  1 RED  COLOR   2 BLUE COLOR  3 APPLE  FRUIT  4 ORANGE FRUIT 执行SQL语句:select category,dbo.concatenate(name) as names from dict group by category. 得到结果表如下  category