How to Use Class Weights with Focal Loss in PyTorch for Imbalanced Multiclass Classification
As a data scientist or software engineer, you will often encounter classification tasks where the dataset is imbalanced. In such cases the majority class dominates training, and the model performs poorly on the minority classes. One way to deal with this is to use class weights to balance each class's contribution to the loss. In this blog post, we will discuss how to combine class weights with focal loss in PyTorch for imbalanced multiclass classification.
Table of Contents
- What is Imbalanced Multiclass Classification?
- Why Use Class Weights?
- What is Focal Loss?
- How to Use Class Weights with Focal Loss in PyTorch
- Conclusion
What is Imbalanced Multiclass Classification?
Multiclass classification is a common task in machine learning where the goal is to predict the class of a given input among multiple classes. In an imbalanced multiclass classification problem, the dataset has an unequal distribution of examples across different classes. For instance, consider a medical diagnosis problem where the task is to classify a patient’s disease into one of the ten possible classes. If one of the diseases is rare, then the dataset may have fewer examples for that class, making it an imbalanced dataset.
Why Use Class Weights?
In an imbalanced dataset, the model trained on such data often becomes biased towards the majority class. As a result, the model may struggle to predict minority classes accurately. One way to address this issue is to use class weights. Class weights assign a weight to each class based on its frequency in the dataset. The idea is to give more weight to the minority class and less weight to the majority class during training. This helps the model to learn equally from all classes and improves the performance on the minority class.
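As a quick, hypothetical illustration (the labels below are made up for demonstration), scikit-learn's compute_class_weight utility implements this inverse-frequency idea; its 'balanced' mode computes n_samples / (n_classes * class_count), so rarer classes receive larger weights:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Made-up labels: class 0 is common, class 2 is rare
y = np.array([0] * 80 + [1] * 15 + [2] * 5)

weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
print(weights)  # approximately [0.42, 2.22, 6.67] -- the rare class is weighted highest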
What is Focal Loss?
Focal loss is a loss function designed to handle imbalanced datasets. It was introduced in the paper “Focal Loss for Dense Object Detection” by Lin et al. in 2017. Focal loss extends cross-entropy with a modulating factor that down-weights easy, well-classified examples so that training focuses on hard, misclassified ones. This keeps the many easy examples from swamping the loss and improves performance on the minority class.
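In symbols, the α-balanced focal loss from the paper is

FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)

where p_t is the model's estimated probability for the true class, \alpha_t is an optional per-class balancing weight, and \gamma \ge 0 is the focusing parameter. With \gamma = 0 this reduces to (weighted) cross-entropy; larger values of \gamma shrink the loss contribution of well-classified examples (p_t close to 1) more aggressively.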
How to Use Class Weights with Focal Loss in PyTorch
PyTorch is a popular deep learning framework that provides a flexible and efficient way to build and train deep learning models. In PyTorch, we can use class weights with focal loss to handle imbalanced datasets.
Step 1: Load the Dataset
The first step is to load the dataset. In this example, we generate a synthetic imbalanced dataset with scikit-learn's make_classification. For simplicity it has two classes, but everything below works unchanged for any number of classes.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Create a hypothetical imbalanced dataset
X, y = make_classification(
    n_samples=1000,      # Total number of examples
    n_features=20,       # Number of features
    n_informative=10,    # Number of informative features
    n_redundant=5,       # Number of redundant features
    weights=[0.8, 0.2],  # Class distribution (80% in Class 0, 20% in Class 1)
    random_state=42
)

# Split the dataset into training and testing sets
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=42)

# Visualize the class distribution in the training set
import matplotlib.pyplot as plt
import numpy as np

class_labels, class_counts = np.unique(y_train, return_counts=True)
plt.bar(class_labels, class_counts, color='blue')
plt.xlabel('Class Label')
plt.ylabel('Number of Examples')
plt.title('Distribution of Examples Across Classes (Training Data)')
plt.show()
Step 2: Calculate Class Weights
Once we have loaded the dataset, we need to calculate the class weights. A simple recipe is to compute each class's frequency (its count divided by the total number of examples) and use the inverse of that frequency as its weight, so rarer classes receive larger weights.
import numpy as np

# Count examples per class and weight each class by the inverse of its frequency
class_counts = np.bincount(y_train)
num_classes = len(class_counts)
total_samples = len(y_train)

class_weights = []
for count in class_counts:
    weight = 1 / (count / total_samples)  # equivalent to total_samples / count
    class_weights.append(weight)
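For this roughly 80/20 split, the weights come out near [1.25, 5.0], so each minority-class example counts about four times as much as a majority-class one (the exact values depend on how the random split falls):

# Inspect the computed weights per class label
print(dict(zip(class_labels, np.round(class_weights, 2))))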
Step 3: Define the Focal Loss Function
Next, we need to define the focal loss function. We can do this by subclassing the PyTorch nn.Module class and implementing the forward method. Note that alpha is optional here, so the same loss also works without class weights.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # per-class weight tensor, or None for unweighted focal loss
        self.gamma = gamma  # focusing parameter: larger gamma down-weights easy examples more

    def forward(self, inputs, targets):
        # Per-example cross-entropy; reduction='none' keeps one loss value per example
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # pt is the model's predicted probability for the true class
        pt = torch.exp(-ce_loss)
        # Down-weight easy examples (pt close to 1) by the factor (1 - pt) ** gamma
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            # Apply the class weight corresponding to each example's true class
            focal_loss = self.alpha[targets] * focal_loss
        return focal_loss.mean()
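As a quick sanity check (reusing num_classes and the class_weights list from Step 2), the loss can be exercised on a random batch of fake logits before wiring it into training:

# Random batch of fake logits and labels, just to confirm the loss runs end to end
logits = torch.randn(8, num_classes)
labels = torch.randint(0, num_classes, (8,))

loss_fn = FocalLoss(alpha=torch.FloatTensor(class_weights), gamma=2)
print(loss_fn(logits, labels))  # a scalar tensor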
Step 4: Train the Model
Finally, we can train the model using the class weights and focal loss function.
import torch.optim as optim

# Convert the class weights and training data to tensors once, outside the loop
class_weights = torch.FloatTensor(class_weights)
inputs = torch.FloatTensor(x_train)
targets = torch.LongTensor(y_train)

# A small feed-forward classifier: 20 input features -> num_classes logits
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, num_classes)
)

optimizer = optim.Adam(model.parameters())
criterion = FocalLoss(alpha=class_weights, gamma=2)

for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch: {epoch}, Loss: {loss.item()}")
Output:
Epoch: 0, Loss: 0.444484144449234
Epoch: 10, Loss: 0.26395726203918457
Epoch: 20, Loss: 0.2063663899898529
Epoch: 30, Loss: 0.1757550984621048
Epoch: 40, Loss: 0.15498967468738556
Epoch: 50, Loss: 0.13924439251422882
Epoch: 60, Loss: 0.12659738957881927
Epoch: 70, Loss: 0.11594895273447037
Epoch: 80, Loss: 0.1070123016834259
Epoch: 90, Loss: 0.09911686927080154
After 100 epochs, the training loss falls steadily from about 0.44 to under 0.10, showing that the model is fitting the weighted objective and can no longer minimize the loss by simply favoring the majority class. Keep in mind that a falling training loss by itself does not prove better minority-class performance; the real test is per-class metrics such as precision, recall, or F1 on the held-out test set.
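As a minimal sketch of that evaluation, using the test split created in Step 1:

from sklearn.metrics import classification_report

# Put the model in eval mode and predict on the held-out test set
model.eval()
with torch.no_grad():
    test_logits = model(torch.FloatTensor(x_test))
    test_preds = test_logits.argmax(dim=1).numpy()

# Per-class precision/recall/F1 shows whether the minority class actually improved
print(classification_report(y_test, test_preds))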
Conclusion
In this blog post, we have discussed how to use class weights with focal loss in PyTorch for imbalanced multiclass classification. We started by introducing the problem of imbalanced datasets and why class weights are necessary. We then discussed focal loss and how it works. Finally, we showed how to implement class weights with focal loss in PyTorch on a synthetic imbalanced dataset generated with scikit-learn. By following the steps outlined in this blog post, you can improve the performance of your model on imbalanced datasets and achieve better results on minority classes.
About Saturn Cloud
Saturn Cloud is your all-in-one solution for data science & ML development, deployment, and data pipelines in the cloud. Spin up a notebook with 4TB of RAM, add a GPU, connect to a distributed cluster of workers, and more. Request a demo today to learn more.