Maybe this will be of interest to someone. Below is my implementation of nearest() (along with the rest of the KdTree class) for a 2d-tree in Java:
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
import java.util.ArrayList;
import java.util.List;
/**
 * A 2d-tree (2-dimensional k-d tree) of distinct points.
 *
 * <p>Levels alternate between vertical and horizontal partition lines: nodes at odd levels
 * (the root is level 1) split by x coordinate, nodes at even levels split by y coordinate.
 * A point whose splitting coordinate is strictly smaller than the node's goes to the
 * left/bottom subtree; every other point (including ties) goes to the right/top subtree.
 *
 * <p>Each node stores the axis-aligned rectangle of the region it covers. {@link #range}
 * and {@link #nearest} use it to skip whole subtrees. The root covers the unit square
 * [0, 1] x [0, 1]. NOTE(review): the tree therefore assumes all points lie in the unit
 * square; points outside it can produce invalid child rectangles.
 */
public class KdTree {
    /** Region covered by the root; every node's rectangle is a sub-rectangle of it. */
    private static final RectHV UNIT_SQUARE = new RectHV(0, 0, 1, 1);

    private Node root; // root of the tree; null when the tree is empty
    private int size;  // number of distinct points in the tree

    private static class Node {
        private final Point2D p;   // the point
        private final RectHV rect; // the axis-aligned rectangle corresponding to this node
        private Node lb;           // the left/bottom subtree
        private Node rt;           // the right/top subtree

        public Node(Point2D p, RectHV rect) {
            this.p = p;
            this.rect = rect;
        }
    }

    /** Creates an empty 2d-tree. */
    public KdTree() {
    }

    /** Returns true if the tree contains no points. */
    public boolean isEmpty() {
        return size == 0;
    }

    /** Returns the number of distinct points in the tree. */
    public int size() {
        return size;
    }

    /**
     * Returns true if the tree contains a point equal to {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) throw new IllegalArgumentException("argument to contains() is null");
        Node node = root;
        int level = 1;
        while (node != null) {
            if (node.p.equals(p)) return true;
            node = goesLeftOrBelow(node, p, level) ? node.lb : node.rt;
            level++;
        }
        return false;
    }

    /**
     * Adds {@code p} to the tree. Inserting a point that is already present leaves the
     * tree unchanged.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) throw new IllegalArgumentException("calls insert() with a null point");
        if (root == null) {
            root = new Node(p, UNIT_SQUARE);
            size++;
            return;
        }
        Node node = root;
        int level = 1;
        while (!node.p.equals(p)) {
            if (goesLeftOrBelow(node, p, level)) {
                if (node.lb == null) {
                    // The new child's region is computed once, from its parent, at creation.
                    node.lb = new Node(p, leftOrBottomRect(node, level));
                    size++;
                    return;
                }
                node = node.lb;
            } else {
                if (node.rt == null) {
                    node.rt = new Node(p, rightOrTopRect(node, level));
                    size++;
                    return;
                }
                node = node.rt;
            }
            level++;
        }
        // Loop exited because an equal point is already stored: nothing to do.
    }

    /** Returns true if nodes at {@code level} split by x coordinate (vertical partition line). */
    private static boolean isVertical(int level) {
        return level % 2 != 0;
    }

    /**
     * Returns true if {@code p} belongs in the left/bottom subtree of {@code node} (which sits
     * at {@code level}). Ties on the splitting coordinate go to the right/top subtree.
     */
    private static boolean goesLeftOrBelow(Node node, Point2D p, int level) {
        return isVertical(level) ? p.x() < node.p.x() : p.y() < node.p.y();
    }

    /** Returns the region of {@code node}'s left/bottom child: its rectangle cut at the partition line. */
    private static RectHV leftOrBottomRect(Node node, int level) {
        RectHV r = node.rect;
        return isVertical(level)
                ? new RectHV(r.xmin(), r.ymin(), node.p.x(), r.ymax())
                : new RectHV(r.xmin(), r.ymin(), r.xmax(), node.p.y());
    }

