Filter out NaN values from a PyTorch N-Dimensional tensor

This question is very similar to filtering np.nan values from pytorch in a 1-dimensional tensor. The difference is that I want to apply the same concept to tensors with 2 or more dimensions.

I have a tensor that looks like this:

import torch

tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
 [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
 [2, 2, 2, 2, 2]]
)
>>> tensor.shape
torch.Size([3, 5])

I would like to find the most Pythonic / PyTorch-idiomatic way to filter out (remove) the rows of the tensor that are NaN. By filtering this tensor along the first (0th) axis, I want to obtain a filtered_tensor that looks like this:

>>> print(filtered_tensor)
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]])
>>> filtered_tensor.shape
torch.Size([2, 5])
Asked by Winola, 29/10/2020 at 15:42

Use PyTorch's isnan() together with any() to build a boolean mask over the rows, then index tensor with the inverted mask:

filtered_tensor = tensor[~torch.any(tensor.isnan(), dim=1)]
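A quick check against the tensor from the question (built here with torch.tensor and explicit floats, which is just a stylistic choice) shows the intermediate row mask and the final shape:

```python
import torch

nan = float("nan")
tensor = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0],
                       [nan, nan, nan, nan, nan],
                       [2.0, 2.0, 2.0, 2.0, 2.0]])

# isnan() gives a [3, 5] bool tensor; any(dim=1) reduces it to one flag per row
row_has_nan = torch.any(tensor.isnan(), dim=1)
# Only row 1 is flagged

# Invert the mask and index dim 0 to keep only the NaN-free rows
filtered_tensor = tensor[~row_has_nan]
print(filtered_tensor.shape)  # torch.Size([2, 5])
```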

Note that this drops any row containing at least one NaN. If you want to drop only rows where all values are NaN, replace torch.any with torch.all.
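To make the any/all difference concrete, here is a small sketch on a made-up tensor (the values are illustrative only) with one partially-NaN row and one fully-NaN row:

```python
import torch

nan = float("nan")
t = torch.tensor([[1.0, 1.0, 1.0],
                  [nan, 2.0, 2.0],   # partially NaN row
                  [nan, nan, nan],   # fully NaN row
                  [3.0, 3.0, 3.0]])

# torch.any: drop every row that has at least one NaN
any_dropped = t[~torch.any(t.isnan(), dim=1)]
# torch.all: drop only rows in which every entry is NaN
all_dropped = t[~torch.all(t.isnan(), dim=1)]

print(any_dropped.shape)  # torch.Size([2, 3])
print(all_dropped.shape)  # torch.Size([3, 3])
```

With torch.all, the partially-NaN row survives, so the result can still contain NaN values.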

For an N-dimensional tensor, you could flatten all dims except the first and apply the same procedure as above:

# Flatten every dim after the first, so each slice along dim 0 becomes one row:
shape = tensor.shape
tensor_reshaped = tensor.reshape(shape[0], -1)
# Drop all rows containing any NaN:
tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(), dim=1)]
# Reshape back; dim 0 now counts only the kept rows:
tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0], *shape[1:])
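A possible variant, offered as an assumption rather than as part of the answer: only the boolean mask needs flattening, because indexing dim 0 with a 1-D boolean mask keeps the trailing dims, so no reshape back is needed. drop_nan_rows is a hypothetical helper name. Like the code above, the sketch assumes the tensor is non-empty along dim 0 (reshape(0, -1) is ambiguous and raises).

```python
import torch

def drop_nan_rows(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: drop every slice t[i] along dim 0 that contains a NaN.
    # Only the mask is flattened to [t.shape[0], -1]; the data is never reshaped.
    row_has_nan = t.isnan().reshape(t.shape[0], -1).any(dim=1)
    # A 1-D bool mask indexes dim 0 and leaves the trailing dims untouched
    return t[~row_has_nan]

x = torch.ones(4, 2, 3)
x[1, 0, 2] = float("nan")
x[3, 1, 1] = float("nan")
y = drop_nan_rows(x)
print(y.shape)  # torch.Size([2, 2, 3])
```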
Answered by Estovers, 29/10/2020 at 16:09. Comments:
Beauuutiful. This is precisely what I was looking for. I should have checked whether torch.any() had a dim parameter. Extra points for showing how to do the same thing with flattening! I will accept this as the answer, but you have a tiny mistake: tensors don't have a t.isnan() method; there is only the top-level torch.isnan(t) function. If you don't mind, please change that and I'll accept your answer :D - Winola
A Tensor does have an isnan method; see pytorch.org/docs/stable/tensors.html. That's why the code works. You are right that it calls torch.isnan in the background. - Estovers
You are right! I was using an older version of PyTorch. Thanks again! I've accepted the answer. - Winola
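On a recent PyTorch release, a one-line check confirms that the method and function forms return the same mask (I have not verified which release first added Tensor.isnan()):

```python
import torch

t = torch.tensor([1.0, float("nan"), 3.0])
# Method form and top-level function form produce identical boolean masks
assert torch.equal(t.isnan(), torch.isnan(t))
print(t.isnan())
```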
The N-dimensional code didn't work for me: start with tensor = torch.arange(600, dtype=torch.float32).reshape(1, 3, 20, 10), then set tensor[0, 2, 0, 0] = float('nan'), and the result has shape (0, 3, 20, 10). - Lucila
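The shape reported in that comment is consistent with what the answer's code does: it removes whole slices along dim 0, and a tensor of shape (1, 3, 20, 10) has only one such slice, which contains the NaN. If the goal is to filter along a different dimension, one option (an assumption, not something the answer proposes) is to move that dimension to the front with movedim, filter, and move it back. drop_nan_slices is a hypothetical helper name; Tensor.movedim needs a reasonably recent PyTorch (I believe 1.7 or newer).

```python
import torch

def drop_nan_slices(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Hypothetical helper: drop each slice along `dim` that contains a NaN.
    # Move `dim` to the front so every slice becomes one "row"
    moved = t.movedim(dim, 0)
    row_has_nan = moved.isnan().reshape(moved.shape[0], -1).any(dim=1)
    # Keep the NaN-free rows, then put `dim` back in its original position
    return moved[~row_has_nan].movedim(0, dim)

tensor = torch.arange(600, dtype=torch.float32).reshape(1, 3, 20, 10)
tensor[0, 2, 0, 0] = float("nan")

# dim 0 has a single slice and it holds the NaN, so nothing survives
print(drop_nan_slices(tensor, dim=0).shape)  # torch.Size([0, 3, 20, 10])
# Filtering along dim 1 removes only the slice tensor[:, 2]
print(drop_nan_slices(tensor, dim=1).shape)  # torch.Size([1, 2, 20, 10])
```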
