Spark 2.3.0 用户自定义聚合函数UserDefinedAggregateFunction和Aggregator

一、无类型的用户自定于聚合函数(Untyped User-Defined Aggregate Functions)

实现无类型的用户自定于聚合函数需要继承抽象类UserDefinedAggregateFunction,并重写该类的8个函数。我们以计算数据类型为Double的列score的平均值为例进行详细说明。score来源于数据文件itemdata.data,格式如下:

0162381440670851711,4,7.0
0162381440670851711,11,4.0
0162381440670851711,32,1.0
0162381440670851711,176,27.0
0162381440670851711,183,11.0
0162381440670851711,184,5.0
0162381440670851711,207,9.0
0162381440670851711,256,3.0
0162381440670851711,258,4.0
0162381440670851711,259,16.0
0162381440670851711,260,8.0
0162381440670851711,261,18.0
0162381440670851711,301,1.0

第一列为user_id,第二列为item_id,第三列为score。

1、inputSchema

定义输入数据的Schema,要求类型是StructType,它的参数是由StructField类型构成的数组。比如这里要定义score列的Schema,首先使用StructField声明score列的名字score_column,数据类型为DoubleType。这里输入只有score这一列,所以StructField构成的数组只有一个元素。如下:

override def inputSchema: StructType = StructType(StructField("score_column",DoubleType)::Nil)

::是Scala中的操作符与Nil空集合操作后生成一个数组。

2、bufferSchema

事实上,计算score的平均值时,需要用到score的总和sum以及score的总个数count这样的中间数据,那么就使用bufferSchema来定义它们。如下:

override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)

这里StructField类型的数组就有两个元素:数据类型为DoubleType的sum和数据类型为LongType类型的count。

3、dataType

我们需要对自定义聚合函数的最终数据类型进行说明,使用dataType函数。比如计算出的平均score是Double类型,如下定义:

override def dataType: DataType = DoubleType

4、deterministic

deterministic函数用于对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出。因为对于同样的score输入,肯定要得到相同的score平均值,所以定义为true,如下:

override def deterministic: Boolean = true

5、initialize

initialize用户初始化缓存数据。比如score的缓存数据有两个:sum和count,需要初始化为sum=0.0和count=0L,第一个初始化为Double类型,第二个初始化为长整型。如下:

override def initialize(buffer: MutableAggregationBuffer): Unit = {//sum=0.0buffer(0)=0.0//count=0buffer(1)=0L}

6、update

当有新的输入数据时,update用户更新缓存变量。比如这里当有新的score输入时,需要将它的值更新变量sum中,并将count加1,如下:

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {//输入非空if(!input.isNullAt(0)){//sum=sum+输入的scorebuffer(0)=buffer.getDouble(0)+input.getDouble(0)//count=count+1buffer(1)=buffer.getLong(1)+1}}

7、merge

merge将更新的缓存变量存入到缓存中。如下:

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)}

8、evaluate

evaluate是一个计算方法,用于计算我们的最终结果。比如这里用于计算平均得分average(score)=sum(score)/count(score),如下:

override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)

这里我们自定义了一个MyAverage聚合函数用于计算score的平均值,如下:

