Spark:查找前 n 个值的高性能方法

2024-05-18

我有一个很大的数据集,我想找到具有 n 个最高值的行。

id, count
id1, 10
id2, 15
id3, 5
...

我能想到的唯一方法是使用row_number没有分区就像

val window = Window.orderBy(desc("count"))

df.withColumn("row_number", row_number over window).filter(col("row_number") <= n)

但是,当数据包含数百万或数十亿行时,这绝不是高性能的,因为它将数据推送到一个分区中,并且出现 OOM。

有没有人想出一个高性能的解决方案?


我看到两种提高算法性能的方法。首先是使用sort https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#sort(sortExprs:org.apache.spark.sql.Column*):org.apache.spark.sql.Dataset%5BT%5D and limit https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#limit(n:Int):org.apache.spark.sql.Dataset%5BT%5D检索前 n 行。第二个是开发您的定制Aggregator https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html#aggregator-in-buf-out.

排序和限制方法

您对数据框进行排序,然后获取第一个n rows:

val n: Int = ???

import org.apache.spark.functions.sql.desc

df.orderBy(desc("count")).limit(n)

Spark 通过首先对每个分区执行排序来优化这种转换序列,首先n每个分区上的行,在最终分区上检索它并重新执行排序并首先获取最终分区n行。您可以通过执行来检查这一点explain()关于转变。您将得到以下执行计划:

== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]

并通过观察如何TakeOrderedAndProject步骤执行于限制.scala https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala在Spark的源代码中(案例类TakeOrderedAndProjectExec, 方法doExecute).

自定义聚合器方法

对于自定义聚合器,您创建一个Aggregator这将填充并更新顶部的有序数组n rows.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {

  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged.append(topRecords2.remove(0))
      } else if (topRecords2.isEmpty) {
        merged.append(topRecords1.remove(0))
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged.append(topRecords1.remove(0))
      } else {
        merged.append(topRecords2.remove(0))
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]

  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]

}

然后将此聚合器应用到数据帧上,并展平聚合结果:

val n: Int = ???

import sparkSession.implicits._

df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)

方法比较

为了比较这两种方法,假设我们想要采用顶部n分布在的数据帧的行p分区,每个分区有大约k记录。所以数据框有大小p·k。这给出了以下复杂性(可能会出现错误):

method total number of operations memory consumption
(on executor)
memory consumption
(on final executor)
Current code O(p·k·log(p·k)) -- O(p·k)
Sort and Limit O(p·k·log(k) + p·n·log(p·n)) O(k) O(p·n)
Custom Aggregator O(p·k) O(k) + O(n) O(p·n)

所以关于操作次数,定制聚合器是性能最好的。然而,这种方法是迄今为止最复杂的,并且意味着大量的序列化/反序列化,因此它的性能可能低于排序和限制在某些情况下。

结论

您有两种方法可以有效地占据顶部n rows, 排序和限制 and 定制聚合器。要选择使用哪一种,您应该使用真实的数据帧对这两种方法进行基准测试。如果在基准测试之后排序和限制比 慢一点定制聚合器,我会选择排序和限制因为它的代码更容易维护。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Spark:查找前 n 个值的高性能方法 的相关文章

随机推荐