How to speed up an N-dimensional interval tree in Python?

Consider the following problem: given a set of n intervals and a set of m floating-point numbers, determine, for each number, the subset of intervals that contain it.

This problem is usually addressed by building an interval tree (range trees and segment trees are closely related structures). Implementations exist for the one-dimensional case, e.g. Python's intervaltree package. These implementations are typically queried with one or a few numbers at a time, i.e. a small m in the notation above.

In my setting, both n and m are extremely large (the data come from an image processing problem). I also need N-dimensional intervals (cuboids when N=3, because I am modeling human brains with the Finite Element Method). I have implemented a simple N-dimensional interval tree in Python, but it runs in a loop and can only take one point at a time. Can anyone help make the implementation more efficient? Feel free to change the data structure.

import sys
import time
import numpy as np

# find the index of a satisfying x > a in one dimension
def find_index_smaller(a, x):
    idx = np.argsort(a)
    ss = np.searchsorted(a, x, sorter=idx)
    res = idx[0:ss]
    return res

# find the index of a satisfying x < a in one dimension
def find_index_larger(a, x):
    return find_index_smaller(-a, -x)

# find the indices of the intervals satisfying amin < x < amax in one dimension
def find_intv_at(amin, amax, x):
    idx = find_index_smaller(amin, x)
    idx2 = find_index_larger(amax[idx], x)
    res = idx[idx2]
    return res

# find the index of a satisfying amin < x < amax in N dimensions
def find_intv_at_nd(amin, amax, x):
    dim = amin.shape[0]
    res = np.arange(amin.shape[-1])
    for i in range(dim):
        idx = find_intv_at(amin[i, res], amax[i, res], x[i])
        res = res[idx]
    return res

I also have two test examples, one for a sanity check and one for performance testing:

def demo1():
    print ("By default, we do a correctness test")
    n_intv = 2
    n_point = 2
    # generate the test data
    point = np.random.rand(3, n_point)
    intv_min = np.random.rand(3, n_intv)
    intv_max = intv_min + np.random.rand(3, n_intv)*8
    print ("point ")
    print (point)
    print ("intv_min")
    print (intv_min)
    print ("intv_max")
    print (intv_max)
    print ("===Indexes of intervals that contain the point===")
    for i in range(n_point):
        print (find_intv_at_nd(intv_min,intv_max, point[:, i]))

def demo2():
    print ("Performance:")
    n_points=100
    n_intv = 1000000

    # generate the test data
    points = np.random.rand(n_points, 3)*512
    intv_min = np.random.rand(3, n_intv)*512
    intv_max = intv_min + np.random.rand(3, n_intv)*8
    print ("point.shape = "+str(points.shape))
    print ("intv_min.shape = "+str(intv_min.shape))
    print ("intv_max.shape = "+str(intv_max.shape))

    starttime = time.time()
    for point in points:
        tmp = find_intv_at_nd(intv_min, intv_max, point)
    print("it took this long to run {} points, with {} intervals: {}".format(n_points, n_intv, time.time()-starttime))
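
Since demo1 only prints the indices, a brute-force reference makes the correctness check explicit. The following is a minimal sketch, not part of the code above: the function name find_intv_bruteforce and the chunk parameter are made up for illustration, and it assumes the same layout as demo2 (amin and amax of shape (dim, n_intv), points of shape (n_points, dim)).

```python
import numpy as np

def find_intv_bruteforce(amin, amax, points, chunk=16):
    """Reference check: one array of interval indices per point.

    amin, amax: shape (dim, n_intv); points: shape (n_points, dim).
    chunk (hypothetical parameter) bounds the size of the temporary
    boolean array, which is chunk * dim * n_intv elements.
    """
    out = []
    for start in range(0, len(points), chunk):
        p = points[start:start + chunk, :, None]            # (c, dim, 1)
        # Broadcast against (1, dim, n_intv), then require all dims to match.
        inside = np.all((amin[None] < p) & (p < amax[None]), axis=1)  # (c, n_intv)
        out.extend(np.flatnonzero(row) for row in inside)
    return out
```

The indices come back in ascending order, so the result of find_intv_at_nd should be sorted with np.sort before comparing it to the reference. For demo1, where point has shape (3, n_point), pass point.T.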

My ideas so far:

  1. Remove np.argsort() from the query. The set of intervals does not change, so the sorting could be done once in pre-processing.
  2. Vectorize over x. The algorithm currently runs once per point x; it would be nice to get rid of the loop over x.
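
One way these two ideas might be combined is sketched below. This is not an interval tree, but a sorted-window filter: the intervals are sorted once by their lower bound in dimension 0, and np.searchsorted then selects a candidate window for all points at once. The sketch assumes the intervals are narrow relative to the domain (as in demo2, where the width is at most 8 in a domain of 512), since the window size depends on the widest interval. The names build_index and query_points are hypothetical.

```python
import numpy as np

def build_index(amin, amax):
    """Pre-process once: sort intervals by their lower bound in dimension 0."""
    order = np.argsort(amin[0], kind="stable")
    width0 = float(np.max(amax[0] - amin[0]))    # widest extent in dim 0
    return order, amin[0, order], width0

def query_points(index, amin, amax, points):
    """Return (point_ids, interval_ids), one entry per containing pair.

    amin, amax: shape (dim, n_intv); points: shape (n_points, dim).
    """
    order, sorted_min0, width0 = index
    x0 = points[:, 0]
    # An interval can contain x only if x0 - width0 < amin0 < x0.
    pad = 4 * np.spacing(np.abs(x0) + width0)    # small guard against rounding
    lo = np.searchsorted(sorted_min0, x0 - width0 - pad, side="left")
    hi = np.searchsorted(sorted_min0, x0, side="left")
    counts = hi - lo
    # Flatten the per-point candidate windows into one long index array.
    point_ids = np.repeat(np.arange(len(points)), counts)
    starts = np.cumsum(counts) - counts
    within = np.arange(counts.sum()) - np.repeat(starts, counts)
    cand = order[np.repeat(lo, counts) + within]
    # Exact test in all dimensions, applied to the candidates only.
    p = points[point_ids].T                      # shape (dim, n_candidates)
    keep = np.all((amin[:, cand] < p) & (p < amax[:, cand]), axis=0)
    return point_ids[keep], cand[keep]
```

Usage would look like index = build_index(intv_min, intv_max) once, followed by query_points(index, intv_min, intv_max, points) with points of shape (n_points, 3). The candidate array holds roughly n_points times the window size, so for very large m the points would need to be processed in batches. If a few intervals are much wider than the rest, the window grows accordingly and this approach loses its advantage.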

Any contribution would be appreciated.

Almanac answered 24/9, 2020 at 9:31

Comments (5):

Weitzel: Have you tried using oct-trees or kd-trees?
Almanac: @Weitzel I have not found an N-D implementation of oct-trees. And the scipy kd-tree is for finding the nearest neighbor.
Weitzel: You need it only for N=3 though? You can also query the kd-tree the other way around, i.e. look for points contained in a given cuboid.
Weitzel: You could also maybe cheat with scikit-learn.org/stable/modules/generated/… using a Chebyshev metric if you don't want to roll your own solution.
Almanac: @Weitzel I do not think we can use a kd-tree here, because I need to find all intervals that contain a specific point. By the way, a kd-tree does not always give an accurate solution, so I prefer not to use it.
