`ndindex()` is NOT the ND equivalent of `range()` (despite some of the other answers here). It works for your simple example, but it doesn't permit arbitrary `start`, `stop`, and `step` arguments. It only accepts `stop`, and it hard-codes `start` to `(0,0,...)` and `step` to `(1,1,...)`.
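Here is a concrete view of that limitation. `np.ndindex` only takes a shape, so any offset or stride has to be applied by hand after the fact. This is a minimal sketch assuming NumPy is imported as `np`. The `start`, `step`, `plain`, and `shifted` names are illustrative only.

```python
import numpy as np

# np.ndindex accepts only a shape (the "stop"); start is fixed at 0 and step at 1.
plain = list(np.ndindex(2, 3))
# plain == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]

# Emulating start/step means working out the shape yourself
# and rescaling every index tuple manually.
start, step = (1, 2), (5, 10)
shifted = [
    tuple(s + k * i for s, k, i in zip(start, step, idx))
    for idx in np.ndindex(2, 2)
]
# shifted == [(1, 2), (1, 12), (6, 2), (6, 12)]
```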
Here's an implementation that acts more like the built-in `range()` function. That is, it permits arbitrary `start`/`stop`/`step` arguments, yet it works on tuples instead of mere integers. Like the built-in `range()`, it produces its values lazily. Note, though, that it is a generator, so unlike a `range` object it can only be iterated once.
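```python
from itertools import product

def ndrange(start, stop=None, step=None):
    # Single-argument form, like range(stop): count from zero on every axis.
    if stop is None:
        stop = start
        start = (0,) * len(stop)
    # Default step is 1 on every axis.
    if step is None:
        step = (1,) * len(start)
    assert len(start) == len(stop) == len(step)
    # One range() per axis; product() yields their Cartesian product,
    # with the last axis varying fastest.
    for index in product(*map(range, start, stop, step)):
        yield index
```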
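The core of `ndrange` is the single expression `product(*map(range, start, stop, step))`. The sketch below takes it apart on a 2-D input so each step is visible. The variable names are illustrative.

```python
from itertools import product

start, stop, step = (1, 2), (10, 20), (5, 10)

# map() walks the three tuples in lockstep and builds one range per axis.
axes = list(map(range, start, stop, step))
# axes == [range(1, 10, 5), range(2, 20, 10)]

# product() takes their Cartesian product; the last axis varies fastest.
indices = list(product(*axes))
# indices == [(1, 2), (1, 12), (6, 2), (6, 12)]
```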
Example:

```
In [7]: for index in ndrange((1,2,3), (10,20,30), step=(5,10,15)):
   ...:     print(index)
   ...:
(1, 2, 3)
(1, 2, 18)
(1, 12, 3)
(1, 12, 18)
(6, 2, 3)
(6, 2, 18)
(6, 12, 3)
(6, 12, 18)
```
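One remaining difference from the built-in `range()` is reuse. `ndrange` is a generator, so it is exhausted after a single pass. A `range` object, by contrast, can be iterated again and supports `len()`. If that matters, a small class can store one `range` per axis and build a fresh `product()` on each iteration. The `NDRange` class below is a hypothetical sketch and is not part of any library. It uses `math.prod`, which assumes Python 3.8 or later.

```python
from itertools import product
from math import prod

class NDRange:
    """Hypothetical re-iterable counterpart to ndrange(), sketched for illustration."""

    def __init__(self, start, stop=None, step=None):
        # Same argument handling as ndrange().
        if stop is None:
            stop, start = start, (0,) * len(start)
        if step is None:
            step = (1,) * len(start)
        if not len(start) == len(stop) == len(step):
            raise ValueError("start, stop and step must have the same length")
        # Store one range per axis; ranges are cheap and re-iterable.
        self.axes = tuple(map(range, start, stop, step))

    def __iter__(self):
        # A fresh product() per call, so the object can be looped over repeatedly.
        return product(*self.axes)

    def __len__(self):
        # Total number of index tuples = product of the per-axis lengths.
        return prod(len(r) for r in self.axes)
```

With the example above, `NDRange((1,2,3), (10,20,30), (5,10,15))` has length 8 and yields the same eight tuples on every pass.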
For numpy users
If your code is NumPy-based, then it's probably more convenient to work directly with `ndarray` objects rather than an iterable of tuples. It might also be faster. The following implementation is faster than calling `ndrange()` above if you planned to convert its output to an `ndarray` anyway.
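```python
import numpy as np

def ndrange_array(start, stop=None, step=1):
    """
    Like np.ndindex, but accepts start/stop/step instead of
    assuming that start is always (0,0,0) and step is (1,1,1),
    and returns an array instead of an iterator.
    """
    if stop is None:
        # Single-argument form: count from zero on every axis.
        stop = start
        start = (0,) * len(stop)
    start, stop, step = np.asarray(start), np.asarray(stop), np.asarray(step)

    def ndindex(shape):
        """Like np.ndindex, but returns an ndarray with one index per row."""
        return np.indices(shape).reshape(len(shape), -1).transpose()

    # Points per axis = len(range(start, stop, step)) = ceil((stop - start) / step).
    # -((start - stop) // step) is that ceiling for positive and negative steps;
    # clip at 0 so empty ranges (e.g. stop < start with step > 0) stay empty.
    shape = tuple(np.maximum(-((start - stop) // step), 0).tolist())
    return start + step * ndindex(shape)
```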
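The `ndindex` helper relies on `np.indices`, which returns one integer grid per axis. Reshaping that to `(ndim, -1)` and transposing turns the grids into one index tuple per row, in the same C (row-major) order that `itertools.product` uses. Here is a small standalone check on the same inputs as the generator example. The `offsets`, `result`, and `expected` names are illustrative.

```python
import numpy as np
from itertools import product

start = np.array([1, 2, 3])
step = np.array([5, 10, 15])
shape = (2, 2, 2)  # ceil((stop - start) / step) for stop = (10, 20, 30)

# np.indices(shape) has shape (3, 2, 2, 2): one integer grid per axis.
# Reshaping to (3, -1) and transposing gives an (8, 3) array whose rows
# are the index tuples in C order, the same order itertools.product uses.
offsets = np.indices(shape).reshape(len(shape), -1).T
result = start + step * offsets

expected = list(product(range(1, 10, 5), range(2, 20, 10), range(3, 30, 15)))
```

Each row of `result` matches the corresponding tuple printed by the `ndrange` example, so the two approaches agree on both values and order.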