How to efficiently retrieve the indices of maximum values in a Torch tensor?

Suppose you have a torch tensor, for example one of the following shape:

x = torch.rand(20, 1, 120, 120)

What I would like now is to get the indices of the maximum value of each 120x120 matrix. To simplify the problem, I would first call x.squeeze() to work with shape [20, 120, 120]. I would then like to get a torch tensor of indices with shape [20, 2].

How can I do this fast?

Haymes answered 8/11, 2018 at 16:53 Comment(4)
Why do you need a [20, 2] matrix? Do you want the maximum along the rows and the maximum along the columns for each of the 120 * 120 matrices?Inpatient
Yes, or in other terms: For each of the 20 120 * 120 matrices I want the [x, y] coordinates of the cell with maximum valueHaymes
If you want to know the indices of the top k elements, use torch.topk().Chuffy
Does this answer your question? Extracting the top-k value-indices from a 1-D TensorKamp

If I understand you correctly, you don't want the values but the indices. Unfortunately, there is no out-of-the-box solution. There is an argmax() function, but I cannot see how to get it to do exactly what you want directly.

So here is a small workaround. The efficiency should also be okay, since we're only doing integer division and modulo on tensors:

import torch

n = 4
d = 4
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# since argmax() only returns the index into the flattened
# matrix block, we have to calculate the 2D indices ourselves
# using // (floor division) and %. Note: in recent PyTorch versions,
# / on integer tensors is true division and returns floats, so use //
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)
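To see why `m // d` and `m % d` recover the row and column: in a row-major flattened d x d block, flat index k sits at row k // d and column k % d. A minimal sketch, where the 4x4 `arange` block is just an illustrative assumption chosen so that every element holds its own flat index:

```python
import torch

d = 4
# Each element holds its own flat (row-major) index: 0..15
block = torch.arange(d * d).view(d, d)

k = torch.arange(d * d)      # every possible flat index
rows, cols = k // d, k % d   # unravel: row = k // d, col = k % d

# Indexing the 2D block with the recovered (row, col) gives back k
assert torch.equal(block[rows, cols], k)

# Example: flat index 14 in a 4x4 block is row 3, column 2
print(14 // d, 14 % d)  # 3 2
```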

n represents your first dimension, and d the last two dimensions. I use smaller numbers here to show the result, but of course this also works for n=20 and d=120:

n = 20
d = 120
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)
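The snippet assumes square matrices. For an H x W matrix the row is the flat index divided by the number of columns W, not by H. A minimal sketch, where the helper name `argmax_2d` and the 5x7 shape are illustrative assumptions:

```python
import torch


def argmax_2d(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for x of shape (n, H, W), return an (n, 2)
    tensor holding the (row, col) of the maximum of each H x W matrix."""
    n, h, w = x.shape
    m = x.reshape(n, -1).argmax(dim=1)  # flat index into each H*W block
    # Divide by the number of columns W (not H) to get the row
    return torch.stack((m // w, m % w), dim=1)


x = torch.rand(3, 1, 5, 7).squeeze(1)  # shape (3, 5, 7), non-square
idx = argmax_2d(x)
print(idx.shape)  # torch.Size([3, 2])

# Sanity check: the values at the returned coordinates are the maxima
vals = x[torch.arange(3), idx[:, 0], idx[:, 1]]
assert torch.equal(vals, x.amax(dim=(1, 2)))
```

For the question's tensor this would be `argmax_2d(x.squeeze(1))`.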

Here is the output for n=4 and d=4:

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])

I hope this is what you wanted to get! :)
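As a quick check, here is a sketch that rebuilds the 4x4 example from the printed values (rounded to four decimals, which does not change where the maxima are) and confirms the printed indices:

```python
import torch

# The four 4x4 matrices printed above, shape (4, 1, 4, 4)
x = torch.tensor([
    [[[0.3699, 0.3584, 0.4940, 0.8618],
      [0.6767, 0.7439, 0.5984, 0.5499],
      [0.8465, 0.7276, 0.3078, 0.3882],
      [0.1001, 0.0705, 0.2007, 0.4051]]],
    [[[0.7520, 0.4528, 0.0525, 0.9253],
      [0.6946, 0.0318, 0.5650, 0.7385],
      [0.0671, 0.6493, 0.3243, 0.2383],
      [0.6119, 0.7762, 0.9687, 0.0896]]],
    [[[0.3504, 0.7431, 0.8336, 0.0336],
      [0.8208, 0.9051, 0.1681, 0.8722],
      [0.5751, 0.7903, 0.0046, 0.1471],
      [0.4875, 0.1592, 0.2783, 0.6338]]],
    [[[0.9398, 0.7589, 0.6645, 0.8017],
      [0.9469, 0.2822, 0.9042, 0.2516],
      [0.2576, 0.3852, 0.7349, 0.2806],
      [0.7062, 0.1214, 0.0922, 0.1385]]],
])
n, d = 4, 4
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(indices.tolist())  # [[0, 3], [3, 2], [1, 1], [1, 0]]
```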

Edit:

Here is a slightly modified version, which might be minimally faster (not by much, I guess :) but is a bit simpler and prettier:

Instead of this, as before:

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

Do the necessary reshaping directly on the argmax values:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)
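Both variants produce the same (n, 2) tensor, and `torch.stack` on the two 1D results is a third way to write it. A minimal sketch checking that the three forms agree (no timing claims implied):

```python
import torch

n, d = 20, 120
x = torch.rand(n, 1, d, d)

# Version 1: reshape each result separately, then concatenate
m = x.view(n, -1).argmax(1)
v1 = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

# Version 2: reshape the argmax result once, then concatenate
m2 = x.view(n, -1).argmax(1).view(-1, 1)
v2 = torch.cat((m2 // d, m2 % d), dim=1)

# Alternative: stack the two 1D tensors along a new dimension 1
v3 = torch.stack((m // d, m % d), dim=1)

assert torch.equal(v1, v2) and torch.equal(v1, v3)
print(v1.shape)  # torch.Size([20, 2])
```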

But as mentioned in the comments, I don't think it is possible to get much more out of it.

One thing you could do, if it is really important for you to squeeze out the last bit of performance, is to implement the function above as a low-level extension (e.g. in C++) for PyTorch.

This would give you a single function to call and would avoid the overhead of the Python code.

https://pytorch.org/tutorials/advanced/cpp_extension.html

Suricate answered 8/11, 2018 at 20:28 Comment(5)
Yes, that's the output I want. I modified it to convert m with .float() and then use // in the division by d. What you proposed is an unraveling, similar to numpy.unravel_index(). If you can think of an even faster way, that would be even better of course.Haymes
@Haymes I just ran a short timing test. Actually, I think it is quite efficient, and I guess there is currently no faster way: calling argmax() itself takes about 10 times as long as calculating the indices in the next line (on CPU; I can also check on GPU later). The operations are really simple and straightforward, so even though this is a workaround, it should be quite efficient from a theoretical perspective too.Suricate
No, it's not slow by any means; I needed about 5.5 ms on a Tesla Volta. I just need to max it out, but I agree: argmax is a linear operation, since tensors are unordered. That's probably the slowest component and not possible to speed up.Haymes
@Haymes I made a small edit at the end with a slightly nicer version. But I wouldn't really expect anything in terms of performance, probably about the same, maybe half a nanosecond ahead. If it is really important for you to get the most out of it, you might want to go with a custom extension in C++. But the gain probably wouldn't be much either, considering the small snippet of code.Suricate
Thank you, it works well. I also made a mistake in the evaluation; it seems it was just 0.5 ms instead of 5 ms.Haymes

torch.topk() is what you are looking for. From the docs,

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

  • If dim is not given, the last dimension of the input is chosen.

  • If largest is False then the k smallest elements are returned.

  • A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

  • The boolean option sorted if True, will make sure that the returned k elements are themselves sorted
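Because topk works along a single dimension, one way to apply it to the question (assuming you want the k largest entries of each matrix rather than only the maximum) is to flatten each matrix first and then unravel the flat indices with // and %. A sketch, with k = 3 chosen arbitrarily:

```python
import torch

n, d, k = 20, 120, 3
x = torch.rand(n, 1, d, d)

flat = x.view(n, -1)                    # (n, d*d): each matrix flattened
values, flat_idx = flat.topk(k, dim=1)  # k largest per matrix, sorted
coords = torch.stack((flat_idx // d, flat_idx % d), dim=-1)  # (n, k, 2)
print(coords.shape)  # torch.Size([20, 3, 2])

# Check: the values at the 2D coordinates are the top-k values
xs = x.view(n, d, d)
rows, cols = coords[..., 0], coords[..., 1]
picked = xs[torch.arange(n).unsqueeze(1), rows, cols]  # (n, k)
assert torch.equal(picked, values)
```

The first pair in each row, `coords[:, 0]`, is the location of the maximum.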

Chuffy answered 29/4, 2020 at 7:26 Comment(2)
Useful function to know, but it does not answer the original question. The OP wanted to obtain the indices, for each of the 20 120x120 matrices, of the maximum element in that matrix. That is, she wanted 20 2D coordinates, one for each matrix. topk returns the index of the maximum element in the maximized dimension only.Pavier
Note that topk's documentation is confusing regarding the meaning of the returned indices. It gives the impression that the function provides indices for the original tensor when in fact it returns the index in the maximized dimension only. See pytorch issue github.com/pytorch/pytorch/issues/50331#issue-782748956 that seeks to clarify it.Pavier
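A small sketch of the behaviour described in these comments, using a made-up 2x3 tensor: along the last dimension, topk returns for each row the column index within that row, not a position in the whole tensor. Getting the 2D coordinate of the overall maximum requires flattening first.

```python
import torch

t = torch.tensor([[1, 5, 3],
                  [9, 2, 4]])

values, indices = t.topk(1, dim=-1)  # the max of each row
print(values.tolist())   # [[5], [9]]
print(indices.tolist())  # [[1], [0]]  column index within each row only

# Overall maximum: topk on the flattened tensor, then unravel
flat_idx = t.view(-1).topk(1).indices  # position in the flattened tensor
row, col = flat_idx // t.shape[1], flat_idx % t.shape[1]
print(row.item(), col.item())  # 1 0
```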

Here is an unravel_index implementation in torch:

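from typing import Tuple

import torch


def unravel_index(
    indices: torch.LongTensor,
    shape: Tuple[int, ...],
) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    This is a `torch` implementation of `numpy.unravel_index`.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """

    coord = []

    # Walk the dimensions from last to first: the remainder modulo a
    # dimension's size is the coordinate along it, and the quotient is
    # the flat index into the remaining leading dimensions.
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim

    # Restore first-to-last order and put the coordinates in the last dim
    coord = torch.stack(coord[::-1], dim=-1)

    return coord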
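The loop relies on row-major (C-order) layout. If coordinates (c_1, ..., c_n) in a tensor of shape (D_1, ..., D_n) have flat index k as written below, then the last coordinate is the remainder of k modulo D_n, and the floor quotient is the flat index into the first n - 1 dimensions, which is why the loop can peel off one dimension at a time. A worked case for shape (2, 3, 2) and k = 5:

```latex
k = \sum_{i=1}^{n} c_i \prod_{j=i+1}^{n} D_j,
\qquad c_n = k \bmod D_n,
\qquad \left\lfloor k / D_n \right\rfloor = \sum_{i=1}^{n-1} c_i \prod_{j=i+1}^{n-1} D_j

% shape (2, 3, 2), k = 5:
c_3 = 5 \bmod 2 = 1,\quad \lfloor 5/2 \rfloor = 2;\qquad
c_2 = 2 \bmod 3 = 2,\quad \lfloor 2/3 \rfloor = 0;\qquad
c_1 = 0 \bmod 2 = 0
\;\Rightarrow\; (c_1, c_2, c_3) = (0, 2, 1),\qquad 0 \cdot 6 + 2 \cdot 2 + 1 = 5
```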

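Then you can use the torch.argmax function along the last dimension to get, for each matrix, the index into the "flattened" matrix. (Without dim, torch.argmax returns a single index into the whole flattened tensor.)

y = x.view(20, -1)
indices = torch.argmax(y, dim=-1)
indices.shape  # (20,)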

And unravel the indices with the unravel_index function.

indices = unravel_index(indices, x.shape[-2:])
indices.shape  # (20, 2)
Lipoprotein answered 6/12, 2020 at 12:54 Comment(2)
This is the closest to a real, generic answer! To answer the original question more directly, which asks how to obtain the indices of the maximum values, you might want to edit to show how to use argmax to obtain the indices in the first place and then unravel them.Pavier
I ended up having to code the connection to argmax, so please check my answer. Feel free to incorporate what I did in yours.Pavier

The accepted answer only works for the given example.

The answer by tejasvi88 is interesting but does not help answering the original question (as explained in my comment there).

I believe Francois' answer is the closest because it deals with a more generic case (any number of dimensions). However, it does not connect with argmax and the shown example does not illustrate that function's capacity to deal with batches.

So I will build upon Francois' answer here and add code to connect to argmax. I write a new function, batch_argmax, that returns the indices of maximum values within a batch. The batch may be organized in multiple dimensions. I also include some test cases for illustration:

from math import prod

import torch

# unravel_index is the function from Francois' answer above


def batch_argmax(tensor, batch_dim=1):
    """
    Assumes that dimensions of tensor up to batch_dim are "batch dimensions"
    and returns the indices of the max element of each "batch row".
    More precisely, returns tensor `a` such that, for each index v of tensor.shape[:batch_dim], a[v] is
    the indices of the max element of tensor[v].
    """
    if batch_dim >= len(tensor.shape):
        raise NoArgMaxIndices()
    batch_shape = tensor.shape[:batch_dim]
    non_batch_shape = tensor.shape[batch_dim:]
    flat_non_batch_size = prod(non_batch_shape)
    tensor_with_flat_non_batch_portion = tensor.reshape(*batch_shape, flat_non_batch_size)

    dimension_of_indices = len(non_batch_shape)

    # We now have each batch row flattened in the last dimension of tensor_with_flat_non_batch_portion,
    # so we can invoke its argmax(dim=-1) method. However, that method throws an exception if the tensor
    # is empty. We cover that case first.
    if tensor_with_flat_non_batch_portion.numel() == 0:
        # If empty, either the batch dimensions or the non-batch dimensions are empty
        batch_size = prod(batch_shape)
        if batch_size == 0:  # if batch dimensions are empty
            # return empty tensor of appropriate shape
            batch_of_unraveled_indices = torch.ones(*batch_shape, dimension_of_indices).long()  # 'ones' is irrelevant as it will be empty
        else:  # non-batch dimensions are empty, so argmax indices are undefined
            raise NoArgMaxIndices()
    else:   # We actually have elements to maximize, so we search for them
        indices_of_non_batch_portion = tensor_with_flat_non_batch_portion.argmax(dim=-1)
        batch_of_unraveled_indices = unravel_index(indices_of_non_batch_portion, non_batch_shape)

    if dimension_of_indices == 1:
        # the function above makes each unraveled index of an n-D tensor an n-long tensor;
        # however, indices of 1D tensors are typically represented by scalars, so we squeeze them in this case.
        batch_of_unraveled_indices = batch_of_unraveled_indices.squeeze(dim=-1)
    return batch_of_unraveled_indices


class NoArgMaxIndices(Exception):

    def __init__(self):
        super().__init__(
            "no argmax indices: batch_argmax requires non-batch shape to be non-empty")

And here are the tests:

def test_basic():
    # a simple array
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 0
    expected = torch.tensor(4)
    run_test(tensor, batch_dim, expected)

    # making batch_dim = 1 renders the non-batch portion empty and argmax indices undefined
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 1
    check_that_exception_is_thrown(lambda: batch_argmax(tensor, batch_dim), NoArgMaxIndices)

    # now a batch of arrays
    tensor = torch.tensor([[1, 2, 3], [6, 5, 4]])
    batch_dim = 1
    expected = torch.tensor([2, 0])
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch 3-dim arrays' shape (the arrays are actually non-existent)
    tensor = torch.ones(0, 3)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: just the batch dimension 0, since indices of arrays are scalar (0D)
    expected = torch.ones(0)
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch matrices' shape (the matrices are actually non-existent)
    tensor = torch.ones(0, 3, 2)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: the batch and two dimensions for the indices since we have 2D matrices
    expected = torch.ones(0, 2)
    run_test(tensor, batch_dim, expected)

    # a batch of 2D matrices:
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 1
    expected = torch.tensor([[1, 0], [1, 2]])  # coordinates of two 6's, one in each 2D matrix
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -2
    expected = torch.tensor([[1, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)

    # Same data, but a 2-dimensional batch of 1D arrays!
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 2
    expected = torch.tensor([[2, 0], [1, 2]])  # coordinates of 3, 6, 3, and 6
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -1
    expected = torch.tensor([[2, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)


def run_test(tensor, batch_dim, expected):
    actual = batch_argmax(tensor, batch_dim)
    print(f"batch_argmax of {tensor} with batch_dim {batch_dim} is\n{actual}\nExpected:\n{expected}")
    assert actual.shape == expected.shape
    assert actual.eq(expected).all()

def check_that_exception_is_thrown(thunk, exception_type):
    if isinstance(exception_type, BaseException):
        raise Exception(f"check_that_exception_is_thrown received an exception instance rather than an exception type: "
                        f"{exception_type}")
    try:
        thunk()
    except exception_type:
        return  # the expected exception was raised
    except Exception as e:
        raise AssertionError(f"Should have thrown {exception_type} but instead threw {e}")
    # raised outside the try block so that the handlers above cannot swallow it
    raise AssertionError(f"Should have thrown {exception_type}")
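To connect batch_argmax to the question's (20, 1, 120, 120) tensor: with batch_dim=2 the first two dimensions form the batch and each 120x120 matrix is searched, giving a (20, 1, 2) result, and squeezing dimension 1 yields (20, 2). The sketch below uses a hypothetical condensed `batch_argmax_sketch` that skips the empty-tensor handling and the 1D squeeze of the full version above:

```python
from math import prod

import torch


def batch_argmax_sketch(tensor: torch.Tensor, batch_dim: int = 1) -> torch.Tensor:
    """Hypothetical condensed illustration; assumes a non-empty tensor and
    that batch_dim leaves at least one non-batch dimension."""
    batch_shape = tensor.shape[:batch_dim]
    non_batch_shape = tensor.shape[batch_dim:]
    # Flatten each batch row, then take the argmax over the flat part
    flat_idx = tensor.reshape(*batch_shape, prod(non_batch_shape)).argmax(dim=-1)
    # Unravel row-major: peel off the last dimension first
    coords = []
    for size in reversed(non_batch_shape):
        coords.append(flat_idx % size)
        flat_idx = flat_idx // size
    return torch.stack(coords[::-1], dim=-1)


x = torch.rand(20, 1, 120, 120)
idx = batch_argmax_sketch(x, batch_dim=2)
print(idx.shape)             # torch.Size([20, 1, 2])
print(idx.squeeze(1).shape)  # torch.Size([20, 2])
```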
Pavier answered 10/1, 2021 at 20:26 Comment(0)

I have a straightforward workaround, though not an optimal solution, for computing batch-wise the 2D coordinates of the max value of each item. The simple workaround is:

# suppose the tensor is of shape (3, 2, 2)
>>> import torch
>>> a = torch.randn(3, 2, 2)
>>> a
tensor([[[ 0.1450, -1.3480],
         [-0.3339, -0.5133]],

        [[ 0.6867, -0.2972],
         [ 0.8768,  0.0844]],

        [[-2.3115, -0.4549],
         [-1.5074, -0.8706]]])

# then perform batch-wise max
>>> torch.stack([(a[i]==torch.max(a[i])).nonzero() for i in range(a.size(0))], dim=0)

tensor([[[0, 0]],

        [[1, 0]],

        [[0, 1]]])
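One caveat: if a matrix contains its maximum more than once, nonzero() returns several rows for it and torch.stack fails because the pieces have different shapes. A loop-free sketch that compares against the per-matrix maximum from amax and returns one (batch, row, col) row per maximal entry, so ties show up as extra rows instead of an error:

```python
import torch

# The example tensor printed above, shape (3, 2, 2)
a = torch.tensor([[[ 0.1450, -1.3480],
                   [-0.3339, -0.5133]],
                  [[ 0.6867, -0.2972],
                   [ 0.8768,  0.0844]],
                  [[-2.3115, -0.4549],
                   [-1.5074, -0.8706]]])

# Per-matrix maximum, kept as shape (3, 1, 1) so it broadcasts against a
mx = a.amax(dim=(1, 2), keepdim=True)
hits = (a == mx).nonzero()  # rows of (batch index, row, col)
print(hits.tolist())  # [[0, 0, 0], [1, 1, 0], [2, 0, 1]]
```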
Giliana answered 6/5, 2023 at 8:18 Comment(0)
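# ps is assumed to be a tensor of shape (1, num_classes), e.g. the class
# probabilities of a single image; only the first row is searched below
ps=ps.numpy()
ps=ps.tolist()

mx=[max(l) for l in ps]
mx=max(mx)
for i in range(len(ps[0])):
  if mx==ps[0][i]:
    print("The digit is "+str(i))
    break

This worked fine for me.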
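If ps is a (1, num_classes) tensor of scores, as the snippet assumes, the same result can be obtained without converting to a Python list. A minimal sketch with made-up scores (an illustrative assumption):

```python
import torch

# Hypothetical scores for a single input over 4 classes, shape (1, 4)
ps = torch.tensor([[0.05, 0.10, 0.80, 0.05]])

digit = ps.argmax(dim=1).item()  # index of the largest score in row 0
print("The digit is " + str(digit))  # The digit is 2
```

For a batch of inputs, `ps.argmax(dim=1)` without `.item()` gives one class index per row.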

Eugene answered 29/7, 2020 at 17:28 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.