Hive通用型自定义聚合函数(UDAF)

在使用hive进行数据处理时,经常会用到group by语法,但对分组的操作,hive没有mysql支持得好:

group_concat([DISTINCT] 要连接的字段 [Order BY
ASC/DESC 排序字段] [Separator ‘分隔符‘])

hive只有一个collect_set内置函数,返回去重后的元素数组,但我们可以通过编写UDAF,来实现想要的功能。

编写通用型UDAF需要两个类:解析器和计算器。解析器负责UDAF的参数检查,操作符的重载以及对于给定的一组参数类型来查找正确的计算器,建议继承AbstractGenericUDAFResolver类,具体实现如下:

  @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
        }
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted.");
        }
        return new CollectListUDAFEvaluator();
    }

计算器实现具体的计算逻辑,需要继承GenericUDAFEvaluator抽象类。

计算器有4种模式,由枚举类GenericUDAFEvaluator.Mode定义:

    public static enum Mode {
        PARTIAL1, //从原始数据到部分聚合数据的过程(map阶段),将调用iterate()和terminatePartial()方法。
        PARTIAL2, //从部分聚合数据到部分聚合数据的过程(map端的combiner阶段),将调用merge() 和terminatePartial()方法。
        FINAL,    //从部分聚合数据到全部聚合的过程(reduce阶段),将调用merge()和 terminate()方法。
        COMPLETE  //从原始数据直接到全部聚合的过程(表示只有map,没有reduce,map端直接出结果),将调用merge() 和 terminate()方法。
    };

计算器必须实现的方法:

1、getNewAggregationBuffer():返回存储临时聚合结果的AggregationBuffer对象。

2、reset(AggregationBuffer agg):重置聚合结果对象,以支持mapper和reducer的重用。

3、iterate(AggregationBuffer agg,Object[] parameters):迭代处理原始数据parameters并保存到agg中。

4、terminatePartial(AggregationBuffer agg):以持久化的方式返回agg表示的部分聚合结果,这里的持久化意味着返回值只能Java基础类型、数组、基础类型包装器、Hadoop的Writables、Lists和Maps。

5、merge(AggregationBuffer agg,Object partial):合并由partial表示的部分聚合结果到agg中。

6、terminate(AggregationBuffer agg):返回最终结果。

通常还需要覆盖初始化方法ObjectInspector init(Mode m,ObjectInspector[] parameters),需要注意的是,在不同的模式下parameters的含义是不同的,比如m为
PARTIAL1 和 COMPLETE 时,parameters为原始数据;m为 PARTIAL2 和 FINAL 时,parameters仅为部分聚合数据(只有一个元素)。在 PARTIAL1 和 PARTIAL2 模式下,ObjectInspector  用于terminatePartial方法的返回值,在FINAL和COMPLETE模式下ObjectInspector 用于terminate方法的返回值。

下面实现一个计算器,按分组中元素的出现次数降序排序,并将每个元素的在分组中的出现次数也一起返回,格式为:

[data1, num1, data2, num2, ...]

  public static class CollectListUDAFEvaluator extends GenericUDAFEvaluator {
        protected PrimitiveObjectInspector inputKeyOI;
        protected StandardListObjectInspector loi;
        protected StandardListObjectInspector internalMergeOI;
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);
            if (m == Mode.PARTIAL1) {
                inputKeyOI = (PrimitiveObjectInspector) parameters[0];
                return ObjectInspectorFactory.getStandardListObjectInspector(
                        ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI));
            } else {
                if ( parameters[0] instanceof StandardListObjectInspector ) {
                    internalMergeOI = (StandardListObjectInspector) parameters[0];
                    inputKeyOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
                    loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
                    return loi;
                } else {
                    inputKeyOI = (PrimitiveObjectInspector) parameters[0];
                    return ObjectInspectorFactory.getStandardListObjectInspector(
                            ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI));
                }
            }
        }

        static class MkListAggregationBuffer implements AggregationBuffer {
            List<Object> container = Lists.newArrayList();
        }
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((MkListAggregationBuffer) agg).container.clear();
        }
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            MkListAggregationBuffer ret = new MkListAggregationBuffer();
            return ret;
        }
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            if(parameters == null || parameters.length != 1){
                return;
            }
            Object key = parameters[0];
            if (key != null) {
                MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
                putIntoList(key, myagg.container);
            }
        }

        private void putIntoList(Object key, List<Object> container) {
            Object pCopy = ObjectInspectorUtils.copyToStandardObject(key,  this.inputKeyOI);
            container.add(pCopy);
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg)
                throws HiveException {
            MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
            List<Object> ret = Lists.newArrayList(myagg.container);
            return ret;
        }
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if(partial == null){
                return;
            }
            MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
            List<Object> partialResult = (List<Object>) internalMergeOI.getList(partial);
            for (Object ob: partialResult) {
                putIntoList(ob, myagg.container);
            }
            return;
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            MkListAggregationBuffer myagg = (MkListAggregationBuffer) agg;
            Map<Text, Integer> map = Maps.newHashMap();
            for (int i = 0; i< myagg.container.size() ; i++){
                Text key = (Text) myagg.container.get(i);
                if (map.containsKey(key)) {
                    map.put(key, map.get(key) + 1);
                }else{
                    map.put(key, 1);
                }
            }
            List<Map.Entry<Text, Integer>> listData = Lists.newArrayList(map.entrySet());
            Collections.sort(listData, new Comparator<Map.Entry<Text, Integer>>() {
                public int compare(Map.Entry<Text, Integer> o1, Map.Entry<Text, Integer> o2) {
                    if (o1.getValue() < o2.getValue())
                        return 1;
                    else if (o1.getValue() == o2.getValue())
                        return 0;
                    else
                        return -1;
                }
            });

            List<Object> ret =  Lists.newArrayList();
            for(Map.Entry<Text, Integer> entry : listData){
                ret.add(entry.getKey());
                ret.add(new Text(entry.getValue().toString()));
            }
            return ret;
        }
    }

