Simplest way to get the top n elements of a Scala Iterable

Is there a simple and efficient solution to determine the top n elements of a Scala Iterable? I mean something like

iter.toList.sortBy(_.myAttr).take(2)

but without having to sort all elements when only the top 2 are of interest. Ideally I'm looking for something like

iter.top(2, _.myAttr)

see also: Solution for the top element using an Ordering: In Scala, how to use Ordering[T] with List.min or List.max and keep code readable

Update:

Thank you all for your solutions. In the end, I took the original solution of user unknown and adapted it to use Iterable and the pimp-my-library pattern:

import scala.language.implicitConversions
import scala.language.reflectiveCalls

implicit def iterExt[A](iter: Iterable[A]) = new {
  def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]): List[A] = {
    def updateSofar (sofar: List [A], el: A): List [A] = {
      //println (el + " - " + sofar)

      if (ord.compare(f(el), f(sofar.head)) > 0)
        (el :: sofar.tail).sortBy (f)
      else sofar
    }

    val (sofar, rest) = iter.splitAt(n)
    (sofar.toList.sortBy (f) /: rest) (updateSofar (_, _)).reverse
  }
}

case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, _.i))
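On Scala 2.10 and later, an implicit class gives the same iter.top(n, f) syntax without the structural type new { ... }, whose methods are invoked via reflection (hence the reflectiveCalls import). A minimal sketch, assuming Scala 2.13; TopSyntax and TopOps are hypothetical names. It also replaces the re-sort with a sorted insertion, as Daniel C. Sobral suggests in the comments on the first answer.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
object TopSyntax {
  // iter.top(n, f): the n elements with the largest key f, largest first.
  implicit class TopOps[A](iter: Iterable[A]) {
    def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]): List[A] = {
      // Insert el into a list that is kept ascending by key.
      def insert(el: A, sorted: List[A]): List[A] = sorted match {
        case h :: t if ord.gt(f(el), f(h)) => h :: insert(el, t)
        case _                             => el :: sorted
      }
      if (n <= 0) Nil
      else {
        val (first, rest) = iter.splitAt(n)
        // Ascending by key, so the head is the smallest element kept so far.
        val start = first.toList.sortBy(f)
        rest.foldLeft(start) { (sofar, el) =>
          if (ord.gt(f(el), f(sofar.head))) insert(el, sofar.tail) else sofar
        }.reverse
      }
    }
  }
}
```

With import TopSyntax._ in scope, li.top(3, _.i) on the example list should give the elements with i = 9, 7 and 6, largest first.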
Teatime answered 15/4, 2011 at 9:21 Comment(3)
I think iter.toList.sortBy(_.myAttr).take(n) is simple and efficient enough when the size of the iterator is close to n. Are you thinking that n will always be small and dealing with possibly large iterators? – Colon
I also use this syntax for small lists, but now I have a case where a collection of arbitrary length (it could get quite large) has to be processed efficiently, and I usually need the top 5 or top 10 elements of this collection. – Teatime
I would think using a list as an accumulator would be best for that exercise. I would recommend a priority queue (not requiring re-sorting every time you need to insert into the accumulator). Thoughts? – Sunderland

My solution (bound to Int, but it should be easy to change to Ordered; give me a few minutes):

def top (n: Int, li: List [Int]) : List[Int] = {

  def updateSofar (sofar: List [Int], el: Int) : List [Int] = {
    // println (el + " - " + sofar)
    if (el < sofar.head) 
      (el :: sofar.tail).sortWith (_ > _) 
    else sofar
  }

  /* better readable:
    val sofar = li.take (n).sortWith (_ > _)
    val rest = li.drop (n)
    (sofar /: rest) (updateSofar (_, _)) */    
  (li.take (n). sortWith (_ > _) /: li.drop (n)) (updateSofar (_, _)) 
}

usage:

val li = List (4, 3, 6, 7, 1, 2, 9, 5)    
top (2, li)    // List(2, 1): as written, this keeps the n smallest elements
  • For the list above, take the first 2 elements (4, 3) as the starting TopTen (here, TopTwo).
  • Sort them so that the first element is the bigger one (if any).
  • Iterate through the rest of the list (li.drop(n)) and compare each element with the maximum of the list of minimums; replace it if necessary, and re-sort.
  • Improvements:
    • Throw away Int and use Ordered.
    • Throw away (_ > _) and use a user-supplied Ordering to allow BottomTen. (Harder: pick the middle 10 :) )

update (abstraction):

def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
     List[T] = {

  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head)) 
      (el :: sofar.tail).sortWith (comp2 (_, _)) 
    else sofar

  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _)) 
}

/*  still bound to Int:  
def top (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ < _), (_ > _))
}
def bottom (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ > _), (_ < _))
}
*/

def top [T] (n: Int, li: List [T]) 
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.lt (_, _), ord.gt (_, _))
}
def bottom [T] (n: Int, li: List [T])
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.gt (_, _), ord.lt (_, _))
}

top (3, li)       // List(3, 2, 1)
bottom (3, li)    // List(6, 7, 9)
val sl = List ("Haus", "Garten", "Boot", "Sumpf", "X", "y", "xkcd", "x11")
bottom (2, sl)    // List(xkcd, y)

Replacing List with Iterable seems to be a bit harder.

As Daniel C. Sobral pointed out in the comments, a large n in topN can lead to a lot of sorting work, so it can be useful to do a manual sorted insertion instead of repeatedly re-sorting the whole list of top-n elements:

def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
     List[T] = {

  def sortedIns (el: T, list: List[T]): List[T] = 
    if (list.isEmpty) List (el) else 
    if (comp2 (el, list.head)) el :: list else 
      list.head :: sortedIns (el, list.tail)

  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head)) 
      sortedIns (el, sofar.tail)
    else sofar

  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _)) 
}

The top/bottom methods and their usage are as above. For a small number of top/bottom elements, the sort is rarely triggered: a few times at the beginning, then less and less often. For example, on shuffled input I counted 70 sorts for top (10) of 10 000 elements and 90 sorts for top (10) of 100 000.
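That the sort fires so rarely is expected: on a random permutation of distinct values, the i-th element (for i > k) belongs to the k best seen so far with probability k/i, so the expected number of replacements is the sum of k/i for i from k+1 to N, about k·ln(N/k). That is roughly 69 for k = 10, N = 10 000 and 92 for N = 100 000, close to the counts above. A small sketch to check this empirically, assuming Scala 2.13; countReplacements and expected are hypothetical helpers. They track the k largest values; by symmetry the same count applies to the k smallest.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.util.Random

object ReplacementCount {
  // Tracks the k largest values of xs in an ascending list (head = smallest kept)
  // and counts how often a new element forces a replacement plus re-sort.
  def countReplacements(k: Int, xs: List[Int]): Int = {
    require(k > 0, "k must be positive")
    val (first, rest) = xs.splitAt(k)
    var replacements = 0
    rest.foldLeft(first.sorted) { (sofar, el) =>
      if (el > sofar.head) {
        replacements += 1
        (el :: sofar.tail).sorted
      } else sofar
    }
    replacements
  }

  // Expected count on a random permutation of distinct values: sum of k/i for i = k+1..n.
  def expected(k: Int, n: Int): Double =
    (k + 1 to n).map(i => k.toDouble / i).sum
}
```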

Darceydarci answered 15/4, 2011 at 14:41 Comment(10)
Thanks for your solution. It's short and easy to understand. I adapted it to work with Iterables of arbitrary types. Unfortunately, Stack Overflow does not allow me to put my code in this comment. – Teatime
Did you adapt the original, or the second solution with extremeN? You may append to your question (but mark it '# update:' to keep the discussion understandable) or post an answer yourself. – Darceydarci
I used your original version. I will append the code to the question. – Teatime
Meanwhile I updated my lower solution a bit further; top and bottom use Ordering too. – Darceydarci
Instead of sorting, you should do sorted insertion in the tail. It will be much faster. – Floatage
@Daniel: For a random, uniformly distributed List of 10 000 ints, val li = rnd.shuffle ((1 to 10000).toList), and a call of top (10, list), I get 70 internal sorts. For a List of 100 000 ints, 90 sorts. With non-pathological input, I wouldn't expect a sorted insertion to be much faster. However, I looked for a suitable collection and didn't find a SortedList or SortedSeq, nor a method 'sortedInsert'. What would you recommend? Write my own sorted insertion? – Darceydarci
What is the "n" you used? For very small n, there isn't much point, indeed. But the algorithm is generic -- I'd expect more frequent, and more costly, sorts with something like n = 20 or n = 30. Yes, I recommend you do the sorted insert yourself: it is trivial. Just write an auxiliary recursive function sortedIns(el: T, list: List[T]) = if (el < list.head) el :: list else list.head :: sortedIns(el, list.tail). Actually, that's it. Then replace (el :: sofar.tail).sortWith (comp2 (_, _)) with sortedIns(el, sofar.tail). – Floatage
Yes, if you pick the top 5000 out of 10000, it would be wiser to sort the whole thing once and cut it into two parts. The idea is to pick only a few elements compared to the whole collection, and, well, I thought I might be missing a library method, some implicit thing. It's not too hard to do it myself, but it will make the code harder to understand. Nevertheless, I will integrate it. – Darceydarci
Look... you need to take n elements from a list of elements. When you see a discussion this size with so much code just to do that... you should start to think that something is very wrong. Don't get me wrong: I find Scala great and I use it 12+ hours a day... but so much code just to take n elements from a collection... honestly... it's time to break paradigms and deliver the product! – Rudich
@RichardGomes: Do you have a better solution that fits the requirements (not sorting the whole thing)? If not, you should perhaps comment on the question, not on my answer. – Darceydarci

Here's another solution that is simple and has pretty good performance.

def pickTopN[T](k: Int, iterable: Iterable[T])(implicit ord: Ordering[T]): Seq[T] = {
  val q = collection.mutable.PriorityQueue[T](iterable.toSeq:_*)
  val end = Math.min(k, q.size)
  (1 to end).map(_ => q.dequeue())
}

The Big O is O(n + k log n), where k <= n. So the performance is linear for small k and at worst n log n.

The solution can also be optimized to be O(k) for memory but O(n log k) for performance. The idea is to use a MinHeap to track only the top k items at all times. Here's the solution.

def pickTopN[A, B](n: Int, iterable: Iterable[A], f: A => B)(implicit ord: Ordering[B]): Seq[A] = {
  val seq = iterable.toSeq
  val q = collection.mutable.PriorityQueue[A](seq.take(n):_*)(ord.on(f).reverse) // initialize with first n

  // invariant: q holds the top n elements scanned so far
  seq.drop(n).foreach(v => {
    q += v
    q.dequeue()
  })

  q.dequeueAll.reverse
}
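Note that iterable.toSeq can copy the entire input (for example when it is a Set or a view), so the O(k) memory bound applies only to the heap itself. A minimal sketch of a variant that consumes an Iterator in a single pass and never holds more than k elements, assuming Scala 2.13; TopKStreaming and topK are hypothetical names.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.collection.mutable

object TopKStreaming {
  // One pass over the iterator, keeping at most k elements in a min-heap on f:
  // O(N log k) time, O(k) extra memory. Returns the k largest by f, largest first.
  def topK[A, B](k: Int, it: Iterator[A])(f: A => B)(implicit ord: Ordering[B]): List[A] =
    if (k <= 0) Nil
    else {
      // PriorityQueue dequeues its maximum, so reverse the ordering to get a min-heap.
      val heap = mutable.PriorityQueue.empty[A](ord.on(f).reverse)
      it.foreach { a =>
        if (heap.size < k) heap.enqueue(a)
        else if (ord.gt(f(a), f(heap.head))) {
          heap.dequeue() // drop the smallest element kept so far
          heap.enqueue(a)
        }
      }
      // dequeueAll yields ascending keys here; reverse for largest first.
      heap.dequeueAll.reverse.toList
    }
}
```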
Madelle answered 1/8, 2017 at 17:0 Comment(6)
You're pulling in 2 different implicits, ClassTag[T] and Ordering[T]. Why not use the same convenience format for both: [T:ClassTag:Ordering]? For that matter, why require ClassTag? It doesn't appear to be required. – Younglove
Oops, my bad. Yes, ClassTag is not required. Thanks! I've made the edit. – Madelle
Thanks for writing down the algorithmic complexity; it helps to pick the right solution! Do you think there is a linear solution if the final result does not need to be sorted? – Rafiq
@Michal: No, I don't think so. I think at minimum you need some number of comparisons in order to figure out which elements are the top k. So I think these two algorithms are the best that you can do. I cannot think of a proof though, so I may be wrong in this case. The last solution (as of today) by ponkin claims to be O(n), but I think the asymptotic analysis is incorrect. – Madelle
@TomWang Thanks for the quick answer. Anyway, I tried your solution; it seems to work just fine and reasonably fast! – Rafiq
Thanks! Actually the solution that @ponkin provided below may be even faster. The average case is O(n) and the worst case is O(n^2). – Madelle

Yet another version:

val big = (1 to 100000)

def maxes[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
    l.foldLeft(collection.immutable.SortedSet.empty[A]) { (xs,y) =>
      if (xs.size < n) xs + y
      else {
        import o._
        val first = xs.firstKey
        if (first < y) xs - first + y
        else xs
      }
    }

println(maxes(4)(big))
println(maxes(2)(List("a","ab","c","z")))

Using the Set forces the result to have unique values. This variant uses a sorted List instead, so duplicates are kept:

def maxes2[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
    l.foldLeft(List.empty[A]) { (xs,y) =>
      import o._
      // xs is kept ascending, so xs.head is the smallest element kept so far
      if (xs.size < n) (y::xs).sorted
      else {
        val first = xs.head
        if (first < y) (y::xs.tail).sorted   // replace the smallest kept element
        else xs
      }
    }
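If duplicates matter but you prefer the SortedSet version, one option is to pair each element with its position in the input, so that equal values stay distinct set entries. A minimal sketch, assuming Scala 2.13; MaxesWithDuplicates is a hypothetical name. Like maxes, it returns the kept elements in ascending order.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.collection.immutable.SortedSet

object MaxesWithDuplicates {
  // Tracks the n largest elements, keeping duplicates: each element is paired
  // with its index, and the index breaks ties so equal values stay distinct.
  def maxesWithDuplicates[A](n: Int)(l: Iterable[A])(implicit o: Ordering[A]): List[A] = {
    implicit val pairOrd: Ordering[(A, Int)] = Ordering.Tuple2(o, Ordering.Int)
    val kept = l.iterator.zipWithIndex.foldLeft(SortedSet.empty[(A, Int)]) { (xs, y) =>
      if (xs.size < n) xs + y
      else if (n > 0 && pairOrd.lt(xs.firstKey, y)) xs - xs.firstKey + y
      else xs
    }
    kept.toList.map(_._1) // ascending, like maxes
  }
}
```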
Lester answered 15/4, 2011 at 15:49 Comment(0)

You don't need to sort the entire collection in order to determine the top N elements. However, I don't believe that this functionality is supplied by the standard library, so you would have to roll your own, possibly using the pimp-my-library pattern.

For example, you can get the nth element of a collection as follows:

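  // Pre-2.13 collections API (TraversableLike and CanBuildFrom were removed in Scala 2.13)
  import scala.collection.TraversableLike
  import scala.collection.generic.CanBuildFrom  // used by topN below

  class Pimp[A, Repr <% TraversableLike[A, Repr]](self : Repr) {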

    def nth(n : Int)(implicit ord : Ordering[A]) : A = {
      val trav : TraversableLike[A, Repr] = self
      var ltp : List[A] = Nil
      var etp : List[A] = Nil
      var mtp : List[A] = Nil
      trav.headOption match {
        case None      => sys.error("Cannot get " + n + " element of empty collection")
        case Some(piv) =>
          trav.foreach { a =>
            val cf = ord.compare(piv, a)
            if (cf == 0) etp ::= a
            else if (cf > 0) ltp ::= a
            else mtp ::= a
          }
          if (n < ltp.length)
            new Pimp[A, List[A]](ltp.reverse).nth(n)(ord)
          else if (n < (ltp.length + etp.length))
            piv
          else
            new Pimp[A, List[A]](mtp.reverse).nth(n - ltp.length - etp.length)(ord)
      }
    }
  }

(This is not very functional; sorry)

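It's then trivial to get the top n elements (this method goes inside the Pimp class as well; note that, because it keeps the elements that compare below the nth one, it returns the n smallest elements under ord):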

def topN(n : Int)(implicit ord : Ordering[A], bf : CanBuildFrom[Repr, A, Repr]) ={
  val b = bf()
  val elem = new Pimp[A, Repr](self).nth(n)(ord)
  import util.control.Breaks._
  breakable {
    var soFar = 0
    self.foreach { tt =>
      if (ord.compare(tt, elem) < 0) {
         b += tt
         soFar += 1
      }
    }
    assert (soFar <= n)
    if (soFar < n) {
      self.foreach { tt =>
        if (ord.compare(tt, elem) == 0) {
          b += tt
          soFar += 1
        }
        if (soFar == n) break
      }
    }

  }
  b.result()
}

Unfortunately I'm having trouble getting this pimp to be discovered via this implicit:

implicit def t2n[A, Repr <% TraversableLike[A, Repr]](t : Repr) : Pimp[A, Repr] 
  = new Pimp[A, Repr](t)

I get this:

scala> List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
<console>:9: error: could not find implicit value for evidence parameter of type (List[Int]) => scala.collection.TraversableLike[A,List[Int]]
   List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
       ^

However, the code actually works OK:

scala> new Pimp(List(4, 3, 6, 7, 1, 2, 8, 5)).topN(4)
res3: List[Int] = List(3, 1, 2, 4)

And

scala> new Pimp("ioanusdhpisjdmpsdsvfgewqw").topN(6)
res2: java.lang.String = adddfe
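Regarding the "not very functional" remark: nth can also be written with immutable lists only, using the same three-way split around the first element and recursing only into the part that contains the requested index (this is quickselect). A minimal sketch, assuming Scala 2.13 and the same zero-based n; Select.nth is a hypothetical name. With the head as pivot, already-sorted input makes it quadratic, as it does for the original.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.annotation.tailrec

object Select {
  // Element at zero-based index n of xs sorted ascending under ord.
  // Three-way split around the head, recursing into one part only.
  @tailrec
  def nth[A](n: Int, xs: List[A])(implicit ord: Ordering[A]): A = {
    require(n >= 0, "index must be non-negative")
    xs match {
      case Nil => throw new NoSuchElementException("index out of range")
      case pivot :: _ =>
        val less  = xs.filter(ord.lt(_, pivot))
        val equal = xs.count(ord.equiv(_, pivot))
        if (n < less.length) nth(n, less)
        else if (n < less.length + equal) pivot
        else nth(n - less.length - equal, xs.filter(ord.gt(_, pivot)))
    }
  }
}
```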
Gardie answered 15/4, 2011 at 9:25 Comment(0)

If the goal is to not sort the whole list, then you could do something like this (of course, it could be optimized a tad so that we don't change the list when the number clearly shouldn't be there):

List(1,6,3,7,3,2).foldLeft(List[Int]()){(l, n) => (n :: l).sorted.take(2)}
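The optimization mentioned in parentheses could look like this: keep the accumulator sorted ascending, so its head is the smallest element kept, and leave it untouched when the new element is not larger. A minimal sketch, assuming Scala 2.13; FoldTop and topK are hypothetical names, and unlike the one-liner above it returns the k largest elements, largest first.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
object FoldTop {
  // The accumulator is kept ascending, so acc.head is the smallest kept element.
  def topK(k: Int, xs: List[Int]): List[Int] =
    xs.foldLeft(List.empty[Int]) { (acc, x) =>
      if (acc.size < k) (x :: acc).sorted
      else if (k > 0 && x > acc.head) (x :: acc.tail).sorted // replace the smallest
      else acc                                                // x cannot enter the top k
    }.reverse // largest first
}
```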
Pleasantry answered 15/4, 2011 at 10:7 Comment(3)
sorted.reverse.take(2) to get the top 2. It's simple, but I'm not sure about the efficiency, as sorted is built on top of java.util.Arrays.sort, so this may create a lot of temporary arrays. – Colon
And it also fails if the goal was to avoid sorting the complete list. – Pleasantry
Just to be sure we are on the same page: your solution returns the bottom 2, and this is why I mentioned reverse to make your solution work. – Colon

I implemented such a ranking algorithm recently in the Rank class of Apache Jackrabbit (in Java, though). See the take method for the gist of it. The basic idea is to quicksort, but terminate prematurely as soon as the top n elements have been found.
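The idea of a quicksort that stops early can be sketched in Scala: after partitioning, only recurse into a partition if it overlaps the first n positions, so those positions end up holding the n largest values in descending order. This is not the Jackrabbit implementation, only an illustration assuming Scala 2.13, Int values and a middle-element pivot; PartialQuicksort is a hypothetical name. It copies all elements into an array first, which is the memory drawback mentioned in the comment below.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
object PartialQuicksort {
  // Sorts only as much of a copy of xs as needed so that the first n slots
  // hold the n largest values in descending order.
  def topN(n: Int, xs: Seq[Int]): Seq[Int] = {
    val a = xs.toArray
    def swap(i: Int, j: Int): Unit = { val t = a(i); a(i) = a(j); a(j) = t }

    // Lomuto partition of a(lo..hi) in descending order; returns the pivot's final index.
    def partition(lo: Int, hi: Int): Int = {
      swap((lo + hi) >>> 1, hi) // middle element as pivot, moved to the end
      val pivot = a(hi)
      var store = lo
      for (i <- lo until hi) if (a(i) > pivot) { swap(i, store); store += 1 }
      swap(store, hi)
      store
    }

    def sort(lo: Int, hi: Int): Unit =
      if (lo < hi && lo < n) { // skip ranges that start at or beyond position n
        val p = partition(lo, hi)
        sort(lo, p - 1)
        sort(p + 1, hi)        // returns immediately when p + 1 >= n
      }

    sort(0, a.length - 1)
    a.take(n).toSeq
  }
}
```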

Berga answered 15/4, 2011 at 11:15 Comment(1)
Oh, that's good to know. It's probably the most efficient solution. The only drawback I see is that it generates an array with all elements of the Iterable. Thus, in extreme cases it may not be applicable because of restricted memory. – Teatime

Here is an asymptotically O(n) solution:

def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
    require( n < data.size)

    def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
      shuffledData.partition( e => ord.compare(e, pivot) > 0 ) match {
          case (left, right) if left.size == n => left
          case (left, x :: rest) if left.size < n => 
            partition_inner(util.Random.shuffle(data), x)
          case (left @ y :: rest, right) if left.size > n => 
            partition_inner(util.Random.shuffle(data), y)
      }

     val shuffled = util.Random.shuffle(data)
     partition_inner(shuffled, shuffled.head)
}

scala> top(List.range(1,10000000), 5)

Due to the recursion (each step reshuffles and repartitions the whole list), this solution will take longer than some of the non-linear solutions above and can cause java.lang.OutOfMemoryError: GC overhead limit exceeded. But it is slightly more readable IMHO, and functional in style. Just for job interviews ;).

More importantly, this solution can easily be parallelized:

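import scala.annotation.tailrec
// On Scala 2.13+, .par also needs the scala-parallel-collections module and:
// import scala.collection.parallel.CollectionConverters._

def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {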
    require( n < data.size)

    @tailrec
    def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
      shuffledData.par.partition( e => ord.compare(e, pivot) > 0 ) match {
          case (left, right) if left.size == n => left.toList
          case (left, right) if left.size < n => 
            partition_inner(util.Random.shuffle(data), right.head)
          case (left, right) if left.size > n => 
            partition_inner(util.Random.shuffle(data), left.head)
      }

     val shuffled = util.Random.shuffle(data)
     partition_inner(shuffled, shuffled.head)
}
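A variant that avoids the termination problem raised in the comments (several elements equal to the n-th value) and only recurses into the side that can still contain the boundary, which is what gives quickselect its average O(n) running time: split around a random pivot into greater, equal and less, and take from the equal group as needed. A minimal sketch assuming Scala 2.13; QuickselectTop.top is a hypothetical name, and the result is not sorted.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.annotation.tailrec
import scala.util.Random

object QuickselectTop {
  // Returns the n largest elements of data, in no particular order.
  // Each step partitions only the part that can still contain the boundary,
  // which gives O(N) average time with a random pivot.
  def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
    @tailrec
    def loop(rest: Vector[T], need: Int, acc: List[T]): List[T] =
      if (need <= 0 || rest.isEmpty) acc
      else if (rest.size <= need) rest.toList ::: acc
      else {
        val pivot   = rest(Random.nextInt(rest.size))
        val greater = rest.filter(ord.gt(_, pivot))
        val equal   = rest.filter(ord.equiv(_, pivot))
        if (greater.size >= need)
          loop(greater, need, acc) // the boundary lies above the pivot
        else if (greater.size + equal.size >= need) // the boundary falls among the ties
          greater.toList ::: equal.take(need - greater.size).toList ::: acc
        else // keep everything >= pivot and continue below it
          loop(rest.filter(ord.lt(_, pivot)), need - greater.size - equal.size,
               greater.toList ::: equal.toList ::: acc)
      }
    loop(data.toVector, n, Nil)
  }
}
```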
Ophthalmitis answered 4/8, 2016 at 7:17 Comment(4)
I don't think this is O(n). This is essentially a modified quicksort. Even without counting the time to do the shuffle, which I believe is meant to be an optimization to make the quadratic worst case extremely unlikely, your algorithm would still take around O(n log n). Say we only want to pick the largest element. Each recursive call is expected to be about half the size of the previous one. So you would still expect to take O(n log n) just to pick out the largest element. – Madelle
@TomWang Not exactly. This is a modification of the quickselect algorithm, which has average complexity O(n). But in our case, we apply it n times, which gives us O(1) * O(n). Shuffling can be implemented in linear time, but we need to do it only once. – Ophthalmitis
Sounds good. I see that the average case is O(n) with a worst case of O(n^2). I missed that the search happens only on one side of the pivot, hence reducing the runtime compared to quicksort. You may want to write an algorithm that reduces the number of shuffles, to reduce the time spent shuffling. – Madelle
Unless I'm missing something, this can get into a state where it never terminates if there are multiple elements equal to the nth top value. The partition will always overshoot or undershoot and never hit the terminating case. – Stirps

For small values of n and large lists, getting the top n elements can be implemented by picking out the max element n times:

def top[T](n:Int, iter:Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
  def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T]  = {
    val max = it.max(ord)
    val (nextElems, rest) = it.partition(ord.gteq(_, max))
    val maxElems = acc ++ nextElems
    if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
    else partitionMax(maxElems, rest)
  }
  if (iter.isEmpty) iter.take(0)
  else partitionMax(iter.take(0), iter)
}

This does not sort the entire list, and it takes an Ordering. I believe every method I call in partitionMax is O(list size), and I only expect to call it at most n times, so for small n the overall cost is proportional to the size of the iterable (roughly n × size operations).

scala> top(5, List.range(1,1000000))
res13: Iterable[Int] = List(999999, 999998, 999997, 999996, 999995)

scala> top(5, List.range(1,1000000))(Ordering[Int].on(- _))
res14: Iterable[Int] = List(1, 2, 3, 4, 5)

You could also add a branch for when n gets close to size of the iterable, and switch to iter.toList.sortBy(_.myAttr).take(n).
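Such a hybrid could look like this. A minimal sketch assuming Scala 2.13; HybridTop is a hypothetical name, and the cutoff n < log2(size) is an assumption chosen because repeated max extraction costs about n·size comparisons while a full sort costs about size·log2(size); the real crossover depends on constant factors.

```scala
// Sketch with hypothetical names, assuming Scala 2.13.
import scala.annotation.tailrec

object HybridTop {
  // The n largest elements, largest first. When n is below log2(size), extract
  // the maximum n times (about n * size comparisons); otherwise sort once
  // (about size * log2(size) comparisons). The cutoff is an assumption.
  def top[T](n: Int, iter: Iterable[T])(implicit ord: Ordering[T]): List[T] = {
    val xs = iter.toList
    val size = xs.size
    if (n <= 0 || size == 0) Nil
    else if (n < math.log(size.toDouble) / math.log(2.0)) {
      @tailrec
      def loop(k: Int, rest: List[T], acc: List[T]): List[T] =
        if (k == 0 || rest.isEmpty) acc.reverse
        else {
          val m = rest.max
          // Remove exactly one occurrence of the maximum.
          val (before, after) = rest.span(x => !ord.equiv(x, m))
          loop(k - 1, before ::: after.tail, m :: acc)
        }
      loop(n, xs, Nil)
    } else xs.sorted(ord.reverse).take(n)
  }
}
```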

It does not return the type of collection provided, but you can look at How do I apply the enrich-my-library pattern to Scala collections? if this is a requirement.

Colon answered 15/4, 2011 at 13:42 Comment(0)

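An optimised solution using a PriorityQueue, with time complexity O(n log k). The approach in the question's update re-sorts the sofar list on every replacement, which is not needed; the version below uses a PriorityQueue instead. Note that this top keeps the n elements with the smallest f value, which is why the usage example passes -_.i.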

import scala.language.implicitConversions
import scala.language.reflectiveCalls
import collection.mutable.PriorityQueue
implicit def iterExt[A](iter: Iterable[A]) = new {
    def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]) : List[A] = {
        def updateSofar (sofar: PriorityQueue[A], el: A): PriorityQueue[A] = {
            if (ord.compare(f(el), f(sofar.head)) < 0){
                sofar.dequeue()
                sofar.enqueue(el)
            }
            sofar
        }

        val (sofar, rest) = iter.splitAt(n)
        (PriorityQueue(sofar.toSeq:_*)( Ordering.by( (x :A) => f(x) ) ) /: rest) (updateSofar (_, _)).dequeueAll.toList.reverse
    }
}

case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, -_.i))
Fineberg answered 19/2, 2018 at 13:4 Comment(0)
