If I understand you correctly, you don't want the values but the indices. Unfortunately there is no out-of-the-box solution for this. There is an argmax()
function, but I cannot see how to get it to do exactly what you want.
So here is a small workaround; the efficiency should be okay too, since we're just dividing tensors:
import torch

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# since argmax() only returns the index into the flattened
# matrix block, we have to calculate the row and column
# indices ourselves using // and % (integer division // is
# used because on newer PyTorch versions / performs true
# division and would return floats)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)
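The // and % step is just the standard row-major flat-index decomposition; as a tiny plain-Python illustration of the same arithmetic (the values here are made up for the example):

```python
# A flat index k into a row-major d x d block maps to (k // d, k % d).
d = 4
flat = [0.1, 0.2, 0.9, 0.3,
        0.4, 0.5, 0.6, 0.7,
        0.0, 0.8, 0.2, 0.1,
        0.3, 0.1, 0.4, 0.2]
k = max(range(len(flat)), key=flat.__getitem__)  # argmax of the flattened block
row, col = k // d, k % d
print(row, col)  # the maximum 0.9 sits at row 0, column 2
```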
n represents your first dimension, and d the last two dimensions. I use smaller numbers here to show the result, but of course this will also work for n=20 and d=120:
n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)
Here is the output for n=4 and d=4:
tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
[0.6767, 0.7439, 0.5984, 0.5499],
[0.8465, 0.7276, 0.3078, 0.3882],
[0.1001, 0.0705, 0.2007, 0.4051]]],
[[[0.7520, 0.4528, 0.0525, 0.9253],
[0.6946, 0.0318, 0.5650, 0.7385],
[0.0671, 0.6493, 0.3243, 0.2383],
[0.6119, 0.7762, 0.9687, 0.0896]]],
[[[0.3504, 0.7431, 0.8336, 0.0336],
[0.8208, 0.9051, 0.1681, 0.8722],
[0.5751, 0.7903, 0.0046, 0.1471],
[0.4875, 0.1592, 0.2783, 0.6338]]],
[[[0.9398, 0.7589, 0.6645, 0.8017],
[0.9469, 0.2822, 0.9042, 0.2516],
[0.2576, 0.3852, 0.7349, 0.2806],
[0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
[3, 2],
[1, 1],
[1, 0]])
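To sanity-check the recovered indices, you can gather the values back out of x with advanced indexing and compare them to the true per-block maxima (a quick self-contained check of the approach above):

```python
import torch

n, d = 4, 4
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

# pick x[i, 0, row_i, col_i] for every block i and compare to the true max
recovered = x[torch.arange(n), 0, indices[:, 0], indices[:, 1]]
assert torch.equal(recovered, x.view(n, -1).max(1).values)
```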
I hope this is what you wanted to get! :)
Edit:
Here is a slightly modified version which might be minimally faster (probably not by much), but it is a bit simpler and prettier. Instead of this, as before:
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
the necessary reshaping is now already done on the argmax
values:
m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)
But as mentioned in the comments, I don't think it is possible to get much more performance out of it.
One thing you could do, if it is really important to squeeze out the last possible bit of speed, is to implement the function above as a low-level extension (e.g. in C++) for PyTorch.
That would give you a single function to call and would avoid slow Python code.
https://pytorch.org/tutorials/advanced/cpp_extension.html