Answer
Spark's LIMIT does not run in parallel: the final limit step is always executed in a single partition.
Reason
There are several physical operators in Spark that implement the limit logic:
- CollectLimitExec:
collects data into a single partition, so it does not work in parallel, but it "performs the limit incrementally".
It is only used when the logical Limit operation is the final operator in a logical plan, which happens when the user collects results back to the driver, e.g. spark.sql("select * from x limit 100").collect()
- LocalLimitExec & GlobalLimitExec
They work together:
LocalLimitExec: takes the first `limit` elements of each child partition, but does not collect or shuffle them.
GlobalLimitExec: takes the first `limit` elements of the child's single output partition.
There is an Exchange (shuffle) between them.
The global limit step runs in a single partition, so it is not parallel.
- Other operators: CollectTailExec (for Dataset.tail) and TakeOrderedAndProjectExec (for LIMIT combined with ORDER BY).
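To see which operator the planner chose, print the physical plan. A minimal sketch, assuming a SparkSession named spark and an existing table x (the exact plan text varies across Spark versions):
// Print the physical plan; a final LIMIT typically appears as CollectLimit,
// a non-final one as GlobalLimit over LocalLimit with an Exchange in between.
spark.sql("SELECT * FROM x LIMIT 100").explain()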
Solution
random sampling (approximate row count, but fully parallel)
SELECT * FROM test TABLESAMPLE (50 PERCENT)
SELECT * FROM x WHERE rand() < 0.01
df.sample(0.01) or rdd.sample(false, 0.01)
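A usage sketch: sampling filters each partition independently, so it scales out, but the result size is only approximate. Assuming a DataFrame df with about 1,000,000 rows (hypothetical):
// The seed is only for reproducibility; the sampled count still varies around 1%.
val sampled = df.sample(withReplacement = false, fraction = 0.01, seed = 42L)
println(sampled.count()) // roughly 10,000 rows, not exactly 10,000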
mapPartitions and take
take directly:
df.mapPartitions((a) => a.take(2853557))
This takes the first N rows of every partition in parallel, so the total can be up to N × numPartitions rows; it is only an approximate limit.
take after countByPartitions (exact limit, still computed in parallel; see the setup sketch below):
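First, a minimal self-contained setup. df5 is hypothetical test data, 2000 rows spread over 5 partitions of 400 rows each, assuming a SparkSession named spark:
import org.apache.spark.TaskContext
import spark.implicits._
// Hypothetical test data: spark.range(start, end, step, numPartitions)
// yields 5 partitions with exactly 400 rows each.
val df5 = spark.range(0, 2000, 1, 5).as[Long]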
// ① Count the rows in each partition (countByPartitions)
val x = df5.mapPartitions { a =>
  val pid = TaskContext.getPartitionId()
  Iterator((pid, a.size)) // a.size consumes the iterator, which is fine here
}
val countByPart = x.collect() // pairs come back in partition-id order
println(countByPart.mkString("[", ", ", "]")) // [(0,400), (1,400), (2,400), (3,400), (4,400)]
// ② Allocate how many rows each partition should contribute.
var limit = 900
val takeByPart = new Array[Int](countByPart.length)
for ((pid, count) <- countByPart) {
  val take = math.min(limit, count) // 0 once the limit is exhausted
  limit -= take
  takeByPart(pid) = take
}
println(takeByPart.mkString("(", ", ", ")")) // (400, 400, 100, 0, 0)
val takeByPartBC = spark.sparkContext.broadcast(takeByPart)
// ③ Take the allocated number of rows from each partition, in parallel.
val result = df5.mapPartitions { a =>
  val pid = TaskContext.getPartitionId()
  a.take(takeByPartBC.value(pid))
}
assert(result.count() == 900)
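Note that this approach scans the data twice: one job for the per-partition counts and one for the actual take. Both passes run fully in parallel, and only the small (pid, count) pairs are ever collected to the driver.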
See my blog for further reading.