This question is very similar to filtering np.nan values from pytorch in a 1-dimensional tensor. The difference is that I want to apply the same concept to tensors of 2 or higher dimensions.
I have a tensor that looks like this:
import torch
tensor = torch.Tensor(
    [[1, 1, 1, 1, 1],
     [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
     [2, 2, 2, 2, 2]]
)
>>> tensor.shape
torch.Size([3, 5])
I would like to find the most pythonic / PyTorch way to filter out (remove) the rows of the tensor which are nan. By filtering this tensor along the first (0th) axis, I want to obtain a filtered_tensor which looks like this:
>>> print(filtered_tensor)
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]])
>>> filtered_tensor.shape
torch.Size([2, 5])
torch.any() had a dim parameter. Extra points for showing how to do the same thing while flattening! I will accept this as the answer, but you have a tiny mistake. Tensors don't have a t.isnan() function, it's just a top-level torch.isnan(t) function. If you don't mind, please change that and I'll accept your answer :D. – Winola
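
For anyone landing here, the approach discussed in the comments above (torch.isnan(t) combined with the dim parameter of any()) might look roughly like the sketch below. This is only an illustration under a few assumptions: the helper name drop_nan_rows is made up for this example, the input is assumed to have at least 2 dimensions, and a "row" is dropped if it contains any nan at all (swap any for all if only fully-nan rows should go).

```python
import torch

tensor = torch.tensor(
    [[1, 1, 1, 1, 1],
     [float('nan')] * 5,
     [2, 2, 2, 2, 2]],
    dtype=torch.float32,
)


def drop_nan_rows(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop entries along dim 0 that contain any nan.

    Assumes t has at least 2 dimensions.
    """
    # Collapse every dimension after the first so each "row" is a flat
    # vector, then flag the rows that hold at least one nan.
    nan_rows = torch.isnan(t).flatten(start_dim=1).any(dim=1)
    # Boolean indexing along dim 0 keeps only the unflagged rows.
    return t[~nan_rows]


filtered_tensor = drop_nan_rows(tensor)
print(filtered_tensor.shape)

# Flattening variant: indexing with a full-shape boolean mask returns a
# 1-D tensor holding every element that is not nan.
flat_filtered = tensor[~torch.isnan(tensor)]
print(flat_filtered.shape)
```

For the 2-D tensor in the question, the flatten call is a no-op, so the mask reduces to torch.isnan(tensor).any(dim=1).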