How to select indices according to another tensor in PyTorch

The task seems to be simple, but I cannot figure out how to do it.

So what I have are two tensors:

  • an indices tensor indices with shape (2, 5, 2), where the last dimension holds the indices in the x and y dimensions
  • a "value tensor" value with shape (2, 5, 2, 16, 16), where I want the last two dimensions to be selected with the x and y indices

To be more concrete, the indices are between 0 and 15 and I want to get an output:

out = value[:, :, :, x_indices, y_indices]

The output should therefore have shape (2, 5, 2). Can anybody help me here? Thanks a lot!
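A minimal sketch using advanced (integer-array) indexing, assuming indices[..., 0] holds the x index (for dim 3) and indices[..., 1] the y index (for dim 4); the tensors are random placeholders with the shapes from the question. Written literally, value[:, :, :, x_indices, y_indices] with two (2, 5) index tensors returns shape (2, 5, 2, 2, 5), because the slices keep every leading position. Giving the leading dimensions their own broadcastable index tensors pairs each (x, y) with its own batch entry instead:

```python
import torch

# Placeholder tensors with the shapes from the question
value = torch.randn(2, 5, 2, 16, 16)
indices = torch.randint(0, 16, (2, 5, 2))

B, N, C, H, W = value.shape

# Index tensors for the leading dims, shaped so they broadcast to (B, N, C)
b = torch.arange(B).view(B, 1, 1)
n = torch.arange(N).view(1, N, 1)
c = torch.arange(C).view(1, 1, C)

# One (x, y) pair per (b, n), shared by every c: shape (B, N, 1)
x = indices[..., 0].unsqueeze(-1)
y = indices[..., 1].unsqueeze(-1)

# All five index tensors broadcast together, so out[i, j, k] = value[i, j, k, x, y]
out = value[b, n, c, x, y]
print(out.shape)  # torch.Size([2, 5, 2])
```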

Edit:

I tried the suggestion to use gather, but unfortunately it does not seem to work (I changed the dimensions, but that should not matter):

First I generate a coordinate grid:

import torch

# 16x16 grid: channel 0 holds the y coordinate, channel 1 the x coordinate
y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)  # shape (1, 2, 16, 16)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # shape (1, 3, 2, 16, 16)
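A quick sketch that rebuilds this grid (device argument dropped) and checks its layout: it has shape (1, 3, 2, 16, 16), channel 0 varies only with the row and channel 1 only with the column, so grid[..., :, r, c] is the (y, x) pair at row r, column c. This matters for the gather attempts, since a coordinate pair only lines up if the row index goes to dim 3 and the column index to dim 4.

```python
import torch

y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)

lin = torch.linspace(-1., 1., 16)
print(grid.shape)  # torch.Size([1, 3, 2, 16, 16])
# Channel 0: value depends only on the row index
print(torch.equal(grid[0, 0, 0], lin.view(16, 1).expand(16, 16)))  # True
# Channel 1: value depends only on the column index
print(torch.equal(grid[0, 0, 1], lin.view(1, 16).expand(16, 16)))  # True
```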

In the next step, I am creating some indices. In this case, I always take index 1:

indices = torch.ones([1, 3, 2], dtype=torch.int64)

Next, I am using your method:

indices = indices.unsqueeze(-1).unsqueeze(-1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)

Finally, I manually select index 1 for both the x and y coordinates:

new_coords_manual = grid[:, :, :, 1, 1]

This outputs the following new coordinates:

new_coords
tensor([[[-1.0000, -0.8667],
         [-1.0000, -0.8667],
         [-1.0000, -0.8667]]])

new_coords_manual
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

As you can see, it only works for one dimension. Do you have an idea how to fix that?
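A sketch of why this happens, rebuilding the grid and indices from above: with an index of shape (1, 3, 2, 1, 1), torch.gather along dim -1 computes out[a][b][c][d][e] = grid[a][b][c][d][index[a][b][c][d][e]]. Since d can only be 0, each channel c receives indices[..., c] as a column index, always in row 0, so the result is grid[:, :, :, 0, 1] rather than grid[:, :, :, 1, 1].

```python
import torch

y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)

indices = torch.ones([1, 3, 2], dtype=torch.int64)

# Same call as in the question: only dim -1 (columns) is indexed, dim -2 (rows) stays at 0
new_coords = torch.gather(grid, -1, indices.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)

print(torch.equal(new_coords, grid[:, :, :, 0, 1]))  # True
print(torch.equal(new_coords, grid[:, :, :, 1, 1]))  # False
```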

Ism answered 20/1, 2021 at 18:42
Could you show a minimal example of indices and value and the desired output? - Vizzone
The desired output is what new_coords_manual produces. - Ism

What you could do is flatten the first three axes together and apply torch.gather:

>>> grid.flatten(start_dim=0, end_dim=2).shape
torch.Size([6, 16, 16])

>>> torch.gather(grid.flatten(0, 2), 1, indices)
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

As explained on the documentation page, this will perform:

out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
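With indices of shape (1, 3, 2), the formula above only ever reads slice i = 0 of the flattened tensor and uses k as the column index, which is why this call does not carry over to a batch size > 1. A minimal sketch of one way to keep the flatten-then-gather idea in general, assuming indices[..., 0] is the row and indices[..., 1] the column, is to flatten the two spatial dimensions instead and turn each pair into the linear index row * W + col (the names flat, linear and idx are illustrative):

```python
import torch

# Rebuild the grid from the question, with batch size 2: shape (2, 3, 2, 16, 16)
y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)

B, N, C, H, W = grid.shape
indices = torch.randint(0, 16, (B, N, 2))  # random (row, col) pairs

# Flatten each 16x16 plane so one gather along the last dim is enough
flat = grid.flatten(start_dim=-2)                     # (B, N, C, H * W)
linear = indices[..., 0] * W + indices[..., 1]        # (B, N), row-major position
# gather needs an index with as many dims as flat; reuse each pair for every channel
idx = linear[:, :, None, None].expand(B, N, C, 1)     # (B, N, C, 1)
new_coords = torch.gather(flat, -1, idx).squeeze(-1)  # (B, N, C)
```

Each new_coords[b, n] should then equal grid[b, n, :, indices[b, n, 0], indices[b, n, 1]].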

Vizzone answered 20/1, 2021 at 20:46
Thanks for your help! This indeed works for a batch size of 1, but it seems to have the same problem with a batch size > 1 :/ I also tried to split the problem into x and y coordinates and apply indices_y = indices[:, :, 0].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) followed by new_y = torch.gather(grid, 3, indices_y).squeeze(-1).squeeze(-1). After doing the same for the x values with gather on dim=4, I concatenated the tensors, but I get exactly the same result as with your first suggestion. - Ism
Why flatten first? - Assimilation

I figured it out, thanks again @Ivan for your help! :)

The problem was that I unsqueezed at the end (dim -1), while I should have inserted the new dimensions in the middle (dim -2), so that the index pair stays in the last dimension, and then gathered along dim 3:

import torch

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)  # batch size 2: shape (2, 3, 2, 16, 16)

indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)  # shape (2, 3, 1, 1, 2)
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)  # shape (2, 3, 2)

new_coords_manual = grid[:, :, :, 1, 1]

Now new_coords equals new_coords_manual.
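A caveat, with a sketch to check it: all-ones indices cannot distinguish rows from columns, and on this particular grid channel 0 depends only on the row and channel 1 only on the column, so the dim-3 gather agrees with grid[b, n, :, row, col] for any indices. Tracing the gather formula for an index of shape (B, N, 1, 1, 2) gives out[i][j][0][0][e] = input[i][j][0][index[i][j][0][0][e]][e]: it reads only channel 0 and uses e as the column. The sketch below uses a tensor with all-distinct entries (a stand-in for the question's value tensor) and a fixed (row, col) = (2, 5) to make this visible:

```python
import torch

# Stand-in value tensor whose entries are all distinct, so any wrong position shows up
value = torch.arange(2 * 3 * 2 * 16 * 16).reshape(2, 3, 2, 16, 16)
B, N, C, H, W = value.shape

# Fixed (row, col) = (2, 5) for every (b, n); row and column now differ
indices = torch.tensor([2, 5]).repeat(B, N, 1)  # (B, N, 2)

# Same gather as in this answer
idx = indices.unsqueeze(-2).unsqueeze(-2)       # (B, N, 1, 1, 2)
via_dim3 = torch.gather(value, 3, idx).squeeze(-2).squeeze(-2)  # (B, N, 2)

manual = value[:, :, :, 2, 5]
print(torch.equal(via_dim3, manual))  # False

# What it actually reads: channel 0 only, with the last index axis acting as the column
print(torch.equal(via_dim3[..., 0], value[:, :, 0, 2, 0]))  # True
print(torch.equal(via_dim3[..., 1], value[:, :, 0, 5, 1]))  # True
```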

Ism answered 20/1, 2021 at 21:59
