Update 2023
I had to revisit this and found that my implementation, though very fast, is not accurate. scikit-learn's implementation should be the best option. I wrote the code below some time back, when I needed a custom distance function (scikit-learn does not support a custom distance function for KDTree, but it does for BallTree).
Here is a Jupyter notebook comparing the timing of scikit-learn's BallTree and KDTree with my custom code:
https://colab.research.google.com/drive/1ymx2r3J7oUMAuPlsZxDnESM7aTfwbLWV?usp=sharing
BallTree is accurate but quite slow. KDTree returns the correct top 5 but in a different order, so it is not as accurate. My code misses one node from the top 5, so it is not accurate.
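To check whether a tree search really found the top 5, one simple reference is an exhaustive search: compute the distance from the query to every point and keep the k smallest. This is a minimal sketch, not the notebook's code. `brute_force_top_k` and `compare_top_k` are hypothetical helpers, and the distance function is assumed to take `(x1, y1, x2, y2)` in the same order as `eucleaden_dist` in the code further down.

```python
import heapq
import math


def euclidean(x1, y1, x2, y2):
    # Plain Euclidean distance between (x1, y1) and (x2, y2)
    return math.hypot(x1 - x2, y1 - y2)


def brute_force_top_k(points, query, k, distance_fn):
    # Hypothetical helper: exhaustive k-nearest-neighbour search.
    # Returns up to k (distance, point) pairs, closest first.
    scored = ((distance_fn(p[0], p[1], query[0], query[1]), tuple(p)) for p in points)
    return heapq.nsmallest(k, scored, key=lambda pair: pair[0])


def compare_top_k(expected, found):
    # Hypothetical helper: classify a candidate top-k list against the reference
    if list(expected) == list(found):
        return "exact"
    if set(expected) == set(found):
        return "same points, different order"
    return "missing points"


points = [[2, 2], [4, 8], [10, 2], [5, 5], [9, 9]]
reference = [p for _, p in brute_force_top_k(points, (8, 1), 3, euclidean)]
print(reference)  # [(10, 2), (5, 5), (2, 2)]
print(compare_top_k(reference, [(5, 5), (10, 2), (2, 2)]))  # same points, different order
```

The three return values of `compare_top_k` mirror the three outcomes described above for BallTree, KDTree and my code.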
I believe this is because latitude and longitude, which are spherical coordinates, are treated as if they lay in a rectangular (x, y) system (the KD tree splits on the x, y, ... axes in turn), and the custom distance function is then applied on top of that. See also a similar discussion.
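For latitude/longitude points, a common custom distance is the great-circle (haversine) distance. Below is a minimal sketch, assuming inputs in decimal degrees in the order `(lat1, lon1, lat2, lon2)` so it has the same four-argument shape as the `distance_fn` used in the code below. `haversine_km` is a hypothetical name, and the 6371 km mean Earth radius is an assumption (the notebook may use a different value). The usage lines illustrate the projection problem: two points on either side of the antimeridian are close on the sphere but far apart along the longitude axis that an x/y split compares.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def haversine_km(lat1, lon1, lat2, lon2):
    # Great-circle distance in km between two (lat, lon) points given in degrees
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    # min() guards against a tiny floating-point overshoot above 1
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, a)))


# About 222 km apart on the sphere...
print(round(haversine_km(0, 179, 0, -179), 2))  # 222.39
# ...but 358 "units" apart if (lat, lon) is treated as a flat x/y plane
print(math.hypot(0 - 0, 179 - (-179)))  # 358.0
```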
I tried converting to Cartesian coordinates, without much better results. I am leaving the original post as is in case it is useful.
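For reference, one standard way to do the Cartesian conversion is to map each (lat, lon) to a point on the unit sphere. The straight-line (chord) distance between two such points is 2·sin(θ/2) for a central angle θ, which increases monotonically for θ in [0, π], so a Euclidean nearest-neighbour search over the 3D points should rank neighbours in the same order as a great-circle search (ties aside). This is a sketch under those assumptions, not what the notebook does; `to_unit_vector`, `chord_length` and `chord_to_km` are hypothetical names, and the 6371 km radius is assumed.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def to_unit_vector(lat_deg, lon_deg):
    # Map (lat, lon) in degrees to (x, y, z) on the unit sphere
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return (math.cos(lat) * math.cos(lon),
            math.cos(lat) * math.sin(lon),
            math.sin(lat))


def chord_length(p, q):
    # Straight-line (Euclidean) distance between two 3D points
    return math.dist(p, q)


def chord_to_km(chord):
    # Convert a unit-sphere chord length back to a great-circle distance in km
    return 2 * EARTH_RADIUS_KM * math.asin(min(1.0, chord / 2))


a = to_unit_vector(0, 179)
b = to_unit_vector(0, -179)
print(round(chord_to_km(chord_length(a, b)), 2))  # 222.39
```

To use this with the KD tree code below, the tree would need three axes: `create_kdtree` takes `no_of_axis`, but `find_nearest_neighbour` as written assumes two.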
Adapted from my gist for 2D https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
# From https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
# Author alex punnen
from collections import namedtuple
from operator import itemgetter
import numpy as np


def find_nearest_neighbour(node, point, distance_fn, current_axis):
    # Algorithm to find the nearest neighbour in a KD tree. The KD tree has done a spatial
    # sort of the given coordinates: at a node that splits on the x axis, the left subtree
    # holds the coordinates whose x value is no greater than the node's, and the right
    # subtree those whose x value is no smaller. Nodes that split on the y axis do the same
    # with the y value.
    # To find the nearest neighbour, start from the root and descend into the child on the
    # same side of the split as the query point, keeping the closest node seen on the way.
    # Note that this descent never backtracks into the other subtree, so it can miss the
    # true nearest neighbour when that neighbour lies just across a splitting line.
    # param: node: the current node
    # param: point: the point to which the nearest neighbour is to be found
    # param: distance_fn: function used to calculate the distance between two points
    # param: current_axis: only two dimensions are assumed here, so the current axis is
    #        either x or y (0 or 1)
    if node is None:
        return None, None
    current_closest_node = node
    closest_known_distance = distance_fn(node.cell[0], node.cell[1], point[0], point[1])
    print(closest_known_distance, node.cell)
    x = (node.cell[0], node.cell[1])
    y = point
    if x[current_axis] > y[current_axis]:
        new_node, new_closest_distance = find_nearest_neighbour(
            node.left_branch, point, distance_fn, (current_axis + 1) % 2)
    else:
        new_node, new_closest_distance = find_nearest_neighbour(
            node.right_branch, point, distance_fn, (current_axis + 1) % 2)
    # Compare against None explicitly: a distance of 0.0 (exact match) is falsy
    if new_closest_distance is not None and new_closest_distance < closest_known_distance:
        print('Reset closest node to', new_node.cell)
        closest_known_distance = new_closest_distance
        current_closest_node = new_node
    return current_closest_node, closest_known_distance

class Node(namedtuple('Node', 'cell, left_branch, right_branch')):
    # This class is taken from the Wikipedia code snippet for KD trees
    pass


def create_kdtree(cell_list, current_axis, no_of_axis):
    # Creates a KD tree recursively, following the Wikipedia snippet for KD trees,
    # but generalised to any number of axes and with changes in the data structure
    if not cell_list:
        return None
    # Take each cell as a tuple of its first no_of_axis coordinates,
    # e.g. (x, y) for two dimensions or (x, y, z) for three
    k = [tuple(cell[:no_of_axis]) for cell in cell_list]
    k.sort(key=itemgetter(current_axis))  # sort on the current axis
    median = len(k) // 2  # index of the median element
    axis = (current_axis + 1) % no_of_axis  # cycle through the axes
    return Node(k[median],  # recurse on each half
                create_kdtree(k[:median], axis, no_of_axis),
                create_kdtree(k[median + 1:], axis, no_of_axis))


def eucleaden_dist(x1, y1, x2, y2):
    # Euclidean distance between (x1, y1) and (x2, y2)
    a = np.array([x1, y1])
    b = np.array([x2, y2])
    dist = np.linalg.norm(a - b)
    return dist


np.random.seed(0)
# cell_list = np.random.random((2, 2))
# cell_list = cell_list.tolist()
cell_list = [[2, 2], [4, 8], [10, 2]]
print(cell_list)
tree = create_kdtree(cell_list, 0, 2)
node, distance = find_nearest_neighbour(tree, (1, 1), eucleaden_dist, 0)
print('Nearest Neighbour=', node.cell, distance)
node, distance = find_nearest_neighbour(tree, (8, 1), eucleaden_dist, 0)
print('Nearest Neighbour=', node.cell, distance)
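Separately from the projection question, `find_nearest_neighbour` above only descends into one child at each node and never visits the other subtree, which on its own can miss the true nearest neighbour. The textbook kd-tree search adds a backtracking step: after searching the near side of a split, also search the far side whenever the distance from the query to the splitting line is smaller than the best distance found so far. The sketch below assumes plain Euclidean distance, for which the distance to the splitting line is just the difference along the current axis; with a custom metric such as haversine, that pruning test would need a bound in the same metric. `build`, `nearest` and `brute_nearest` are hypothetical names, and this is not the gist's code.

```python
import math
from collections import namedtuple

Node = namedtuple('Node', 'cell, left_branch, right_branch')


def build(points, axis=0, k=2):
    # Same median-split construction as create_kdtree, cycling through k axes
    if not points:
        return None
    points = sorted(points, key=lambda p: p[axis])
    median = len(points) // 2
    next_axis = (axis + 1) % k
    return Node(tuple(points[median]),
                build(points[:median], next_axis, k),
                build(points[median + 1:], next_axis, k))


def nearest(node, query, axis=0, k=2, best=None):
    # Returns (best_point, best_distance); best carries the closest point so far
    if node is None:
        return best
    d = math.dist(node.cell, query)
    if best is None or d < best[1]:
        best = (node.cell, d)
    diff = query[axis] - node.cell[axis]
    # Visit the side of the split that contains the query first
    if diff < 0:
        near, far = node.left_branch, node.right_branch
    else:
        near, far = node.right_branch, node.left_branch
    next_axis = (axis + 1) % k
    best = nearest(near, query, next_axis, k, best)
    # Backtrack: the far side can only hold a closer point if the
    # splitting line is nearer than the best distance found so far
    if abs(diff) < best[1]:
        best = nearest(far, query, next_axis, k, best)
    return best


def brute_nearest(points, query):
    # Exhaustive reference answer for checking the tree search
    best = min(points, key=lambda p: math.dist(p, query))
    return tuple(best), math.dist(best, query)


# The query sits just left of the root's x split, but its nearest neighbour
# is on the right; a one-sided descent would return (5, 9) instead
tree = build([(1, 9), (5, 9), (5.5, 0)])
print(nearest(tree, (4.9, 0))[0])  # (5.5, 0)
```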