How to implement range search in KD-Tree

I have built a d-dimensional KD-Tree and I want to do range search on it. Wikipedia mentions range search in KD-Trees, but doesn't describe the implementation or algorithm at all. Can someone please help me with this? If not for arbitrary d, help for at least d = 2 and d = 3 would be great. Thanks!

Multifoil answered 20/12, 2019 at 13:42 Comment(3)
I found these notes <cs.utah.edu/~lifeifei/cis5930/kdtree.pdf>. They have pseudo-code for the cases d=1 and d=2. – Abacist
@Abacist thank you so much. This algorithm can be generalised to other values of d by cycling through the axes as you go down the tree. – Multifoil
@9mat, you should write an answer with the link and a brief description of its contents. I think you pointed to the best resource online for solving this problem. – Cannular

There are multiple variants of kd-tree. The one I used had the following specs:

  1. Each internal node has at most two children.
  2. Each leaf node holds at most maxCapacity points.
  3. Internal nodes do not store any points.

Side note: there are also versions where each node (whether internal or leaf) stores exactly one point. The algorithm below can be adapted to those too; the key difference lies mainly in buildTree.

I wrote an algorithm for this about 2 years ago, thanks to the resource pointed to by @9mat.

Suppose the task is to count the points that lie in a given d-dimensional hyper-rectangle. The task could instead be to list all such points, or all points in the range that also satisfy some other criterion; either is a straightforward change to the code below.

Define a base node class as:

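// Abstract base class shared by internal and leaf nodes.
template <typename T> class kdNode{
    public:
    virtual ~kdNode(){}
    // number of points in this subtree that lie inside the hyper-rectangle [q_min, q_max]
    virtual long rangeQuery(const T* q_min, const T* q_max) const = 0;
};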

Then, an internal node (non-leaf node) can look like this:

template <typename T> class internalNode:public kdNode<T>{
    const kdNode<T> *left = nullptr, *right = nullptr; // left and right sub trees
    int axis; // the axis on which the points are split
    T value;  // split value: left subtree has coordinates <= value on `axis`, right subtree >= value

    public: internalNode(){}

    void buildTree(...){
        // builds the tree recursively (omitted)
    }

    // returns the number of points in this subtree that lie inside the hyper-rectangle formed by q_min and q_max
    long rangeQuery(const T* q_min, const T* q_max) const override{
        // num of points that satisfy the range query conditions
        long rangeCount = 0;

        // the left subtree can only hold matches if the query reaches down to `value`
        if(left != nullptr && q_min[axis] <= value) {
            rangeCount += left->rangeQuery(q_min, q_max);
        }
        // the right subtree can only hold matches if the query reaches up to `value`
        if(right != nullptr && q_max[axis] >= value) {
            rangeCount += right->rangeQuery(q_min, q_max);
        }

        return rangeCount;
    }
};

Finally, the leaf node would look like:

#include <array>

// d: number of dimensions; maxCapacity: hyper-parameter, the max num of points a leaf may hold
template <typename T, int d, int maxCapacity> class leaf:public kdNode<T>{
    std::array<T, d> points[maxCapacity];
    int keyCount = 0; // the actual num of points in this leaf (keyCount <= maxCapacity)

    public: leaf(){}

    // add a point p (d coordinates) to the leaf node; the caller must make sure the leaf is not full
    void addPoint(const T* p){
        for (int i = 0; i < d; i++) {
            points[keyCount][i] = p[i];
        }
        keyCount++;
    }

    // check if points[index] lies inside the hyper-rectangle formed by q_min and q_max
    bool containsPoint(const int index, const T* q_min, const T* q_max) const{
        for (int i = 0; i < d; i++) {
            if (points[index][i] > q_max[i] || points[index][i] < q_min[i]) {
                return false;
            }
        }
        return true;
    }

    // returns the number of points in this leaf that lie inside the hyper-rectangle formed by q_min and q_max
    long rangeQuery(const T* q_min, const T* q_max) const override{
        long rangeCount = 0; // num of points that satisfy the range query conditions
        for (int i = 0; i < this->keyCount; i++) {
            if (containsPoint(i, q_min, q_max)) {
                rangeCount++;
            }
        }
        return rangeCount;
    }
};
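The answer leaves buildTree elided, so here is a minimal, self-contained sketch of the same bucket variant to show how the pieces could fit together. It assumes a compile-time dimension D, double coordinates, and a median split along an axis cycled by depth; Node, buildTree and rangeCount are hypothetical names, and a single node struct with std::unique_ptr children stands in for the class hierarchy above.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

constexpr int D = 2;                          // assumed dimension
using Point = std::array<double, D>;

// Hypothetical node type: internal nodes use axis/value/left/right, leaves use points.
struct Node {
    bool isLeaf = false;
    int axis = 0;
    double value = 0;
    std::unique_ptr<Node> left, right;
    std::vector<Point> points;
};

// Split at the median along `depth % D` until at most `capacity` points remain.
std::unique_ptr<Node> buildTree(std::vector<Point> pts, std::size_t capacity, int depth = 0) {
    assert(capacity > 0);                     // capacity 0 would never terminate
    auto node = std::make_unique<Node>();
    if (pts.size() <= capacity) {
        node->isLeaf = true;
        node->points = std::move(pts);
        return node;
    }
    const int ax = depth % D;
    const std::size_t mid = pts.size() / 2;
    std::nth_element(pts.begin(), pts.begin() + mid, pts.end(),
                     [ax](const Point& a, const Point& b) { return a[ax] < b[ax]; });
    node->axis = ax;
    node->value = pts[mid][ax];
    // after nth_element: pts[0..mid) <= value <= pts[mid..n) on axis `ax`
    node->left = buildTree(std::vector<Point>(pts.begin(), pts.begin() + mid), capacity, depth + 1);
    node->right = buildTree(std::vector<Point>(pts.begin() + mid, pts.end()), capacity, depth + 1);
    return node;
}

// Count points inside the closed box [qMin, qMax], pruning like internalNode::rangeQuery.
long rangeCount(const Node* node, const Point& qMin, const Point& qMax) {
    long count = 0;
    if (node->isLeaf) {
        for (const Point& p : node->points) {
            bool inside = true;
            for (int i = 0; i < D; i++)
                if (p[i] < qMin[i] || p[i] > qMax[i]) { inside = false; break; }
            if (inside) count++;
        }
        return count;
    }
    if (qMin[node->axis] <= node->value) count += rangeCount(node->left.get(), qMin, qMax);
    if (qMax[node->axis] >= node->value) count += rangeCount(node->right.get(), qMin, qMax);
    return count;
}
```

For the 100 grid points (x, y) with integer 0 ≤ x, y ≤ 9 and a capacity of 4, rangeCount(root.get(), {2, 3}, {4, 5}) returns 9 (3 values of x times 3 values of y).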

In the leaf node's range query, it is also possible to use a "binary search" instead of a "linear search". If the points in a leaf are kept sorted along one axis, you can binary-search for the indices l and r using q_min and q_max on that axis, and then scan only from l to r instead of from 0 to keyCount-1. In the worst case this won't help, but in practice, especially with a large maxCapacity, it may.
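A minimal sketch of that idea, assuming a leaf keeps its points in a std::vector sorted by the coordinate `axis` (the leaf class above does not maintain this order by itself, so insertion would have to keep it sorted). leafRangeCountSorted is a hypothetical name. std::lower_bound and std::upper_bound find l and r in logarithmic time, and only the points in [l, r) are then checked in every dimension.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

constexpr int D = 3;                          // assumed dimension
using Point = std::array<double, D>;

// Count the points of one leaf inside [qMin, qMax], assuming `pts` is sorted by p[axis].
long leafRangeCountSorted(const std::vector<Point>& pts, int axis,
                          const Point& qMin, const Point& qMax) {
    assert(axis >= 0 && axis < D);
    // l: first point with p[axis] >= qMin[axis]
    auto l = std::lower_bound(pts.begin(), pts.end(), qMin[axis],
                              [axis](const Point& p, double v) { return p[axis] < v; });
    // r: first point with p[axis] > qMax[axis]
    auto r = std::upper_bound(l, pts.end(), qMax[axis],
                              [axis](double v, const Point& p) { return v < p[axis]; });
    long count = 0;
    for (auto it = l; it != r; ++it) {        // only [l, r) can match on `axis`
        bool inside = true;
        for (int i = 0; i < D; i++)
            if ((*it)[i] < qMin[i] || (*it)[i] > qMax[i]) { inside = false; break; }
        if (inside) count++;
    }
    return count;
}
```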

Multifoil answered 10/9, 2022 at 6:31 Comment(3)
You have a bug in your code. Let's say the root node has a value of 50 because that is where the median split happened, and we query for the range 15 to 30. It would return 0 because the value of the root node is not in the range, and therefore none of the children are visited. – Koo
To fix this, you could traverse to the node where the first split happens. Another thing is that there are 3 cases: fully outside, intersecting and fully contained. The nice thing is that when a node is fully contained, the whole branch can be added without checking whether the values are within the range. – Koo
I wish I could provide a more detailed answer; the thing is, I'm still in the process of wrapping my head around it all. – Koo

This is my solution for a KD-tree in which every node stores a point (not just the leaves). Adapting it to a tree where points are stored only in the leaves is straightforward.

I left some optimizations out to reduce the complexity of the solution; they are explained at the end.

The get_range function takes varargs at the end: first all the low values of the range, then all the high values, e.g. x1, y1, x2, y2 or x1, y1, z1, x2, y2, z2. You can use as many dimensions as you like.

static public <T> void get_range(K_D_Tree<T> tree, List<T> result, float... range) {
    if (tree.root == null) return;

    // node_region layout: [min_0 .. min_(d-1), max_0 .. max_(d-1)]; the root covers all of space
    float[] node_region = new float[tree.DIMENSIONS * 2];
    for (int i = 0; i < tree.DIMENSIONS; i++) {
        node_region[i] = -Float.MAX_VALUE;
        node_region[i+tree.DIMENSIONS] = Float.MAX_VALUE;
    }

    // range uses the same layout: all low values first, then all high values
    _get_range(tree, result, tree.root, node_region, 0, range);
}

node_region represents the region covered by a node. At the root we start with the largest possible region, because as far as we know the data could span all of it.

Here is the recursive _get_range implementation:

static public <T> void _get_range(K_D_Tree<T> tree, List<T> result, K_D_Tree_Node<T> node, float[] node_region, int dimension, float[] target_region) {
    // cycle through the axes: 0, 1, ..., DIMENSIONS-1, 0, ...
    if (dimension == tree.DIMENSIONS) dimension = 0;

    // fully contained: node_region lies inside target_region, so take the whole branch
    if (_contains_region(tree, node_region, target_region)) {
        _add_whole_branch(node, result);
    }
    else {
        // this node's coordinate on the current axis splits its region in two
        float value = _value(tree, dimension, node);

        if (node.left != null) {
            // left child's region: same as this node's, but its max on this axis is value
            float[] node_region_left  = new float[tree.DIMENSIONS*2];
            System.arraycopy(node_region, 0, node_region_left, 0, node_region.length);
            node_region_left[dimension + tree.DIMENSIONS] = value;

            // fully outside: skip; intersecting: recurse
            if (_intersects_region(tree, node_region_left, target_region)){
                _get_range(tree, result, node.left, node_region_left, dimension+1, target_region);
            }
        }

        if (node.right != null) {
            // right child's region: same as this node's, but its min on this axis is value
            float[] node_region_right = new float[tree.DIMENSIONS*2];
            System.arraycopy(node_region, 0, node_region_right, 0, node_region.length);
            node_region_right[dimension] = value;

            if (_intersects_region(tree, node_region_right, target_region)){
                _get_range(tree, result, node.right, node_region_right, dimension+1, target_region);
            }
        }

        // the node's own point still has to be checked in every dimension
        if (_region_contains_node(tree, target_region, node)) {
            result.add(node.point);
        }
    }
}
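The helpers _contains_region, _intersects_region, _region_contains_node and _value are not shown. Judging only from how they are called, the regions use the layout [min_0 .. min_(d-1), max_0 .. max_(d-1)], _contains_region(tree, node_region, target_region) must be true when node_region lies entirely inside target_region, and _value presumably returns the node's coordinate on the given axis. Here is a C++ sketch of the three region tests under those assumptions; the names are hypothetical, and bounds are treated as inclusive because a point equal to a split value can end up on either side.

```cpp
#include <array>
#include <cassert>

constexpr int D = 2;                          // assumed dimension
using Point = std::array<float, D>;
using Region = std::array<float, 2 * D>;      // [min_0..min_(D-1), max_0..max_(D-1)]

// true if `inner` lies completely inside `outer` (role of _contains_region)
bool containsRegion(const Region& inner, const Region& outer) {
    for (int i = 0; i < D; i++)
        if (inner[i] < outer[i] || inner[i + D] > outer[i + D]) return false;
    return true;
}

// true if the two boxes overlap; touching boundaries count (role of _intersects_region)
bool intersectsRegion(const Region& a, const Region& b) {
    for (int i = 0; i < D; i++)
        if (a[i + D] < b[i] || a[i] > b[i + D]) return false;
    return true;
}

// true if point p lies inside `region`, inclusive (role of _region_contains_node)
bool regionContainsPoint(const Region& region, const Point& p) {
    for (int i = 0; i < D; i++)
        if (p[i] < region[i] || p[i] > region[i + D]) return false;
    return true;
}
```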

One important thing that the other answer does not provide is this part:

if (_contains_region(tree, node_region, target_region)) {
    _add_whole_branch(node, result);
}

In a range search on a KD-Tree, a node's region can be in one of 3 states relative to the query range:

  1. fully outside
  2. intersecting
  3. fully contained

Once you know a region is fully contained, you can add the whole branch without doing any dimension checks. To make this clearer, here is _add_whole_branch:

static public <T> void _add_whole_branch(K_D_Tree_Node<T> node, List<T> result) {
    result.add(node.point);
    if (node.left != null)  _add_whole_branch(node.left, result);
    if (node.right != null) _add_whole_branch(node.right, result);
}

In the accompanying image (not preserved), all the big white dots were added using _add_whole_branch, and only the red dots needed a check in every dimension.
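To see the three cases working together, here is a hedged C++ sketch of the whole search for a tree that stores one point per node, following the structure of get_range, _get_range and _add_whole_branch. The dimension D, the median build and every name in it (build, getRange, addWholeBranch, rangeSearch) are assumptions for illustration; like the Java version, it makes a copy of the region for each child.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cfloat>
#include <memory>
#include <vector>

constexpr int D = 2;                                  // assumed dimension
using Point = std::array<float, D>;
using Region = std::array<float, 2 * D>;              // [min_0..min_(D-1), max_0..max_(D-1)]

struct Node {                                         // one point per node
    Point point;
    std::unique_ptr<Node> left, right;
};

// Median build cycling the axis with depth: left subtree <= point[axis] <= right subtree.
std::unique_ptr<Node> build(std::vector<Point>::iterator b, std::vector<Point>::iterator e, int depth) {
    if (b == e) return nullptr;
    const int axis = depth % D;
    auto mid = b + (e - b) / 2;
    std::nth_element(b, mid, e, [axis](const Point& p, const Point& q) { return p[axis] < q[axis]; });
    auto node = std::make_unique<Node>();
    node->point = *mid;
    node->left = build(b, mid, depth + 1);
    node->right = build(mid + 1, e, depth + 1);
    return node;
}

bool pointInside(const Point& p, const Region& r) {
    for (int i = 0; i < D; i++)
        if (p[i] < r[i] || p[i] > r[i + D]) return false;
    return true;
}
bool regionInside(const Region& inner, const Region& outer) {
    for (int i = 0; i < D; i++)
        if (inner[i] < outer[i] || inner[i + D] > outer[i + D]) return false;
    return true;
}
bool regionsIntersect(const Region& a, const Region& b) {
    for (int i = 0; i < D; i++)
        if (a[i + D] < b[i] || a[i] > b[i + D]) return false;
    return true;
}

void addWholeBranch(const Node* n, std::vector<Point>& out) {
    if (n == nullptr) return;
    out.push_back(n->point);
    addWholeBranch(n->left.get(), out);
    addWholeBranch(n->right.get(), out);
}

void getRange(const Node* n, const Region& region, int depth, const Region& target, std::vector<Point>& out) {
    if (n == nullptr) return;
    if (regionInside(region, target)) {               // fully contained: no per-point checks
        addWholeBranch(n, out);
        return;
    }
    const int axis = depth % D;
    const float value = n->point[axis];
    Region leftRegion = region;                       // left child: max on `axis` becomes value
    leftRegion[axis + D] = value;
    Region rightRegion = region;                      // right child: min on `axis` becomes value
    rightRegion[axis] = value;
    // fully outside children are skipped; intersecting ones are searched recursively
    if (regionsIntersect(leftRegion, target)) getRange(n->left.get(), leftRegion, depth + 1, target, out);
    if (regionsIntersect(rightRegion, target)) getRange(n->right.get(), rightRegion, depth + 1, target, out);
    if (pointInside(n->point, target)) out.push_back(n->point);
}

std::vector<Point> rangeSearch(const Node* root, const Region& target) {
    Region all;                                       // the root's region is all of space
    for (int i = 0; i < D; i++) { all[i] = -FLT_MAX; all[i + D] = FLT_MAX; }
    std::vector<Point> out;
    getRange(root, all, 0, target, out);
    return out;
}
```

On the 10 × 10 integer grid, rangeSearch(root.get(), {2, 3, 4, 5}) returns the 9 points with 2 ≤ x ≤ 4 and 3 ≤ y ≤ 5.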

Optimization

1) Instead of starting _get_range at the root node, you can first find the split node: the first node whose point lies within the query range along that node's splitting axis. Finding the split node still starts at the root, but each step is cheaper, because you only go either left or right until you reach it.
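A C++ sketch of finding the split node, assuming a tree that stores one point per node with the left subtree ≤ the node's coordinate ≤ the right subtree on its splitting axis (Node, SplitResult and findSplitNode are hypothetical names). It descends while the query interval on the current axis lies entirely on one side of the node's coordinate, and also returns the narrowed region and the depth, so that the recursive search can continue from the split node with the right axis.

```cpp
#include <array>
#include <cassert>
#include <cfloat>
#include <memory>

constexpr int D = 2;                               // assumed dimension
using Point = std::array<float, D>;
using Region = std::array<float, 2 * D>;           // [min_0..min_(D-1), max_0..max_(D-1)]

struct Node {                                      // one point per node;
    Point point;                                   // left subtree <= point[axis] <= right subtree
    std::unique_ptr<Node> left, right;
};

struct SplitResult {
    const Node* node;   // where the search can start (nullptr if nothing can match)
    int depth;          // depth of that node, so its axis is depth % D
    Region region;      // that node's region, narrowed on the way down
};

SplitResult findSplitNode(const Node* root, const Region& target) {
    Region region;
    for (int i = 0; i < D; i++) { region[i] = -FLT_MAX; region[i + D] = FLT_MAX; }
    const Node* n = root;
    int depth = 0;
    while (n != nullptr) {
        const int axis = depth % D;
        const float value = n->point[axis];
        if (target[axis + D] < value) {            // query entirely below value: go left
            region[axis + D] = value;
            n = n->left.get();
        } else if (target[axis] > value) {         // query entirely above value: go right
            region[axis] = value;
            n = n->right.get();
        } else {
            break;                                 // query straddles value: split node found
        }
        depth++;
    }
    return {n, depth, region};
}
```

The recursive search can then start at the returned node with its region and depth instead of at the root. Nodes skipped on the way down cannot be in the result, because their coordinate on their own splitting axis lies outside the query interval. For a root at (50, 50) splitting on x and a query with 15 ≤ x ≤ 30, the walk moves straight into the left subtree.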

2) The code above allocates new float[] node_region_left and float[] node_region_right arrays at every step, and since this happens in a recursive function it can create a lot of arrays. However, you can reuse the left array for the right one. I didn't do that in this example, for clarity. You could also store each node's region in the node itself, but this takes considerably more memory and might lead to a lot of cache misses.
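A C++ sketch of reusing one buffer, assuming the same one-point-per-node tree and [mins..., maxs...] region layout as in the Java code (all names are hypothetical): a single region buffer is narrowed before each recursive call and restored immediately afterwards, so no arrays are allocated during the recursion.

```cpp
#include <array>
#include <cassert>
#include <cfloat>
#include <memory>
#include <vector>

constexpr int D = 2;                                   // assumed dimension
using Point = std::array<float, D>;
using Region = std::array<float, 2 * D>;               // [min_0..min_(D-1), max_0..max_(D-1)]

struct Node {                                          // one point per node;
    Point point;                                       // left subtree <= point[axis] <= right subtree
    std::unique_ptr<Node> left, right;
};

bool pointInside(const Point& p, const Region& r) {
    for (int i = 0; i < D; i++)
        if (p[i] < r[i] || p[i] > r[i + D]) return false;
    return true;
}
bool regionInside(const Region& inner, const Region& outer) {
    for (int i = 0; i < D; i++)
        if (inner[i] < outer[i] || inner[i + D] > outer[i + D]) return false;
    return true;
}
bool regionsIntersect(const Region& a, const Region& b) {
    for (int i = 0; i < D; i++)
        if (a[i + D] < b[i] || a[i] > b[i + D]) return false;
    return true;
}
void addWholeBranch(const Node* n, std::vector<Point>& out) {
    if (n == nullptr) return;
    out.push_back(n->point);
    addWholeBranch(n->left.get(), out);
    addWholeBranch(n->right.get(), out);
}

// `region` is one buffer shared by the whole recursion: each call narrows a bound,
// recurses, and restores the bound, so the caller's region is unchanged on return.
void getRangeInPlace(const Node* n, Region& region, int depth, const Region& target, std::vector<Point>& out) {
    if (n == nullptr) return;
    if (regionInside(region, target)) { addWholeBranch(n, out); return; }
    const int axis = depth % D;
    const float value = n->point[axis];

    const float savedMax = region[axis + D];           // left child: max on `axis` becomes value
    region[axis + D] = value;
    if (regionsIntersect(region, target)) getRangeInPlace(n->left.get(), region, depth + 1, target, out);
    region[axis + D] = savedMax;                       // restore before the right child

    const float savedMin = region[axis];               // right child: min on `axis` becomes value
    region[axis] = value;
    if (regionsIntersect(region, target)) getRangeInPlace(n->right.get(), region, depth + 1, target, out);
    region[axis] = savedMin;                           // restore for the caller

    if (pointInside(n->point, target)) out.push_back(n->point);
}

std::vector<Point> rangeSearchInPlace(const Node* root, const Region& target) {
    Region region;                                     // the root's region is all of space
    for (int i = 0; i < D; i++) { region[i] = -FLT_MAX; region[i + D] = FLT_MAX; }
    std::vector<Point> out;
    getRangeInPlace(root, region, 0, target, out);
    return out;
}
```

Because every call restores the bounds it changed before returning, the buffer always describes the current node's region when control comes back to that node.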

Koo answered 13/1, 2023 at 17:14 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.