Spark Window Functions

Introduction to Window Functions

When using Spark SQL for complex offline statistics jobs, you sometimes need to compute ranking features, window features, and so on. These clearly cannot be produced with a plain groupBy, which is where Spark's window functions come in.

For example, consider the following requirements:

  1. From the orders table, compute for each shop the sum of each order's price and the previous order's price. Doing this with groupBy alone is quite painful.
  2. From the orders table, compute for each shop the difference between each order and the previous one. This calls for a custom aggregate function.
  3. Averages over the previous few orders, period-over-period ratios, and the like also require window functions.

Using Window Functions

The examples below use an orders table to walk through Spark's window functions.

Orders table columns: order id, shop id, payment time, payment amount.

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window, WindowSpec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

object WindowFuncApp extends App {
  private val spark: SparkSession = SparkSession.builder()
    .appName("WindowFuncApp")
    .master("local")
    .getOrCreate()

  import spark.implicits._

  // Prices are Long so they match the LongType input schema of the UDAFs below
  private val ordersDF: DataFrame = Seq(
    ("o1", "s1", "2017-05-01", 100L),
    ("o2", "s1", "2017-05-02", 200L),
    ("o3", "s2", "2017-05-01", 200L),
    ("o4", "s1", "2017-05-03", 200L),
    ("o5", "s2", "2017-05-02", 100L),
    ("o6", "s1", "2017-05-04", 300L)
  ).toDF("order_id", "seller_id", "pay_time", "price")
  ordersDF.printSchema()

  // 1. Order sequence within each shop
  private val rankSpec: WindowSpec = Window.partitionBy("seller_id").orderBy("pay_time")
  ordersDF.withColumn("rank", dense_rank.over(rankSpec)).show()

  private val rankSpec2: WindowSpec = Window.partitionBy("seller_id").orderBy("price")
  ordersDF.withColumn("rank2", rank.over(rankSpec2)).show()             // rank leaves gaps: 1,2,2,4
  ordersDF.withColumn("dense_rank2", dense_rank.over(rankSpec2)).show() // dense_rank does not: 1,2,2,3

  println("****************************************")

  // Window covering the previous order and the current one
  private val winSpec: WindowSpec = Window.partitionBy("seller_id").orderBy("pay_time").rowsBetween(-1, 0)
  // Sum of the current order's price and the previous order's price
  ordersDF.withColumn("sum_pay", sum("price").over(winSpec)).show()

  // Average of the current order and the previous one, via a UDAF
  def getAvgUdaf: UserDefinedAggregateFunction = new MyAverage
  ordersDF.withColumn("avg", getAvgUdaf($"price").over(winSpec)).show()
  ordersDF.withColumn("avg2", avg("price").over(winSpec)).show()

  println("****************************************")

  // Difference between each order and the previous one per shop:
  // either a custom aggregate function or the lag function
  def getMinusUdaf: UserDefinedAggregateFunction = new MyMinus
  ordersDF.withColumn("rank", dense_rank.over(rankSpec))
    .withColumn("prePrice", lag("price", 1).over(rankSpec)) // value of the previous row
    .withColumn("minus", getMinusUdaf($"price").over(winSpec)) // given prePrice, a plain UDF would work too
    .show()
  /*
   * lag(field, n): the value of the field n rows before the current row; here, the previous row's value
   * first/last(): the first/last value of the group under the given ordering; handy e.g. for finding where a user dropped off
   * lead(field, n): the opposite of lag; especially useful in backtesting, to trace conditions back from results
   */

  spark.stop()

  /**
    * Custom aggregate function (UDAF)
    */
  class MyAverage extends UserDefinedAggregateFunction {
    // Subclasses must implement the methods below

    // Data type of the input arguments
    override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)

    // Data types of the values in the aggregation buffer
    override def bufferSchema: StructType = StructType(StructField("count", LongType) :: StructField("sum", DoubleType) :: Nil)

    // Initializes the aggregation buffer, i.e. its zero value (arrays and maps inside the buffer are still immutable)
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L   // the count
      buffer(1) = 0.0D // the running sum
    }

    // Updates the buffer with new input data; called once per input row
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + 1L // increment the count
        buffer(1) = buffer.getDouble(1) + input.getLong(0).toDouble // add to the sum
      }
    }

    // Whether this function always returns the same output for the same input
    override def deterministic: Boolean = true

    // Merges two aggregation buffers and stores the result back into buffer1.
    // Called when two partially aggregated buffers are combined:
    // Spark is distributed, so buffers from different partitions must be merged.
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // merge counts
      buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1) // merge sums
    }

    // Computes the final result
    override def evaluate(buffer: Row): Any = {
      buffer.getDouble(1) / buffer.getLong(0).toDouble
    }

    // Data type of the return value
    override def dataType: DataType = DoubleType
  }

  class MyMinus extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)

    override def bufferSchema: StructType = StructType(StructField("minus", LongType) :: Nil)

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // the difference
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        // later input minus the earlier value
        buffer(0) = input.getLong(0) - buffer.getLong(0)
      }
    }

    // Whether this function always returns the same output for the same input
    override def deterministic: Boolean = true

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      // merging partitions: likewise, later minus earlier
      buffer1(0) = buffer2.getLong(0) - buffer1.getLong(0)
    }

    // Computes the final result
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0)
    }

    // Data type of the return value
    override def dataType: DataType = LongType
  }
}
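
For reference, the same window specifications can be written directly in Spark SQL with the OVER clause. Here is a minimal sketch, assuming ordersDF is registered as a temporary view named orders (the view name is illustrative); it reproduces the running pay sum and the previous-price lookup from the code above:

  ordersDF.createOrReplaceTempView("orders")
  spark.sql(
    """
      |SELECT order_id, seller_id, pay_time, price,
      |       sum(price) OVER (PARTITION BY seller_id ORDER BY pay_time
      |                        ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS sum_pay,
      |       lag(price, 1) OVER (PARTITION BY seller_id ORDER BY pay_time) AS prePrice
      |FROM orders
    """.stripMargin).show()

ROWS BETWEEN 1 PRECEDING AND CURRENT ROW is the SQL spelling of rowsBetween(-1, 0).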
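
Also note that UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of typed Aggregator objects wrapped with functions.udaf. A minimal sketch of MyAverage in that style, assuming a Spark 3.x runtime (the names AvgBuffer, MyAverageAgg and myAvg are illustrative):

  import org.apache.spark.sql.{Encoder, Encoders}
  import org.apache.spark.sql.expressions.Aggregator
  import org.apache.spark.sql.functions.udaf

  case class AvgBuffer(count: Long, sum: Double)

  object MyAverageAgg extends Aggregator[Long, AvgBuffer, Double] {
    def zero: AvgBuffer = AvgBuffer(0L, 0.0)                        // empty buffer
    def reduce(b: AvgBuffer, a: Long): AvgBuffer =
      AvgBuffer(b.count + 1, b.sum + a)                             // fold in one input row
    def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
      AvgBuffer(b1.count + b2.count, b1.sum + b2.sum)               // combine partial buffers
    def finish(b: AvgBuffer): Double = b.sum / b.count              // final average
    def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  // Wrap the Aggregator as an untyped UDAF usable over a window
  val myAvg = udaf(MyAverageAgg)
  ordersDF.withColumn("avg3", myAvg($"price").over(winSpec)).show()

The Aggregator version does the same count-and-sum bookkeeping as MyAverage, but with compile-time-checked types instead of Row accessors.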