requires_grad=True with a tensor, backward() and retain_grad() in PyTorch

requires_grad (bool, optional, default: False) with True enables a tensor to compute and accumulate its gradient, as shown below:

*Memos:

  • A tensor is either a leaf tensor or a non-leaf tensor.
  • data must be of float or complex type to use requires_grad=True.
  • backward() does backpropagation. *Backpropagation calculates gradients using the mean (average) of the sum of the losses (differences) between the model's predictions and the true values (train data), working from the output layer to the input layer.
  • A gradient is accumulated each time backward() is called.
  • To call backward() (the error cases are sketched right after this list):
    • requires_grad must be True.
    • data must be a scalar (only one element) of float type (the tensor itself can be 0D or more D).
  • grad can get the accumulated gradient.
  • is_leaf can check whether it's a leaf tensor or a non-leaf tensor.
  • To call retain_grad(), requires_grad must be True.
  • To enable a non-leaf tensor to get a gradient with grad without a warning, retain_grad() must be called on it before backward().
  • Using retain_graph=True with backward() keeps the computation graph, so a later backward() call can go through the same graph without an error (the error case is sketched after the 2nd example below).
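
For reference, a minimal sketch of the error cases from the memos above (the exact error messages may vary by PyTorch version):

import torch

# requires_grad=True needs a float or complex tensor:
try:
    torch.tensor(data=7, requires_grad=True) # int64 tensor
except RuntimeError as e:
    print(e) # e.g. "Only Tensors of floating point and complex dtype can require gradients"

# backward() needs requires_grad=True:
try:
    torch.tensor(data=7.).backward()
except RuntimeError as e:
    print(e) # e.g. "element 0 of tensors does not require grad and does not have a grad_fn"

# backward() without arguments needs a scalar (single-element) float tensor:
try:
    torch.tensor(data=[7., 4.], requires_grad=True).backward()
except RuntimeError as e:
    print(e) # e.g. "grad can be implicitly created only for scalar outputs"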

1 tensor with backward():

import torch

my_tensor = torch.tensor(data=7., requires_grad=True) # Leaf tensor

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), None, True)

my_tensor.backward() # Accumulates d(my_tensor)/d(my_tensor) = 1. into my_tensor.grad

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(2.), True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(3.), True)
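
Since a gradient accumulates with every backward() call, you can reset it between calls. A minimal sketch (setting grad to None is one common way; grad.zero_() also works but keeps a zero tensor instead of None):

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor.backward()
my_tensor.backward()

my_tensor.grad
# tensor(2.)

my_tensor.grad = None # Clear the accumulated gradient

my_tensor.backward()

my_tensor.grad
# tensor(1.)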

3 tensors with backward(retain_graph=True) and retain_grad():

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), None, True)

tensor1.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), None, False)

tensor2.backward(retain_graph=True) # Important: keeps the graph so tensor3.backward() below can go through tensor2 = tensor1 * 4 again

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3 = tensor2 * 5 # Non-leaf tensor

tensor3.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), None, False)

tensor3.backward() # Adds d(tensor3)/d(tensor1) = 20. to tensor1.grad and d(tensor3)/d(tensor2) = 5. to tensor2.grad

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(25.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(6.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), tensor(1.), False)
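
And a minimal sketch of what happens without retain_graph=True in the example above: the first backward() frees the graph's saved values, so a later backward() through the same graph raises an error (the message may vary by version):

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor
tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.backward() # retain_graph defaults to False, so the graph's saved values are freed here

tensor3 = tensor2 * 5 # Non-leaf tensor

try:
    tensor3.backward() # Needs to go through tensor2 = tensor1 * 4 again
except RuntimeError as e:
    print(e) # e.g. "Trying to backward through the graph a second time ..."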

In addition, 3 tensors with detach_() and requires_grad_(requires_grad=True), where tensor3.backward() doesn't propagate gradients back to tensor1 and tensor2:

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), None, True)

tensor1.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), None, False)

tensor2.backward() # The graph's saved values for tensor2 = tensor1 * 4 are freed here (retain_graph defaults to False)

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3 = tensor2 * 5 # Non-leaf tensor
tensor3 = tensor3.detach_().requires_grad_(requires_grad=True) # Leaf tensor
# Important: detach_() cuts tensor3 off from the freed graph, so tensor3.backward() below doesn't go through tensor2 = tensor1 * 4 again
tensor3.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., requires_grad=True), None, True)

tensor3.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., requires_grad=True), tensor(1.), True)
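
For comparison, detach() (without the underscore) is the out-of-place version: it returns a new leaf tensor that shares data with the original non-leaf tensor and leaves the original graph untouched. A minimal sketch:

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor
tensor2 = tensor1 * 4 # Non-leaf tensor

tensor3 = tensor2.detach().requires_grad_(requires_grad=True) # New leaf tensor

tensor3.backward()

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(28., requires_grad=True), tensor(1.), True)

tensor1.grad, tensor2.is_leaf
# (None, False) <- No gradient flows back to tensor1; tensor2 is still a non-leaf tensor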

In addition, you can manually set a gradient on a tensor whether requires_grad is True or False, as shown below:
*Memos:

  • A gradient must be:
    • a tensor.
    • of the same type and size as its tensor. (The error cases are sketched right after this list.)
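
A minimal sketch of the error cases when those conditions aren't met (the exact messages and exception types may vary by version):

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

try:
    my_tensor.grad = 4. # Not a tensor
except (TypeError, RuntimeError) as e:
    print(e)

try:
    my_tensor.grad = torch.tensor(data=4) # int64, a different type from float32
except (TypeError, RuntimeError) as e:
    print(e)

try:
    my_tensor.grad = torch.tensor(data=[4., 2.]) # A different size
except (TypeError, RuntimeError) as e:
    print(e)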

float:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), None, True)

my_tensor.grad = torch.tensor(data=4.)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(4.), True)

my_tensor = torch.tensor(data=7., requires_grad=False)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.), None, True)

my_tensor.grad = torch.tensor(data=4.)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.), tensor(4.), True)

complex:

import torch

my_tensor = torch.tensor(data=7.+0.j, requires_grad=True)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j, requires_grad=True), None, True)

my_tensor.grad = torch.tensor(data=4.+0.j)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j, requires_grad=True), tensor(4.+0.j), True)

my_tensor = torch.tensor(data=7.+0.j, requires_grad=False)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j), None, True)

my_tensor.grad = torch.tensor(data=4.+0.j)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j), tensor(4.+0.j), True)
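
Finally, a manually set gradient is accumulated on by later backward() calls, just like any other gradient. A minimal sketch:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor.grad = torch.tensor(data=4.) # Manually set gradient

my_tensor.backward() # Accumulates d(my_tensor)/d(my_tensor) = 1. on top of it

my_tensor.grad
# tensor(5.)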
