|
--- |
|
license: mit |
|
library_name: pytorch |
|
tags: |
|
- computer-vision |
|
- image-fusion |
|
- multi-focus |
|
- transformer |
|
- pytorch |
|
- focal-transformer |
|
- crossvit |
|
- attention |
|
- vision-transformer |
|
- image-processing |
|
pipeline_tag: image-to-image |
|
datasets: |
|
- lytro-multi-focus |
|
metrics: |
|
- psnr |
|
- ssim |
|
- qabf |
|
- spatial-frequency
|
- vif |
|
- mutual-information |
|
language: |
|
- en |
|
base_model: [] |
|
model-index: |
|
- name: HybridTransformer-MFIF |
|
results: |
|
- task: |
|
type: image-fusion |
|
name: Multi-Focus Image Fusion |
|
dataset: |
|
type: lytro-multi-focus |
|
name: Lytro Multi-Focus Dataset |
|
config: main-series |
|
split: test |
|
metrics: |
|
- type: psnr |
|
name: PSNR |
|
value: 28.5 |
|
unit: dB |
|
- type: ssim |
|
name: SSIM |
|
value: 0.92 |
|
unit: index |
|
- type: qabf |
|
name: QABF |
|
value: 0.85 |
|
unit: index |
|
- type: spatial-frequency
|
name: Spatial Frequency
|
value: 12.3 |
|
unit: score |
|
- type: vif |
|
name: VIF |
|
value: 0.78 |
|
unit: index |
|
widget: |
|
- src: https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF/resolve/main/assets/lytro-01-A.jpg |
|
candidate_labels: near-focus |
|
- src: https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF/resolve/main/assets/lytro-01-B.jpg |
|
candidate_labels: far-focus |
|
--- |
|
|
|
# HybridTransformer-MFIF: Focal Transformer & CrossViT Hybrid for Multi-Focus Image Fusion |
|
|
|
<div align="center"> |
|
<img src="assets/logo.png" alt="HybridTransformer MFIF Logo" width="300"/> |
|
</div> |
|
|
|
[License: MIT](LICENSE)

[Python ≥3.8](https://python.org)

[PyTorch ≥2.0](https://pytorch.org)

[Model on Hugging Face](https://huggingface.co/divitmittal/HybridTransformer-MFIF)

[Live Demo on Spaces](https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF)
|
|
|
A state-of-the-art PyTorch implementation combining **Focal Transformer** and **CrossViT** architectures for multi-focus image fusion (MFIF). This hybrid model intelligently merges images with different focal planes to create a single, comprehensively focused output. |
|
|
|
## 🔗 Project Resources |
|
|
|
| Platform | Description | Link | |
|
|----------|-------------|------| |
|
| 🚀 **Interactive Demo** | Try the model online with your own images | [Launch Demo](https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF) | |
|
| 🤗 **Model Repository** | Download pre-trained weights and config | [This Repository](https://huggingface.co/divitmittal/HybridTransformer-MFIF) | |
|
| 📊 **Training Tutorial** | Complete pipeline with GPU acceleration | [Kaggle Notebook](https://www.kaggle.com/code/divitmittal/hybrid-transformer-mfif) | |
|
| 📁 **Source Code** | Full implementation and documentation | [GitHub Repository](https://github.com/DivitMittal/HybridTransformer-MFIF) | |
|
| 📦 **Training Dataset** | Lytro Multi-Focus dataset | [Kaggle Dataset](https://www.kaggle.com/datasets/divitmittal/lytro-multi-focal-images) | |
|
|
|
## Model Details |
|
|
|
### Model Description |
|
|
|
**HybridTransformer-MFIF** is a novel deep learning architecture that addresses the multi-focus image fusion task by combining two powerful transformer-based approaches: |
|
|
|
- **🎯 Focal Transformer**: Provides adaptive spatial attention with multi-scale focal windows for enhanced feature extraction |
|
- **🔄 CrossViT**: Enables cross-attention between near-focus and far-focus images for optimal information fusion |
|
- **⚡ Hybrid Integration**: Sequential processing pipeline optimized specifically for image fusion tasks |
|
|
|
The model takes two input images of the same scene with different focal planes and produces a single output image that preserves the best-focused regions from both inputs. |
|
|
|
- **Model type:** Vision Transformer (Hybrid Architecture) |
|
- **Framework:** PyTorch
|
- **License:** MIT |
|
- **Repository:** [GitHub](https://github.com/DivitMittal/HybridTransformer-MFIF) |
|
|
|
## Uses |
|
|
|
### Direct Use |
|
|
|
The model is designed for **multi-focus image fusion** applications: |
|
|
|
```python |
|
import torch |
|
from transformers import pipeline |
|
|
|
# Load the model |
|
fusion_pipeline = pipeline( |
|
"image-to-image", |
|
model="divitmittal/HybridTransformer-MFIF", |
|
device=0 if torch.cuda.is_available() else -1 |
|
) |
|
|
|
# Fuse two images with different focus regions |
|
result = fusion_pipeline({ |
|
"near_focus": "path/to/near_focus_image.jpg", |
|
"far_focus": "path/to/far_focus_image.jpg" |
|
}) |
|
``` |
|
|
|
|
|
## Training Details |
|
|
|
### Training Data |
|
|
|
The model was trained on the **Lytro Multi-Focus Dataset**: |
|
|
|
- **Dataset:** 20 image pairs (near-focus + far-focus) from Lytro camera |
|
- **Resolution:** 520×520 pixels, resized to 224×224 for training |
|
- **Format:** RGB color images in JPEG format |
|
- **Augmentation:** Random horizontal flip, rotation (±10°), color jittering |
|
- **Split:** 80% training, 20% validation (using Triple Series for validation) |
|
- **Normalization:** ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
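A minimal sketch of the preprocessing and augmentation described above, using standard `torchvision` transforms; augmentation strengths not listed (such as the jitter magnitude) are assumptions, not the repository's exact values:

```python
import torchvision.transforms as T

# Training-time preprocessing as described above (illustrative; the exact
# pipeline in the training code may differ in ordering or parameters).
train_transform = T.Compose([
    T.Resize((224, 224)),                        # 520x520 source pairs -> 224x224
    T.RandomHorizontalFlip(p=0.5),               # random horizontal flip
    T.RandomRotation(degrees=10),                # rotation within +/-10 degrees
    T.ColorJitter(brightness=0.1, contrast=0.1), # mild color jittering (strength assumed)
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],      # ImageNet statistics
                std=[0.229, 0.224, 0.225]),
])
```

In practice, the geometric augmentations (flip, rotation) need to be applied identically to both images of a near/far pair so the two inputs stay spatially aligned.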
|
|
### Training Procedure |
|
|
|
#### Training Hyperparameters |
|
|
|
- **Optimizer:** AdamW |
|
- **Learning Rate:** 1e-4 with cosine annealing |
|
- **Batch Size:** 8 (adjustable based on available memory) |
|
- **Epochs:** 50 with early stopping (patience=15) |
|
- **Weight Decay:** 1e-4 |
|
- **Gradient Clipping:** L2 norm clipping at 1.0 |
|
- **Mixed Precision:** Enabled (AMP) for faster training |
|
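A hedged sketch of how these settings map onto standard PyTorch utilities; `model`, `train_loader`, and `criterion` are placeholders for the objects defined in the actual training code, and the real script may organize the loop differently:

```python
import torch

def train(model, train_loader, criterion, epochs=50, device="cuda"):
    """Illustrative training loop matching the settings above (not the repo's script)."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler()  # mixed precision (AMP)

    for epoch in range(epochs):
        # `target` is a placeholder for whatever reference the reconstruction loss uses.
        for near, far, target in train_loader:
            near, far, target = near.to(device), far.to(device), target.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                fused = model(near, far)
                loss = criterion(fused, target)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # L2-norm clip at 1.0
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        # Early stopping (patience=15 on a validation metric) omitted for brevity.
```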
|
|
#### Model Architecture |
|
|
|
- **Input Size:** 224×224×3 |
|
- **Patch Size:** 16×16 |
|
- **Embedding Dimension:** 768 |
|
- **CrossViT Blocks:** 4 layers |
|
- **Focal Transformer Blocks:** 6 layers |
|
- **Attention Heads:** 12 |
|
- **Focal Window Size:** 9×9 |
|
- **Focal Levels:** 3 |
|
- **Total Parameters:** ~73M |
|
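The specification above can be collected into a small configuration object; the field names below are illustrative, not the repository's actual config schema:

```python
from dataclasses import dataclass

@dataclass
class HybridConfig:
    # Illustrative field names; the actual config keys in the repository may differ.
    img_size: int = 224
    patch_size: int = 16
    in_channels: int = 3
    embed_dim: int = 768
    num_heads: int = 12
    crossvit_depth: int = 4   # CrossViT blocks
    focal_depth: int = 6      # Focal Transformer blocks
    focal_window: int = 9     # base focal window size
    focal_levels: int = 3

config = HybridConfig()
# 224 / 16 = 14, so each input image becomes a 14x14 grid of 768-dim tokens.
num_patches = (config.img_size // config.patch_size) ** 2  # 196 tokens per input
```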
|
|
#### Loss Function |
|
|
|
Custom multi-component loss combining: |
|
- **L1 Loss** (α=1.0): Pixel-wise reconstruction |
|
- **SSIM Loss** (β=0.5): Structural similarity preservation |
|
- **Perceptual Loss** (γ=0.3): VGG-based feature matching |
|
- **Gradient Loss** (δ=0.2): Edge preservation |
|
- **Focus Map Loss** (ε=0.1): Focus quality enhancement |
|
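A hedged sketch of how such a weighted sum is typically assembled; the SSIM, perceptual, gradient, and focus-map terms are passed in as placeholder callables because the repository defines its own implementations of each:

```python
import torch.nn.functional as F

# Loss weights as listed above.
ALPHA, BETA, GAMMA, DELTA, EPSILON = 1.0, 0.5, 0.3, 0.2, 0.1

def fusion_loss(fused, target, ssim_fn, perceptual_fn, gradient_fn, focus_fn):
    """Weighted multi-component fusion loss (illustrative composition only).

    `ssim_fn`, `perceptual_fn`, `gradient_fn`, and `focus_fn` are placeholders
    for the SSIM, VGG-feature, edge, and focus-map terms described above.
    """
    l1 = F.l1_loss(fused, target)              # pixel-wise reconstruction
    ssim = 1.0 - ssim_fn(fused, target)        # convert similarity to a loss
    perceptual = perceptual_fn(fused, target)  # VGG-based feature matching
    gradient = gradient_fn(fused, target)      # edge preservation
    focus = focus_fn(fused, target)            # focus quality term
    return (ALPHA * l1 + BETA * ssim + GAMMA * perceptual
            + DELTA * gradient + EPSILON * focus)
```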
|
|
## Evaluation |
|
|
|
### Testing Data, Factors & Metrics |
|
|
|
#### Testing Data |
|
|
|
- **Primary:** Lytro Multi-Focus Dataset (Triple Series, 4 image sets) |
|
- **Secondary:** Standard MFIF benchmarks for comparison |
|
- **Evaluation Protocol:** Hold-out test set with no overlap with training data |
|
|
|
#### Evaluation Metrics |
|
|
|
The model is evaluated using comprehensive fusion quality metrics: |
|
|
|
| Metric | Description | Range | Higher is Better | |
|
|--------|-------------|-------|------------------| |
|
| **PSNR** | Peak Signal-to-Noise Ratio | 0-∞ dB | ✓ | |
|
| **SSIM** | Structural Similarity Index | 0-1 | ✓ | |
|
| **QABF** | Edge information preservation (Q^AB/F) | 0-1 | ✓ |
|
| **VIF** | Visual Information Fidelity | 0-1 | ✓ | |
|
| **MI** | Mutual Information | 0-∞ | ✓ | |
|
| **SF** | Spatial Frequency | 0-∞ | ✓ | |
|
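For reference, the two metrics with simple closed-form definitions, PSNR and spatial frequency (SF), can be computed directly from image arrays. This NumPy sketch assumes grayscale or per-channel arrays scaled to [0, 255] and is independent of the repository's evaluation code:

```python
import numpy as np

def psnr(fused, reference, max_val=255.0):
    """Peak Signal-to-Noise Ratio in dB between two same-shape arrays."""
    mse = np.mean((fused.astype(np.float64) - reference.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)

def spatial_frequency(img):
    """Spatial frequency: sqrt(row_frequency^2 + column_frequency^2)."""
    img = img.astype(np.float64)
    rf = np.sqrt(np.mean((img[:, 1:] - img[:, :-1]) ** 2))  # horizontal differences
    cf = np.sqrt(np.mean((img[1:, :] - img[:-1, :]) ** 2))  # vertical differences
    return np.sqrt(rf ** 2 + cf ** 2)
```

SSIM, QABF, VIF, and MI are usually taken from library or reference implementations rather than re-derived.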
|
|
### Results |
|
|
|
#### Quantitative Performance |
|
|
|
| Metric | Value | Unit | Qualitative Assessment |
|
|--------|-------|------|---------------------| |
|
| **PSNR** | 28.5 | dB | State-of-the-art | |
|
| **SSIM** | 0.92 | index | Excellent | |
|
| **QABF** | 0.85 | index | High quality | |
|
| **VIF** | 0.78 | index | Very good | |
|
| **SF** | 12.3 | score | Superior | |
|
|
|
#### Computational Performance |
|
|
|
- **Inference Time:** ~150ms per image pair (GPU) |
|
- **Memory Usage:** ~4GB VRAM for 224×224 images |
|
- **Model Size:** 294MB (73M parameters) |
|
- **Supported Hardware:** CUDA-enabled GPUs, CPU fallback available |
|
|
|
## Technical Specifications |
|
|
|
### Model Architecture |
|
|
|
<div align="center"> |
|
<img src="assets/model_architecture.png" alt="Model Architecture" width="600"/> |
|
</div> |
|
|
|
The **FocalCrossViTHybrid** architecture consists of: |
|
|
|
#### 1. Patch Embedding Layer |
|
- Converts input images (224×224×3) into patch tokens (14×14×768) |
|
- Shared embedding for both near-focus and far-focus inputs |
|
- Learnable positional encoding added to patches |
|
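A common way to realize such a patch embedding is a strided convolution; this sketch is illustrative rather than the repository's exact module:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """224x224x3 image -> 14x14 grid of 768-dim tokens (patch size 16)."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 196
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):                 # x: (B, 3, 224, 224)
        x = self.proj(x)                  # (B, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)  # (B, 196, 768)
        return x + self.pos_embed         # learnable positional encoding

# The same embedding module is shared by the near-focus and far-focus branches.
```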
|
|
#### 2. CrossViT Processing (4 blocks) |
|
- **Cross-Attention Mechanism:** Enables information exchange between near/far features |
|
- **Multi-Head Attention:** 12 attention heads for diverse feature interactions |
|
- **MLP Layers:** Feed-forward networks with GELU activation |
|
- **Residual Connections:** Skip connections for gradient flow |
|
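A simplified cross-attention block along these lines, built on `nn.MultiheadAttention` as a stand-in for the actual CrossViT blocks:

```python
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """One branch's tokens attend to the other branch's tokens."""

    def __init__(self, dim=768, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, query_tokens, context_tokens):
        # Cross-attention with a residual connection.
        attn_out, _ = self.attn(self.norm_q(query_tokens),
                                self.norm_kv(context_tokens),
                                self.norm_kv(context_tokens))
        x = query_tokens + attn_out
        # Feed-forward (GELU MLP) with a second residual connection.
        return x + self.mlp(self.norm2(x))
```

In the hybrid model this kind of block would be applied symmetrically, letting near-focus tokens attend to far-focus tokens and vice versa.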
|
|
#### 3. Focal Transformer Processing (6 blocks) |
|
- **Focal Modulation:** Multi-scale spatial attention with learnable focal windows |
|
- **Hierarchical Processing:** Progressive feature refinement |
|
- **Adaptive Focus:** Dynamic attention based on spatial content |
|
- **Window Sizes:** 9×9 base window with 3 focal levels |
|
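A heavily simplified sketch in the spirit of focal modulation, aggregating context at several spatial scales with depthwise convolutions and per-token gating; the kernel growth, gating, and projection details are illustrative and will differ from the repository's blocks:

```python
import torch
import torch.nn as nn

class SimpleFocalModulation(nn.Module):
    """Aggregate multi-scale spatial context and use it to modulate a query projection."""

    def __init__(self, dim=768, focal_window=9, focal_levels=3):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.gates = nn.Linear(dim, focal_levels)
        # Depthwise convolutions with growing receptive fields (illustrative sizes: 9, 11, 13).
        self.contexts = nn.ModuleList([
            nn.Conv2d(dim, dim, kernel_size=focal_window + 2 * level,
                      padding=(focal_window + 2 * level) // 2, groups=dim)
            for level in range(focal_levels)
        ])
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                      # x: (B, N, C) with N = H*W, H == W
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        q = self.q(x)
        gates = self.gates(x).softmax(dim=-1)  # per-token weights over focal levels
        feat = x.transpose(1, 2).reshape(B, C, H, W)
        ctx = torch.zeros_like(q)
        for level, conv in enumerate(self.contexts):
            level_ctx = conv(feat).flatten(2).transpose(1, 2)     # (B, N, C)
            ctx = ctx + gates[..., level:level + 1] * level_ctx   # gated aggregation
        return self.proj(q * ctx)              # modulate queries with aggregated context
```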
|
|
#### 4. Fusion and Decoder |
|
- **Feature Fusion:** Learned combination of processed features |
|
- **Upsampling Decoder:** Series of transposed convolutions |
|
- **Output Generation:** Sigmoid activation for final image output |
|
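An illustrative decoder of this shape, upsampling a 14×14 grid of 768-dim fused tokens back to a 224×224 RGB image with strided transposed convolutions; the intermediate channel widths are assumptions:

```python
import torch.nn as nn

class FusionDecoder(nn.Module):
    """Upsample fused 14x14x768 features to a 224x224x3 image (16x upsampling)."""

    def __init__(self, embed_dim=768):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=4, stride=2, padding=1),  # 14 -> 28
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),        # 28 -> 56
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),         # 56 -> 112
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),           # 112 -> 224
            nn.Sigmoid(),                                                            # output in [0, 1]
        )

    def forward(self, tokens):  # tokens: (B, 196, 768)
        B, N, C = tokens.shape
        H = W = int(N ** 0.5)   # 14
        feat = tokens.transpose(1, 2).reshape(B, C, H, W)
        return self.decoder(feat)  # (B, 3, 224, 224)
```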
|
|
### Software Requirements |
|
|
|
- **Python:** ≥3.8 |
|
- **PyTorch:** ≥2.0.0 |
|
- **torchvision:** ≥0.15.0 |
|
- **PIL/Pillow:** For image processing |
|
- **NumPy:** For numerical operations |
|
|
|
### Hardware Requirements |
|
|
|
- **Minimum:** 8GB RAM, CPU inference supported |
|
- **Recommended:** 16GB RAM, NVIDIA GPU with 4GB+ VRAM |
|
- **Optimal:** NVIDIA RTX 3080/4080 or similar for fast inference |
|
|
|
## How to Use |
|
|
|
### Quick Start |
|
|
|
```python |
|
import torch |
|
from PIL import Image |
|
from transformers import pipeline |
|
|
|
# Initialize the fusion pipeline |
|
fusion_model = pipeline(

    "image-to-image",

    model="divitmittal/HybridTransformer-MFIF",

    device=0 if torch.cuda.is_available() else -1,

    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32

)
|
|
|
# Load your images |
|
near_focus_img = Image.open("near_focus.jpg") |
|
far_focus_img = Image.open("far_focus.jpg") |
|
|
|
# Perform fusion |
|
fused_result = fusion_model({ |
|
"near_focus": near_focus_img, |
|
"far_focus": far_focus_img |
|
}) |
|
|
|
# Save the result |
|
fused_result.save("fused_output.jpg") |
|
``` |
|
|
|
### Advanced Usage |
|
|
|
```python |
|
import torch

from PIL import Image

from torchvision import transforms

from transformers import AutoModel, AutoConfig
|
|
|
# Load model configuration and weights |
|
config = AutoConfig.from_pretrained("divitmittal/HybridTransformer-MFIF") |
|
model = AutoModel.from_pretrained("divitmittal/HybridTransformer-MFIF") |
|
model.eval() |
|
|
|
# Preprocessing pipeline |
|
transform = transforms.Compose([ |
|
transforms.Resize((224, 224)), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
]) |
|
|
|
# Load and preprocess the image pair

near_focus_img = Image.open("near_focus.jpg").convert("RGB")

far_focus_img = Image.open("far_focus.jpg").convert("RGB")

near_tensor = transform(near_focus_img).unsqueeze(0)

far_tensor = transform(far_focus_img).unsqueeze(0)
|
|
|
# Inference |
|
with torch.no_grad(): |
|
fused_tensor = model(near_tensor, far_tensor) |
|
|
|
# Post-process output (the model's final sigmoid keeps values in [0, 1])

fused_image = transforms.ToPILImage()(fused_tensor.squeeze(0))

fused_image.save("fused_output.jpg")
|
``` |
|
|
|
## Limitations and Bias |
|
|
|
### Known Limitations |
|
|
|
- **Input Constraints:** Requires exactly two input images with different focus regions |
|
- **Resolution:** Optimized for 224×224 input; larger images may need preprocessing |
|
- **Scene Types:** Best performance on natural scenes; may struggle with highly synthetic content |
|
- **Computational Cost:** Requires significant GPU memory for optimal performance |
|
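For inputs larger than 224×224, one simple workaround is to downscale the pair before fusion and upscale the fused result back to the original resolution. The sketch below does exactly that, with `fuse_fn` standing in for whichever fusion call from the usage examples you prefer; some fine detail is lost this way, and tiled inference would preserve more.

```python
from PIL import Image

def fuse_large_pair(near_path, far_path, fuse_fn, model_size=(224, 224)):
    """Resize a large image pair to the model's input size, fuse, then restore the size.

    `fuse_fn` is a placeholder: any callable that takes two PIL images and
    returns the fused PIL image (e.g. a wrapper around the pipeline above).
    """
    near = Image.open(near_path).convert("RGB")
    far = Image.open(far_path).convert("RGB")
    original_size = near.size
    fused_small = fuse_fn(near.resize(model_size), far.resize(model_size))
    return fused_small.resize(original_size, Image.LANCZOS)
```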
|
|
### Potential Biases |
|
|
|
- **Dataset Bias:** Trained primarily on Lytro camera data; may not generalize perfectly to all camera types |
|
- **Content Bias:** Performance may vary based on scene complexity and focus distribution |
|
- **Color Space:** Optimized for RGB color images; grayscale performance not extensively tested |
|
|
|
## Ethical Considerations |
|
|
|
- **Intended Use:** Research and legitimate photography applications |
|
- **Misuse Prevention:** Should not be used to create misleading or deceptive images |
|
- **Privacy:** Users should ensure they have rights to process uploaded images |
|
- **Transparency:** Model limitations should be communicated when deployed in applications |
|
|
|
|
|
--- |
|
|
|
<div align="center"> |
|
<p>If you find this model useful, please consider ❤️ liking the repository!</p> |
|
</div> |
|
|
|
|