Hybrid Vision Transformer (HVT)

Model Name: Hybrid Vision Transformer (HVT)
Author: CodeWithDark
License: MIT (or specify another license)
Model Type: Hybrid Vision Transformer
Task: Image Classification / Object Detection / Medical Imaging
Dataset: Specify dataset used (e.g., ImageNet, CIFAR-100, custom dataset)
Framework: PyTorch


Model Description

The Hybrid Vision Transformer (HVT) integrates the strengths of Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs) to effectively capture both local and global features in images. This hybrid approach addresses the limitations of traditional ViTs, especially when trained on smaller datasets, by incorporating inductive biases inherent to CNNs.
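
To make the idea concrete, here is a minimal sketch of a CNN-plus-Transformer classifier in PyTorch. It is not the released HVT architecture: the class name `TinyHybridViT`, the layer sizes, and the mean-pooling head are all assumptions chosen for illustration, and positional embeddings are omitted for brevity. A small convolutional stem extracts local features, and a standard Transformer encoder then mixes the resulting tokens globally.

```python
import torch
from torch import nn


class TinyHybridViT(nn.Module):
    """Illustrative CNN-stem + Transformer-encoder classifier (hypothetical, not the released HVT)."""

    def __init__(self, num_classes=10, embed_dim=64, depth=2, num_heads=4):
        super().__init__()
        # Convolutional stem: contributes the locality bias of CNNs and
        # downsamples the image 4x in each spatial dimension.
        self.stem = nn.Sequential(
            nn.Conv2d(3, embed_dim // 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dim // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.GELU(),
        )
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 2,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        feats = self.stem(x)                       # (B, C, H/4, W/4)
        tokens = feats.flatten(2).transpose(1, 2)  # (B, N, C), one token per spatial cell
        tokens = self.encoder(tokens)              # global self-attention over all tokens
        pooled = self.norm(tokens).mean(dim=1)     # average the tokens into one vector
        return self.head(pooled)                   # (B, num_classes)


sketch = TinyHybridViT(num_classes=10).eval()
with torch.no_grad():
    logits = sketch(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

With a 32x32 input the stem produces an 8x8 feature map, so the encoder attends over 64 tokens instead of 1,024 pixels.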

Key Features:

  • Hybrid Architecture: Combines CNNs' local feature extraction with ViTs' global context understanding.
  • Hierarchical Representation: Utilizes hierarchical pooling to reduce sequence length and computational complexity, similar to the approach in Scalable Vision Transformers with Hierarchical Pooling.
  • Dynamic Feature Aggregation: Enhances channel representation by recalibrating channel groups and letting them interact, inspired by the Dynamic Hybrid Vision Transformer (DHVT).
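
The effect of hierarchical pooling on sequence length can be checked with simple arithmetic. The sketch below assumes a 224x224 input, 16x16 patches, three stages, and a pooling stride of 2 between stages. None of these values are stated for this model, and the function names are hypothetical. Because self-attention cost grows with the square of the token count, halving the sequence cuts the attention cost of a stage to roughly a quarter.

```python
def stage_token_counts(image_size=224, patch_size=16, num_stages=3, pool_stride=2):
    """Token count entering each stage when the sequence is pooled between stages."""
    tokens = (image_size // patch_size) ** 2
    counts = []
    for _ in range(num_stages):
        counts.append(tokens)
        # Pooling with this stride shortens the token sequence before the next stage.
        tokens = max(1, tokens // pool_stride)
    return counts


def relative_attention_cost(counts):
    """Quadratic attention cost of each stage relative to the first stage."""
    base = counts[0] ** 2
    return [n * n / base for n in counts]


counts = stage_token_counts()
print(counts)                           # [196, 98, 49]
print(relative_attention_cost(counts))  # [1.0, 0.25, 0.0625]
```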

Use Cases

  • Image Classification: Suitable for various classification tasks across different domains.
  • Object Detection: Effective in identifying and localizing objects within images.
  • Medical Imaging: Applicable in analyzing medical images for diagnostic purposes.

Training Details

  • Framework: PyTorch
  • Model Architecture: Hybrid Vision Transformer
  • Optimizer: AdamW
  • Loss Function: CrossEntropyLoss
  • Learning Rate: your value
  • Batch Size: your value
  • Epochs: your value
  • Hardware Used: Specify hardware (e.g., NVIDIA RTX 3090)
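
A training loop with the optimizer and loss listed above might look like the following sketch. The learning rate, weight decay, batch size, and step count are placeholders, not the values used for the released weights, and a small linear classifier stands in for the HVT module so the snippet runs on its own.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in classifier (assumption); substitute the real HVT module here.
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# Placeholder hyperparameters, chosen only for illustration.
optimizer = torch.optim.AdamW(stand_in.parameters(), lr=3e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

# One fixed random batch standing in for a real DataLoader.
images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

stand_in.train()
losses = []
for step in range(5):
    optimizer.zero_grad()                       # clear gradients from the previous step
    loss = criterion(stand_in(images), labels)  # CrossEntropyLoss expects raw logits
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

In a real run the fixed batch would be replaced by iteration over a `DataLoader`, with a separate validation pass under `torch.no_grad()`.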


How to Use

Load the Model from Hugging Face Hub

from huggingface_hub import hf_hub_download
import torch

# Download model
model_path = hf_hub_download("codewithdark/hvt", "hvt.pth")

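# Load Model
# Assumes hvt.pth stores the full pickled model object, so the class that defines it
# must be importable. weights_only=False unpickles arbitrary objects; use it only
# with files you trust.
model = torch.load(model_path, map_location="cpu", weights_only=False)
model.eval()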

Inference Example

import torch
from torchvision import transforms
from PIL import Image

# Load Image
image_path = "your_image.jpg"
image = Image.open(image_path).convert("RGB")  # force 3 channels for grayscale or RGBA files

# Preprocess
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
image = transform(image).unsqueeze(0)

# Make Prediction
with torch.no_grad():
    output = model(image)
    predicted_class = torch.argmax(output, dim=1).item()

print(f"Predicted Class: {predicted_class}")
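
The prediction above is a bare class index. If confidence scores are also useful, the logits can be turned into probabilities with a softmax and the top few classes reported. The helper below is a hypothetical addition, and the logits are made up for a 4-class model so the snippet runs without the checkpoint.

```python
import torch


def top_k_predictions(logits, k=3):
    """Return [(class_index, probability), ...] for the k most likely classes of the first image."""
    probs = torch.softmax(logits, dim=1)                # logits -> probabilities per image
    top_probs, top_idx = torch.topk(probs, k=k, dim=1)  # sorted from most to least likely
    return list(zip(top_idx[0].tolist(), top_probs[0].tolist()))


# Made-up logits shaped (batch=1, num_classes=4); in practice pass the model output.
example_logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
for class_index, prob in top_k_predictions(example_logits, k=2):
    print(f"class {class_index}: {prob:.3f}")
```

Mapping indices to human-readable labels requires the class list of the training dataset, which this card does not specify.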

Limitations & Future Work

  • Data Requirements: While the hybrid approach mitigates some data limitations, performance may still benefit from larger datasets.
  • Computational Resources: The model's complexity may require substantial computational power for training and inference.
  • Future Improvements: Explore advanced token mixing operations and structural reparameterization techniques, as discussed in FastViT, to enhance efficiency and performance.
