Nearest neighbour search with KDTree
Asked Answered
For a list of N points [(x_1,y_1), (x_2,y_2), ... ] I am trying to find the nearest neighbours of each point based on distance. My dataset is too large for a brute-force approach, so a KD-tree seems best.

Rather than implement one from scratch, I see that sklearn.neighbors.KDTree can find the nearest neighbours. Can this be used to find the nearest neighbour of each particle, i.e. return a list of dimension N?

Teriann answered 6/1, 2018 at 11:16 Comment(2)
Do you want to find the nearest neighbour of each particle with an sklearn method only, or by any method?Cerous
Any kd-tree based method will do, but this seems an attractive library due to the wide range of customisation!Teriann
This question is very broad and missing details. It's unclear what you tried, what your data looks like, and what counts as a nearest neighbour (identity?).

Assuming you are not interested in the identity (the point itself, with distance 0), you can query the two nearest neighbours and drop the first column. This is probably the easiest approach here.

Code:

 import numpy as np
 from sklearn.neighbors import KDTree
 np.random.seed(0)
 X = np.random.random((5, 2))  # 5 points in 2 dimensions
 tree = KDTree(X)
 nearest_dist, nearest_ind = tree.query(X, k=2)  # k=2 nearest neighbours; the first is the point itself
 print(X)
 print(nearest_dist[:, 1])    # drop the identity column; results are sorted by distance (see the sort_results arg)
 print(nearest_ind[:, 1])     # drop the identity column

Output:

 [[ 0.5488135   0.71518937]
  [ 0.60276338  0.54488318]
  [ 0.4236548   0.64589411]
  [ 0.43758721  0.891773  ]
  [ 0.96366276  0.38344152]]
 [ 0.14306129  0.1786471   0.14306129  0.20869372  0.39536284]
 [2 0 0 0 1]
Pharisaic answered 6/1, 2018 at 11:57 Comment(3)
OP wants nearest neighbours based on distance, not k nearest neighbours.Preciousprecipice
@landogardner And that's where I disagree. IMHO he asks for the nearest neighbour of each observation, based on distance! That's also the interpretation under which the fixed output dimension he asked for works out.Pharisaic
It is now scipy.spatial.KDTreeLuane
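
For reference, a minimal sketch of the same identity-dropping query with scipy.spatial.KDTree (assuming SciPy is available; the API is analogous):

import numpy as np
from scipy.spatial import KDTree

np.random.seed(0)
X = np.random.random((5, 2))
tree = KDTree(X)
dist, ind = tree.query(X, k=2)  # first column is the point itself
print(dist[:, 1])  # distance to the true nearest neighbour
print(ind[:, 1])   # index of the true nearest neighbour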
You can use sklearn.neighbors.KDTree's query_radius() method, which returns a list of the indices of the nearest neighbours within some radius (as opposed to returning the k nearest neighbours).

from sklearn.neighbors import KDTree

points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]

tree = KDTree(points, leaf_size=2)
all_nn_indices = tree.query_radius(points, r=1.5)  # NNs within distance of 1.5 of point
all_nns = [[points[idx] for idx in nn_indices] for nn_indices in all_nn_indices]
for nns in all_nns:
    print(nns)

Outputs:

[(1, 1), (2, 2)]
[(1, 1), (2, 2), (3, 3)]
[(2, 2), (3, 3), (4, 4)]
[(3, 3), (4, 4), (5, 5)]
[(4, 4), (5, 5)]

Note that each point includes itself in its list of nearest neighbours within the given radius. If you want to remove these identity points, the line computing all_nns can be changed to:

all_nns = [
    [points[idx] for idx in nn_indices if idx != i]
    for i, nn_indices in enumerate(all_nn_indices)
]

Resulting in:

[(2, 2)]
[(1, 1), (3, 3)]
[(2, 2), (4, 4)]
[(3, 3), (5, 5)]
[(4, 4)]
Preciousprecipice answered 6/1, 2018 at 11:33 Comment(0)
Update 2023

I had to revisit this and found that my implementation, though very fast, is not accurate. The sklearn one should be the best. I wrote the code below some time back, when I needed a custom distance. (sklearn does not support a custom distance function for KDTree, but does for BallTree.)

Here is a Jupyter notebook timing sklearn's BallTree and KDTree against my custom code: https://colab.research.google.com/drive/1ymx2r3J7oUMAuPlsZxDnESM7aTfwbLWV?usp=sharing

BallTree is accurate but pretty slow. KDTree gets the top 5 but in a different order, so it is not that accurate. My code misses one node from the top 5, and hence is not accurate either.

I believe this is due to projecting lat/long, which are spherical coordinates, onto a rectangular system (the KD-tree splits alternately on the x, y, ... axes) and then using the custom distance function. See also a similar discussion. I tried casting to Cartesian coordinates without much better results. I am leaving the original post as-is in case it is useful.
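
For lat/long data specifically, here is a minimal sketch using sklearn's BallTree with its built-in haversine metric (an assumption on my part that this fits the use case; the sample coordinates are hypothetical, inputs must be converted from degrees to radians, and returned distances are in radians on the unit sphere):

import numpy as np
from sklearn.neighbors import BallTree

# hypothetical sample of (lat, long) pairs in degrees
coords_deg = np.array([[52.52, 13.40], [48.86, 2.35], [51.51, -0.13]])
tree = BallTree(np.radians(coords_deg), metric='haversine')
dist, ind = tree.query(np.radians(coords_deg), k=2)  # first column is the identity
earth_radius_km = 6371.0
print(dist[:, 1] * earth_radius_km)  # distance to the nearest other point in km
print(ind[:, 1])                     # index of the nearest other point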


Adapted from my gist for 2D https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8

# From https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
# Author alex punnen
from collections import namedtuple
from operator import itemgetter
import numpy as np
    
def find_nearest_neighbour(node,point,distance_fn,current_axis):
    # Algorithm to find the nearest neighbour in a KD tree; the KD tree has done a spatial sort
    # of the given co-ordinates, such that to the left of the root lie the co-ordinates nearest to the x-axis
    # and to the right of the root lie the co-ordinates farthest from the x-axis.
    # On the y-axis split, to the left of the parent/root node lie the co-ordinates nearest to the y-axis and to
    # the right of the root lie the co-ordinates farthest from the y-axis.
    # To find the nearest neighbour, starting from the root you check the left and right node; if the distance is
    # closer to the right node, then the entire left subtree can be discarded from the search, because of the
    # spatial split, and that node becomes the new root. This continues recursively till the nearest is found.
    # NOTE: only one branch is descended at each step and the sibling is never revisited,
    # which is why this can miss the true nearest neighbour (see the 2023 update above).
    # param: node: The current node
    # param: point: The point to which the nearest neighbour is to be found
    # param: distance_fn: function to compute the distance to a candidate neighbour
    # param: current_axis: assuming only two dimensions here; the current axis is either x or y, 0 or 1

    if node is None:
        return None,None
    current_closest_node = node
    closest_known_distance = distance_fn(node.cell[0],node.cell[1],point[0],point[1])
    print(closest_known_distance, node.cell)
    
    x = (node.cell[0],node.cell[1])
    y = point
    
    new_node = None
    new_closest_distance = None
    if x[current_axis] > y[current_axis]:
        new_node,new_closest_distance= find_nearest_neighbour(node.left_branch,point,distance_fn,
                                                          (current_axis+1) %2)
    else:
        new_node,new_closest_distance = find_nearest_neighbour(node.right_branch,point,distance_fn,
                                                           (current_axis+1) %2) 
    
    if new_closest_distance and new_closest_distance < closest_known_distance:
        print('Reset closest node to', new_node.cell)
        closest_known_distance = new_closest_distance
        current_closest_node = new_node
        
    return current_closest_node,closest_known_distance
    
    
class Node(namedtuple('Node','cell, left_branch, right_branch')):
    # This Class is taken from wikipedia code snippet for  KD tree
    pass
    
def create_kdtree(cell_list,current_axis,no_of_axis):
    # Creates a KD Tree recursively, following the snippet from wikipedia for a KD tree
    # but making it generic for any number of axes and for changes in data structure
    if not cell_list:
        return None
    # get the cell as a tuple list this is for 2 dimensions
    k= [(cell[0],cell[1])  for cell  in cell_list]
    # say for three dimension
    # k= [(cell[0],cell[1],cell[2])  for cell  in cell_list]
    k.sort(key=itemgetter(current_axis)) # sort on the current axis
    median = len(k) // 2 # get the median of the list
    axis = (current_axis + 1) % no_of_axis # cycle the axis
    return Node(k[median], # recurse 
                create_kdtree(k[:median],axis,no_of_axis),
                create_kdtree(k[median+1:],axis,no_of_axis))

def euclidean_dist(x1,y1,x2,y2):
    a = np.array([x1,y1])
    b = np.array([x2,y2])
    dist = np.linalg.norm(a-b)
    return dist


np.random.seed(0)
#cell_list = np.random.random((2, 2))
#cell_list = cell_list.tolist()
cell_list = [[2,2],[4,8],[10,2]]
print(cell_list)
tree = create_kdtree(cell_list,0,2)

node,distance = find_nearest_neighbour(tree,(1, 1),euclidean_dist,0)
print('Nearest Neighbour=', node.cell, distance)

node,distance = find_nearest_neighbour(tree,(8, 1),euclidean_dist,0)
print('Nearest Neighbour=', node.cell, distance)
Ostium answered 5/2, 2019 at 14:35 Comment(2)
Plus one from me. This is cool, but it does not give very accurate results.Baynebridge
Updated the answer with a Jupyter notebook illustrating the inaccuracy.Ostium
I implemented a solution to this problem and I think it might be helpful.

from collections import namedtuple
from operator import itemgetter
from pprint import pformat
from math import inf


def nested_getter(idx1, idx2):
    def g(obj):
        return obj[idx1][idx2]
    return g


class Node(namedtuple('Node', 'location left_child right_child')):
    def __repr__(self):
        return pformat(tuple(self))


def kdtree(point_list, depth: int = 0):
    if not point_list:
        return None

    k = len(point_list[0])  # assumes all points have the same dimension
    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k

    # Sort point list by axis and choose median as pivot element
    point_list.sort(key=nested_getter(1, axis))
    median = len(point_list) // 2

    # Create node and construct subtrees
    return Node(
        location=point_list[median],
        left_child=kdtree(point_list[:median], depth + 1),
        right_child=kdtree(point_list[median + 1:], depth + 1)
    )


def nns(q, n, p, w, depth: int = 0):
    """
    NNS = Nearest Neighbor Search
    :param q: query point, as an (id, coords) tuple
    :param n: current node of the kd-tree
    :param p: best (reference) point found so far
    :param w: distance to the best point found so far
    :param depth: current depth, used to select the splitting axis
    :return: new_p, new_w
    """

    new_w = distance(q[1], n.location[1])
    # below we test if new_w > 0 because we don't want to allow p = q
    if (new_w > 0) and new_w < w:
        p, w = n.location, new_w

    k = len(p[1])  # dimensionality of the coordinate tuple, not of the (id, coords) pair
    axis = depth % k
    n_value = n.location[1][axis]
    search_left_first = (q[1][axis] <= n_value)
    if search_left_first:
        if n.left_child and (q[1][axis] - w <= n_value):
            new_p, new_w = nns(q, n.left_child, p, w, depth + 1)
            if new_w < w:
                p, w = new_p, new_w
        if n.right_child and (q[1][axis] + w >= n_value):
            new_p, new_w = nns(q, n.right_child, p, w, depth + 1)
            if new_w < w:
                p, w = new_p, new_w
    else:
        if n.right_child and (q[1][axis] + w >= n_value):
            new_p, new_w = nns(q, n.right_child, p, w, depth + 1)
            if new_w < w:
                p, w = new_p, new_w
        if n.left_child and (q[1][axis] - w <= n_value):
            new_p, new_w = nns(q, n.left_child, p, w, depth + 1)
            if new_w < w:
                p, w = new_p, new_w
    return p, w


def main():
    """Example usage of kdtree"""
    point_list = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
    tree = kdtree(point_list)
    print(tree)


def city_houses():
    """
    Here we compute the distance to the nearest city from a list of N cities.
    The first line of input contains N, the number of cities.
    Each of the next N lines contain two integers x and y, which locate the city in (x,y),
    separated by a single whitespace.
    It's guaranteed that a spot (x,y) does not contain more than one city.
    The output contains N lines, the line i with a number representing the distance
    for the nearest city from the i-th city of the input.
    """
    n = int(input())
    cities = []
    for i in range(n):
        city = i, tuple(map(int, input().split(' ')))
        cities.append(city)
    # print(cities)
    tree = kdtree(cities)
    # print(tree)
    ans = [(target[0], nns(target, tree, tree.location, inf)[1]) for target in cities]
    ans.sort(key=itemgetter(0))
    ans = [item[1] for item in ans]
    print('\n'.join(map(str, ans)))


def distance(a, b):
    # Taxicab distance is used below. You can use squared Euclidean distance if you prefer
    k = len(b)
    total = 0
    for i in range(k):
        total += abs(b[i] - a[i])
    return total


if __name__ == '__main__':
    city_houses()
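
As a quick sanity check of the logic, here is a hypothetical brute-force cross-check (not part of the original answer) of the nearest taxicab distances for the six points in main():

# Hypothetical brute-force cross-check, not part of the original answer.
point_list = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
for a in point_list:
    # nearest taxicab distance from a to any other point
    print(min(abs(a[0] - b[0]) + abs(a[1] - b[1]) for b in point_list if b != a))
# prints 2, 4, 6, 4, 2, 4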
Ashla answered 20/9, 2020 at 1:45 Comment(0)