package com.leboop.rddimport org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._/*** 用户自定义集成算子Demo*/
object MyAverageTest {/*** 读取itemdata.data数据,计算平均score*/class MyAverage extends UserDefinedAggregateFunction{/*** 计算平均score,输入的应该是score这一列数据* StructField定义了列字段的名称score_column,字段的类型Double* StructType要求输入数StructField构成的数组Array,这里只有一列,所以与Nil运算生成Array* @return StructType*/override def inputSchema: StructType = StructType(StructField("score_column",DoubleType)::Nil)/*** 缓存Schema,存储中间计算结果,* 比如计算平均score,需要计算score的总和和score的个数,然后average(score)=sum(score)/count(score)* 所以这里定义了StructType类型:两个StructField字段:sum和count* @return StructType*/override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)/*** 自定义集成算子最终返回的数据类型* 也就是average(score)的类型,所以是Double* @return DataType 返回数据类型*/override def dataType: DataType = DoubleType/*** 数据一致性检验:对于同样的输入,输出是一样的* @return Boolean true 同样的输入,输出也一样*/override def deterministic: Boolean = true/*** 初始化缓存sum和count* sum=0.0,count=0* @param buffer 中间数据*/override def initialize(buffer: MutableAggregationBuffer): Unit = {//sum=0.0buffer(0)=0.0//count=0buffer(1)=0L}/*** 每次计算更新缓存* @param buffer 缓存数据* @param input 输入数据score*/override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {//输入非空if(!input.isNullAt(0)){//sum=sum+输入的scorebuffer(0)=buffer.getDouble(0)+input.getDouble(0)//count=count+1buffer(1)=buffer.getLong(1)+1}}/*** 将更新后的buffer存储到缓存* @param buffer1 缓存* @param buffer2 更新后的buffer*/override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)}/*** 计算最终的结果:average(score)=sum(score)/count(score)* @param buffer* @return*/override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)}def main(args: Array[String]): Unit = {//创建Spark SQL切入点val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()//注册名为myAverage的自定义集成算子MyAveragespark.udf.register("myAverage",MyAverage)//读取HDFS文件系统数据itemdata.data转换成指定列名的DataFrameval dataDF=spark.read.csv("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data").toDF("user_id","item_id","score")//创建临时视图dataDF.createOrReplaceTempView("data")//通过sql计算平均工资spark.sql("select myAverage(score) as average_score from data").show()}
}

程序运行结果

+-----------------+
|    average_score|
+-----------------+
|3.257425742574257|
+-----------------+

二、类型安全的用户自定义聚合函数(Type-Safe User-Defined Aggregate Functions)

实现类型安全的用户自定义聚合函数需要集成org.apache.spark.sql.expressions.Aggregator的Aggregator[K,V,C]抽象类,并且实现该类的6个函数。以上面计算score平均值的例子进行说明,并与无类型的用户自定于聚合函数对比。但是实现需要定义两个case class,如下:


case class Data(user_id: String, item_id: String, score: Double)
case class Average(var sum: Double,var count: Long)

Data用于存储itemdata.data数据,Average用于存储计算score平均值的中间数据,需要注意的是Average的参数sum和count都要声明为变量var。具体如下:

1、zero

zero相当于1中的initialize初始化函数,初始化存储中间数据的Average,如下:

override def zero: Average = Average(0.0D, 0L)
 

2、reduce

reduce函数相当于1中的update函数,当有新的数据a时,更新中间数据b,这里可使用+=复制(因为sum和count都是var),如下:

override def reduce(b: Average, a: Data): Average = {b.sum += a.scoreb.count += 1Lb}

当然三行代码也可以直接写成Average(b.sum+a.score,b.count+1L),这样每次计算都会创建新的对象Average。

3、merge

merge函数同1中的merge函数。如下:

override def merge(b1: Average, b2: Average): Average = {b1.sum+=b2.sumb1.count+= b2.countb1}

4、finish

finish函数同1中的evaluate函数。计算最终的数据。如下:

override def finish(reduction: Average): Double = reduction.sum / reduction.count

5、bufferEncoder

缓冲数据编码方式,如下:

override def bufferEncoder: Encoder[Average] = Encoders.product

6、outputEncoder

最终数据输出编码方式,如下:

override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

整体代码如下:

package com.leboop.rddimport org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator/*** 类型安全自定义聚合函数*/
object TypeSafeMyAverageTest {/***Data类存储读取的文件数据*/case class Data(user_id: String, item_id: String, score: Double)//Averagecase class Average(var sum: Double,var count: Long)object SafeMyAverage extends Aggregator[Data, Average, Double] {override def zero: Average = Average(0.0D, 0L)override def reduce(b: Average, a: Data): Average = {b.sum += a.scoreb.count += 1Lb}override def merge(b1: Average, b2: Average): Average = {b1.sum+=b2.sumb1.count+= b2.countb1}override def finish(reduction: Average): Double = reduction.sum / reduction.countoverride def bufferEncoder: Encoder[Average] = Encoders.productoverride def outputEncoder: Encoder[Double] = Encoders.scalaDouble}def main(args: Array[String]): Unit = {//创建Spark SQL切入点val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()//读取HDFS文件系统数据itemdata.data生成RDDval rdd = spark.sparkContext.textFile("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data")//RDD转化成DataSetimport spark.implicits._val dataDS =rdd.map(_.split(",")).map(d => Data(d(0), d(1), d(2).toDouble)).toDS()//自定义聚合函数val averageScore = SafeMyAverage.toColumn.name("average_score")dataDS.select(averageScore).show()}
}

程序执行结果如下:

+-----------------+
|    average_score|
+-----------------+
|3.257425742574257|
+-----------------+

Spark 2.3.0 用户自定义聚合函数UserDefinedAggregateFunction和Aggregator相关推荐

  1. Spark踩坑填坑-聚合函数-序列化异常

