DEV Community

Super Kai (Kazuya Ito)

argmin() and argmax() in PyTorch

argmin() can get a 0D or more D tensor of zero or more indices of the first minimum elements from a 0D or more D tensor of zero or more elements, as shown below:

*Memos:

  • argmin() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called on a tensor, is input (Required-Type: tensor of int or float).
  • The 2nd argument with torch or the 1st argument with a tensor is dim (Optional-Type: int). *Setting dim gets the indices of the first minimum elements along that dimension. Without dim, the index into the flattened input is returned.
  • The 3rd argument with torch or the 2nd argument with a tensor is keepdim (Optional-Type: bool). *My post explains keepdim argument.
  • A 1D or more D tensor holding a single complex number or boolean value works if dim is set.
  • An empty 2D or more D input tensor doesn't work unless dim is set.
  • An empty 1D input tensor doesn't work even if dim is set.

import torch

my_tensor = torch.tensor([[5, 4, 7, 7],
                          [6, 5, 3, 5],
                          [3, 8, 9, 3]])
torch.argmin(input=my_tensor)
my_tensor.argmin()
# tensor(6)

torch.argmin(input=my_tensor, dim=0)
torch.argmin(input=my_tensor, dim=-2)
# tensor([2, 0, 1, 2])

torch.argmin(input=my_tensor, dim=1)
torch.argmin(input=my_tensor, dim=-1)
# tensor([1, 2, 0])

my_tensor = torch.tensor([[5., 4., 7., 7.],
                          [6., 5., 3., 5.],
                          [3., 8., 9., 3.]])
torch.argmin(input=my_tensor)
# tensor(6)

my_tensor = torch.tensor([5.+7.j])

torch.argmin(input=my_tensor, dim=0)
# tensor(0)

my_tensor = torch.tensor([[True]])

torch.argmin(input=my_tensor, dim=0)
# tensor([0])

my_tensor = torch.tensor([]) # Empty 1D tensor
my_tensor = torch.tensor([[]]) # Empty 2D tensor
my_tensor = torch.tensor([[[]]]) # Empty 3D tensor

torch.argmin(input=my_tensor) # Error with any of the empty tensors above

my_tensor = torch.tensor([])

torch.argmin(input=my_tensor, dim=0) # Error

my_tensor = torch.tensor([[]])

torch.argmin(input=my_tensor, dim=0)
# tensor([], dtype=torch.int64)

my_tensor = torch.tensor([[[]]])

torch.argmin(input=my_tensor, dim=0)
# tensor([], size=(1, 0), dtype=torch.int64)
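
The memos above mention keepdim, but none of the examples set it. The following is a minimal sketch of its effect, reusing the same 3x4 tensor; the variable names (without_keepdim, with_keepdim, min_values) and the gather() lookup are illustrative additions, not part of the original examples:

```python
import torch

my_tensor = torch.tensor([[5, 4, 7, 7],
                          [6, 5, 3, 5],
                          [3, 8, 9, 3]])

# keepdim=False (the default) drops the reduced dimension.
without_keepdim = torch.argmin(input=my_tensor, dim=1)
print(without_keepdim.shape)
# torch.Size([3])

# keepdim=True keeps the reduced dimension with size 1.
with_keepdim = torch.argmin(input=my_tensor, dim=1, keepdim=True)
print(with_keepdim.shape)
# torch.Size([3, 1])

# Because the reduced dimension is kept, the indices can be passed
# to gather() to look up the minimum values themselves.
min_values = torch.gather(my_tensor, dim=1, index=with_keepdim)
print(min_values)
# tensor([[4],
#         [3],
#         [3]])
```

The same pattern should apply to argmax(), since it takes the same dim and keepdim arguments.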

argmax() can get a 0D or more D tensor of zero or more indices of the first maximum elements from a 0D or more D tensor of zero or more elements, as shown below:

*Memos:

  • argmax() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called on a tensor, is input (Required-Type: tensor of int or float).
  • The 2nd argument with torch or the 1st argument with a tensor is dim (Optional-Type: int). *Setting dim gets the indices of the first maximum elements along that dimension. Without dim, the index into the flattened input is returned.
  • The 3rd argument with torch or the 2nd argument with a tensor is keepdim (Optional-Type: bool). *My post explains keepdim argument.
  • A 1D or more D tensor holding a single complex number or boolean value works if dim is set.
  • An empty 2D or more D input tensor doesn't work unless dim is set.
  • An empty 1D input tensor doesn't work even if dim is set.

import torch

my_tensor = torch.tensor([[5, 4, 7, 7],
                          [6, 5, 3, 5],
                          [3, 8, 9, 3]])
torch.argmax(input=my_tensor)
my_tensor.argmax()
# tensor(10)

torch.argmax(input=my_tensor, dim=0)
torch.argmax(input=my_tensor, dim=-2)
# tensor([1, 2, 2, 0])

torch.argmax(input=my_tensor, dim=1)
torch.argmax(input=my_tensor, dim=-1)
# tensor([2, 0, 2])

my_tensor = torch.tensor([[5., 4., 7., 7.],
                          [6., 5., 3., 5.],
                          [3., 8., 9., 3.]])
torch.argmax(input=my_tensor)
# tensor(10)

my_tensor = torch.tensor([5.+7.j])

torch.argmax(input=my_tensor, dim=0)
# tensor(0)

my_tensor = torch.tensor([[True]])

torch.argmax(input=my_tensor, dim=0)
# tensor([0])

my_tensor = torch.tensor([]) # Empty 1D tensor
my_tensor = torch.tensor([[]]) # Empty 2D tensor
my_tensor = torch.tensor([[[]]]) # Empty 3D tensor

torch.argmax(input=my_tensor) # Error with any of the empty tensors above

my_tensor = torch.tensor([])

torch.argmax(input=my_tensor, dim=0) # Error

my_tensor = torch.tensor([[]])

torch.argmax(input=my_tensor, dim=0)
# tensor([], dtype=torch.int64)

my_tensor = torch.tensor([[[]]])

torch.argmax(input=my_tensor, dim=0)
# tensor([], size=(1, 0), dtype=torch.int64)
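
When dim is not set, the result (tensor(10) above) is an index into the flattened tensor, which can be hard to read for a 2D input. The following is a minimal sketch of converting it back to a row and column with plain integer arithmetic; the names flat_index, num_cols, row, col and max_value are illustrative, and the approach assumes a 2D tensor:

```python
import torch

my_tensor = torch.tensor([[5, 4, 7, 7],
                          [6, 5, 3, 5],
                          [3, 8, 9, 3]])

# Without dim, argmax() returns the index into the flattened tensor.
flat_index = torch.argmax(input=my_tensor).item()
print(flat_index)
# 10

# For a 2D tensor, dividing by the number of columns recovers
# the row (quotient) and the column (remainder).
num_cols = my_tensor.size(1)
row, col = divmod(flat_index, num_cols)
print(row, col)
# 2 2

# Indexing with the recovered position gives the maximum value.
max_value = my_tensor[row, col].item()
print(max_value)
# 9
```

The same conversion should work for the flattened index returned by argmin().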