Improvement of Segmentation Algorithms using Attention and Loss Combinations
Designed, trained, and evaluated LiDAR segmentation models (Cylinder3D and PointNet) with combined loss functions and attention-based enhancements on the SemanticKITTI dataset using PyTorch and MMDetection3D.
Overview
This project focused on improving LiDAR-based point cloud segmentation using deep learning techniques. I implemented and compared two established segmentation models, Cylinder3D and PointNet, and extended them with attention mechanisms and custom loss-function combinations. The entire pipeline was built and tested on the SemanticKITTI dataset, using both the MMDetection3D framework and custom PyTorch modules.
Models Explored
- Cylinder3D: Efficient LiDAR segmentation using cylindrical voxel representation. Implemented via MMDetection3D.
- PointNet: Lightweight point cloud processing model with TNet-based feature transformation. Fully implemented in PyTorch.
- Group-Free 3D (Transformer): source of the Transformer-style attention mechanism adapted into the Cylinder3D and PointNet backbones.
- MMDetection3D (OpenMMLab): used as the base framework to configure and run the Cylinder3D models with custom components.
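The core idea behind Cylinder3D's representation is to convert each point from Cartesian (x, y, z) to cylindrical (ρ, φ, z) coordinates before voxelizing, so cells grow with distance the way LiDAR point density thins out. The NumPy snippet below is a minimal sketch of that step only; the 480 × 360 × 32 grid and the ρ ∈ [0, 50] m and z ∈ [-4, 2] m ranges are assumptions loosely following common Cylinder3D settings, not values taken from this project's config.

```python
import numpy as np


def cylindrical_voxel_indices(points, grid_size=(480, 360, 32),
                              rho_range=(0.0, 50.0), z_range=(-4.0, 2.0)):
    """Map Cartesian LiDAR points (N, >=3) to integer (rho, phi, z) voxel indices.

    grid_size, rho_range and z_range are illustrative assumptions.
    """
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    rho = np.sqrt(x ** 2 + y ** 2)           # radial distance from the sensor
    phi = np.arctan2(y, x)                   # azimuth angle in [-pi, pi]
    cyl = np.stack([rho, phi, z], axis=1)

    lower = np.array([rho_range[0], -np.pi, z_range[0]])
    upper = np.array([rho_range[1], np.pi, z_range[1]])
    grid = np.asarray(grid_size)
    cell = (upper - lower) / grid            # extent of one voxel along each axis

    # Clip so out-of-range points fall into the border voxels instead of overflowing.
    idx = np.floor((np.clip(cyl, lower, upper) - lower) / cell).astype(np.int64)
    return np.minimum(idx, grid - 1)         # a point exactly on the upper bound goes to the last cell
```

Points at the same index share a voxel; Cylinder3D then pools point features per voxel before the sparse 3D convolutions.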
Our Contribution
We extended baseline implementations by adding self-attention modules and experimenting with loss functions such as Focal Loss, Dice Loss, and Jaccard Loss. These were embedded into both Cylinder3D and PointNet to study their impact on segmentation quality and training efficiency.
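The loss combinations can be sketched in PyTorch as follows. This is a minimal illustration, not the project's exact code: the focal exponent `gamma=2.0`, treating label 0 as ignored, the equal loss weights, and the helper names (`focal_loss`, `dice_loss`, `jaccard_loss`, `combined_loss`) are all assumptions. Dice and Jaccard are written as differentiable "soft" versions computed from softmax probabilities.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0, ignore_index=0):
    """Multi-class focal loss. logits: (B, C, N), target: (B, N) int64."""
    ce = F.cross_entropy(logits, target, reduction="none", ignore_index=ignore_index)
    pt = torch.exp(-ce)                                   # probability of the true class
    valid = target != ignore_index
    return ((1 - pt) ** gamma * ce)[valid].mean()         # down-weights easy points


def _soft_stats(logits, target, ignore_index=0):
    """Per-class soft intersection and (|P| + |G|), excluding the ignore class."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)                                         # (B, C, N)
    onehot = F.one_hot(target, num_classes).permute(0, 2, 1).float()      # (B, C, N)
    valid = (target != ignore_index).unsqueeze(1).float()                 # mask ignored points
    probs, onehot = probs * valid, onehot * valid
    inter = (probs * onehot).sum(dim=(0, 2))
    total = probs.sum(dim=(0, 2)) + onehot.sum(dim=(0, 2))
    keep = torch.arange(num_classes, device=logits.device) != ignore_index
    return inter[keep], total[keep]


def dice_loss(logits, target, eps=1e-6, ignore_index=0):
    inter, total = _soft_stats(logits, target, ignore_index)
    return 1 - ((2 * inter + eps) / (total + eps)).mean()


def jaccard_loss(logits, target, eps=1e-6, ignore_index=0):
    inter, total = _soft_stats(logits, target, ignore_index)
    union = total - inter                                  # |P ∪ G| = |P| + |G| - |P ∩ G|
    return 1 - ((inter + eps) / (union + eps)).mean()


def combined_loss(logits, target, use_jaccard=False,
                  w_focal=1.0, w_dice=1.0, w_jaccard=1.0):
    """Focal + Dice baseline, with the Jaccard term as an optional toggle."""
    loss = w_focal * focal_loss(logits, target) + w_dice * dice_loss(logits, target)
    if use_jaccard:
        loss = loss + w_jaccard * jaccard_loss(logits, target)
    return loss
```

Because Dice and Jaccard are computed over whole classes rather than individual points, they counteract the class imbalance of driving scenes, while the focal term keeps per-point gradients focused on hard examples.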
Implementation Details
MMDetection3D Setup
I faced significant challenges installing and configuring MMDetection3D because of version incompatibilities between CUDA, PyTorch, and supporting libraries. After extensive trials across different environments (Windows, WSL, Docker), I settled on CUDA 11.6 under WSL, which gave a stable build.
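Before building MMDetection3D's compiled ops, it helps to confirm which CUDA version the installed PyTorch wheel was built against and whether a GPU is visible. The helper below is a small sketch using standard PyTorch attributes; the function name `describe_torch_env` is hypothetical.

```python
import torch


def describe_torch_env():
    """Collect the version facts that typically decide whether CUDA extensions build."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,               # None for CPU-only wheels
        "cuda_available": cuda_ok,
        "cudnn": torch.backends.cudnn.version() if cuda_ok else None,
        "device": torch.cuda.get_device_name(0) if cuda_ok else "cpu",
    }


if __name__ == "__main__":
    for key, value in describe_torch_env().items():
        print(f"{key:>15}: {value}")
```

A mismatch between `cuda_build` and the CUDA toolkit used to compile the extensions is a common source of the kind of installation failures described above.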
Cylinder3D
Using MMDetection3D, I customized the model config file to define a lightweight Cylinder3D variant and manually integrated the Transformer-based attention mechanism from Group-Free 3D. I tested two approaches: using MMDetection3D's native Group-Free 3D classes (unsuccessful) and directly modifying the sparse convolution blocks (successful).
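One way to attach attention to a sparse convolution block is to operate directly on the sparse tensor's (N, C) feature matrix, restricting attention to voxels from the same scan. The module below is a hypothetical sketch of that idea, not the project's actual block: the class name `VoxelSelfAttention`, the pre-norm residual layout, and the single-head design are assumptions.

```python
import math
import torch
import torch.nn as nn


class VoxelSelfAttention(nn.Module):
    """Residual single-head self-attention over the active voxels of each scan.

    Hypothetical block: feats is the (N, C) voxel feature matrix and batch_idx
    the (N,) scan index of each voxel.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.qkv = nn.Linear(channels, 3 * channels)   # joint query/key/value projection
        self.proj = nn.Linear(channels, channels)
        self.scale = 1.0 / math.sqrt(channels)

    def forward(self, feats: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(feats)
        for b in torch.unique(batch_idx):
            mask = batch_idx == b                        # voxels of one scan only
            x = self.norm(feats[mask])
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            attn = torch.softmax((q @ k.transpose(0, 1)) * self.scale, dim=-1)  # (n_b, n_b)
            out[mask] = feats[mask] + self.proj(attn @ v)                      # residual update
        return out
```

Assuming spconv 2.x, this would be applied as `x = x.replace_feature(block(x.features, x.indices[:, 0]))`, since the first index column holds the batch index. The attention matrix grows quadratically with the number of active voxels per scan, so this only fits in memory at coarse resolutions.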
PointNet
I fully implemented PointNet in PyTorch with modular TNet and self-attention layers. The architecture supports toggling attention on or off, and the loss combination is selected through configuration.
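A compact PointNet segmentation network with TNet alignment and an attention toggle might look like the sketch below. Layer widths are reduced from the original PointNet paper, batch normalization and the feature-transform regularizer are omitted for brevity, and placing a single-head `nn.MultiheadAttention` after the feature TNet is an assumption about where the project inserted attention.

```python
import torch
import torch.nn as nn


class TNet(nn.Module):
    """Predicts a k x k alignment matrix, initialised to the identity."""

    def __init__(self, k: int):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Conv1d(k, 64, 1), nn.ReLU(),
            nn.Conv1d(64, 128, 1), nn.ReLU(),
            nn.Conv1d(128, 256, 1), nn.ReLU(),
        )
        self.fc = nn.Linear(256, k * k)
        nn.init.zeros_(self.fc.weight)          # zero output + identity = identity at start
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):                                    # x: (B, k, N)
        feat = self.mlp(x).max(dim=2).values                 # (B, 256) global descriptor
        eye = torch.eye(self.k, device=x.device).flatten()
        return (self.fc(feat) + eye).view(-1, self.k, self.k)


class PointNetSeg(nn.Module):
    def __init__(self, num_classes=20, in_dim=3, use_attention=False):
        super().__init__()
        self.input_tnet = TNet(in_dim)
        self.local = nn.Sequential(nn.Conv1d(in_dim, 64, 1), nn.ReLU())
        self.feature_tnet = TNet(64)
        self.global_mlp = nn.Sequential(
            nn.Conv1d(64, 128, 1), nn.ReLU(),
            nn.Conv1d(128, 512, 1), nn.ReLU(),
        )
        # Single-head self-attention over points, only built when enabled.
        self.attn = (nn.MultiheadAttention(64, num_heads=1, batch_first=True)
                     if use_attention else None)
        self.head = nn.Sequential(
            nn.Conv1d(64 + 512, 256, 1), nn.ReLU(),
            nn.Conv1d(256, 128, 1), nn.ReLU(),
            nn.Conv1d(128, num_classes, 1),
        )

    def forward(self, pts):                                  # pts: (B, N, in_dim)
        x = pts.transpose(1, 2)                              # (B, in_dim, N)
        x = torch.bmm(self.input_tnet(x), x)                 # align input coordinates
        x = self.local(x)                                    # (B, 64, N) per-point features
        x = torch.bmm(self.feature_tnet(x), x)               # align feature space
        if self.attn is not None:
            t = x.transpose(1, 2)                            # (B, N, 64)
            t = t + self.attn(t, t, t, need_weights=False)[0]    # residual attention
            x = t.transpose(1, 2)
        g = self.global_mlp(x).max(dim=2, keepdim=True).values   # (B, 512, 1)
        g = g.expand(-1, -1, x.size(2))                      # broadcast to every point
        return self.head(torch.cat([x, g], dim=1))           # (B, num_classes, N)
```

The output layout (B, num_classes, N) matches what `F.cross_entropy` expects, so it can feed the loss functions directly.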
Attention & Loss Modules
- Attention: single-head self-attention.
- Losses: Focal + Dice (baseline), with an optional Jaccard loss toggle.
- Training metrics, weights, and checkpoints are saved after every epoch so training can resume safely after GPU crashes.
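Per-epoch checkpointing for crash recovery can be sketched as below. The function names and checkpoint keys are hypothetical; writing to a temporary file and then renaming it is a design choice so that a crash during saving never leaves a corrupted "latest" checkpoint behind.

```python
import os
import torch


def save_checkpoint(path, model, optimizer, epoch, metrics):
    """Write everything needed to resume training after a crash."""
    tmp = path + ".tmp"
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "metrics": metrics,
    }, tmp)
    os.replace(tmp, path)        # atomic rename: the old checkpoint survives a crash mid-write


def load_checkpoint(path, model, optimizer):
    """Restore model/optimizer state and return the epoch to resume from (0 if none)."""
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1
```

In a training loop, `start = load_checkpoint(...)` runs once before `for epoch in range(start, num_epochs)`, and `save_checkpoint(...)` runs at the end of each epoch.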
Dataset
SemanticKITTI was used as the primary dataset. It consists of large-scale LiDAR sequences with point-wise semantic labels. Initial training attempts failed because of irregular raw label IDs; this was resolved by remapping them to the 20 training classes (19 semantic classes plus unlabeled) using the official label mapping.
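SemanticKITTI stores one `uint32` per point in each `.label` file: the lower 16 bits hold the raw semantic ID and the upper 16 bits the instance ID. Raw IDs are sparse (for example 10 = car, 252 = moving car), which is why they must be remapped before training. The sketch below uses a lookup table built from an excerpt of the official `learning_map`; the full table is in the dataset's `semantic-kitti.yaml`, and sending unlisted IDs to 0 is this sketch's assumption.

```python
import numpy as np

# Excerpt of SemanticKITTI's learning_map (raw id -> training id).
# Take the complete mapping from semantic-kitti.yaml; unlisted ids fall back to 0 here.
LEARNING_MAP_EXCERPT = {
    0: 0,     # unlabeled
    1: 0,     # outlier
    10: 1,    # car
    11: 2,    # bicycle
    15: 3,    # motorcycle
    18: 4,    # truck
    20: 5,    # other-vehicle
    30: 6,    # person
    40: 9,    # road
    48: 11,   # sidewalk
    50: 13,   # building
    70: 15,   # vegetation
    252: 1,   # moving-car -> car
}


def build_lut(learning_map, ignore_id=0):
    """Dense lookup table with one slot per possible 16-bit raw semantic id."""
    lut = np.full(1 << 16, ignore_id, dtype=np.int64)
    for raw_id, train_id in learning_map.items():
        lut[raw_id] = train_id
    return lut


def remap(raw, lut):
    semantic = raw & 0xFFFF          # keep lower 16 bits (semantic), drop instance id
    return lut[semantic]


def load_semantic_labels(path, lut):
    raw = np.fromfile(path, dtype=np.uint32)   # one uint32 label per point
    return remap(raw, lut)
```

A lookup table turns the remap into a single vectorized indexing operation, which matters when each scan has on the order of a hundred thousand points.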
I also explored the nuScenes dataset, though I couldn't fully integrate it with MMDetection3D's loader due to format issues.
Compute Resources
- Cylinder3D training: NVIDIA RTX 3080, 4 epochs in 22 hours
- PointNet training: NVIDIA RTX 4060, 30–40 epochs per model in ~10 hours
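The per-epoch cost implied by these figures is worked out below. The two runs used different GPUs and models, so this is not a like-for-like speed comparison.

```latex
\begin{aligned}
\text{Cylinder3D:}\quad & \frac{22\ \text{h}}{4\ \text{epochs}} = 5.5\ \text{h/epoch} \\
\text{PointNet:}\quad & \frac{\approx 600\ \text{min}}{30\text{--}40\ \text{epochs}} \approx 15\text{--}20\ \text{min/epoch}
\end{aligned}
```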
Models Trained
Cylinder3D
- Baseline: 91.10% accuracy, mIoU 0.5203
- With attention: 80.16% accuracy, mIoU 0.3472 (slower training and lower scores than the baseline)
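For reference, both reported metrics can be derived from a confusion matrix. The sketch below assumes point-wise overall accuracy over non-ignored points and IoU = TP / (TP + FP + FN) averaged over classes that occur; the source does not state the exact evaluation convention, and treating class 0 as ignored is an assumption.

```python
import numpy as np


def confusion_matrix(pred, gt, num_classes, ignore_index=0):
    """Rows = ground-truth class, columns = predicted class; ignored points dropped."""
    valid = gt != ignore_index
    idx = gt[valid] * num_classes + pred[valid]
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def accuracy_and_miou(conf, ignore_index=0):
    tp = np.diag(conf).astype(np.float64)
    accuracy = tp.sum() / conf.sum()                       # fraction of correctly labeled points
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp       # TP + FP + FN per class
    iou = np.divide(tp, union, out=np.full_like(tp, np.nan), where=union > 0)
    iou[ignore_index] = np.nan                             # the ignore class is not scored
    return accuracy, np.nanmean(iou)                       # absent classes (NaN) are skipped
```

Because mIoU weights every class equally while accuracy is dominated by frequent classes such as road and vegetation, a model can keep a high accuracy while its mIoU drops sharply on rare classes.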
PointNet
- Focal + Dice Loss – good baseline performance
- Focal + Dice + Jaccard Loss – faster convergence, smoother training
- Above + Self-Attention – increased training time, no consistent visual improvement
Analysis
- Jaccard loss improved segmentation consistency and training smoothness.
- Attention did not produce performance gains; visualizations suggested potential overfitting or misalignment of feature weighting.
- Manual attention tuning was computationally expensive and difficult to optimize without high-end GPUs.
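A rough estimate shows why dense attention is so costly on point clouds. It counts only a single float32 attention matrix for one head and one sample in the forward pass. N = 4096 is an assumed per-sample point count, and roughly 120,000 points is the approximate size of a full SemanticKITTI scan.

```latex
\begin{aligned}
M_{\text{attn}} &= 4N^2\ \text{bytes} \\
N = 4096:\quad M_{\text{attn}} &= 4 \cdot 4096^2 = 67{,}108{,}864\ \text{bytes} = 64\ \text{MiB} \\
N \approx 1.2 \times 10^5:\quad M_{\text{attn}} &\approx 4 \cdot (1.2 \times 10^5)^2 = 5.76 \times 10^{10}\ \text{bytes} \approx 57.6\ \text{GB}
\end{aligned}
```

Activations stored for backpropagation add to this, so full-scan attention exceeds the memory of either GPU used here unless the points are subsampled or attention is restricted to local neighborhoods.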
Lessons Learned
- Self-attention must be carefully integrated into existing models to avoid performance degradation.
- Lightweight segmentation models benefit significantly from the right combination of loss functions.
- Installing and customizing frameworks like MMDetection3D on limited hardware requires a deep understanding of version compatibility and GPU architecture constraints.
- Working with raw LiDAR point clouds offered valuable hands-on experience in 3D data processing and 3D vision pipelines.