DEV Community

Super Kai (Kazuya Ito)

Transformer in PyTorch


Transformer() returns a 2D or 3D tensor of one or more elements, computed by the Transformer model from the 2D or 3D src and tgt tensors of one or more elements, as shown below:

*Memos:

  • The 1st argument for initialization is d_model(Optional-Default:512-Type:int): *Memos:
    • It must be 1 <= x.
    • It must be the same as the number of elements in the deepest dimension of src and tgt.
    • It must be divisible by nhead.
  • The 2nd argument for initialization is nhead(Optional-Default:8-Type:int). *It must be 1 <= x.
  • The 3rd argument for initialization is num_encoder_layers(Optional-Default:6-Type:int). *It must be 1 <= x.
  • The 4th argument for initialization is num_decoder_layers(Optional-Default:6-Type:int). *It must be 1 <= x.
  • The 5th argument for initialization is dim_feedforward(Optional-Default:2048-Type:int): *Memos:
    • It must be 0 <= x.
    • 0 does nothing.
  • The 6th argument for initialization is dropout(Optional-Default:0.1-Type:int or float). *It must be 0 <= x <= 1.
  • The 7th argument for initialization is activation(Optional-Default:'relu'-Type:str or activation function): *Memos:
    • 'relu' or 'gelu' can be set as a str.
    • A unary callable such as torch.nn.functional.gelu can also be set.
  • The 8th argument for initialization is custom_encoder(Optional-Default:None-Type:transformer encoder). *TransformerEncoder() can be set.
  • The 9th argument for initialization is custom_decoder(Optional-Default:None-Type:transformer decoder). *TransformerDecoder() can be set.
  • The 10th argument for initialization is layer_norm_eps(Optional-Default:1e-05-Type:int or float).
  • The 11th argument for initialization is batch_first(Optional-Default:False-Type:bool).
  • The 12th argument for initialization is norm_first(Optional-Default:False-Type:bool).
  • The 13th argument for initialization is bias(Optional-Default:True-Type:bool). *My post explains the bias argument.
  • The 14th argument for initialization is device(Optional-Default:None-Type:str, int or device()).
  • The 15th argument for initialization is dtype(Optional-Default:None-Type:dtype).
  • The 1st argument is src(Required-Type:tensor of float): *Memos:
    • It must be a 2D or 3D tensor of one or more elements.
    • Its number of dimensions must be the same as tgt's.
    • The number of elements in its deepest dimension must be the same as d_model and tgt's.
    • Its device and dtype must be the same as tgt's and Transformer()'s.
    • Its requires_grad (False by default) is not changed by Transformer(); the returned tensor has requires_grad=True because Transformer()'s parameters require gradients.
  • The 2nd argument is tgt(Required-Type:tensor of float): *Memos:
    • It must be a 2D or 3D tensor of one or more elements.
    • Its number of dimensions must be the same as src's.
    • The number of elements in its deepest dimension must be the same as d_model and src's.
    • Its device and dtype must be the same as src's and Transformer()'s.
    • Its requires_grad (False by default) is not changed by Transformer(); the returned tensor has requires_grad=True because Transformer()'s parameters require gradients.
  • The 3rd argument is src_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 2D or 3D tensor of one or more elements.
  • The 4th argument is tgt_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 2D or 3D tensor of one or more elements.
  • The 5th argument is memory_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 2D or 3D tensor of one or more elements.
  • The 6th argument is src_key_padding_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 1D tensor (unbatched input) or a 2D tensor (batched input) of one or more elements.
  • The 7th argument is tgt_key_padding_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 1D tensor (unbatched input) or a 2D tensor (batched input) of one or more elements.
  • The 8th argument is memory_key_padding_mask(Optional-Default:None-Type:tensor of float or bool). *It must be a 1D tensor (unbatched input) or a 2D tensor (batched input) of one or more elements.
  • The 9th argument is src_is_causal(Optional-Default:None-Type:bool).
  • The 10th argument is tgt_is_causal(Optional-Default:None-Type:bool).
  • The 11th argument is memory_is_causal(Optional-Default:False-Type:bool).
  • The device and dtype (float) of src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask and memory_key_padding_mask must be the same as those of Transformer(), src and tgt.
  • The dtype (bool) of src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask and memory_key_padding_mask must be the same.
  • tran1.device and tran1.dtype don't work because Transformer() has no device or dtype attribute, so accessing them raises an AttributeError.
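The shape and mask rules above can be checked with a minimal sketch. It assumes PyTorch 2.x and batch_first=True, so src is (N, S, E) and tgt is (N, T, E); the sizes (batch of 2, source length 5, target length 3, d_model=8, nhead=2), the smaller layer counts and the padding pattern are arbitrary values chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes: batch N=2, source length S=5, target length T=3, d_model E=8.
N, S, T, E = 2, 5, 3, 8
tran = nn.Transformer(d_model=E, nhead=2, num_encoder_layers=2,
                      num_decoder_layers=2, dim_feedforward=16,
                      batch_first=True)

src = torch.randn(N, S, E)  # (N, S, E) because batch_first=True
tgt = torch.randn(N, T, E)  # (N, T, E)

# For batched input the key padding mask is 2D (N, S); True marks a padded position.
src_key_padding_mask = torch.tensor([[False, False, False, True, True],
                                     [False, False, False, False, False]])

# Causal float mask of shape (T, T) for the decoder's self-attention.
tgt_mask = tran.generate_square_subsequent_mask(T)

out = tran(src=src, tgt=tgt, tgt_mask=tgt_mask,
           src_key_padding_mask=src_key_padding_mask,
           memory_key_padding_mask=src_key_padding_mask)

print(out.shape)          # torch.Size([2, 3, 8]), i.e. (N, T, E)
print(src.requires_grad)  # False: the input tensor is left unchanged
print(out.requires_grad)  # True: the module's parameters require gradients
```

The encoder output (memory) has the same length S as src, which is why the same (N, S) mask is reused as memory_key_padding_mask here.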
import torch
from torch import nn

tensor1 = torch.tensor([[8., -3., 0., 1.]])
tensor2 = torch.tensor([[5., 9., -4., 8.],
                        [-2., 7., 3., 6.]])
tensor1.requires_grad
# False

tensor2.requires_grad
# False

torch.manual_seed(42)

tran1 = nn.Transformer(d_model=4, nhead=2)

tensor3 = tran1(src=tensor1, tgt=tensor2)
tensor3
# tensor([[1.5608, 0.1450, -0.6434, -1.0624],
#         [0.8815, 1.0994, -1.1523, -0.8286]],
#        grad_fn=<NativeLayerNormBackward0>)

tensor3.requires_grad
# True

tran1
# Transformer(
#   (encoder): TransformerEncoder(
#     (layers): ModuleList(
#       (0-5): 6 x TransformerEncoderLayer(
#         (self_attn): MultiheadAttention(
#           (out_proj): NonDynamicallyQuantizableLinear(
#                         in_features=4, out_features=4, bias=True
#                       )
#         )
#         (linear1): Linear(in_features=4, out_features=2048, bias=True)
#         (dropout): Dropout(p=0.1, inplace=False)
#         (linear2): Linear(in_features=2048, out_features=4, bias=True)
#         (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#         (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#         (dropout1): Dropout(p=0.1, inplace=False)
#         (dropout2): Dropout(p=0.1, inplace=False)
#       )
#     )
#     (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#   )
#   (decoder): TransformerDecoder(
#     (layers): ModuleList(
#       (0-5): 6 x TransformerDecoderLayer(
#         (self_attn): MultiheadAttention(
#           (out_proj): NonDynamicallyQuantizableLinear(
#                         in_features=4, out_features=4, bias=True
#                       )
#         )
#         (multihead_attn): MultiheadAttention(
#           (out_proj): NonDynamicallyQuantizableLinear(
#                         in_features=4, out_features=4, bias=True
#                       )
#         )
#         (linear1): Linear(in_features=4, out_features=2048, bias=True)
#         (dropout): Dropout(p=0.1, inplace=False)
#         (linear2): Linear(in_features=2048, out_features=4, bias=True)
#         (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#         (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#         (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#         (dropout1): Dropout(p=0.1, inplace=False)
#         (dropout2): Dropout(p=0.1, inplace=False)
#         (dropout3): Dropout(p=0.1, inplace=False)
#       )
#     )
#     (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#   )
# )

tran1.encoder
# TransformerEncoder(
#   (layers): ModuleList(
#     (0-5): 6 x TransformerEncoderLayer(
#       (self_attn): MultiheadAttention(
#         (out_proj): NonDynamicallyQuantizableLinear(
#                       in_features=4, out_features=4, bias=True
#                     )
#       )
#       (linear1): Linear(in_features=4, out_features=2048, bias=True)
#       (dropout): Dropout(p=0.1, inplace=False)
#       (linear2): Linear(in_features=2048, out_features=4, bias=True)
#       (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#       (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#       (dropout1): Dropout(p=0.1, inplace=False)
#       (dropout2): Dropout(p=0.1, inplace=False)
#     )
#   )
#   (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
# )

tran1.decoder
# TransformerDecoder(
#   (layers): ModuleList(
#     (0-5): 6 x TransformerDecoderLayer(
#       (self_attn): MultiheadAttention(
#         (out_proj): NonDynamicallyQuantizableLinear(
#                       in_features=4, out_features=4, bias=True
#                     )
#       )
#       (multihead_attn): MultiheadAttention(
#         (out_proj): NonDynamicallyQuantizableLinear(
#                       in_features=4, out_features=4, bias=True
#                     )
#       )
#       (linear1): Linear(in_features=4, out_features=2048, bias=True)
#       (dropout): Dropout(p=0.1, inplace=False)
#       (linear2): Linear(in_features=2048, out_features=4, bias=True)
#       (norm1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#       (norm2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#       (norm3): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
#       (dropout1): Dropout(p=0.1, inplace=False)
#       (dropout2): Dropout(p=0.1, inplace=False)
#       (dropout3): Dropout(p=0.1, inplace=False)
#     )
#   )
#   (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
# )

tran1.d_model 
# 4

tran1.nhead 
# 2

tran1.batch_first
# False

torch.manual_seed(42)

tran2 = nn.Transformer(d_model=4, nhead=2)

tran2(src=tensor2, tgt=tensor3)
# tensor([[-0.8631, 1.6747, -0.6517, -0.1599],
#         [-0.0919, 1.6377, -0.5336, -1.0122]],
#        grad_fn=<NativeLayerNormBackward0>)

torch.manual_seed(42)

tran = nn.Transformer(d_model=4, nhead=2, num_encoder_layers=6,
          num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 
          activation='relu', custom_encoder=None, custom_decoder=None, 
          layer_norm_eps=1e-05, batch_first=False, norm_first=False, 
          bias=True, device=None, dtype=None)
tran(src=tensor1, tgt=tensor2, src_mask=None, tgt_mask=None, 
     memory_mask=None, src_key_padding_mask=None,
     tgt_key_padding_mask=None, memory_key_padding_mask=None,
     src_is_causal=None, tgt_is_causal=None, memory_is_causal=False)
# tensor([[1.5608, 0.1450, -0.6434, -1.0624],
#         [0.8815, 1.0994, -1.1523, -0.8286]],
#        grad_fn=<NativeLayerNormBackward0>)

tensor1 = torch.tensor([[8., -3.], [0., 1.]])
tensor2 = torch.tensor([[5., 9.], [-4., 8.],
                        [-2., 7.], [3., 6.]])
torch.manual_seed(42)

tran = nn.Transformer(d_model=2, nhead=2)
tran(src=tensor1, tgt=tensor2)
# tensor([[1.0000, -1.0000],
#         [-1.0000, 1.0000],
#         [-1.0000, 1.0000],
#         [-1.0000, 1.0000]], grad_fn=<NativeLayerNormBackward0>)

tensor1 = torch.tensor([[8.], [-3.], [0.], [1.]])
tensor2 = torch.tensor([[5.], [9.], [-4.], [8.],
                        [-2.], [7.], [3.], [6.]])
torch.manual_seed(42)

tran = nn.Transformer(d_model=1, nhead=1)
tran(src=tensor1, tgt=tensor2)
# tensor([[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
#        grad_fn=<NativeLayerNormBackward0>)

tensor1 = torch.tensor([[[8.], [-3.], [0.], [1.]]])
tensor2 = torch.tensor([[[5.], [9.], [-4.], [8.]],
                        [[-2.], [7.], [3.], [6.]]])
torch.manual_seed(42)

tran = nn.Transformer(d_model=1, nhead=1)
tran(src=tensor1, tgt=tensor2)
# tensor([[[0.], [0.], [0.], [0.]],
#         [[0.], [0.], [0.], [0.]]], grad_fn=<NativeLayerNormBackward0>)
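The printed structure can also be inspected programmatically. This is a minimal sketch assuming the same d_model=4, nhead=2 configuration as tran1, running on the CPU with the default float32 dtype. It also shows why the examples above call torch.manual_seed(42): a freshly constructed module is in training mode, so dropout is active, while after eval() dropout is disabled and repeated calls return identical results.

```python
import torch
from torch import nn

torch.manual_seed(42)
tran = nn.Transformer(d_model=4, nhead=2)

# Six encoder layers and six decoder layers by default.
print(len(tran.encoder.layers), len(tran.decoder.layers))  # 6 6

# Each layer's feed-forward block maps d_model -> dim_feedforward -> d_model.
layer = tran.encoder.layers[0]
print(layer.linear1.in_features, layer.linear1.out_features)  # 4 2048
print(layer.linear2.in_features, layer.linear2.out_features)  # 2048 4

# There is no tran.device or tran.dtype; read them from a parameter instead.
print(hasattr(tran, "device"))  # False
param = next(tran.parameters())
print(param.device, param.dtype)  # cpu torch.float32

src = torch.tensor([[8., -3., 0., 1.]])
tgt = torch.tensor([[5., 9., -4., 8.],
                    [-2., 7., 3., 6.]])

# A newly constructed module is in training mode, so dropout (p=0.1) is active.
print(tran.training)  # True

# eval() turns dropout off, so repeated calls return identical outputs.
tran.eval()
with torch.no_grad():
    out1 = tran(src=src, tgt=tgt)
    out2 = tran(src=src, tgt=tgt)
print(torch.equal(out1, out2))  # True
```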

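Transformer().generate_square_subsequent_mask() returns a 2D square tensor of size sz x sz whose entries on and below the main diagonal are 0. (default), 0.+0.j or False, and whose entries above it are -inf (default), -inf+0.j or True, depending on dtype, as shown below: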

*Memos:

  • The 1st argument is sz(Required-Type:int). *It must be 0 <= x.
  • The 2nd argument is device(Optional-Default:None-Type:str, int or device()): *Memos:
    • If it's None, cpu is set.
    • device= can be omitted.
    • My post explains the device argument.
  • The 3rd argument is dtype(Optional-Default:None-Type:dtype): *Memos:
    • If it's None, float32 is set.
    • dtype= can be omitted.
    • My post explains the dtype argument.
import torch
from torch import nn

tran = nn.Transformer()

tran.generate_square_subsequent_mask(sz=3)
tran.generate_square_subsequent_mask(sz=3, device=None, dtype=None)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])

tran.generate_square_subsequent_mask(sz=5)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

tran.generate_square_subsequent_mask(sz=5, dtype=torch.complex64)
# tensor([[0.+0.j, -inf+0.j, -inf+0.j, -inf+0.j, -inf+0.j],
#         [0.+0.j, 0.+0.j, -inf+0.j, -inf+0.j, -inf+0.j],
#         [0.+0.j, 0.+0.j, 0.+0.j, -inf+0.j, -inf+0.j],
#         [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, -inf+0.j],
#         [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]])

tran.generate_square_subsequent_mask(sz=5, dtype=torch.bool)
# tensor([[False, True, True, True, True],
#         [False, False, True, True, True],
#         [False, False, False, True, True],
#         [False, False, False, False, True],
#         [False, False, False, False, False]])
