Bala Priya C

Originally published at towardsai.net

Useful Tensor Manipulation Functions in PyTorch

PyTorch is a popular, open-source, optimized tensor library widely used in deep learning and AI research, developed by researchers at Facebook AI. The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over them.

In this blog post, you'll learn some useful functions that the torch package provides for manipulating tensors. You'll work through examples to understand how the different functions behave, including cases where they throw errors instead of returning a result. We'll look at the following tensor manipulation functions.

  1. torch.cat: Concatenates the given sequence of tensors in the given dimension.
  2. torch.unbind: Removes a tensor dimension and returns a tuple of the slices along it.
  3. torch.movedim: Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
  4. torch.squeeze: Returns a tensor with all the dimensions of input of size 1 removed.
  5. torch.unsqueeze: Returns a new tensor with a dimension of size 1 inserted at the specified position.

Before we begin, let's import torch.

import torch

1. torch.cat

torch.cat(tensors, dim=0, *, out=None)

  • Concatenates the given sequence of tensors in the given dimension.

  • All tensors must either have the same shape (except in the concatenating dimension) or be empty.

The argument tensors denotes the sequence of tensors to be concatenated.

dim is an optional argument that specifies the dimension along which we want the tensors to be concatenated (default dim=0).

out is an optional keyword argument that specifies the output tensor.

# Example 1
ip_tensor_1=torch.tensor([[1,2,3],[4,5,6]])
ip_tensor_2=torch.tensor([[7,8,9],[10,11,12]])

torch.cat((ip_tensor_1,ip_tensor_2),dim=0)

# Output
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

As we specified dim=0, the input tensors have been concatenated along dimension 0. Each input tensor has shape (2,3), so concatenating along dimension 0 gives an output tensor of shape (4,3).

# Example 2
ip_tensor_1=torch.tensor([[1,2,3],[4,5,6]])
ip_tensor_2=torch.tensor([[7,8,9,10],[11,12,13,14]])

torch.cat((ip_tensor_1,ip_tensor_2),dim=1)

# Output
tensor([[ 1,  2,  3,  7,  8,  9, 10],
        [ 4,  5,  6, 11, 12, 13, 14]])

This time, we chose to concatenate along dimension 1 (dim=1). ip_tensor_1 has shape (2,3) and ip_tensor_2 has shape (2,4). Both have size 2 along dimension 0, so concatenating along dimension 1 returns an output tensor of shape (2,7).

Now, let's see what happens when we try to concatenate the above two input tensors along dim=0.

# Example 3
ip_tensor_1=torch.tensor([[1,2,3],[4,5,6]])
ip_tensor_2=torch.tensor([[7,8,9,10],[11,12,13,14]])

torch.cat((ip_tensor_1,ip_tensor_2),dim=0)

# Output
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-cb649a6e60ac> in <module>()
      3 ip_tensor_2=torch.tensor([[7,8,9,10],[11,12,13,14]])
      4 
----> 5 torch.cat((ip_tensor_1,ip_tensor_2),dim=0)

RuntimeError: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 
(The offending index is 1)

We see that an error is thrown when we try to concatenate along dim=0. This is because the sizes of the tensors must agree in all dimensions other than the one that we're concatenating along.
Here, ip_tensor_1 has size 3 along dim=1 whereas ip_tensor_2 has size 4 along dim=1, which is why we ran into an error.

Therefore, we can use the torch.cat function when we want to concatenate tensors along a valid dimension provided the tensors have the same size in all other dimensions.
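
The shape rule above can be checked before calling torch.cat. The following is a minimal sketch, not part of the torch API: can_cat is a hypothetical helper introduced here for illustration, and it deliberately ignores the special case of empty tensors. It also shows that dim accepts negative values, where dim=-1 refers to the last dimension.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])             # shape (2, 3)
b = torch.tensor([[7, 8, 9, 10], [11, 12, 13, 14]])  # shape (2, 4)


def can_cat(tensors, dim):
    """Hypothetical helper: True if all shapes agree outside `dim`.

    Assumption: empty tensors are not treated specially here.
    """
    ndim = tensors[0].dim()
    dim = dim % ndim  # map a negative dim to its non-negative index
    ref = [s for i, s in enumerate(tensors[0].shape) if i != dim]
    return all(
        t.dim() == ndim
        and [s for i, s in enumerate(t.shape) if i != dim] == ref
        for t in tensors
    )


ok_dim1 = can_cat([a, b], dim=1)    # sizes along dim 0 are both 2
ok_dim0 = can_cat([a, b], dim=0)    # sizes along dim 1 are 3 and 4
joined = torch.cat([a, b], dim=-1)  # dim=-1 is dim=1 for 2-D tensors

print(ok_dim1, ok_dim0, joined.shape)
```

A check like this is mainly useful when the tensors come from different sources and a clearer error message than the RuntimeError is wanted.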


2. torch.unbind

torch.unbind(input, dim=0)

This function removes the tensor dimension specified by the argument dim (default dim=0).

It returns a tuple of all the slices of the tensor along the specified dim, each without that dimension.

# Example 1
ip_tensor=torch.tensor([[1,2,3],[4,5,6]])
torch.unbind(ip_tensor,dim=0)

# Output
(tensor([1, 2, 3]), tensor([4, 5, 6]))

The ip_tensor is of shape (2,3). As we specified dim=0, unbind returns a tuple of the two slices of ip_tensor along dimension 0, each of shape (3,).

# Example 2
ip_tensor=torch.tensor([[1,2,3],[4,5,6]])
torch.unbind(ip_tensor,dim=1)

# Output
(tensor([1, 4]), tensor([2, 5]), tensor([3, 6]))

In the above example, we see that when we choose to unbind along dim=1, we get a tuple containing the three slices of the input tensor along dimension 1, each of shape (2,).

# Example 3
ip_tensor=torch.randn(10,10)
torch.unbind(ip_tensor,dim=2)

# Output
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-19-0159bee72931> in <module>()
      2 ip_tensor=torch.randn(10,10)
      3 print(ip_tensor)
----> 4 torch.unbind(ip_tensor,dim=2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

The input tensor is of shape (10,10), so its only valid dimensions are 0 and 1 (or, counting from the end, -1 and -2). When we try to unbind along a dimension that is not valid, we run into an IndexError.

A clear understanding of dimensions, and of the size along a specific dimension, is necessary. Even though our input tensor has 100 elements and has size 10 in each of the dimensions 0 and 1, it does not have a third dimension with index 2. Hence, it's important to pass in a valid dimension for tensor manipulation operations.

The unbind function can be useful when we would like to examine slices of a tensor along a specified input dimension.
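
As a small illustration of examining slices, the sketch below unbinds a tensor along its last dimension using a negative index, processes each slice, and then reassembles the tensor with torch.stack. The variable names are chosen here for illustration only.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

# dim=-1 refers to the last dimension, which is dim=1 for a 2-D tensor.
cols = torch.unbind(x, dim=-1)            # tuple of 3 tensors of shape (2,)

# Each slice can be inspected on its own.
col_sums = [int(c.sum()) for c in cols]

# Stacking the slices back along the same dimension recovers the input.
rebuilt = torch.stack(cols, dim=-1)

print(col_sums)
print(torch.equal(rebuilt, x))
```

In this sense torch.stack acts as the inverse of torch.unbind when the same dim is used for both.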


3. torch.movedim

torch.movedim(input, source, destination)

This function moves the dimensions of input at the positions in source to the positions specified in destination.

source and destination can be either int (single dimension) or tuple of dimensions to be moved.

Other dimensions of input that are not explicitly moved remain in their original order and appear at the positions not specified in destination.

# Example 1 
ip_tensor= torch.randn(4,3,2)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.movedim(ip_tensor,1,2)
print(f"Output Tensor shape:{op_tensor.shape}\n")
# Output
Input Tensor shape:torch.Size([4, 3, 2])

Output Tensor shape:torch.Size([4, 2, 3])

In this example, we wanted to move dimension 1 of the input tensor to dimension 2 of the output tensor, and we've done just that using the movedim function.

ip_tensor is of shape (4,3,2) whereas op_tensor is of shape (4,2,3). That is, dimension 1 of the input tensor (size 3) has moved to dimension 2 of the output tensor, and the remaining dimensions have kept their relative order.

# Example 2 
ip_tensor= torch.randn(4,3,2)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.movedim(ip_tensor,(1,0),(2,1))
print(f"Output Tensor shape:{op_tensor.shape}\n")

# Output
Input Tensor shape:torch.Size([4, 3, 2])

Output Tensor shape:torch.Size([2, 4, 3])

In this example, we want to move dimensions 1 and 0 of the input tensor to dimensions 2 and 1 of the output tensor, respectively. The remaining input dimension 2 (size 2) fills the one free position, dimension 0, so the output shape is (2,4,3).

# Example 3
ip_tensor= torch.randn(4,3,2)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.movedim(ip_tensor,(1,0),(1,1))
print(f"Output Tensor shape:{op_tensor.shape}\n")

# Output
Input Tensor shape:torch.Size([4, 3, 2])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-f3275d938a7d> in <module>()
      2 ip_tensor= torch.randn(4,3,2)
      3 print(f"Input Tensor shape:{ip_tensor.shape}\n")
----> 4 op_tensor=torch.movedim(ip_tensor,(1,0),(1,1))
      5 print(f"Output Tensor shape:{op_tensor.shape}\n")

RuntimeError: movedim: repeated dim in `destination` ([1, 1])

In this example, we get an error as we've repeated dimension 1 in the destination tuple. The entries in source and destination tuples should all be unique.

Thus, the function movedim lets us reorder the dimensions of a tensor while retaining the same underlying elements.
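
One place where this reordering comes up is converting image batches between layouts. The sketch below assumes a batch stored as (N, H, W, C), which is an assumption made for illustration, and moves the channel dimension to position 1. It also compares the result with Tensor.permute, which expresses the same reordering by listing every dimension.

```python
import torch

# Assumed layout for illustration: (batch, height, width, channels).
images = torch.randn(8, 32, 24, 3)

# Move the last dimension (channels) to position 1.
# The other dimensions keep their relative order: (N, C, H, W).
moved = torch.movedim(images, -1, 1)

# The same reordering with permute needs the full dimension order.
permuted = images.permute(0, 3, 1, 2)

# Both results are views over the same storage as the input.
shares_storage = moved.data_ptr() == images.data_ptr()

print(moved.shape, torch.equal(moved, permuted), shares_storage)
```

movedim is convenient when only one or two dimensions change place, since the positions of the untouched dimensions do not have to be spelled out.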


4. torch.squeeze

torch.squeeze(input, dim=None, *, out=None)

This operation returns a tensor with all the dimensions of input of size 1 removed.

When dim is specified, the squeeze operation is applied only to that dimension, and only if its size is 1.

# Example 1
ip_tensor=torch.randn(2,1,3)
print(ip_tensor)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.squeeze(ip_tensor)
print(op_tensor)
print(f"Output Tensor shape:{op_tensor.shape}\n")

# Output

tensor([[[ 0.2819,  0.3406, -1.8031]],

        [[-0.9314,  1.0048, -0.3198]]])
Input Tensor shape:torch.Size([2, 1, 3])

tensor([[ 0.2819,  0.3406, -1.8031],
        [-0.9314,  1.0048, -0.3198]])
Output Tensor shape:torch.Size([2, 3])

We had ip_tensor of shape (2,1,3). After the squeezing operation, op_tensor has shape (2,3): the dimension of size 1 at dim=1 of ip_tensor has been dropped.

# Example 2 a
ip_tensor=torch.randn(2,1,3,1)
print(ip_tensor)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.squeeze(ip_tensor,dim=0)
print(op_tensor)
print(f"Output Tensor shape:{op_tensor.shape}\n")

# Output

tensor([[[[ 0.4133],
          [-0.6541],
          [-0.5506]]],


        [[[-1.1734],
          [-0.3823],
          [-0.8710]]]])
Input Tensor shape:torch.Size([2, 1, 3, 1])

tensor([[[[ 0.4133],
          [-0.6541],
          [-0.5506]]],


        [[[-1.1734],
          [-0.3823],
          [-0.8710]]]])
Output Tensor shape:torch.Size([2, 1, 3, 1])

In the above example, we set the dimension argument dim=0. The input tensor ip_tensor has size=2 along dim=0. As we did not have size=1 along dim=0, there's no effect of squeezing operation on the tensor and the output tensor is identical to the input tensor.

# Example 2 b
ip_tensor=torch.randn(2,1,3,1)
print(ip_tensor)
print(f"Input Tensor shape:{ip_tensor.shape}\n")
op_tensor=torch.squeeze(ip_tensor,dim=1)
print(op_tensor)
print(f"Output Tensor shape:{op_tensor.shape}\n")

# Output

tensor([[[[-1.7004],
          [-0.1863],
          [ 1.1550]]],


        [[[-1.1890],
          [-0.4821],
          [-0.3731]]]])
Input Tensor shape:torch.Size([2, 1, 3, 1])

tensor([[[-1.7004],
         [-0.1863],
         [ 1.1550]],

        [[-1.1890],
         [-0.4821],
         [-0.3731]]])
Output Tensor shape:torch.Size([2, 3, 1])

In the above example, we set the dimension argument dim=1.

As the tensor had size=1 along dim=1, that dimension was removed and the output tensor is of shape (2,3,1). The size-1 dimension at dim=3 is left untouched because we only asked to squeeze dim=1.

# Example 3
ip_tensor=torch.randn(2,3)
print(ip_tensor)
print(f"Input Tensor shape:{ip_tensor.size()}\n")
op_tensor=torch.squeeze(ip_tensor,dim=2)

# Output
tensor([[-0.0688, -0.7170, -1.5563],
        [-0.2138, -0.5387, -1.0245]])
Input Tensor shape:torch.Size([2, 3])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-8-d312de22f1d3> in <module>()
      3 print(ip_tensor)
      4 print(f"Input Tensor shape:{ip_tensor.size()}\n")
----> 5 op_tensor=torch.squeeze(ip_tensor,dim=2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In the above example, we see that ip_tensor has shape (2,3) and is a 2-D tensor with dim=0,1 defined. As we tried to squeeze along dim=2 which does not exist in the original tensor, we get an IndexError.

In essence, the squeeze function removes either all dimensions of size 1 or, when dim is given, only that dimension if its size is 1.
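
One practical consequence is worth a sketch: calling squeeze without dim removes every size-1 dimension, including ones we may want to keep, such as a batch dimension of size 1. The (N, C, H, W) layout below is an assumption made for illustration. The example also shows that the squeezed tensor shares its data with the input.

```python
import torch

# Assumed layout for illustration: a batch of one single-channel image.
batch = torch.zeros(1, 1, 28, 28)

# Without dim, both size-1 dimensions are removed, batch dimension included.
all_squeezed = torch.squeeze(batch)

# With dim=1, only the channel dimension is removed.
chan_squeezed = torch.squeeze(batch, dim=1)

# The result shares storage with the input, so writes are visible in both.
chan_squeezed[0, 0, 0] = 5.0

print(all_squeezed.shape)   # shape without any size-1 dimensions
print(chan_squeezed.shape)  # batch dimension preserved
print(batch[0, 0, 0, 0])    # reflects the write made through the view
```

Passing dim explicitly is the safer choice whenever some size-1 dimensions carry meaning.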


5. torch.unsqueeze

torch.unsqueeze(input, dim)

Here, dim denotes the index at which we want the dimension of size 1 to be inserted.

This function returns a new tensor with a dimension of size 1 inserted at the specified position. The returned tensor shares the same underlying data with the input tensor.

# Example 1
ip_tensor = torch.tensor([1, 2, 3, 4])
torch.unsqueeze(ip_tensor, 0)

# Output
tensor([[1, 2, 3, 4]])

In this simple example shown above, unsqueeze inserts a singleton dimension at the specified index 0.

# Example 2
ip_tensor=torch.rand(2,3)
torch.unsqueeze(ip_tensor,2)

# Output
tensor([[[0.3670],
         [0.1786],
         [0.7115]],

        [[0.4241],
         [0.0422],
         [0.2277]]])

In this example, unsqueeze inserts a singleton dimension at the specified index 2. The input is 2-D, with dimensions 0 and 1, and has shape (2,3). The output has a new dimension of size 1 at dim=2 and is of shape (2,3,1).

# Example 3 
ip_tensor=torch.rand(2,3)
torch.unsqueeze(ip_tensor,3)

# Output
--------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-28-e1fcd3f3f58a> in <module>()
      1 # Example 3
      2 ip_tensor=torch.rand(2,3)
----> 3 torch.unsqueeze(ip_tensor,3)

IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

We get an IndexError, as expected. For an input with input.dim() dimensions, the argument dim must lie in the range [-input.dim()-1, input.dim()]. Our input is 2-D, so the valid range is [-3, 2] and dim can take a maximum value of 2.

Thus, the unsqueeze function lets us insert a dimension of size 1 at the required index.
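
A common reason to insert a singleton dimension is to make two tensors broadcast against each other. The sketch below, with variable names chosen for illustration, uses a negative dim, the equivalent None indexing syntax, and a round trip through squeeze.

```python
import torch

v = torch.tensor([1, 2, 3, 4])   # shape (4,)

col = torch.unsqueeze(v, -1)     # dim=-1 appends a dimension: shape (4, 1)
row = v[None, :]                 # indexing with None: shape (1, 4)

# (4, 1) * (1, 4) broadcasts to (4, 4), giving an outer product.
table = col * row

# Squeezing the inserted dimension returns to the original shape.
roundtrip = torch.squeeze(col, dim=-1)

print(col.shape, row.shape, table.shape)
print(torch.equal(roundtrip, v))
```

Here v[None, :] produces the same result as torch.unsqueeze(v, 0); which form to use is largely a matter of readability.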


In this post, we've tried to cover some useful functions that can be used for manipulating tensors. Hope you found it useful. Happy Learning!


References

[1] Official documentation for tensor operations
[2] A useful blog on basics of tensors: https://www.kdnuggets.com/2018/05/pytorch-tensor-basics.html

Cover Image: Photo by Annie Spratt on Unsplash
