Speed up NumPy's where function

I am trying to extract the indices of all values of a 1D array of numbers that exceed some threshold. The array is on the order of 1e9 long.

My approach is the following in NumPy:

idxs = np.where(data > threshold)

This takes something upwards of 20 mins, which is unacceptable. How can I speed this function up? Or, are there faster alternatives?

(To be specific, it takes that long on a Mac running OS X 10.6.7, with a 1.86 GHz Intel processor and 4 GB of RAM, doing nothing else.)

Curhan answered 9/2, 2013 at 21:20 Comment(5)
Does it take 20 minutes to run np.where, or to delete the values below the threshold? - Psephology
It takes 20 minutes to run np.where. - Curhan
Does it matter that I am reading each variable from a dictionary? That is, data is really data['timeseries'] and threshold is really data[threshold][spikes]. I am sure the second variable is a scalar. - Curhan
Remember when I said the threshold was definitely a scalar? It is really array(array([[ 99.48158966]]), dtype=object). After fixing that, it now takes about 2 minutes. - Curhan
Why do the singleton dimensions gum everything up? - Curhan
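
A minimal sketch of what the singleton dimensions do, assuming a plain (1, 1) float array as a simplified stand-in for the nested object array quoted above (the names raw, mask2d, and the array size are illustrative, not from the original post). Comparing a 1D array against a (1, 1) array broadcasts the result to 2D, so np.where has to return two index arrays instead of one. The object dtype in the original threshold likely adds further cost, since object comparisons generally fall back to per-element Python calls.

```python
import numpy as np

n = 1_000_000  # assumed size, far smaller than the 1e9 in the question
data = np.random.default_rng(0).random(n) * 200.0

# Simplified stand-in for the threshold described in the comments:
# a float array with two singleton dimensions.
raw = np.array([[99.48158966]])

# Broadcasting (n,) against (1, 1) yields a (1, n) mask,
# so np.where returns a tuple of two index arrays.
mask2d = data > raw
print(mask2d.shape, len(np.where(mask2d)))

# Unwrapping to a plain Python float keeps everything 1D.
threshold = raw.item()
idxs = np.flatnonzero(data > threshold)
print(idxs.shape)
```

Calling .item() once, before the comparison, is enough here because the array holds exactly one element.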

Try indexing with a boolean mask.

The syntax would be:

 b = a[a > threshold]

Note that b holds the values of a that satisfy the condition, not their indices (unlike where). Boolean indexing returns a new array (a copy), not a view of a.

Example:

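import time
import numpy as np

# 1e9 float64 values need about 8 GB of RAM; reduce the size if memory is limited
a = np.random.random_sample(int(1e9))

t1 = time.time()
b = a[a > 0.5]  # boolean mask indexing: keeps only the values above 0.5
print(time.time() - t1, 'seconds')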

On my machine, that prints 22.389815092086792 seconds
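
Since the question asks for indices rather than values, here is a small sketch of the difference, using a made-up five-element array and threshold. The same mask can feed either a[mask] (values) or np.flatnonzero(mask) (positions); for a 1D input, np.flatnonzero(mask) gives the same result as np.where(mask)[0].

```python
import numpy as np

# Tiny illustrative array; the values and threshold are arbitrary.
a = np.array([0.2, 0.9, 0.4, 0.7, 0.5])
threshold = 0.5

mask = a > threshold           # boolean array, same length as a
values = a[mask]               # the elements above the threshold
idxs = np.flatnonzero(mask)    # their positions in a

print(values)  # [0.9 0.7]
print(idxs)    # [1 3]
```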


edit

I tried the same with np.where, and it is just as fast. I am suspicious: are you deleting these values from the array?
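
One way to check this yourself is a rough timing harness like the sketch below. The array size, the seed, and the timed helper are assumptions for illustration; timings will vary by machine, so no expected numbers are shown.

```python
import time
import numpy as np

n = 1_000_000  # assumed size; scale up as memory allows
a = np.random.default_rng(0).random(n)
threshold = 0.5

def timed(fn):
    """Hypothetical helper: run fn once, return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

idx_where, t_where = timed(lambda: np.where(a > threshold)[0])
idx_flat, t_flat = timed(lambda: np.flatnonzero(a > threshold))
vals_mask, t_mask = timed(lambda: a[a > threshold])

print(f"np.where:       {t_where:.4f} s")
print(f"np.flatnonzero: {t_flat:.4f} s")
print(f"boolean mask:   {t_mask:.4f} s")
```

All three build the same boolean mask first, so on a machine with enough RAM their timings would be expected to be of the same order.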

Psephology answered 9/2, 2013 at 21:41 Comment(1)
If I am doing so, it is unintentional. My syntax is the same as yours. How could I be deleting something? I agree it would explain the slower time. - Curhan
