Types in Python are optional, but for any PyTorch project exceeding 1,000 lines of code, implementing them can be highly beneficial.
You've probably encountered a scenario where, after coding a model with several networks and initiating training, you realize there's a mismatch in data dimensions or type. Types help clarify the information flow. Moreover, they enable code completion tools to suggest more appropriate options.
While types won't eliminate debugging, they do make the process more focused.
Python Types
First, check whether your editor or IDE supports static type checking. If not, a quick online search will reveal popular options such as Mypy or Pyright.
Python types are akin to a simplified version of the typing systems in Java or C++.
Here's how the syntax looks:
variable: int = 10
For functions, specify the types of parameters and the return type:
def function(x: float, y: float) -> float:
    return x ** y
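A type checker validates these annotations before the code ever runs. As a quick illustration (the `power` function here is my own example, not from any library), mypy rejects a mistyped call even though plain Python would only fail at runtime:

```python
def power(base: float, exponent: float) -> float:
    return base ** exponent

result = power(2.0, 10.0)  # fine: arguments match the annotations

# power("2", 10.0) would crash at runtime with a TypeError,
# but mypy flags it before the program ever runs, with something like:
#   error: Argument 1 to "power" has incompatible type "str"; expected "float"
```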
Below is a list of the most fundamental Python types you'll encounter:
int # Whole number integers
float # Floating point numbers
bool # Boolean values (True or False)
str # Strings
list # Ordered mutable lists
tuple # Ordered immutable lists
dict # Dictionaries of key-value pairs
For the latter three, use the Python typing package for more precise type specifications:
from typing import Dict, List, Tuple
Example usages:
a_dict: Dict[str, float] = {'a': 1.0, ...}
a_list: List[int] = [1, 2, 3]
a_tuple: Tuple[int, bool, float] = (1, True, 3.141)
def forward(input: List[float]) -> Dict[str, float]:
    ...
For dealing with uncertain types, Any is a versatile type that accepts any value. Use it sparingly to maintain the effectiveness of your type checker.
The Self type represents the current class:
import numpy as np
from dataclasses import dataclass
from typing import List, Self, Tuple

@dataclass
class PointCloud:
    points: List[Tuple[float, float, float]]

    @classmethod
    def from_numpy(cls, arr: np.ndarray) -> Self:
        ...
Frequently recurring patterns in your codebase may warrant the creation of custom types from basic types:
from typing import Dict, List
SomeType = Dict[str, List[int]]
var: SomeType = ...
However, using expressive classes as types is often more sensible:
pcd: PointCloud = PointCloud([(1.0, 1.0, 1.0), ...]) # example from above
Be aware that your static type checker can often infer the correct class automatically upon assignment. Use explicit annotations in scenarios where inference is not possible.
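The classic case where inference fails is an empty container: the checker has no element to deduce a type from, so an explicit annotation is the only way to tell it. A minimal sketch (variable names are illustrative):

```python
from typing import Dict, List

# Inferred: the checker deduces List[int] from the literal.
scores = [90, 85, 77]

# Not inferable: an empty container gives the checker nothing to go on,
# so explicit annotations are required.
losses: List[float] = []
metrics: Dict[str, float] = {}

losses.append(0.42)
metrics["accuracy"] = 0.93
```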
Types for PyTorch
A clean architecture combined with proper types elevates the professionalism of your PyTorch codebases significantly.
Key base classes in PyTorch you should be aware of:
torch.Tensor
torch.optim.Optimizer
torch.utils.data.Dataset
torch.nn.Module
For more details, refer to the official PyTorch documentation or explore the PyTorch codebase using your code editor's symbol search feature.
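To show how these base classes fit into annotations, here is a sketch of a typed training step. The `train_step` function and the tiny linear model are my own illustration, not a PyTorch API:

```python
import torch
from torch import nn
from torch.optim import Optimizer

def train_step(
    model: nn.Module,
    optimizer: Optimizer,
    batch: torch.Tensor,
    target: torch.Tensor,
) -> float:
    """Run one optimization step and return the scalar loss."""
    optimizer.zero_grad()
    prediction = model(batch)
    loss = nn.functional.mse_loss(prediction, target)
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage with a minimal linear model:
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch = torch.randn(8, 4)
target = torch.randn(8, 1)
loss_value = train_step(model, optimizer, batch, target)
```

With these annotations, passing an optimizer where a module is expected (or vice versa) is caught by the checker instead of surfacing as a confusing runtime error.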
Dataclasses are extremely useful for defining types. They allow for quick class definitions intended for data holding and validation, usually without writing an __init__ method by hand. Validation can be implemented using the __post_init__ method.
Example:
import torch
from dataclasses import dataclass
@dataclass
class PointCloud:
    points: torch.Tensor
    colors: torch.Tensor

    def __post_init__(self):
        assert self.points.shape == torch.Size([...])
        ...
Employing these types for network inputs and outputs significantly reduced my debugging time. I hope they will prove just as beneficial for you.
Common Troubleshooting for DataLoaders
If you adopt a strict dataclass typing strategy in PyTorch, you might face a unique challenge. The PyTorch DataLoader typically expects data to consist of Tensors or iterable collections of Tensors. To resolve this, reimplement the default collate_fn and provide it to the DataLoader. Here's an example:
from dataclasses import dataclass
from typing import List

import torch
from torch.utils.data import DataLoader

@dataclass
class PointCloud:
    position: torch.Tensor
    color: torch.Tensor

def custom_collate_fn(batch: List[PointCloud]) -> PointCloud:
    # Stack the per-sample tensors into batched tensors.
    positions = torch.stack([pcd.position for pcd in batch])
    colors = torch.stack([pcd.color for pcd in batch])
    return PointCloud(positions, colors)

...
loader = DataLoader(..., collate_fn=custom_collate_fn)
For distinguishing between batched and unbatched types, I usually rely on simple shape assertions.
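As an illustration of that convention (the `(B, N, 3)` vs. `(N, 3)` shapes and the `assert_batched` helper are my own assumptions, not fixed by PyTorch), such an assertion might look like this:

```python
import torch

def assert_batched(points: torch.Tensor) -> None:
    # Assumed convention: batched point clouds are (B, N, 3),
    # unbatched ones (N, 3).
    assert points.dim() == 3 and points.shape[-1] == 3, (
        f"expected a batched (B, N, 3) tensor, got {tuple(points.shape)}"
    )

batched = torch.zeros(4, 100, 3)
assert_batched(batched)  # passes

unbatched = torch.zeros(100, 3)
# assert_batched(unbatched) would raise an AssertionError
```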
With these minor adjustments, you'll find that your types work harmoniously with the PyTorch framework.
Let me know if you use types in your code and how they've benefited your projects.