    Spark踩坑填坑-聚合函数-序列化异常 一.Spark聚合函数特殊场景 二.spark sql group by 三.Spark Caused by: java.io.NotSerializable ...

  2. Spark:group by和聚合函数使用

    groupBy分组和使用agg聚合函数demo: df.show +----+-----+---+ |YEAR|MONTH|NUM| +----+-----+---+ |2017| 1| 10| |2 ...

  3. mysql 8 json 支持_体验 MySQL 8.0 JSON聚合函数

    MySQL 最近的动作很快,已经计划推出 8.0 版本,会新增很多新特性 在 5.7 中,JSON 已经被正式支持,但在 SQL 中对 JSON 的处理能力较弱,8.0 中这部分能力会加强,例如新增了 ...

  4. Spark UDAF用户自定义聚合函数

    文章目录 处理流程 弱类型 强类型 UDAF的特点就是:N:1,目的就是为了做聚合(group by) 处理流程 首先准备好数据源: 这里我们人为的将其分为2个分区: 按照group by字段进行分组 ...

  5. 【极简spark教程】spark聚合函数

    聚合函数分为两类,一种是spark内置的常用聚合函数,一种是用户自定义聚合函数 UDAF 不带类型的UDAF[较常用] 继承UserDefinedAggregateFunction 定义输入数据的sc ...

  6. SparkSQL自定义AVG强类型聚合函数与弱类型聚合函数汇总

    AVG是求平均值,所以输出类型是Double类型 1)创建弱类型聚合函数类extends UserDefinedAggregateFunction class MyAgeFunction extend ...

  7. Hive学习---4、函数(单行函数、高级聚合函数、炸裂函数、窗口函数)

    1.函数 1.1 函数简介 Hive会将常用的逻辑封装成函数给用户进行使用,类似java中的函数. 好处:避免用户反复写逻辑,可以直接拿来使用 重点:用户需要知道函数叫什么,能做什么 Hive提供了大 ...

  8. 【大数据】Presto开发自定义聚合函数

    Presto 在交互式查询任务中担当着重要的职责.随着越来越多的人开始使用 SQL 在 Presto 上分析数据,我们发现需要将一些业务逻辑开发成类似 Hive 中的 UDF,提高 SQL 使用人员的 ...

  9. 第三章 SQL聚合函数 COUNT(一)

    文章目录 第三章 SQL聚合函数 COUNT(一) 大纲 参数 描述 没有行返回 流字段 第三章 SQL聚合函数 COUNT(一) 返回表或指定列中的行数的聚合函数. 大纲 COUNT(*)COUNT ...

最新文章

  1. html oninput的作用,html范围滑块 - oninput在IE 11中不起作用
  2. 能量视角下的GAN模型(二):GAN=“分析”+“采样”
  3. mybatis-逻辑翻页
  4. 理解zookeeper的一致性及缺点
  5. Python编写的数字拼图游戏(含爬山算法人机对战功能)
  6. magrittr | R语言的管道操作符
  7. Linux知识积累(2)dirname的使用方法
  8. ContestHunter暑假欢乐赛 SRM 03
  9. 深浅拷贝的使用场景分析
  10. speedoffice表格如何方框内打勾
  11. oracle发生20001,Oracle10g重建EM 报ORA-20001: SYSMAN already exists
  12. Google Play 新增付款功能一览表
  13. LoRa和NB-IoT会长期共存吗?
  14. 视觉slam中的一种单目稠密建图方法
  15. 2019 ICPC 南京 F题 Paper Grading
  16. 嘟噜噜的难受伴快乐的一天。
  17. C语言图形编程--俄罗斯方块制作(二)源代码
  18. CentOS 7.8 (2003) 发布,附下载地址
  19. w ndows11如何设置电源选项,2018年度巨献(4):11款650W全模组80Plus金牌+电源横评
  20. 区块链学习笔记五 BTC网络

热门文章

  1. <Zhuuu_ZZ>Linux远程连接
  2. cocos creator 碰撞检测系统collider
  3. python实现一个秒表(可用于跑步比赛记录名次与时间)
  4. python pandas 官网_Pandas 最详细教程
  5. Kotlin系列二:面向对象编程(类与对象)
  6. Java实现矩阵对角线元素之和
  7. MATLAB学习(1)
  8. ISP算法学习之LSC(镜头阴影校正)
  9. python语句学习系列(1)--print()输出结果不全,部分内容省略问题
  10. MyJupyter,一款支持Python和Java的可移动Jupyter软件包