Spark【Spark SQL(四)UDF函数和UDAF函数】

2023-10-27

UDF 函数

        UDF 是我们用户可以自定义的函数,我们通过SparkSession对象来调用 udf 的 register(name:String,func(A1,A2,A3...)) 方法来注册一个我们自定义的函数。其中,name 是我们自定义的函数名称,func 是我们自定义的函数,它可以有很多个参数。

        通过 UDF 函数,我们可以针对某一列数据或者某单元格数据进行针对的处理。

案例 1

定义一个函数,给 Andy 的 name 字段的值前 + "Name: "。

def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
    conf.setAppName("spark sql udf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("prefixName",(name:String)=>{
      if (name.equals("Andy"))
        "Name: " + name
      else
        name
    })
    spark.sql("select prefixName(name) as name,age,sex from people").show()

    spark.stop()
  }

        这里我们定义了一个自定义的 UDF 函数:prefixName,它会判断name字段的值是否为 "Andy",如果是,就会在她的值前+"Name: "。

运行结果:

+----------+---+---+
|      name|age|sex|
+----------+---+---+
|   Michael| 30| 男|
|Name: Andy| 19| 女|
|    Justin| 19| 男|
|Bernadette| 20| 女|
|  Gretchen| 23| 女|
|     David| 27| 男|
|    Joseph| 33| 女|
|     Trish| 27| 女|
|      Alex| 33| 女|
|       Ben| 25| 男|
+----------+---+---+

UDAF 函数

        强类型的DataSet和弱类型的DataFrame都提供了相关聚合函数,如count、countDistinct、avg、max、min。

        UDAF 也就是我们用户的自定义聚合函数。聚合函数就比如 avg、sum这种函数,需要先把所有数据放到一起(缓冲区),再进行统一处理的一个函数。

        实现 UDAF 函数需要有我们自定义的聚合函数的类(主要任务就是计算),我们可以继承 UserDefinedAggregateFunction,并实现里面的八种方法,来实现弱类型的聚合函数。(Spark3.0之后就不推荐使用了,更加推荐强类型的聚合函数)

        我们可以继承Aggregator来实现强类型的聚合函数。

案例1 - 平均年龄

case 类可以直接构建对象,不需要new,因为样例类可以自动生成它的伴生对象和apply方法。

弱类型实现

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}

/**
 * 弱类型
 */
object UDAFTest01 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("avgAge",new MyAvgUDAF())

    spark.sql("select avgAge(age) from people").show()

    spark.stop()
  }
}
class MyAvgUDAF extends UserDefinedAggregateFunction{

  // 输入数据的结构 IN
  override def inputSchema: StructType = {
   StructType(
     Array(StructField("age",LongType))
   )}

  // 缓冲区数据的结构 BUFFER
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total",LongType),
        StructField("count",LongType)
      )
    )}

  // 函数计算结果的数据类型 OUT
  override def dataType: DataType = LongType

  // 函数的稳定性 (传入相同的参数结果是否相同)
  override def deterministic: Boolean = true

  // 缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //这两种写法都一样
//    buffer(0) = 0L
//    buffer(1) = 0L
    //第二种方法
    buffer.update(0,0L) //total 给缓冲区的第0个数据结构-total-初始化赋值0L
    buffer.update(1,0L) //count 给缓冲区的第1个数据结构-count-初始化赋值0L
  }

  // 数据过来之后 如何更新缓冲区
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 第一个参数代表缓冲区的第i个数据结构 0代表total 1代表count
    // 第二个参数是对第一个参数的数据结构进行重新赋值
    // buffer.getLong(0)是取出缓冲区第0个值-也就是total的值,给它+上输入的值中的第0个值(因为我们输入结构只有一个就是age:Long)
    buffer.update(0,buffer.getLong(0)+input.getLong(0))
    buffer.update(1,buffer.getLong(1)+1)  //count 每次数据过来+1
  }

  // 多个缓冲区数据合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))
    buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
  }

  // 计算结果操作
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1)
  }
}

运行结果:

+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+

 

强类型实现

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

/**
 * 强类型
 */
object UDAFTest02 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("avgAge",functions.udaf(new MyAvg_UDAF()))

    spark.sql("select avgAge(age) from people").show()

    spark.stop()
  }
}

/**
 * 自定义聚合函数类:
 *  1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:
 *    IN  : 输入数据类型 Long
 *    BUF : 缓冲区数据类型
 *    OUT : 输出数据类型 Long
 *  2.重写方法
 */
