Efficient nearest neighbour search in Scala

Consider this coordinate class with a Euclidean distance method:

case class coord(x: Double, y: Double) {
  def dist(c: coord) = Math.sqrt( Math.pow(x-c.x, 2) + Math.pow(y-c.y, 2) ) 
}

and a grid of coordinates, for instance:

val grid = (1 to 25).map {_ => coord(Math.random*5, Math.random*5) }

Then, for any given coordinate

val x = coord(Math.random*5, Math.random*5) 

the points of the grid, sorted from nearest to farthest from x, are

val nearest = grid.sortWith( (p,q) => p.dist(x) < q.dist(x) )

so the three closest points are nearest.take(3).

Is there a way to make this computation more time-efficient, especially for a grid with one million points?

Bowerbird asked 6/9, 2014 at 5:09 Comment(13)
Good question. A very obvious way is to find the minimum instead of sorting, val nearest = grid.minBy( p => p.dist(x) ), then remove that element from the list and try again. This works if the number needed is small, like 3. This is not worthy of an answer. I suspect a bitwise operation somewhere could speed this up. - Patric
en.wikipedia.org/wiki/Nearest_neighbor_search - Arterial
I thought of k-d trees as preprocessing for bisecting the search space (grid): en.wikipedia.org/wiki/K-d_tree - Bowerbird
@DavidEisenstat, SO needs a "mark as duplicate of Wikipedia article" :) - Oscaroscillate
@enzyme, k-d trees are mentioned in the article David refers you to, along with lots of other suggestions. It's a well-studied problem. - Oscaroscillate
Many thanks to all for the good ideas. It is a well-studied problem, yet a simple and efficient Scala solution for the specific problem described above would be much appreciated. - Bowerbird
One straightforward optimisation: you can use the squared distance, def distSquare(c: coord) = Math.pow(x-c.x, 2) + Math.pow(y-c.y, 2), as the measure, which saves you calculating .sqrt each time. - Mussulman
@arturgrzesiak Excellent observation. - Bowerbird
@enzyme, this looks relevant: #5675241 - Oscaroscillate
Also, Apache Spark has a top method that does what you want. Maybe you can repurpose its source? spark.apache.org/docs/0.8.1/api/core/org/apache/spark/rdd/… - Oscaroscillate
And how many nearest neighbors will you need? Different algorithms would apply if you need a few (3, in your example) or most of them (a million). - Oscaroscillate
@Paul Thanks a ton! Checking Spark with special interest :) - Bowerbird
If you really only need a few of the million coordinates, then rather than sorting them you'd be much better off pushing them all onto a PriorityQueue and taking the top few. See #13 on this page. - Passionless
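Combining two suggestions from the comments above (compare squared distances to skip the sqrt, and keep a bounded PriorityQueue instead of sorting everything), a minimal sketch might look like this. It assumes Scala 2.13 or 3; `Coord`, `distSq` and `kNearest` are hypothetical names, not from the question. The queue is a max-heap holding at most k candidates, so its head is always the worst candidate kept, and a new point only enters if it beats that head.

```scala
import scala.collection.mutable

case class Coord(x: Double, y: Double) {
  // Squared Euclidean distance: same ordering as the true distance, no sqrt.
  def distSq(c: Coord): Double = {
    val dx = x - c.x
    val dy = y - c.y
    dx * dx + dy * dy
  }
}

// Returns the k points of `grid` closest to `q`, nearest first.
// At most k candidates are kept, so the cost is O(n log k) rather than O(n log n).
def kNearest(grid: Seq[Coord], q: Coord, k: Int): List[Coord] = {
  // mutable.PriorityQueue is a max-heap: head is the largest squared distance kept.
  val byDist: Ordering[(Double, Coord)] = new Ordering[(Double, Coord)] {
    def compare(a: (Double, Coord), b: (Double, Coord)): Int =
      java.lang.Double.compare(a._1, b._1)
  }
  val heap = mutable.PriorityQueue.empty[(Double, Coord)](byDist)
  if (k > 0) {
    for (p <- grid) {
      val d = p.distSq(q)
      if (heap.size < k) heap.enqueue((d, p))
      else if (d < heap.head._1) { heap.dequeue(); heap.enqueue((d, p)) }
    }
  }
  // dequeueAll yields the farthest first, so reverse to get nearest first.
  heap.dequeueAll.reverse.map(_._2).toList
}
```

For example, `kNearest(grid, x, 3)` replaces `nearest.take(3)` without ever sorting the whole grid.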

I'm not sure if this is helpful (or even stupid), but I thought of this:

You use a sort function to sort ALL elements in the grid and then pick the first k elements. Consider a sorting algorithm like recursive merge sort; it looks something like this:

  1. Split collection in half
  2. Recurse on both halves
  3. Merge both sorted halves

Maybe you could optimize such a function for your needs. The merge step normally merges all elements from both halves, but you are only interested in the first k elements of the result. So you could stop merging once you have k elements and ignore the rest.

So in the worst case, where k >= n (n is the size of the grid), you would still only have the complexity of merge sort, O(n log n). To be honest, I'm not able to determine the complexity of this solution relative to k (too tired for that at the moment).
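A rough estimate of the dependence on k, assuming for simplicity that n is a power of two: a merge at a recursion node holding s elements emits at most min(s, k) elements and makes at most one comparison per emitted element, and there are n/s such nodes at that level. Summing over the levels s = 2^j gives a bound on the number of comparisons C(n, k):

$$
C(n,k) \;\le\; \sum_{j=1}^{\log_2 n} \frac{n}{2^j}\,\min\!\left(2^j, k\right)
\;\le\; n\,\lfloor \log_2 k \rfloor \;+\; n k \sum_{j:\,2^j > k} 2^{-j}
\;<\; n\,\lfloor \log_2 k \rfloor + 2n \;=\; O(n \log k)
$$

The first term covers the levels where s <= k (each costs at most n), and the geometric tail over the levels where s > k sums to less than 2n. For k = 3 this is fewer than 3n comparisons, versus roughly n log2 n (about 20n for a million points) for a full sort.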

Here is an example implementation of that solution (it's definitely not optimal and not generalized):

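import scala.annotation.tailrec

// Returns the k elements of seq nearest to x, nearest first.
def minK(seq: IndexedSeq[coord], x: coord, k: Int): IndexedSeq[coord] = {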

  val dist = (c: coord) => c.dist(x)

  def sort(seq: IndexedSeq[coord]): IndexedSeq[coord] = seq.size match {
    case 0 | 1 => seq
    case size => {
      val (left, right) = seq.splitAt(size / 2)
      merge(sort(left), sort(right))
    }
  }

  // Like a normal merge, but stops as soon as k elements have been emitted.
  def merge(left: IndexedSeq[coord], right: IndexedSeq[coord]): IndexedSeq[coord] = {

    val leftF = left.lift
    val rightF = right.lift

    val builder = IndexedSeq.newBuilder[coord]

    @tailrec
    def loop(leftIndex: Int = 0, rightIndex: Int = 0): Unit = {
      if (leftIndex + rightIndex < k) {
        (leftF(leftIndex), rightF(rightIndex)) match {
          case (Some(leftCoord), Some(rightCoord)) => {
            if (dist(leftCoord) < dist(rightCoord)) {
              builder += leftCoord
              loop(leftIndex + 1, rightIndex)
            } else {
              builder += rightCoord
              loop(leftIndex, rightIndex + 1)
            }
          }
          case (Some(leftCoord), None) => {
            builder += leftCoord
            loop(leftIndex + 1, rightIndex)
          }
          case (None, Some(rightCoord)) => {
            builder += rightCoord
            loop(leftIndex, rightIndex + 1)
          }
          case _ =>
        }
      }
    }

    loop()

    builder.result()
  }

  sort(seq)
}
Sciomancy answered 6/9, 2014 at 7:27 Comment(0)

Profile your code to see what is costly.

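Your way of sorting is already highly inefficient: sortWith recomputes both distances on every comparison, and sorting all the points costs O(n log n) when you only need the first few.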

Do not recompute distances all the time. That isn't free: most likely your program spends 99% of its time computing distances (use a profiler to find out!).
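To make the point about recomputation concrete: in the Scala 2.13 standard library, `sortBy(f)` is defined as `sorted(ord on f)`, so `f` is called on every comparison, just like the `sortWith` in the question. Below is a minimal sketch that computes each squared distance exactly once and sorts indices by the cached values. `nearestByCachedDistance` is a hypothetical name, and `coord` mirrors the question's class without its `dist` method.

```scala
case class coord(x: Double, y: Double)

// Returns the k points nearest to q, nearest first. Each squared distance is
// computed once and cached in an array; sorting only compares cached keys.
def nearestByCachedDistance(grid: IndexedSeq[coord], q: coord, k: Int): IndexedSeq[coord] = {
  val keys = Array.tabulate(grid.length) { i =>
    val dx = grid(i).x - q.x
    val dy = grid(i).y - q.y
    dx * dx + dy * dy // squared distance: same order as Euclidean, no sqrt
  }
  // Sort indices by their cached key instead of recomputing distances.
  val order = (0 until grid.length).sortWith((i, j) => keys(i) < keys(j))
  order.take(k).map(i => grid(i))
}
```

This still sorts all n indices, so it only removes the repeated distance work; a bounded heap or an index structure is needed to avoid the full sort.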

Finally, you can use index structures. For Euclidean distance you probably have the largest choice of indexes to accelerate nearest-neighbor search. There is the k-d tree, but I have often found the R-tree to be faster. If you want to play around with these, I recommend ELKI. It's a Java library for data mining (so it should be easy to use from Scala, too), and it has a huge choice of index structures.
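For reference, a minimal 2-d k-d tree with a k-nearest-neighbour query might look like the sketch below. It rests on simplifying assumptions: the tree is built once from a fixed point set (no inserts or deletes) by splitting at the median of alternating axes, and `Pt`, `buildKd` and `kdNearest` are hypothetical names, not ELKI's API. The query visits the side of each split that contains the query point first and only crosses the splitting line when that line is closer than the current k-th best candidate.

```scala
import scala.collection.mutable

case class Pt(x: Double, y: Double) {
  def distSq(o: Pt): Double = { val dx = x - o.x; val dy = y - o.y; dx * dx + dy * dy }
}

// Each node splits on x (axis 0) or y (axis 1) at its own point.
sealed trait KdTree
case object KdLeaf extends KdTree
case class KdNode(p: Pt, axis: Int, left: KdTree, right: KdTree) extends KdTree

def coordOf(p: Pt, axis: Int): Double = if (axis == 0) p.x else p.y

// Builds a balanced tree by splitting at the median along alternating axes.
def buildKd(pts: IndexedSeq[Pt], depth: Int = 0): KdTree =
  if (pts.isEmpty) KdLeaf
  else {
    val axis = depth % 2
    val sorted = pts.sortWith((a, b) => coordOf(a, axis) < coordOf(b, axis))
    val mid = sorted.length / 2
    KdNode(sorted(mid), axis,
      buildKd(sorted.take(mid), depth + 1),
      buildKd(sorted.drop(mid + 1), depth + 1))
  }

// Returns the k points nearest to q, nearest first.
def kdNearest(tree: KdTree, q: Pt, k: Int): List[Pt] = {
  // Max-heap on squared distance: head is the worst of the current k best.
  val worstFirst: Ordering[(Double, Pt)] = new Ordering[(Double, Pt)] {
    def compare(a: (Double, Pt), b: (Double, Pt)): Int = java.lang.Double.compare(a._1, b._1)
  }
  val best = mutable.PriorityQueue.empty[(Double, Pt)](worstFirst)

  def visit(t: KdTree): Unit = t match {
    case KdLeaf => ()
    case KdNode(p, axis, left, right) =>
      val d = p.distSq(q)
      if (best.size < k) best.enqueue((d, p))
      else if (d < best.head._1) { best.dequeue(); best.enqueue((d, p)) }
      val diff = coordOf(q, axis) - coordOf(p, axis)
      val (near, far) = if (diff < 0) (left, right) else (right, left)
      visit(near)
      // Points on the far side are at least |diff| away along this axis.
      if (best.size < k || diff * diff < best.head._1) visit(far)
  }

  if (k > 0) visit(tree)
  best.dequeueAll.reverse.map(_._2).toList
}
```

Build once with `val tree = buildKd(grid)` and then call `kdNearest(tree, q, 3)` per query. Building costs O(n log² n) here because every level re-sorts its points, which a tuned version would avoid.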

Refutation answered 6/9, 2014 at 18:57 Comment(0)

This one was quite fun to do.

case class Coord(x: Double, y: Double) {
    def dist(c: Coord) = Math.sqrt(Math.pow(x - c.x, 2) + Math.pow(y - c.y, 2))
}
class CoordOrdering(x: Coord) extends Ordering[Coord] {
    def compare(a: Coord, b: Coord) = a.dist(x) compare b.dist(x)
}

def top[T](xs: Seq[T], n: Int)(implicit ord: Ordering[T]): Seq[T] = {
    // xs is an ordered sequence of n elements. insert returns xs with e inserted
    // if e is smaller than at least one element currently in the sequence (in that
    // case the last element is dropped); otherwise it returns an unmodified sequence.
    // insert reuses top's own type parameter T and ordering ord.
    def insert(xs: Seq[T], e: T): Seq[T] = {
      val (l, r) = xs.span(x => ord.lt(x, e))
      (l ++ (e +: r)).take(n)
    }
    xs.drop(n).foldLeft(xs.take(n).sorted)(insert)
}

Minimally tested. Call it like this:

val grid = (1 to 250000).map { _ => Coord(Math.random * 5, Math.random * 5) }
val x = Coord(Math.random * 5, Math.random * 5)
top(grid, 3)(new CoordOrdering(x)) 

EDIT: It's quite easy to extend this to (pre)compute the distances just once:

val zippedGrid = grid.map(p => (p.dist(x), p))  // pair each point with its precomputed distance

object ZippedCoordOrdering extends Ordering[(Double, Coord)] {
   def compare(a:(Double, Coord), b:(Double, Coord)) = a._1 compare b._1
}

top(zippedGrid,3)(ZippedCoordOrdering).unzip._2
Oscaroscillate answered 6/9, 2014 at 18:53 Comment(1)
Actually, insert can be made more efficient still: if (as is probably the usual case) e is greater than anything in the list, only one comparison is needed, but currently it does n. To do this, keep the current top n in reverse order (greatest first). I may make that modification later. - Oscaroscillate
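The reverse-order idea from this comment can be sketched as a variant of `top`. The name `topReversed` is hypothetical, and this is not the answerer's own follow-up. The running top n are kept greatest-first in an immutable List, so an element that does not beat the current worst candidate is rejected after a single comparison with the head.

```scala
// Returns the n smallest elements of xs under ord, in ascending order.
// The running result is stored greatest-first, so the common case
// (e is no better than the worst kept element) costs one comparison.
def topReversed[T](xs: Seq[T], n: Int)(implicit ord: Ordering[T]): Seq[T] =
  if (n <= 0) Nil
  else {
    def insert(acc: List[T], e: T): List[T] =
      if (!ord.lt(e, acc.head)) acc // e cannot enter the top n
      else {
        // Drop the current worst (the head) and place e in descending position.
        val (greater, rest) = acc.tail.span(x => ord.gt(x, e))
        greater ::: (e :: rest)
      }
    val seed = xs.take(n).sorted(ord.reverse).toList
    xs.drop(n).foldLeft(seed)(insert).reverse
  }
```

It is called the same way as `top`, e.g. `topReversed(grid, 3)(new CoordOrdering(x))`. Inserting a qualifying element still walks part of the list, but for a million points and small n that case is rare.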

Here is an algorithm that makes use of an R-tree data structure. Not useful for the small data set described, but it scales well to a large number of objects.

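Use an ordered list whose nodes represent either objects or R-tree bounding boxes. The order is closest first, using whatever distance function you want; for a bounding box, use the minimum distance from the query point to the box, so that it never exceeds the distance to any object inside it. Maintain the order on insert.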

Initialize the list by inserting the bounding boxes in the root node of the R-tree.

To get the next closest object:

(1) Remove the first element from the list.

(2) If it is an object, it is the closest one not yet returned.

(3) If it is the bounding box of a non-leaf node of the R-tree, insert all the bounding boxes representing children of that node into the list in their proper places according to their distance.

(4) If it is the bounding box of an R-tree leaf node, insert the objects that are children of that node (the objects, not their bounding boxes) according to their distance.

(5) Go back to step (1).

The list will remain pretty short. At the front will be nearby objects that we are interested in, and later nodes in the list will be boxes representing collections of objects that are farther away.
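A minimal sketch of this best-first traversal follows, under a loudly stated assumption: instead of a real R-tree it uses a crude static tree, built by sorting the points by x, cutting them into leaves, and grouping nodes level by level (no R-tree insertion or node splitting). `scala.collection.mutable.PriorityQueue` plays the role of the ordered list, and a box's key is its minimum squared distance to the query point. The names `P`, `Box`, `pack` and `nearestIterator` are hypothetical.

```scala
import scala.collection.mutable

case class P(x: Double, y: Double)

case class Box(minX: Double, minY: Double, maxX: Double, maxY: Double) {
  // Smallest squared distance from q to any point in the box (0 if q is inside).
  // It never exceeds the distance to a point inside, which makes step (2) valid.
  def minDistSq(q: P): Double = {
    val dx = math.max(0.0, math.max(minX - q.x, q.x - maxX))
    val dy = math.max(0.0, math.max(minY - q.y, q.y - maxY))
    dx * dx + dy * dy
  }
}

def pointBox(p: P): Box = Box(p.x, p.y, p.x, p.y)
def union(a: Box, b: Box): Box =
  Box(math.min(a.minX, b.minX), math.min(a.minY, b.minY),
      math.max(a.maxX, b.maxX), math.max(a.maxY, b.maxY))

sealed trait RNode { def box: Box }
case class RLeaf(box: Box, points: Vector[P]) extends RNode
case class RInner(box: Box, children: Vector[RNode]) extends RNode

// Crude static packing, NOT real R-tree insertion: sort by x, cut into leaves
// of `fanout` points, then group nodes until a single root remains.
def pack(points: Vector[P], fanout: Int = 16): RNode = {
  require(points.nonEmpty && fanout >= 2, "need points and fanout >= 2")
  var level: Vector[RNode] = points.sortWith(_.x < _.x).grouped(fanout)
    .map[RNode](ps => RLeaf(ps.map(pointBox).reduce(union), ps)).toVector
  while (level.length > 1)
    level = level.grouped(fanout)
      .map[RNode](ns => RInner(ns.map(_.box).reduce(union), ns)).toVector
  level.head
}

// Steps (1) to (5): lazily yields points in increasing distance from q.
def nearestIterator(root: RNode, q: P): Iterator[P] = {
  type Entry = (Double, Either[RNode, P]) // key, then a node or a point
  val closestFirst: Ordering[Entry] = new Ordering[Entry] {
    // Reversed so that the max-heap PriorityQueue pops the smallest key.
    def compare(a: Entry, b: Entry): Int = java.lang.Double.compare(b._1, a._1)
  }
  val queue = mutable.PriorityQueue.empty[Entry](closestFirst)
  queue.enqueue((root.box.minDistSq(q), Left(root)))

  new Iterator[P] {
    // Expand nodes until a point (or nothing) is at the front of the queue.
    private def advance(): Unit =
      while (queue.nonEmpty && queue.head._2.isLeft) {
        queue.dequeue()._2 match {
          case Left(RInner(_, children)) =>
            children.foreach(c => queue.enqueue((c.box.minDistSq(q), Left(c))))
          case Left(RLeaf(_, pts)) =>
            pts.foreach { p =>
              val dx = p.x - q.x; val dy = p.y - q.y
              queue.enqueue((dx * dx + dy * dy, Right(p)))
            }
          case Right(_) => () // unreachable: the loop only dequeues nodes
        }
      }
    def hasNext: Boolean = { advance(); queue.nonEmpty }
    def next(): P = {
      advance()
      queue.dequeue()._2 match {
        case Right(p) => p
        case Left(_)  => throw new NoSuchElementException("no more points")
      }
    }
  }
}
```

With `val root = pack(grid)` built once, `nearestIterator(root, x).take(3).toList` returns the three closest points, and the queue stays small because only boxes near the front are ever expanded.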

Chromic answered 12/9, 2014 at 5:13 Comment(0)

It depends on whether you need exact or approximate results.

Several benchmarks, such as http://www.slideshare.net/erikbern/approximate-nearest-neighbor-methods-and-vector-models-nyc-ml-meetup, show that approximate methods are a good choice in terms of efficiency.

I wrote ann4s, which is a Scala implementation of Annoy.

Annoy (Approximate Nearest Neighbors Oh Yeah) is a C++ library with Python bindings to search for points in space that are close to a given query point. It also creates large read-only file-based data structures that are mmapped into memory so that many processes may share the same data.

Take a look at this repo.
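To illustrate what "approximate" means here, the following is a deliberately simplified, single-tree sketch of the random-hyperplane splitting idea that Annoy is built around. It is not the API of ann4s or Annoy, and all names are hypothetical. Each split picks two distinct random points and divides the set by the perpendicular bisector between them; a query descends to one leaf and ranks only the points in it. Real libraries build many trees and search several leaves, so this sketch can miss true neighbours and may return fewer than k points.

```scala
import scala.util.Random

case class V(x: Double, y: Double)

sealed trait RpTree
case class RpLeaf(points: Vector[V]) extends RpTree
// Points with nx * x + ny * y < offset go left, the rest go right.
case class RpSplit(nx: Double, ny: Double, offset: Double, left: RpTree, right: RpTree) extends RpTree

// Recursively splits by the perpendicular bisector of two random points.
def buildRp(points: Vector[V], leafSize: Int, rnd: Random): RpTree = {
  require(leafSize >= 1, "leafSize must be positive")
  if (points.length <= leafSize) RpLeaf(points)
  else {
    // Two distinct random indices.
    val i = rnd.nextInt(points.length)
    val j = (i + 1 + rnd.nextInt(points.length - 1)) % points.length
    val (a, b) = (points(i), points(j))
    val (nx, ny) = (b.x - a.x, b.y - a.y)
    // The bisector is the set of p with dot(n, p) == dot(n, midpoint of a and b).
    val offset = nx * (a.x + b.x) / 2 + ny * (a.y + b.y) / 2
    val (l, r) = points.partition(p => nx * p.x + ny * p.y < offset)
    // Degenerate split (e.g. duplicate points): stop rather than recurse forever.
    if (l.isEmpty || r.isEmpty) RpLeaf(points)
    else RpSplit(nx, ny, offset, buildRp(l, leafSize, rnd), buildRp(r, leafSize, rnd))
  }
}

// Approximate k-NN: descend to the single leaf containing q and rank its points.
def approxNearest(tree: RpTree, q: V, k: Int): Vector[V] = tree match {
  case RpLeaf(ps) =>
    def d2(p: V): Double = { val dx = p.x - q.x; val dy = p.y - q.y; dx * dx + dy * dy }
    ps.sortWith((a, b) => d2(a) < d2(b)).take(k)
  case RpSplit(nx, ny, offset, left, right) =>
    approxNearest(if (nx * q.x + ny * q.y < offset) left else right, q, k)
}
```

A query then only touches about `leafSize` points plus one comparison per tree level, which is where the speed of approximate methods comes from; accuracy is recovered in practice by combining results from several trees.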

Russianize answered 9/9, 2016 at 14:11 Comment(0)
