I see two ways to improve the performance of your algorithm. The first is to use sort (https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#sort(sortExprs:org.apache.spark.sql.Column*):org.apache.spark.sql.Dataset%5BT%5D) and limit (https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#limit(n:Int):org.apache.spark.sql.Dataset%5BT%5D) to retrieve the top n rows. The second is to develop a custom Aggregator (https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html#aggregator-in-buf-out).
Sort and Limit method
You sort the dataframe and then take the first n rows:
val n: Int = ???
import org.apache.spark.sql.functions.desc
df.orderBy(desc("count")).limit(n)
Spark optimizes this sequence of transformations: it first sorts each partition and takes the first n rows of that partition, then gathers those partial results on a final partition, where it sorts again and takes the final n rows. You can check this by calling explain() on the transformation; you will get the following execution plan:
== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]
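For reference, that plan is what a call like the following prints (reusing df and n from the snippet above):
df.orderBy(desc("count")).limit(n).explain()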
You can also check this by looking at how the TakeOrderedAndProject step is performed in limit.scala (https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala) in Spark's source code (case class TakeOrderedAndProjectExec, method doExecute).
Custom Aggregator method
For the custom aggregator, you create an Aggregator that populates and updates an ordered array holding the top n rows:
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {

  // The buffer is kept sorted by count in descending order
  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  // Insert the current record at its sorted position; once the buffer already
  // holds `limit` records, drop the last (smallest) element
  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  // Merge two sorted buffers, keeping only the top `limit` records overall
  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged.append(topRecords2.remove(0))
      } else if (topRecords2.isEmpty) {
        merged.append(topRecords1.remove(0))
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged.append(topRecords1.remove(0))
      } else {
        merged.append(topRecords2.remove(0))
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]

  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]
}
You then apply this aggregator to the dataframe and flatten the aggregation result:
val n: Int = ???
import sparkSession.implicits._
df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)
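As a quick sanity check, a minimal end-to-end sketch could look like this (toy data, illustrative only; the SparkSession variable name sparkSession is assumed, as in the snippet above):
// Build a small Dataset[Record] with toy data
import sparkSession.implicits._
val sample = Seq(
  Record("a", 10),
  Record("b", 30),
  Record("c", 20),
  Record("d", 5)
).toDS()
// Keep the top 2 records by count and print them
sample.select(TopRecords(2).toColumn)
  .flatMap(records => records)
  .show()
// Expected: the records with counts 30 and 20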
Comparison of the two methods
To compare the two methods, let's say we want to take the top n rows of a dataframe distributed over p partitions, each partition holding about k records, so the dataframe has p·k rows in total. This gives the following complexities (errors possible):
| method | total number of operations | memory consumption (on executor) | memory consumption (on final executor) |
| --- | --- | --- | --- |
| Current code | O(p·k·log(p·k)) | -- | O(p·k) |
| Sort and Limit | O(p·k·log(k) + p·n·log(p·n)) | O(k) | O(p·n) |
| Custom Aggregator | O(p·k) | O(k) + O(n) | O(p·n) |
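For a rough sense of scale (illustrative numbers, not measurements), with p = 100, k = 1,000,000 and n = 10, using base-2 logarithms: the current code performs about 10^8 · 26.6 ≈ 2.7·10^9 operations, Sort and Limit about 10^8 · 19.9 + 10^3 · 10 ≈ 2.0·10^9, and the Custom Aggregator about 10^8, roughly twenty times fewer.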
So regarding the number of operations, the Custom Aggregator is the most performant. However, this method is by far the most complex and implies a lot of serialization/deserialization, so in some cases it may be less performant than Sort and Limit.
Conclusion
You have two methods to efficiently take the top n rows: Sort and Limit and Custom Aggregator. To choose between them, you should benchmark both methods with your real dataframe. If, after benchmarking, Sort and Limit turns out to be only a bit slower than the Custom Aggregator, I would pick Sort and Limit, as its code is much easier to maintain.
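If it helps, a rough benchmarking sketch could look like the one below (illustrative only; it reuses df, n, Record and TopRecords from above and forces execution with collect()):
import org.apache.spark.sql.functions.desc
import sparkSession.implicits._
// Naive wall-clock timing helper; a real benchmark should warm up and average several runs
def time[T](label: String)(block: => T): T = {
  val start = System.nanoTime()
  val result = block
  println(f"$label: ${(System.nanoTime() - start) / 1e9}%.2f s")
  result
}
time("sort and limit") {
  df.orderBy(desc("count")).limit(n).collect()
}
time("custom aggregator") {
  df.as[Record].select(TopRecords(n).toColumn).flatMap(records => records).collect()
}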