//样例类中的参数默认是 val 所以这里必须指定为var
case class Buff(var total: Long,var count: Long)
class MyAvg_UDAF extends Aggregator[Long,Buff,Long]{

  // zero: Buff zero代表这个方法是用来初始值(0值)
  // Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化
  override def zero: Buff = {
    Buff(0L,0L)
  }

  // 根据输入数据更新缓冲区 要求返回-Buff
  override def reduce(buff: Buff, in: Long): Buff = {
    buff.total += in
    buff.count += 1
    buff
  }

  // 合并缓冲区 同样返回buff1
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // 计算结果
  override def finish(buff: Buff): Long = {
    buff.total/buff.count
  }

  // 网络传输需要序列化 缓冲区的编码操作 -编码
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // 输出的编码操作 -解码
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

运行结果:

+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+

 

早期UDAF强类型聚合函数

SQL:结构化数据查询 & DSL:面向对象查询(有对象有方法,与类型相关,所以通过DSL语句结合起来使用)

早期的UDAF强类型聚合函数使用DSL操作。

定义一个case类对应数据类型,然后通过as[对象]方法将DataFrame转为DataSet类型,然后将我们的UDAF聚合类转为列对象。

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn, functions}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}

/**
 * 早期的UDAF强类型聚合函数使用DSL操作
 */
object UDAFTest03 {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")

    val ds: Dataset[User] = df.as[User]

    // 将UDAF强类型聚合函数转为查询的类对象
    val udafCol: TypedColumn[User, Long] = new OldAvg_UDAF().toColumn
    ds.select(udafCol).show()

    spark.stop()
  }
}

/**
 * 自定义聚合函数类:
 *  1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:
 *    IN  : 输入数据类型 User
 *    BUF : 缓冲区数据类型
 *    OUT : 输出数据类型 Long
 *  2.重写方法
 */
//样例类中的参数默认是 val 所以这里必须指定为var
case class User(name: String,age: Long,sex: String)
case class Buff(var total: Long,var count: Long)
class OldAvg_UDAF extends Aggregator[User,Buff,Long]{

  // zero: Buff zero代表这个方法是用来初始值(0值)
  // Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化
  override def zero: Buff = {
    Buff(0L,0L)
  }

  // 根据输入数据更新缓冲区 要求返回-Buff
  override def reduce(buff: Buff, in: User): Buff = {
    buff.total += in.age
    buff.count += 1
    buff
  }

  // 合并缓冲区 同样返回buff1
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // 计算结果
  override def finish(buff: Buff): Long = {
    buff.total/buff.count
  }

  // 网络传输需要序列化 缓冲区的编码操作 -编码
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // 输出的编码操作 -解码
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

运行结果:

+------------------------------------------+
|OldAvg_UDAF(com.study.spark.core.sql.User)|
+------------------------------------------+
|                                        25|
+------------------------------------------+

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Spark【Spark SQL(四)UDF函数和UDAF函数】 的相关文章

  • 【Protocol-WebSocket】WebSocket基本概念

    目录 1 概念图 2 定义 3 握手协议 4 优点 5 用途 1 概念图 WebSocket概念图 其中客户端Client 此处是浏览器 服务端Server 此处是给客户端提供资源数据的电脑 2 定义 WebSocket是一种在单个TCP连

随机推荐

  • 伺服控制的三环控制原理及整定仿真和Simulink模型

    伺服控制的三环控制原理及整定仿真和Simulink模型 我们平时使用的工业伺服 通常是成套伺服 即驱动器和电机型号存在配对关系 但有些时候 我们要用电机定转子和编码器制作非成套电机 例如机床上使用的直驱转台 永磁同步电机直接驱动的主轴 这种
  • 理解脸书为何从互联网消失了

    原文 https blog cloudflare com october 2021 facebook outage 译 时序 FaceBook不会宕机 不是吗 我们想了几分钟这个问题 今天2021 10 4 16 51 UTC 我们建了一条
  • mysql:使用已有的记录更新另一条数据

    create table test id integer primary key name varchar 100 例如现在在test表中有N条数据 其中有两条为 12 hello 13 world 如果想把上一条中的数据改成和下一条记录一
  • sql中having,group,select,where,order by,join的执行顺序