使用方法:

add jar /export/data/hiveUDF.jar;
create temporary function collect_list  as 'com.test.hive.udf.CollectListUDAF';
select id, collect_list(value) from test group by id;

test表中数据为:

+------+-------+
| id   | value |
+------+-------+
|    1 | a     |
|    1 | a     |
|    1 | b     |
|    2 | c     |
|    2 | d     |
|    2 | d     |
+------+-------+

运行结果为:

1    ["a", "2", "b", "1"]
2    ["d", "2", "c", "1"]  

时间: 2024-08-04 21:46:46

Hive通用型自定义聚合函数(UDAF)的相关文章

Hive学习之自定义聚合函数

Hive支持用户自定义聚合函数(UDAF),这种类型的函数提供了更加强大的数据处理功能.Hive支持两种类型的UDAF:简单型和通用型.正如名称所暗示的,简单型UDAF的实现非常简单,但由于使用了反射的原因会出现性能的损耗,并且不支持长度可变的参数列表等特征.而通用型UDAF虽然支持长度可变的参数等特征,但不像简单型那么容易编写. 这篇文章将学习编写UDAF的规则,比如需要实现哪些接口,继承哪些类,定义哪些方法等, 实现通用型UDAF需要编写两个类:解析器和计算器.解析器负责UDAF的参数检查,

pandas rolling对象的自定义聚合函数

pandas rolling对象的自定义聚合函数 计算标准差型的波动率剪刀差 利用自定义的聚合函数, 把它应用到pandas的滚动窗长对象上, 可以求出 标准差型的波动率剪刀差 代码 def volat_diff(roc1_rolling, center=-0.001, nSD=5): '''计算: 标准差型波动率剪刀差 参数: roc1_rolling: 滚动窗长里的roc1 center: roc1(1日波动率)的平均值 nSD: 求标准差时用的窗长 用法: 1. rolling.apply

SQL Server 自定义聚合函数

说明:本文依据网络转载整理而成,因为时间关系,其中原理暂时并未深入研究,只是整理备份留个记录而已. 目标:在SQL Server中自定义聚合函数,在Group BY语句中 ,不是单纯的SUM和MAX等运算,可以加入拼接字符串. 环境: 1:Sqlserver 2008 R2 2:Visual Studio 2013 第一部分: .net代码: using System; using System.Data; using Microsoft.SqlServer.Server; using Syst

SQL SERVER 2005允许自定义聚合函数-表中字符串分组连接

不多说了,说明后面是完整的代码,用来将字符串型的字段的各行的值拼成一个大字符串,也就是通常所说的Concat 例如有如下表dict  ID  NAME  CATEGORY  1 RED  COLOR   2 BLUE COLOR  3 APPLE  FRUIT  4 ORANGE FRUIT 执行SQL语句:select category,dbo.concatenate(name) as names from dict group by category. 得到结果表如下  category  

sql server 2012 自定义聚合函数(MAX_O3_8HOUR_ND) 计算最大的臭氧8小时滑动平均值

采用c#开发dll,并添加到sql server 中. 具体代码,可以用visual studio的向导生成模板. using System; using System.Collections; using System.Data; using Microsoft.SqlServer.Server; using System.Data.SqlTypes; using System.IO; using System.Text; [Serializable] [Microsoft.SqlServer

hive grouping sets 等聚合函数

函数说明: grouping sets 在一个 group by 查询中,根据不同的维度组合进行聚合,等价于将不同维度的 group by 结果集进行 union allcube 根据 group by 的维度的所有组合进行聚合rollup 是 cube 的子集,以最左侧的维度为主,从该维度进行层级聚合. -- grouping sets select order_id, departure_date, count(*) as cnt from ord_test where order_id=4

Hive Sum MAX MIN聚合函数

数据准备cookie1,2015-04-10,1cookie1,2015-04-11,5cookie1,2015-04-12,7cookie1,2015-04-13,3cookie1,2015-04-14,2cookie1,2015-04-15,4cookie1,2015-04-16,4创建数据库及表create database if not exists cookie;use cookie;drop table if exists cookie1;create table cookie1(c

Oracle 自定义聚合函数

create or replace type str_concat_type as object ( cat_string varchar2(4000), static function ODCIAggregateInitialize(cs_ctx In Out str_concat_type) return number, member function ODCIAggregateIterate(self In Out str_concat_type,value in varchar2) re

Hive聚合函数及采样函数详解

 本文主要使用实例对Hive内建的一些聚合函数.分析函数以及采样函数进行比较详细的讲解. 一.基本聚合函数 数据聚合是按照特定条件将数据整合并表达出来,以总结出更多的组信息.Hive包含内建的一些基本聚合函数,如MAX, MIN, AVG等等,同时也通过GROUPING SETS, ROLLUP, CUBE等函数支持更高级的聚合.Hive基本内建聚合函数通常与GROUP BY连用,默认情况下是对整个表进行操作.在使用GROUP BY时,除聚合函数外其他已选择列必须包含在GROUP BY子句中