Mask2Former for Semantic Segmentation

This repository contains a Mask2Former model fine-tuned for semantic segmentation. It predicts a per-pixel class map for an input image and is based on the pre-trained facebook/mask2former-swin-large-cityscapes-semantic checkpoint.

Model Overview

Mask2Former is a general-purpose framework for mask prediction tasks, including:

  • Semantic Segmentation
  • Instance Segmentation
  • Panoptic Segmentation

This version has been fine-tuned for semantic segmentation. Typical applications include road scene understanding for autonomous driving and similar scene-parsing tasks.
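
To make the link between mask prediction and semantic segmentation concrete, here is a minimal pure-Python sketch of how per-query outputs can be combined into a per-pixel class map. It mirrors the general idea behind the semantic post-processing in transformers (class probabilities weighted by mask probabilities, then an argmax per pixel), but it is a simplified illustration, not the library's code: the helper names and the toy logits are made up, and mask resizing is omitted.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def semantic_map(class_logits, mask_logits):
    """Combine per-query predictions into a per-pixel class map.

    class_logits: [num_queries][num_classes + 1], last entry is "no object".
    mask_logits:  [num_queries][height][width].
    Returns a [height][width] list of class ids.
    """
    num_queries = len(class_logits)
    num_classes = len(class_logits[0]) - 1
    # Drop the trailing "no object" probability for every query.
    class_probs = [softmax(row)[:num_classes] for row in class_logits]
    height, width = len(mask_logits[0]), len(mask_logits[0][0])

    result = []
    for y in range(height):
        row = []
        for x in range(width):
            # Score of class c at this pixel: sum over queries of
            # P(class c | query) * P(pixel belongs to query mask).
            scores = [
                sum(
                    class_probs[q][c] * sigmoid(mask_logits[q][y][x])
                    for q in range(num_queries)
                )
                for c in range(num_classes)
            ]
            row.append(max(range(num_classes), key=scores.__getitem__))
        result.append(row)
    return result


# Toy input (hypothetical values): 2 queries, 2 classes, a 2x2 image.
class_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
mask_logits = [
    [[5.0, 5.0], [-5.0, -5.0]],  # query 0 covers the top row
    [[-5.0, -5.0], [5.0, 5.0]],  # query 1 covers the bottom row
]
toy_map = semantic_map(class_logits, mask_logits)
print(toy_map)
```

In this toy case the top row is assigned class 0 and the bottom row class 1, because each query is confident about one class and one region.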


How to Use the Model

You can use this model with the Hugging Face transformers library.

Installation

First, ensure you have the required libraries installed:

pip install transformers torch torchvision pillow matplotlib
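
To confirm that the installation worked before running inference, a quick check such as the following may help. It is only a sketch using the standard library. The `missing_packages` helper is a hypothetical name introduced here, and the mapping reflects that the pillow package is imported as `PIL`.

```python
import importlib.util

# pip package name -> import name (they differ for pillow)
REQUIRED = {
    "transformers": "transformers",
    "torch": "torch",
    "torchvision": "torchvision",
    "pillow": "PIL",
    "matplotlib": "matplotlib",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose module cannot be found."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


missing = missing_packages()
if missing:
    print("Missing packages, try: pip install " + " ".join(missing))
else:
    print("All required packages are importable.")
```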

Inference Example

The following script loads the model, segments one image, and plots the result next to the input:

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Load the processor and model
model_name = "saninmohammedn/mask2former-deployment"
processor = AutoImageProcessor.from_pretrained(model_name)
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)

# Load an input image
image_path = "your_image.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")

# Prepare the image for the model
inputs = processor(images=image, return_tensors="pt")

# Perform inference
with torch.no_grad():
    outputs = model(**inputs)

# Post-process the predicted segmentation map.
# target_sizes expects (height, width); PIL's image.size is (width, height), hence [::-1].
predicted_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0].cpu().numpy()

# Visualize the input and predicted segmentation map
plt.figure(figsize=(10, 5))

# Display original image
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Display predicted segmentation map
plt.subplot(1, 2, 2)
plt.imshow(predicted_map, cmap="jet")
plt.title("Predicted Segmentation Map")
plt.axis("off")

plt.tight_layout()
plt.show()
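
The predicted map holds one integer class id per pixel, so it can also be summarized numerically instead of only plotted. The sketch below computes the fraction of pixels per class on a toy map. The `class_coverage` helper, the toy map, and the three-entry label table are illustrative assumptions, not part of this repository.

```python
from collections import Counter


def class_coverage(segmentation_map, id2label):
    """Return {label: fraction of pixels}, most frequent class first.

    segmentation_map: 2D nested list of integer class ids.
    id2label: dict mapping class id -> human-readable label.
    """
    counts = Counter(class_id for row in segmentation_map for class_id in row)
    total = sum(counts.values())
    return {
        id2label.get(class_id, f"class_{class_id}"): count / total
        for class_id, count in counts.most_common()
    }


# Toy 3x4 map and a hypothetical label table for illustration only.
toy_map = [
    [0, 0, 0, 2],
    [0, 0, 1, 2],
    [0, 0, 1, 2],
]
toy_labels = {0: "road", 1: "sidewalk", 2: "building"}

coverage = class_coverage(toy_map, toy_labels)
for label, fraction in coverage.items():
    print(f"{label}: {fraction:.1%}")
```

With the script above, something like `class_coverage(predicted_map.tolist(), model.config.id2label)` should give the same kind of summary, assuming the checkpoint's config carries a meaningful `id2label` mapping.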

Model size: 216M params (Safetensors; tensor types I64 and F32)