---
language: en
tags:
- vision
- vit
- xray
- chest-xray
- classification
license: mit
pipeline_tag: image-classification
author: Om Kumar (@itsomk)
---
# ViT X-ray Multi-label (vit-xray-v1)
## Model Description
This model is a fine-tuned **Vision Transformer** ([google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k)) for **multi-label classification of chest X-rays**.
It predicts the presence of multiple findings such as:
- **Nodule**
- **Infiltration**
- **Effusion**
- **Atelectasis**
**Author:** Om Kumar (Hugging Face: [@itsomk](https://huggingface.co/itsomk))
The model is designed for **research and educational purposes only** and should not be used as a substitute for clinical diagnosis.
---
## Intended Use
- **Research** in medical imaging and computer vision
- **Educational purposes** for understanding X-ray image classification
- **Baseline model** for further fine-tuning or domain adaptation
⚠️ **Not intended for clinical use**. Predictions should not guide medical decisions.
---
## Training Data
- Dataset: publicly available chest X-ray datasets (e.g., NIH ChestX-ray14).
- Images were preprocessed (resized to 224x224, normalized).
- Labels are **multi-label**, meaning an X-ray can contain more than one finding.
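To make the multi-label setup concrete, here is a minimal sketch of how the findings on one X-ray could be encoded as a multi-hot target vector. The `encode_findings` helper and the label order are assumptions for illustration only; the actual training code is not part of this card, and the real label order is whatever the model config defines.

```python
# Label order is an assumption for illustration; the real order is defined
# by the model's config (id2label), not by this sketch.
LABELS = ["Nodule", "Infiltration", "Effusion", "Atelectasis"]


def encode_findings(findings, labels=LABELS):
    """Return a multi-hot vector with 1.0 for each finding that is present."""
    present = set(findings)
    unknown = present - set(labels)
    if unknown:
        raise ValueError(f"Unknown findings: {sorted(unknown)}")
    return [1.0 if label in present else 0.0 for label in labels]


# One X-ray with two findings, and one with none.
print(encode_findings(["Effusion", "Atelectasis"]))  # [0.0, 0.0, 1.0, 1.0]
print(encode_findings([]))                           # [0.0, 0.0, 0.0, 0.0]
```

Because several entries can be 1.0 at once, each label is treated as its own yes/no decision rather than as one choice among mutually exclusive classes.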
---
## Model Performance
- Optimized for detecting **common thoracic abnormalities**.
- Evaluation metric: AUC (area under the ROC curve), reported per label:
  - Nodule: 0.696
  - Infiltration: 0.684
  - Effusion: 0.843
  - Atelectasis: 0.762
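For readers who want to compute per-label AUC on their own predictions, the following is a rough sketch using the pairwise-ranking definition of AUC. It is not the evaluation code used for the numbers above (that code is not included in this card); `roc_auc`, `per_label_auc`, and the toy data are hypothetical. In practice `sklearn.metrics.roc_auc_score` is the usual choice.

```python
def roc_auc(y_true, y_score):
    """AUC as the chance that a random positive outranks a random negative.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def per_label_auc(labels, y_true, y_score):
    """Compute one AUC per label column of multi-hot targets and scores."""
    return {
        name: roc_auc([row[j] for row in y_true], [row[j] for row in y_score])
        for j, name in enumerate(labels)
    }


# Toy data for illustration only (not taken from the model's evaluation).
labels = ["Effusion", "Atelectasis"]
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_score = [[0.9, 0.2], [0.3, 0.6], [0.4, 0.7], [0.1, 0.8]]
print(per_label_auc(labels, y_true, y_score))
# {'Effusion': 1.0, 'Atelectasis': 0.5}
```

An AUC of 0.5 corresponds to chance-level ranking and 1.0 to perfect ranking, which gives a frame of reference for the per-label values listed above.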
---
## Quick Usage
```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL = "itsomk/vit-xray-v1"

# Load the image processor and the fine-tuned classifier from the Hub.
processor = AutoImageProcessor.from_pretrained(MODEL)
model = AutoModelForImageClassification.from_pretrained(MODEL)

# Convert to RGB so grayscale X-rays match the 3-channel input ViT expects.
img = Image.open("path/to/xray.jpg").convert("RGB")
inputs = processor(images=img, return_tensors="pt")

# Inference only, so skip gradient tracking.
with torch.no_grad():
    logits = model(**inputs).logits

# Multi-label output: apply a sigmoid to each label independently.
probs = torch.sigmoid(logits).squeeze(0).tolist()
results = {model.config.id2label[i]: p for i, p in enumerate(probs)}
print(results)