窗口函数引入
使用Spark SQL进行复杂的离线统计任务,有时需要计算一些排序特征、窗口特征等,显然不能简单地通过groupBy来完成,这时就需要了解spark中的窗口函数。
比如下面的统计需求:
- 统计订单表,求每个店铺中每个订单与前一单的价格之和。这类需求如果只用groupBy来实现会非常费劲。
- 统计订单表,求每个店铺中每个订单与前一单的价格差值。这需要自定义聚合函数,或者借助lag窗口函数。
- 还有计算前几单的平均值、计算环比之类的,都要用到窗口函数。
窗口函数的使用
下面以订单表推演spark中窗口函数的使用。
订单表字段:订单id,店铺id,支付时间,支付金额。
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window, WindowSpec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
 * Demonstrates Spark SQL window functions on a toy orders table
 * (order_id, seller_id, pay_time, price):
 *
 *  - ranking within a partition (`rank` / `dense_rank`)
 *  - sliding-window aggregation (`rowsBetween`)
 *  - `lag` / `lead` for accessing neighbouring rows
 *  - two hand-written UDAFs (moving average, current-minus-previous)
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since
 * Spark 3.0; prefer `org.apache.spark.sql.expressions.Aggregator`
 * for new code. It is kept here because this example targets Spark 2.x.
 */
object WindowFuncApp {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("WindowFuncApp")
      .master("local")
      .getOrCreate()

    import spark.implicits._

    val ordersDF: DataFrame = Seq(
      ("o1", "s1", "2017-05-01", 100),
      ("o2", "s1", "2017-05-02", 200),
      ("o3", "s2", "2017-05-01", 200),
      ("o4", "s1", "2017-05-03", 200),
      ("o5", "s2", "2017-05-02", 100),
      ("o6", "s1", "2017-05-04", 300)
    ).toDF("order_id", "seller_id", "pay_time", "price")
    ordersDF.printSchema()

    // 1. Order sequence per shop: rank rows within each seller by pay time.
    val rankSpec: WindowSpec = Window.partitionBy("seller_id").orderBy("pay_time")
    ordersDF.withColumn("rank", dense_rank.over(rankSpec)).show()

    // rank leaves gaps after ties (1,2,2,4); dense_rank does not (1,2,2,3).
    val rankSpec2: WindowSpec = Window.partitionBy("seller_id").orderBy("price")
    ordersDF.withColumn("rank2", rank.over(rankSpec2)).show() // 1,2,2,4
    ordersDF.withColumn("dense_rank2", dense_rank.over(rankSpec2)).show() // 1,2,2,3

    println("****************************************")

    // Sliding frame covering the previous row and the current row.
    val winSpec: WindowSpec = Window.partitionBy("seller_id").orderBy("pay_time").rowsBetween(-1, 0)
    // Sum of the current order's price and the previous order's price.
    ordersDF.withColumn("sum_pay", sum("price").over(winSpec)).show()

    // Average of the current and previous order via a custom UDAF,
    // shown next to the built-in avg for comparison — both columns match.
    def getAvgUdaf: UserDefinedAggregateFunction = new MyAverage
    ordersDF.withColumn("avg", getAvgUdaf($"price").over(winSpec)).show()
    ordersDF.withColumn("avg2", avg("price").over(winSpec)).show()

    println("****************************************")

    // Difference between the current order and the previous one, computed
    // both with a custom UDAF and with the built-in lag function.
    def getMinusUdaf: UserDefinedAggregateFunction = new MyMinus
    ordersDF.withColumn("rank", dense_rank.over(rankSpec))
      .withColumn("prePrice", lag("price", 1).over(rankSpec)) // value from the previous row
      .withColumn("minus", getMinusUdaf($"price").over(winSpec)) // a plain UDF over prePrice would also work
      .show()
    /*
     * lag(field, n): value of `field` n rows before the current row.
     * first/last(): first/last value of the ordered partition; handy e.g.
     *               for finding when a user dropped out.
     * lead(field, n): the opposite of lag — the value n rows after the
     *                 current row; useful for back-testing, i.e. deriving
     *                 conditions from later outcomes.
     */

    spark.stop()
  }

  /**
   * Custom UDAF computing the arithmetic mean of its numeric input.
   *
   * Input is declared as LongType; Spark implicitly casts the IntegerType
   * `price` column up to Long before calling `update`.
   * Buffer layout: (count: Long, sum: Double). Result: DoubleType.
   */
  class MyAverage extends UserDefinedAggregateFunction {

    // Data type of the single input argument.
    override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)

    // Types of the intermediate aggregation buffer: running count and sum.
    override def bufferSchema: StructType = StructType(StructField("count", LongType) :: StructField("sum", DoubleType) :: Nil)

    // Zero value of the aggregation buffer.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // row count
      buffer(1) = 0.0D // running sum
    }

    // Folds one input row into the buffer; null inputs are skipped,
    // matching the null semantics of the built-in avg.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + 1L
        buffer(1) = buffer.getDouble(1) + input.getAs[Long](0).toDouble
      }
    }

    // Same input always yields the same output.
    override def deterministic: Boolean = true

    // Combines two partial buffers (count + count, sum + sum); called when
    // Spark merges partial aggregates from different partitions.
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
    }

    // Final result: sum / count. NOTE: divides by zero count if every
    // input in the frame was null — the frame here always has >= 1 row.
    override def evaluate(buffer: Row): Any = {
      buffer.getDouble(1) / buffer.getLong(0).toDouble
    }

    // Data type of the returned value.
    override def dataType: DataType = DoubleType
  }

  /**
   * Custom UDAF returning `current - previous` over a two-row frame
   * (`rowsBetween(-1, 0)`).
   *
   * CAUTION(review): unlike a true commutative aggregate, `update` and
   * `merge` here are order-dependent (later value minus earlier value).
   * This only behaves as intended for running-window evaluation, where
   * rows are fed in frame order; for the first row of a partition the
   * frame has a single row, so the result is that row's own price
   * (diff against the zero-initialised buffer), where `lag` yields null.
   */
  class MyMinus extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)

    override def bufferSchema: StructType = StructType(StructField("minus", LongType) :: Nil)

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // running difference
    }

    // Later input minus whatever is already in the buffer; null inputs skipped.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = input.getLong(0) - buffer.getLong(0)
      }
    }

    // Same input always yields the same output.
    override def deterministic: Boolean = true

    // Partition merge, also "later minus earlier" — see class-level caveat
    // on order dependence.
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer2.getLong(0) - buffer1.getLong(0)
    }

    // Final result is the accumulated difference.
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0)
    }

    // Data type of the returned value.
    override def dataType: DataType = LongType
  }
}