Fine-tuning is a powerful technique that lets you adapt a pre-trained model to a new task while saving training time and compute. This tutorial walks you through fine-tuning a ResNet18 model for handwritten digit classification on MNIST using PyTorch.
Step 1: Setting Up the Environment and Model
First, let's import the required libraries and set up the pre-trained model.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
# Load a pre-trained ResNet18 model
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # the pretrained= argument is deprecated
# Modify the last layer to match MNIST classes (10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)
# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
Explanation: We are using ResNet18, which is pre-trained on ImageNet. The final fully connected layer is replaced so its output size matches our new task (10 classes for digits 0-9). All layers remain trainable, so the whole network will be fine-tuned.
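If your dataset is small, a common variant is to freeze the pretrained backbone and train only the new head. A minimal sketch of that option (not used in the rest of this tutorial):
# Optional: freeze every pretrained layer so only the new head learns
for param in model.parameters():
    param.requires_grad = False
# Re-enable gradients for the replacement classifier head
for param in model.fc.parameters():
    param.requires_grad = True
If you go this route, pass only model.fc.parameters() to the optimizer in Step 3 so it does not track frozen weights.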
Step 2: Preparing the Dataset
We resize the MNIST images to 224x224, the input size the pretrained ResNet18 weights were trained on. MNIST images are also single-channel grayscale, while ResNet18 expects 3-channel input, so we replicate the channel three times.
# Convert to 3 channels, resize to 224x224, and normalize
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # ResNet18 expects 3 input channels
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load the MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
Explanation: We expand each image to 3 channels, resize it, and normalize it to fit what the model expects. We also create data loaders for training and testing.
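Before training, it is worth sanity-checking one batch to confirm the transform produced the shape the model expects; a quick throwaway check:
# Sanity check: a batch should be [64, 3, 224, 224] after the transform
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64])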
Step 3: Setting Up the Loss Function and Optimizer
Define the loss function and optimizer. We use Adam, a popular choice for fine-tuning.
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Learning rate scheduler: decay the learning rate by a factor of 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
Explanation: CrossEntropyLoss is the standard loss for multi-class classification, and Adam is a popular optimizer for fine-tuning. The StepLR scheduler multiplies the learning rate by 0.1 every 7 epochs, so later epochs take smaller, more careful steps.
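To see how StepLR behaves before committing to a long run, you can step a throwaway copy and print the learning rate it would use each epoch; a small illustration (the demo_ names are just local variables for this check):
# Illustration only: StepLR multiplies the LR by gamma every step_size epochs.
# A throwaway optimizer/scheduler keeps the real training state untouched.
demo_opt = optim.Adam([nn.Parameter(torch.zeros(1))], lr=0.001)
demo_sched = optim.lr_scheduler.StepLR(demo_opt, step_size=7, gamma=0.1)
for epoch in range(15):
    print(epoch, demo_sched.get_last_lr())  # 0.001 for epochs 0-6, then 0.0001, ...
    demo_opt.step()
    demo_sched.step()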
Step 4: Fine-Tuning the Model
Train the model for a few epochs to fine-tune it on the new task.
# Fine-tune the model
num_epochs = 5
# If training is too slow (e.g., on Colab), reduce num_epochs to 1 and make sure a GPU runtime (such as a T4) is selected.
model.train()  # ensure the model is in training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Step the scheduler after each epoch
    scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")
print('Fine-tuning complete!')
Explanation: We loop through the dataset, perform forward and backward passes, and update the model's weights. This step adapts the pre-trained model to our specific task.
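If each epoch is still slow on a GPU, PyTorch's automatic mixed precision (AMP) can speed up the forward and backward passes. A minimal sketch of an AMP training step, as an optional drop-in replacement for the inner loop above (CUDA only):
# Optional GPU speed-up: automatic mixed precision (AMP)
scaler = torch.cuda.amp.GradScaler()
for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        outputs = model(images)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss to avoid underflow in fp16 gradients
    scaler.step(optimizer)
    scaler.update()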
Step 5: Saving the Fine-Tuned Model
Save the fine-tuned model for later use.
# Save the fine-tuned model
torch.save(model.state_dict(), 'finetuned_resnet18_mnist.pth')
print('Model saved!')
Explanation: We save the model's state dictionary (parameters) to a file. This allows us to load it later without retraining.
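If you expect to resume training later, a common pattern is to save a full checkpoint rather than just the weights; a sketch (the file name is arbitrary):
# Sketch: save a resumable checkpoint, not just the model weights
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'finetune_checkpoint.pth')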
Step 6: Evaluating the Model
Check how well the model performs on unseen data.
# Set the model to evaluation mode
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the fine-tuned model on the test images: {100 * correct / total:.2f}%')
Explanation: We set the model to evaluation mode and calculate accuracy on the test set to see how well the model generalizes to new data.
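Overall accuracy can hide weak spots, so it is often useful to break the result down per digit; a short sketch of that check:
# Sketch: per-digit accuracy on the test set
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(1)
        for label, pred in zip(labels, predicted):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(label == pred)
for digit in range(10):
    print(f"Digit {digit}: {100 * class_correct[digit] / class_total[digit]:.2f}%")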
Step 7: Making Predictions
Finally, use the fine-tuned model to make predictions on a single image.
# Load the model for inference
model = models.resnet18(weights=None)  # fresh architecture; the weights come from our saved file
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('finetuned_resnet18_mnist.pth', map_location=device))
model.eval()
model = model.to(device)
# Make a prediction on a single image from the test set
test_image, _ = testset[0]  # Get the first image from the test set
test_image = test_image.unsqueeze(0).to(device)  # Add a batch dimension and move to device
with torch.no_grad():
    output = model(test_image)
_, predicted = torch.max(output, 1)
print('Predicted label:', predicted.item())
Explanation: We load the saved model, set it to evaluation mode, and use it to predict the class of a new image.
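If you want a confidence score rather than just the predicted label, apply softmax to the raw logits; a short sketch using the output tensor from the code above:
# Sketch: convert raw logits into class probabilities
probs = torch.softmax(output, dim=1)
conf, pred = probs.max(dim=1)
print(f"Predicted digit {pred.item()} with confidence {conf.item():.2%}")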
Conclusion
Fine-tuning is an efficient way to adapt powerful pre-trained models to new tasks with minimal effort. In this guide, you learned how to:
1. Set up a pre-trained model and modify it for a new task.
2. Prepare and load your dataset.
3. Fine-tune the model on the new dataset.
4. Save and load the model for future use.
5. Make predictions using the fine-tuned model.
Hope you found this post helpful and enjoyable.
Thank you!