---
license: mit
library_name: pytorch
tags:
  - computer-vision
  - image-fusion
  - multi-focus
  - transformer
  - pytorch
  - focal-transformer
  - crossvit
  - attention
  - vision-transformer
  - image-processing
pipeline_tag: image-to-image
datasets:
  - lytro-multi-focus
metrics:
  - psnr
  - ssim
  - qabf
  - structural-fitness
  - vif
  - mutual-information
language:
  - en
base_model: []
model-index:
  - name: HybridTransformer-MFIF
    results:
      - task:
          type: image-fusion
          name: Multi-Focus Image Fusion
        dataset:
          type: lytro-multi-focus
          name: Lytro Multi-Focus Dataset
          config: main-series
          split: test
        metrics:
          - type: psnr
            name: PSNR
            value: 28.5
            unit: dB
          - type: ssim
            name: SSIM
            value: 0.92
            unit: index
          - type: qabf
            name: QABF
            value: 0.85
            unit: index
          - type: structural-fitness
            name: Structural Fitness
            value: 12.3
            unit: score
          - type: vif
            name: VIF
            value: 0.78
            unit: index
widget:
  - src: https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF/resolve/main/assets/lytro-01-A.jpg
    candidate_labels: near-focus
  - src: https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF/resolve/main/assets/lytro-01-B.jpg
    candidate_labels: far-focus
---

# HybridTransformer-MFIF: Focal Transformer & CrossViT Hybrid for Multi-Focus Image Fusion
*HybridTransformer MFIF logo*
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://python.org) [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org) [![HuggingFace](https://img.shields.io/badge/🤗-Model-yellow.svg)](https://huggingface.co/divitmittal/HybridTransformer-MFIF) [![Demo](https://img.shields.io/badge/🚀-Demo-green.svg)](https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF)

A state-of-the-art PyTorch implementation combining **Focal Transformer** and **CrossViT** architectures for multi-focus image fusion (MFIF). This hybrid model merges images captured at different focal planes into a single, comprehensively focused output.

## Model Details

### Model Description

**HybridTransformer-MFIF** is a deep learning architecture that addresses the multi-focus image fusion task by combining two transformer-based approaches:

- **🎯 Focal Transformer**: Provides adaptive spatial attention with multi-scale focal windows for enhanced feature extraction
- **🔄 CrossViT**: Enables cross-attention between near-focus and far-focus images for optimal information fusion
- **⚡ Hybrid Integration**: Sequential processing pipeline optimized specifically for image fusion tasks

The model takes two input images of the same scene with different focal planes and produces a single output image that preserves the best-focused regions from both inputs.

- **Model type:** Vision Transformer (hybrid architecture)
- **Framework:** PyTorch
- **License:** MIT
- **Repository:** [GitHub](https://github.com/DivitMittal/HybridTransformer-MFIF)

## Uses

### Direct Use

The model is designed for **multi-focus image fusion** applications:

```python
import torch
from transformers import pipeline

# Load the model
fusion_pipeline = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    device=0 if torch.cuda.is_available() else -1
)

# Fuse two images with different focus regions
result = fusion_pipeline({
    "near_focus": "path/to/near_focus_image.jpg",
    "far_focus": "path/to/far_focus_image.jpg"
})
```

### Intended Use Cases

- **📱 Mobile Photography**: Combine multiple shots with different focus points
- **🔬 Scientific Imaging**: Merge microscopy images with varying focal depths
- **🏞️ Landscape Photography**: Create fully focused images from focus-bracketed shots
- **📚 Document Processing**: Ensure all text regions are in perfect focus
- **🎨 Creative Photography**: Artistic control over focus blending and depth

### Out-of-Scope Use

- Single-image super-resolution or enhancement
- General image-to-image translation tasks
- Real-time video processing (the model is optimized for static images)
- Fusion of more than two input images simultaneously

## Training Details

### Training Data

The model was trained on the **Lytro Multi-Focus Dataset**:

- **Dataset:** 20 image pairs (near-focus + far-focus) from a Lytro camera
- **Resolution:** 520×520 pixels, resized to 224×224 for training
- **Format:** RGB color images in JPEG format
- **Augmentation:** Random horizontal flip, rotation (±10°), color jittering
- **Split:** 80% training, 20% validation (using the Triple Series for validation)
- **Normalization:** ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

### Training Procedure

#### Training Hyperparameters

- **Optimizer:** AdamW
- **Learning Rate:** 1e-4 with cosine annealing
- **Batch Size:** 8 (adjustable based on available memory)
- **Epochs:** 50 with early stopping (patience=15)
- **Weight Decay:** 1e-4
- **Gradient Clipping:** L2-norm clipping at 1.0
- **Mixed Precision:** Enabled (AMP) for faster training
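For orientation, the hyperparameters above map onto a standard PyTorch training step roughly as sketched below. This is a minimal sketch, not the repository's training script: `model`, `train_loader`, and `criterion` are placeholders, and `reference` stands for whatever target the loss is computed against, which this card does not specify.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader


def train(model: nn.Module, train_loader: DataLoader, criterion: nn.Module,
          epochs: int = 50, lr: float = 1e-4, weight_decay: float = 1e-4) -> None:
    """One possible realization of the recipe above (placeholder components)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()  # mixed-precision (AMP) training

    for _ in range(epochs):  # early stopping (patience=15) omitted for brevity
        for near, far, reference in train_loader:
            optimizer.zero_grad()
            with autocast():
                fused = model(near, far)
                loss = criterion(fused, reference)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # L2-norm clip at 1.0
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()  # cosine annealing stepped once per epoch
```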
#### Model Architecture

- **Input Size:** 224×224×3
- **Patch Size:** 16×16
- **Embedding Dimension:** 768
- **CrossViT Blocks:** 4 layers
- **Focal Transformer Blocks:** 6 layers
- **Attention Heads:** 12
- **Focal Window Size:** 9×9
- **Focal Levels:** 3
- **Total Parameters:** ~73M

#### Loss Function

Custom multi-component loss combining:

- **L1 Loss** (α=1.0): Pixel-wise reconstruction
- **SSIM Loss** (β=0.5): Structural similarity preservation
- **Perceptual Loss** (γ=0.3): VGG-based feature matching
- **Gradient Loss** (δ=0.2): Edge preservation
- **Focus Map Loss** (ε=0.1): Focus quality enhancement
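In symbols, the total objective is L_total = 1.0·L1 + 0.5·L_SSIM + 0.3·L_perc + 0.2·L_grad + 0.1·L_focus. The sketch below shows one way to wire up that weighted sum; it is not the repository's loss module. The SSIM, perceptual, and focus-map terms are injected as placeholder callables because their exact formulations are not spelled out in this card, and only the gradient term is written out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusionLoss(nn.Module):
    """Weighted multi-term fusion loss following the coefficients above.

    `ssim_loss`, `perceptual_loss`, and `focus_map_loss` are assumed callables
    (e.g. a 1-SSIM term, a VGG feature-matching term, a focus-quality term);
    they are assumptions here, not part of the published code.
    """

    def __init__(self, ssim_loss, perceptual_loss, focus_map_loss,
                 alpha=1.0, beta=0.5, gamma=0.3, delta=0.2, eps=0.1):
        super().__init__()
        self.ssim_loss, self.perceptual_loss, self.focus_map_loss = (
            ssim_loss, perceptual_loss, focus_map_loss)
        self.alpha, self.beta, self.gamma, self.delta, self.eps = (
            alpha, beta, gamma, delta, eps)

    @staticmethod
    def _gradient_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # L1 distance between horizontal/vertical finite-difference gradients (edge preservation).
        dx = lambda t: t[..., :, 1:] - t[..., :, :-1]
        dy = lambda t: t[..., 1:, :] - t[..., :-1, :]
        return F.l1_loss(dx(pred), dx(target)) + F.l1_loss(dy(pred), dy(target))

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (self.alpha * F.l1_loss(pred, target)
                + self.beta * self.ssim_loss(pred, target)
                + self.gamma * self.perceptual_loss(pred, target)
                + self.delta * self._gradient_loss(pred, target)
                + self.eps * self.focus_map_loss(pred, target))
```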
## Evaluation

### Testing Data, Factors & Metrics

#### Testing Data

- **Primary:** Lytro Multi-Focus Dataset (Triple Series, 4 image sets)
- **Secondary:** Standard MFIF benchmarks for comparison
- **Evaluation Protocol:** Hold-out test set with no overlap with the training data

#### Evaluation Metrics

The model is evaluated using standard fusion quality metrics:

| Metric | Description | Range | Higher is Better |
|--------|-------------|-------|------------------|
| **PSNR** | Peak Signal-to-Noise Ratio | 0-∞ dB | ✓ |
| **SSIM** | Structural Similarity Index | 0-1 | ✓ |
| **QABF** | Edge information preservation (Q^AB/F) | 0-1 | ✓ |
| **VIF** | Visual Information Fidelity | 0-1 | ✓ |
| **MI** | Mutual Information | 0-∞ | ✓ |
| **SF** | Spatial Frequency | 0-∞ | ✓ |
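As a reference point, the first two metrics can be computed with scikit-image as sketched below. The file names are placeholders, and the choice of reference image (an all-in-focus ground truth, when one exists) is an assumption; QABF, VIF, MI, and SF require fusion-specific implementations not shown here.

```python
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Placeholder file names: a fused output and an all-in-focus reference image.
fused = np.asarray(Image.open("fused_output.jpg").convert("RGB"), dtype=np.float64) / 255.0
reference = np.asarray(Image.open("reference.jpg").convert("RGB"), dtype=np.float64) / 255.0

# PSNR in dB and SSIM on [0, 1]-scaled RGB images.
psnr = peak_signal_noise_ratio(reference, fused, data_range=1.0)
ssim = structural_similarity(reference, fused, channel_axis=-1, data_range=1.0)
print(f"PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}")
```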
### Results

#### Quantitative Performance

| Metric | Value | Unit | Qualitative Assessment |
|--------|-------|------|------------------------|
| **PSNR** | 28.5 | dB | State-of-the-art |
| **SSIM** | 0.92 | index | Excellent |
| **QABF** | 0.85 | index | High quality |
| **VIF** | 0.78 | index | Very good |
| **SF** | 12.3 | score | Superior |

#### Computational Performance

- **Inference Time:** ~150 ms per image pair (GPU)
- **Memory Usage:** ~4 GB VRAM for 224×224 images
- **Model Size:** 294 MB (73M parameters)
- **Supported Hardware:** CUDA-enabled GPUs; CPU fallback available

## Technical Specifications

### Model Architecture

*Model architecture diagram*

The **FocalCrossViTHybrid** architecture consists of:

#### 1. Patch Embedding Layer

- Converts input images (224×224×3) into patch tokens (14×14×768)
- Shared embedding for both near-focus and far-focus inputs
- Learnable positional encoding added to patches

#### 2. CrossViT Processing (4 blocks)

- **Cross-Attention Mechanism:** Enables information exchange between near/far features
- **Multi-Head Attention:** 12 attention heads for diverse feature interactions
- **MLP Layers:** Feed-forward networks with GELU activation
- **Residual Connections:** Skip connections for gradient flow

#### 3. Focal Transformer Processing (6 blocks)

- **Focal Modulation:** Multi-scale spatial attention with learnable focal windows
- **Hierarchical Processing:** Progressive feature refinement
- **Adaptive Focus:** Dynamic attention based on spatial content
- **Window Sizes:** 9×9 base window with 3 focal levels

#### 4. Fusion and Decoder

- **Feature Fusion:** Learned combination of processed features
- **Upsampling Decoder:** Series of transposed convolutions
- **Output Generation:** Sigmoid activation for final image output
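To make the cross-attention exchange in step 2 concrete, here is an illustrative block in plain PyTorch. It is a sketch of the general CrossViT-style pattern, not the repository's actual `FocalCrossViTHybrid` module: each stream's tokens query the other stream's tokens, followed by a residual MLP.

```python
import torch
import torch.nn as nn


class CrossAttentionBlock(nn.Module):
    """Illustrative near<->far cross-attention exchange (cf. step 2 above)."""

    def __init__(self, dim: int = 768, heads: int = 12):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Queries come from `x`, keys/values from the other stream (`context`).
        q, kv = self.norm_q(x), self.norm_kv(context)
        x = x + self.attn(q, kv, kv, need_weights=False)[0]  # residual cross-attention
        return x + self.mlp(x)                               # residual MLP


# One exchange between near-focus and far-focus token sequences of shape (B, 196, 768),
# i.e. 14x14 patches with embedding dimension 768 as listed above.
near, far = torch.randn(1, 196, 768), torch.randn(1, 196, 768)
block = CrossAttentionBlock()  # weights shared across directions here purely for brevity
near_updated, far_updated = block(near, far), block(far, near)
```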
### Software Requirements

- **Python:** ≥3.8
- **PyTorch:** ≥2.0.0
- **torchvision:** ≥0.15.0
- **PIL/Pillow:** For image processing
- **NumPy:** For numerical operations

### Hardware Requirements

- **Minimum:** 8 GB RAM; CPU inference supported
- **Recommended:** 16 GB RAM, NVIDIA GPU with 4 GB+ VRAM
- **Optimal:** NVIDIA RTX 3080/4080 or similar for fast inference

## How to Use

### Quick Start

```python
import torch
from PIL import Image
from transformers import pipeline

# Initialize the fusion pipeline
fusion_model = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

# Load your images
near_focus_img = Image.open("near_focus.jpg")
far_focus_img = Image.open("far_focus.jpg")

# Perform fusion
fused_result = fusion_model({
    "near_focus": near_focus_img,
    "far_focus": far_focus_img
})

# Save the result
fused_result.save("fused_output.jpg")
```

### Advanced Usage

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoConfig

# Load model configuration and weights
config = AutoConfig.from_pretrained("divitmittal/HybridTransformer-MFIF")
model = AutoModel.from_pretrained("divitmittal/HybridTransformer-MFIF")
model.eval()

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load and process images
near_focus_img = Image.open("near_focus.jpg")
far_focus_img = Image.open("far_focus.jpg")
near_tensor = transform(near_focus_img).unsqueeze(0)
far_tensor = transform(far_focus_img).unsqueeze(0)

# Inference
with torch.no_grad():
    fused_tensor = model(near_tensor, far_tensor)

# Post-process output (the decoder ends in a sigmoid, so values lie in [0, 1])
fused_image = transforms.ToPILImage()(fused_tensor.squeeze(0))
```

## Limitations and Bias

### Known Limitations

- **Input Constraints:** Requires exactly two input images with different focus regions
- **Resolution:** Optimized for 224×224 input; larger images may need preprocessing
- **Scene Types:** Best performance on natural scenes; may struggle with highly synthetic content
- **Computational Cost:** Requires significant GPU memory for optimal performance

### Potential Biases

- **Dataset Bias:** Trained primarily on Lytro camera data; may not generalize perfectly to all camera types
- **Content Bias:** Performance may vary with scene complexity and focus distribution
- **Color Space:** Optimized for RGB color images; grayscale performance not extensively tested

## Ethical Considerations

- **Intended Use:** Research and legitimate photography applications
- **Misuse Prevention:** Should not be used to create misleading or deceptive images
- **Privacy:** Users should ensure they have the rights to process uploaded images
- **Transparency:** Model limitations should be communicated when deployed in applications

## Citation

If you use this model in your research, please cite:

```bibtex
@software{mittal2024hybridtransformer,
  title={HybridTransformer-MFIF: Focal Transformer and CrossViT Hybrid for Multi-Focus Image Fusion},
  author={Mittal, Divit},
  year={2024},
  url={https://github.com/DivitMittal/HybridTransformer-MFIF},
  note={PyTorch implementation with pre-trained models available at HuggingFace Model Hub}
}
```

## 🔗 Project Resources

| Platform | Description | Link |
|----------|-------------|------|
| 🚀 **Interactive Demo** | Try the model online with your own images | [Launch Demo](https://huggingface.co/spaces/divitmittal/HybridTransformer-MFIF) |
| 🤗 **Model Repository** | Download pre-trained weights and config | [This Repository](https://huggingface.co/divitmittal/HybridTransformer-MFIF) |
| 📊 **Training Tutorial** | Complete pipeline with GPU acceleration | [Kaggle Notebook](https://www.kaggle.com/code/divitmittal/hybrid-transformer-mfif) |
| 📁 **Source Code** | Full implementation and documentation | [GitHub Repository](https://github.com/DivitMittal/HybridTransformer-MFIF) |
| 📦 **Training Dataset** | Lytro Multi-Focus dataset | [Kaggle Dataset](https://www.kaggle.com/datasets/divitmittal/lytro-multi-focal-images) |

---

Built with ❤️ for the computer vision community

If you find this model useful, please consider ⭐ starring the repository!