DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

unbind in PyTorch

Buy Me a Coffee

*Memos:

unbind() removes exactly one specified dimension from a tensor of one or more dimensions (with zero or more elements) and returns the resulting slices as one or more view tensors, each with zero or more dimensions and zero or more elements, as shown below:

*Memos:

  • unbind() can be used with torch or a tensor.
  • The 1st argument(input) with torch, or the tensor itself when called on a tensor, is the input tensor(Required-Type:tensor of int, float, complex or bool).
  • The 2nd argument with torch, or the 1st argument with a tensor, is dim(Optional-Default:0-Type:int).
  • The total number of elements across all returned tensors is the same as in the original tensor.
  • Each returned tensor has one fewer dimension than the original tensor, because the specified dimension is removed.
import torch

# 1D example: the tensor has a single dimension holding four ints, so the
# four calls below (default dim, dim=0 and dim=-1) all give the result shown.
my_tensor = torch.tensor([0, 1, 2, 3])

# Function form, relying on the default dim (0).
torch.unbind(input=my_tensor)
# Method form on the tensor itself, also with the default dim.
my_tensor.unbind()
# Explicit dim, written as a non-negative and as a negative index.
torch.unbind(input=my_tensor, dim=0)
torch.unbind(input=my_tensor, dim=-1)
# Result of each call above: a tuple of four 0D tensors.
# (tensor(0),
#  tensor(1),
#  tensor(2),
#  tensor(3))

# 2D example: 3 rows x 4 columns of ints.
my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])
# Removing the first dimension (default, dim=0 or dim=-2) yields one 1D
# tensor per row.
torch.unbind(input=my_tensor)
torch.unbind(input=my_tensor, dim=0)
torch.unbind(input=my_tensor, dim=-2)
# (tensor([0, 1, 2, 3]),
#  tensor([4, 5, 6, 7]),
#  tensor([8, 9, 10, 11]))

# Removing the second dimension (dim=1 or dim=-1) yields one 1D tensor per
# column instead.
torch.unbind(input=my_tensor, dim=1)
torch.unbind(input=my_tensor, dim=-1)
# (tensor([0, 4, 8]),
#  tensor([1, 5, 9]),
#  tensor([2, 6, 10]),
#  tensor([3, 7, 11]))

# 3D example: the same 3 x 4 int data wrapped in one extra outer dimension
# of size 1.
my_tensor = torch.tensor([[[0, 1, 2, 3],
                           [4, 5, 6, 7],
                           [8, 9, 10, 11]]])
# Removing the outer dimension (default, dim=0 or dim=-3) yields a
# one-element tuple holding the whole 2D tensor.
torch.unbind(input=my_tensor)
torch.unbind(input=my_tensor, dim=0)
torch.unbind(input=my_tensor, dim=-3)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

# Removing the middle dimension (dim=1 or dim=-2) yields three 2D tensors,
# one per row, each still carrying the outer dimension of size 1.
torch.unbind(input=my_tensor, dim=1)
torch.unbind(input=my_tensor, dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

# Removing the last dimension (dim=2 or dim=-1) yields four 2D tensors,
# one per column.
torch.unbind(input=my_tensor, dim=2)
torch.unbind(input=my_tensor, dim=-1)
# (tensor([[0, 4, 8]]),
#  tensor([[1, 5, 9]]),
#  tensor([[2, 6, 10]]),
#  tensor([[3, 7, 11]]))

# Same 3D layout with float elements, unbound along the default dim (0):
# the result is a one-element tuple holding the 2D float tensor.
my_tensor = torch.tensor([[[0., 1., 2., 3.],
                           [4., 5., 6., 7.],
                           [8., 9., 10., 11.]]])
torch.unbind(input=my_tensor)
# (tensor([[0., 1., 2., 3.],
#          [4., 5., 6., 7.],
#          [8., 9., 10., 11.]]),)

# Same 3D layout with complex elements, unbound along the default dim (0):
# the result is a one-element tuple holding the 2D complex tensor.
my_tensor = torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                           [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                           [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]])
torch.unbind(input=my_tensor)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)

# Same 3D layout with bool elements, unbound along the default dim (0):
# the result is a one-element tuple holding the 2D bool tensor.
my_tensor = torch.tensor([[[True, False, True, False],
                           [False, True, False, True],
                           [True, False, True, False]]])
torch.unbind(input=my_tensor)
# (tensor([[True, False, True, False],
#          [False, True, False, True],
#          [True, False, True, False]]),)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)