    /** Returns the region of {@code node}'s right/top child: its rectangle cut at the partition line. */
    private static RectHV rightOrTopRect(Node node, int level) {
        RectHV r = node.rect;
        return isVertical(level)
                ? new RectHV(node.p.x(), r.ymin(), r.xmax(), r.ymax())
                : new RectHV(r.xmin(), node.p.y(), r.xmax(), r.ymax());
    }

    /**
     * Draws every point (black) and its partition line using StdDraw. Vertical lines are
     * drawn in red and horizontal lines in blue, each clipped to the node's rectangle.
     */
    public void draw() {
        draw(root, 1);
    }

    private void draw(Node node, int level) {
        if (node == null) return;
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        node.p.draw();
        StdDraw.setPenRadius();
        if (isVertical(level)) {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax());
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y());
        }
        draw(node.lb, level + 1);
        draw(node.rt, level + 1);
    }

    /**
     * Returns all points of the tree that lie inside {@code rect} (boundary included, as
     * defined by {@code RectHV.contains}).
     *
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) throw new IllegalArgumentException("calls range() with a null rect");
        List<Point2D> points = new ArrayList<>();
        range(root, rect, points);
        return points;
    }

    /** Adds to {@code points} every point in {@code node}'s subtree that lies inside {@code rect}. */
    private void range(Node node, RectHV rect, List<Point2D> points) {
        // A subtree whose region does not meet the query rectangle cannot contribute.
        if (node == null || !node.rect.intersects(rect)) return;
        if (rect.contains(node.p)) points.add(node.p);
        range(node.lb, rect, points);
        range(node.rt, rect, points);
    }

    /**
     * Returns the point in the tree closest (Euclidean distance) to {@code query}, or null
     * if the tree is empty.
     *
     * @throws IllegalArgumentException if {@code query} is null, even when the tree is empty
     */
    public Point2D nearest(Point2D query) {
        if (query == null) throw new IllegalArgumentException("calls nearest() with a null point");
        if (isEmpty()) return null;
        return nearest(root, query, root.p, 1);
    }

    /**
     * Returns whichever is closest to {@code query}: the current champion {@code champ} or
     * a point of {@code node}'s subtree. The champion is replaced only by a strictly closer
     * point.
     */
    private Point2D nearest(Node node, Point2D query, Point2D champ, int level) {
        if (node == null) return champ;
        double best = champ.distanceSquaredTo(query);
        // Every point in this subtree is at least as far as the subtree's rectangle, so if
        // the rectangle is not strictly closer than the champion, nothing here can win.
        if (node.rect.distanceSquaredTo(query) >= best) return champ;
        if (node.p.distanceSquaredTo(query) < best) champ = node.p;

        // First descend into the subtree on the query's side of the partition line. It is the
        // most likely to shrink the champion distance, letting the other side be pruned.
        boolean queryRightOrAbove = isVertical(level) ? node.p.x() < query.x() : node.p.y() < query.y();
        Node first = queryRightOrAbove ? node.rt : node.lb;
        Node second = queryRightOrAbove ? node.lb : node.rt;
        champ = nearest(first, query, champ, level + 1);
        // The other side is explored only if its rectangle passes the prune test above,
        // evaluated against the (possibly improved) champion.
        return nearest(second, query, champ, level + 1);
    }

    /** Small demo: builds a five-point tree, prints a nearest-neighbour query and draws it. */
    public static void main(String[] args) {
        KdTree kd = new KdTree();
        kd.insert(new Point2D(0.7, 0.2));
        kd.insert(new Point2D(0.5, 0.4));
        kd.insert(new Point2D(0.2, 0.3));
        kd.insert(new Point2D(0.4, 0.7));
        kd.insert(new Point2D(0.9, 0.6));
        Point2D query = new Point2D(0.972, 0.887);
        System.out.println(kd.nearest(query));
        kd.draw();
    }
}