DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

argwhere() and nonzero() in PyTorch

*My post explains where() and count_nonzero().

argwhere() returns a 2D tensor of shape (number of non-zero elements, number of dimensions) holding the indices of the zero or more non-zero elements of a 0D or higher-dimensional tensor, as shown below:

*Memos:

  • argwhere() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called as a tensor method, is input (Required, a tensor of int, float, complex or bool).
import torch

# argwhere() returns a 2D int64 tensor of shape
# (number of non-zero elements, number of dimensions of input).
# Row i holds the full index of the i-th non-zero element; rows are
# ordered lexicographically (row-major order).

# 0D input: the single non-zero value gives 1 row, and a 0D tensor has
# 0 dimensions, so each row has 0 columns -> shape (1, 0).
my_tensor = torch.tensor(5)

torch.argwhere(input=my_tensor)
my_tensor.argwhere()  # method form; gives the same result
# tensor([], size=(1, 0), dtype=torch.int64)

my_tensor = torch.tensor([5, 0, 4, 0, 3, 1])

torch.argwhere(input=my_tensor)
# tensor([[0], [2], [4], [5]])

my_tensor = torch.tensor([5., 0., 4., 0., 3., 1.])

torch.argwhere(input=my_tensor)
# tensor([[0], [2], [4], [5]])

my_tensor = torch.tensor([5.+0.j, 0.+0.j, 4.+0.j, 0.+0.j, 3.+0.j, 1.+0.j])

torch.argwhere(input=my_tensor)
# tensor([[0], [2], [4], [5]])

# For bool tensors, True counts as non-zero and False as zero.
my_tensor = torch.tensor([True, False, True, False, True, False])

torch.argwhere(input=my_tensor)
# tensor([[0], [2], [4]])

# 2D input: each row of the result is a (row, column) index pair.
my_tensor = torch.tensor([[5, 0, 4],
                          [0, 3, 1]])
torch.argwhere(input=my_tensor)
# tensor([[0, 0], [0, 2], [1, 1], [1, 2]])

# 3D input: each row of the result is a 3-element index.
my_tensor = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                          [[0, 7, 0], [0, 6, 8]]])
torch.argwhere(input=my_tensor)
# tensor([[0, 0, 0],
#         [0, 0, 2],
#         [0, 1, 1],
#         [0, 1, 2],
#         [1, 0, 1],
#         [1, 1, 1],
#         [1, 1, 2]])
Enter fullscreen mode Exit fullscreen mode

nonzero() returns, by default, a 2D tensor of shape (number of non-zero elements, number of dimensions) holding the indices of the zero or more non-zero elements of a 0D or higher-dimensional tensor, as shown below:

*Memos:

  • nonzero() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called as a tensor method, is input (Required, a tensor of int, float, complex or bool).
  • There is the as_tuple argument (bool) (Optional, Default: False) with torch or a tensor. *Memos:
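    • If as_tuple is True, a tuple of 1D tensors is returned, one per dimension of input, holding the indices of the non-zero elements along that dimension. A 0D input is treated as a 1-element 1D tensor, so it gives a tuple of one tensor.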
import torch

# By default nonzero() returns a 2D int64 tensor of shape
# (number of non-zero elements, number of dimensions of input), the same
# layout as argwhere().
# With as_tuple=True it instead returns a tuple holding one 1D tensor per
# dimension. Tensor d of the tuple equals column d of the default result.

# 0D input: the default result has shape (1, 0) (1 non-zero value,
# 0 dimensions). With as_tuple=True a 0D non-zero input is treated as a
# 1-element 1D tensor, so a single index tensor is returned.
my_tensor = torch.tensor(5)

torch.nonzero(input=my_tensor)
my_tensor.nonzero()  # method form; gives the same result
# tensor([], size=(1, 0), dtype=torch.int64)

torch.nonzero(input=my_tensor, as_tuple=True)
# (tensor([0]),)

my_tensor = torch.tensor([5, 0, 4, 0, 3, 1])

torch.nonzero(input=my_tensor)
my_tensor.nonzero()
# tensor([[0], [2], [4], [5]])

torch.nonzero(input=my_tensor, as_tuple=True)
# (tensor([0, 2, 4, 5]),)

my_tensor = torch.tensor([5., 0., 4., 0., 3., 1.])

torch.nonzero(input=my_tensor)
# tensor([[0], [2], [4], [5]])

my_tensor = torch.tensor([5.+0.j, 0.+0.j, 4.+0.j, 0.+0.j, 3.+0.j, 1.+0.j])

torch.nonzero(input=my_tensor)
# tensor([[0], [2], [4], [5]])

# For bool tensors, True counts as non-zero and False as zero.
my_tensor = torch.tensor([True, False, True, False, True, False])

torch.nonzero(input=my_tensor)
# tensor([[0], [2], [4]])

my_tensor = torch.tensor([[5, 0, 4],
                          [0, 3, 1]])
torch.nonzero(input=my_tensor)
# tensor([[0, 0], [0, 2], [1, 1], [1, 2]])

# The tuple holds the row indices, then the column indices, i.e. the
# columns of the 2D result above.
torch.nonzero(input=my_tensor, as_tuple=True)
# (tensor([0, 0, 1, 1]),
#  tensor([0, 2, 1, 2]))

my_tensor = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                          [[0, 7, 0], [0, 6, 8]]])
torch.nonzero(input=my_tensor)
# tensor([[0, 0, 0],
#         [0, 0, 2],
#         [0, 1, 1],
#         [0, 1, 2],
#         [1, 0, 1],
#         [1, 1, 1],
#         [1, 1, 2]])

torch.nonzero(input=my_tensor, as_tuple=True)
# (tensor([0, 0, 0, 0, 1, 1, 1]),
#  tensor([0, 0, 1, 1, 0, 1, 1]),
#  tensor([0, 2, 1, 2, 1, 1, 2]))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)