`take()` can get the 0D or more D tensor of zero or more elements using the 0D or more D tensor of zero or more indices from the 0D or more D tensor of zero or more elements (the input is treated as if it were flattened to 1D) as shown below:
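*Memos:
- `take()` can be used with `torch` or a tensor.
- The 1st argument (`input`) with `torch`, or the tensor itself with a tensor method (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
- The 2nd argument with `torch` or the 1st argument with a tensor is `index` (Required-Type: `tensor` of `int` (`int64`)):
  *Memos:
  - It decides the shape of the returned tensor.
  - A negative index counts back from the end of the flattened input, e.g. `-7` is `3` for a tensor of 10 elements.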
import torch

# torch.take() reads `input` as if it were flattened to 1D, so the 2x5
# tensor below is addressed by the positions 0..9 of
#   [9, 5, 0, 6, 2, 7, 1, 3, 4, 8]
# A negative index counts back from the end (-7 -> 3, -10 -> 0, ...), and the
# result always has exactly the same shape as `index`.
my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

# 0D index -> 0D result. The function form, the method form and the
# equivalent negative index all select the same element.
idx_0d = torch.tensor(3)
torch.take(input=my_tensor, index=idx_0d)
my_tensor.take(index=idx_0d)
torch.take(input=my_tensor, index=torch.tensor(-7))
# tensor(6)

# 1D index -> 1D result.
torch.take(input=my_tensor, index=torch.tensor([3, 0, 7, 4]))
torch.take(input=my_tensor, index=torch.tensor([-7, -10, -3, -6]))
# tensor([6, 9, 3, 2])

# 2D index -> 2D result.
torch.take(input=my_tensor, index=torch.tensor([[3, 0], [7, 4]]))
torch.take(input=my_tensor, index=torch.tensor([[-7, -10], [-3, -6]]))
# tensor([[6, 9], [3, 2]])

# 3D index -> 3D result. The same 3D index is reused for every dtype below.
idx_3d = torch.tensor([[[3, 0], [7, 4]],
                       [[8, 2], [3, 5]]])
torch.take(input=my_tensor, index=idx_3d)
torch.take(input=my_tensor, index=torch.tensor([[[-7, -10], [-3, -6]],
                                                [[-2, -8], [-7, -5]]]))
# tensor([[[6, 9], [3, 2]], [[4, 0], [6, 7]]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])
torch.take(input=my_tensor, index=idx_3d)
# tensor([[[6., 9.], [3., 2.]], [[4., 0.], [6., 7.]]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take(input=my_tensor, index=idx_3d)
# tensor([[[6.+0.j, 9.+0.j], [3.+0.j, 2.+0.j]],
#         [[4.+0.j, 0.+0.j], [6.+0.j, 7.+0.j]]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take(input=my_tensor, index=idx_3d)
# tensor([[[False, True], [False, True]],
#         [[True, True], [False, False]]])
`take_along_dim()` can get the 1D or more D tensor of zero or more elements using the 0D or more D tensor of zero or more indices from the 0D or more D tensor of zero or more elements as shown below:
*Memos:
- `take_along_dim()` can be used with `torch` or a tensor.
- The 1st argument (`input`) with `torch`, or the tensor itself with a tensor method (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
- The 2nd argument with `torch` or the 1st argument with a tensor is `indices` (Required-Type: `tensor` of `int` (`int64`)).
- The 3rd argument with `torch` or the 2nd argument with a tensor is `dim` (Optional-Type: `int`):
  *Memos:
  - Not setting `dim` returns a 1D tensor, because `input` and `indices` are treated as if they were flattened to 1D.
  - If `dim` is set, both tensors must have the same number of dimensions (D), and the returned tensor has that D too. The sizes of the other dimensions are broadcast against each other.
- There is `out` argument with `torch` (Optional-Default: `None`-Type: `tensor`):
  *Memos:
  - `out=` must be used.
  - My post explains `out` argument.
import torch

# With `dim` omitted, take_along_dim() treats `input` and `indices` as if both
# were flattened to 1D, so the result is always 1D. This holds even for a 0D
# index: the result is tensor([6]), not tensor(6) as torch.take() would give.
# Flattened my_tensor: [9, 5, 0, 6, 2, 7, 1, 3, 4, 8]
my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])
torch.take_along_dim(input=my_tensor, indices=torch.tensor(3))
my_tensor.take_along_dim(indices=torch.tensor(3))
# torch.gather() has no `indices` keyword and requires `dim`. Gathering along
# dim 0 of the flattened input and index is the equivalent of omitting `dim`.
torch.gather(input=my_tensor.flatten(), dim=0,
             index=torch.tensor(3).flatten())
# tensor([6])

# Index tensors of any shape are flattened too, so 1D, 2D and 3D indices all
# produce a 1D result.
torch.take_along_dim(input=my_tensor, indices=torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[[3, 0], [7, 4]],
                                           [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

# With `dim` set, `indices` must have as many dimensions as `input`. The other
# dimensions are broadcast, so the (4, 1) indices below select whole rows of
# the (2, 5) input. For this 2D input, dim=-2 is the same axis as dim=0.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]),
                     dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]),
                     dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

# A full (4, 5) index chooses the source row independently for every element.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]),
                     dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]),
                     dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# dim=1 (or dim=-1) selects columns within each row. The (1, 3) indices are
# broadcast, so both rows use the same column positions.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4]]),
                     dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4]]),
                     dim=-1)
# tensor([[6, 9, 2], [4, 7, 8]])

# A (2, 3) index gives each row its own column positions.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=-1)
# tensor([[6, 9, 2], [8, 1, 3]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[6., 9., 2.], [8., 1., 3.]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[6.+0.j, 9.+0.j, 2.+0.j],
#         [8.+0.j, 1.+0.j, 3.+0.j]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[False, True, True],
#         [False, True, False]])
Top comments (0)