    在SQL中执行的顺序 1 先连接from后的数据源 若有join 则先执行on后条件 再连接数据源 2 执行where条件 3 执行group by 4 执行having 5 执行order by 6 输出结果 顺序 FROM ON JOI
  • C++ 指针

    每个变量都有一个内存位置 每一个内存位置都定义了可使用连字号 运算符访问的地址 它表示了在内存中的一个地址 例如输出定义变量的地址 include
  • 【分享之路001】springboot整合双redis配置

    springboot双redis配置 1 背景 springboot项目中本来用到了redis 由于业务要求 需要将数据也写到另一个redis中 2 配置文件改动 2 1 之前redis配置 spring redis database xx
  • 【计算机网络篇】TCP协议

    作者简介 大家好 我是小杨 个人主页 小杨 的csdn博客 希望大家多多支持 一起进步呀 TCP协议 1 TCP 简介 TCP Transmission Control Protocol 是一种在计算机网络中广泛使用的传输层协议 用于在网络
  • [机器学习与scikit-learn-29]:算法-回归-普通线性回归LinearRegression拟合线性分布数据的代码示例

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 目录 第1章 LinearRegression类说明 第2章 LinearRegression使用的代码示例 2 1 导入库 2 2 导数数
  • SSM框架的流程及优点

    SSM框架 SSM Spring SpringMVC MyBatis 框架集由Spring MyBatis两个开源框架整合而成 SpringMVC是Spring中的部分内容 在这个快速发展的互联经济的时代 SSM框架提高了开发人员的工作效率
  • 查询topn的另一种方法通过orderby排序后利用limit来实现

    文章目录 前言 1 热身题实践 其他 前言 一直有个想法 把面试需要的知识点全都总结一下 包括数据库 语言 算法 数据结构等知识 形成一个面试总结笔记 这样以后面试的时候只看这些文章回顾下就行了 今天就先总结下Mysql的面试热身题吧 后续
  • HBase运维中遇到的问题

    1 PleaseHoldException Master is initializing hadoop 3 2 1 hbase 2 2 5 各种配置之后 出现的错误具体为 进去 hbase shell 之后 出现 hbase main 00
  • C#协变

    namespace 协变 public class Animal public string name public Animal string name1 name name1 public class dog Animal public
  • 【vue】vue 中插槽的三种类型:

    文章目录 一 匿名插槽 二 具名插槽 三 作用域插槽 一 匿名插槽
  • CSS学习笔记——搭建京东购物车网页

    大家好 作为一名互联网行业的小白 写博客只是为了巩固自己学习的知识 但由于水平有限 博客中难免会有一些错误出现 有不妥之处恳请各位大佬指点一二 博客主页 链接 https blog csdn net weixin 52720197 spm
  • elasticsearch心得记录-搭建到使用过程中

    1 es评分机制 使用场景 匹配多个关键词的时候 增加其中某个关键词的权重 增加其评分 搜索出来即可排到前面 评分默认为倒叙排 2 es基础的增删改查 搜索 search type dfs query then fetch 每个分片会根据
  • 条码规范——Code 93

    CODE 93 BACKGROUND INFORMATION Code 93 was designed to complement and improve upon Code 39 Code 93 is similar in that it
  • 在线代码测试小项目

    小项目 代码在线测试 http是我们生活中最常使用的协议 现如今网络浏览器越来越贴近人们的生活 使得做什么事都很方便 但是想要运行一段代码还得需要在电脑指定的环境下来运行 这在有些情况下让人很抓狂 我在网上也看到过很多代码在线测试的网页 感
  • 模块打包器Webpack详解!

    Webpack 1 什么是Webpack Webpack 是一个前端资源加载 打包工具 它将根据模块的依赖关系进行静态分析 然后将这些模块按照指定的规则生成对应的静态资源 从图中我们可以看出 Webpack 可以将多种静态资源 js css
  • 关于c语言main函数中int argc,char **argv的理解

    关于c语言main函数中int argc char argv的理解 c语言main函数通常形如int main int argc char argv 那么argc和argv代表啥呢 其实 argc表示传入main函数的参数的个数 而argv
  • Spark【Spark SQL(四)UDF函数和UDAF函数】

    UDF 函数 UDF 是我们用户可以自定义的函数 我们通过SparkSession对象来调用 udf 的 register name String func A1 A2 A3 方法来注册一个我们自定义的函数 其中 name 是我们自定义的函