Introduction
In this article, we will walk through training a simple neural network on the MNIST dataset with PyTorch and then deploying it with Gradio for interactive predictions. MNIST is a classic machine learning benchmark consisting of 70,000 28x28 grayscale images of handwritten digits (60,000 for training and 10,000 for testing).
Training a Neural Network with PyTorch
PyTorch is an open-source deep learning framework developed by Facebook's artificial intelligence research group. It provides a wide range of functionalities for building and training neural networks.
Step 1: Import necessary libraries
First, we need to import PyTorch, torchvision (a package with popular datasets, model architectures, and common image transformations), and some other necessary libraries:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
Step 2: Load the dataset
Next, we load the MNIST dataset using torchvision's built-in datasets module. We also apply a transform that converts each image to a tensor and normalizes the pixel values to the range [-1, 1] by subtracting a mean of 0.5 and dividing by a standard deviation of 0.5:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)  # no need to shuffle the test set
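As an optional sanity check, we can pull one batch from the loader and confirm the shapes and the normalized value range:

# Inspect one batch: shapes and value range after normalization
images, labels = next(iter(trainloader))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(labels.shape)   # torch.Size([64])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0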
Step 3: Define the network
We'll define a simple feed-forward neural network with one hidden layer:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)        # flatten each 28x28 image into a 784-vector
        x = torch.relu(self.fc1(x))    # hidden layer with ReLU activation
        x = self.fc2(x)                # raw logits for the 10 digit classes
        return x

net = Net()
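Before training, it can be reassuring to run a dummy batch through the untrained network and check that it produces one 10-dimensional logit vector per image:

# Optional wiring check with random input
dummy = torch.randn(4, 1, 28, 28)
print(net(dummy).shape)  # torch.Size([4, 10])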
Step 4: Define the loss function and optimizer
We'll use CrossEntropyLoss for our loss function and SGD with momentum for our optimizer. Note that CrossEntropyLoss applies log-softmax internally, which is why our forward method returns raw logits rather than probabilities:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
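To make that concrete, CrossEntropyLoss can be called directly on raw scores and integer class labels; here is a tiny illustrative example with made-up values:

# CrossEntropyLoss expects raw logits and integer class labels
example_logits = torch.randn(3, 10)          # fake batch of 3 logit vectors
example_labels = torch.tensor([0, 7, 3])     # fake target digits
print(criterion(example_logits, example_labels).item())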
Step 5: Train the network
Now we're ready to train our network:
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = net(inputs)              # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(trainloader):.3f}')
print('Finished Training')
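The testloader we created earlier lets us measure accuracy on held-out data. Here is a minimal evaluation sketch (no gradients are needed at inference time):

# Evaluate accuracy on the test set
correct = 0
total = 0
net.eval()
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')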
Deploying with Gradio
Gradio is an open-source library for creating customizable UI components around your ML models. It allows us to demonstrate a model’s functionality in an intuitive manner.
Step 1: Install Gradio
!pip install gradio
Step 2: Import Gradio and define the prediction function
import gradio as gr
def predict(image):
    # The canvas sends a 28x28 grayscale numpy array with values in [0, 255]
    image = torch.from_numpy(image).float().reshape(1, 1, 28, 28)
    image = image / 255.0            # scale to [0, 1]
    image = (image - 0.5) / 0.5      # apply the same normalization used in training
    net.eval()
    with torch.no_grad():
        output = net(image)
    _, predicted = torch.max(output, 1)
    return predicted.item()
In the predict function, we take the input image, convert it to a float tensor, reshape it to the (batch, channel, height, width) shape our model expects, apply the same normalization used during training, pass it through the model, and return the digit with the highest score.
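We can sanity-check predict before building the UI by feeding it a raw test image; testset.data holds the un-normalized uint8 pixels, which matches what the drawing canvas will send:

# Quick check on a raw (un-normalized) test image
raw = testset.data[0].numpy()                   # 28x28 uint8 array in [0, 255]
print(predict(raw), testset.targets[0].item())  # predicted digit vs. true label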
Step 3: Define the Gradio interface
Now, we define the interface for our model. We'll use an 'Image' input component and a 'Label' output component (gr.inputs is the legacy Gradio 2.x API; in newer Gradio versions these components live at the top level, e.g. gr.Image and gr.Sketchpad):
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(28, 28), image_mode="L", invert_colors=True, source="canvas"),
    outputs="label",
    interpretation="default"
)
The canvas input lets users draw a digit with their mouse. We set image_mode="L" so the drawing arrives as a single-channel grayscale array (matching the reshape in predict), and invert_colors=True because MNIST consists of white digits on a black background, while the Gradio canvas is dark-on-light by default.
Step 4: Launch the interface
Finally, we launch the interface:
iface.launch()
With this, you should see an interactive interface where you can draw a digit and see the prediction from your PyTorch model.
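By default, launch() serves the app locally. Passing share=True additionally creates a temporary public URL, which is useful for showing the demo to others:

iface.launch(share=True)  # optional: generates a temporary public link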
Conclusion
In this article, we saw how to train a simple neural network with PyTorch and deploy it with Gradio for interactive predictions. This combination lets us put deep learning models in front of users in a way that is easy to use and easy to interpret.