Alzheimer-Stage-Classifier

Alzheimer-Stage-Classifier is a multi-class image classification model fine-tuned from google/siglip2-base-patch16-224 to identify the stage of Alzheimer’s disease from brain scan images. It is intended to assist with clinical decision support, early diagnosis, and tracking of disease progression.

Classification Report:
                  precision    recall  f1-score   support

    MildDemented     0.9634    0.9860    0.9746      8960
ModerateDemented     1.0000    1.0000    1.0000      6464
     NonDemented     0.8920    0.8910    0.8915      9600
VeryMildDemented     0.8904    0.8704    0.8803      8960

        accuracy                         0.9314     33984
       macro avg     0.9364    0.9369    0.9366     33984
    weighted avg     0.9309    0.9314    0.9311     33984
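The summary rows follow the standard definitions: the macro average is the unweighted mean of the per-class scores, and the weighted average weights each class by its support. The layout matches scikit-learn's `classification_report` with four decimal places, although the card does not say which tool produced it. A minimal sketch that recomputes the summary rows from the per-class rows above:

```python
# Per-class (precision, recall, f1-score, support) copied from the report above.
rows = {
    "MildDemented":     (0.9634, 0.9860, 0.9746, 8960),
    "ModerateDemented": (1.0000, 1.0000, 1.0000, 6464),
    "NonDemented":      (0.8920, 0.8910, 0.8915, 9600),
    "VeryMildDemented": (0.8904, 0.8704, 0.8803, 8960),
}

# Total number of evaluated images (sum of per-class support).
total = sum(r[3] for r in rows.values())

def macro_avg(col: int) -> float:
    """Unweighted mean of one metric column across the four classes."""
    return sum(r[col] for r in rows.values()) / len(rows)

def weighted_avg(col: int) -> float:
    """Support-weighted mean of one metric column."""
    return sum(r[col] * r[3] for r in rows.values()) / total

for name, col in (("precision", 0), ("recall", 1), ("f1-score", 2)):
    print(f"{name:9s} macro={macro_avg(col):.4f} weighted={weighted_avg(col):.4f}")
print("support", total)
```

The weighted-average recall equals overall accuracy, because each class's recall multiplied by its support is that class's number of correct predictions. Any difference in the last digit comes from the per-class values already being rounded to four decimals.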


Label Classes

The model classifies input images into the following stages of Alzheimer’s disease:

0: MildDemented  
1: ModerateDemented  
2: NonDemented  
3: VeryMildDemented
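The class indices follow the alphabetical order of the label names, which is what a name-sorted dataset loader typically produces. This is an observation about the list above, not something the card states about training. Note that index order is therefore not clinical severity order. A minimal sketch that builds both lookup directions from the names, so a predicted index can be turned into a label and back:

```python
# Class names as listed above; index i corresponds to the model's i-th logit.
LABELS = ["MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented"]

# Both lookup directions: integer index -> name, and name -> integer index.
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in id2label.items()}

# The listed order is alphabetical, not ordered by disease severity.
assert LABELS == sorted(LABELS)

print(id2label[2])                   # NonDemented
print(label2id["VeryMildDemented"])  # 3
```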

Installation

pip install transformers torch pillow gradio

Example Inference Code

import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch

# Load model and processor
model_name = "prithivMLmods/Alzheimer-Stage-Classifier"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# ID to label mapping
id2label = {
    "0": "MildDemented",
    "1": "ModerateDemented",
    "2": "NonDemented",
    "3": "VeryMildDemented"
}

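def classify_alzheimer_stage(image):
    # Gradio passes None when the user submits without uploading an image
    if image is None:
        return None

    # Gradio delivers a NumPy array; convert it to a 3-channel PIL image
    image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits  # shape: (1, 4), one score per class
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()

    # Map each class index to its probability, rounded to 3 decimals
    prediction = {id2label[str(i)]: round(probs[i], 3) for i in range(len(probs))}
    return prediction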

# Gradio Interface
iface = gr.Interface(
    fn=classify_alzheimer_stage,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=4, label="Alzheimer Stage"),
    title="Alzheimer-Stage-Classifier",
    description="Upload a brain scan image to classify the stage of Alzheimer's: NonDemented, VeryMildDemented, MildDemented, or ModerateDemented."
)

if __name__ == "__main__":
    iface.launch()
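The last step of classify_alzheimer_stage turns the model's four logits into a label-to-probability dictionary. That step can be checked in isolation, without downloading the model or installing torch, using a plain-Python sketch. The helper names `softmax`, `logits_to_prediction`, and `top_label` and the example logits are hypothetical, chosen only for illustration:

```python
import math

# Same index-to-label mapping as in the inference example above (int keys here).
ID2LABEL = {0: "MildDemented", 1: "ModerateDemented", 2: "NonDemented", 3: "VeryMildDemented"}

def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def logits_to_prediction(logits: list[float], ndigits: int = 3) -> dict[str, float]:
    """Hypothetical helper mirroring the dict returned by classify_alzheimer_stage."""
    probs = softmax(logits)
    return {ID2LABEL[i]: round(p, ndigits) for i, p in enumerate(probs)}

def top_label(logits: list[float]) -> str:
    """Label of the highest-scoring class (softmax preserves the argmax)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return ID2LABEL[best]

example_logits = [0.2, -1.5, 3.1, 1.0]  # made-up values, not real model output
print(logits_to_prediction(example_logits))
print(top_label(example_logits))  # NonDemented
```

Subtracting the maximum logit does not change the result of the softmax, but it keeps `math.exp` from overflowing when logits are large.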

Applications

  • Early Alzheimer’s Screening
  • Clinical Diagnosis Support
  • Longitudinal Study & Disease Monitoring
  • Research on Cognitive Decline

Model Details

Model size: 92.9M params
Tensor type: F32 (Safetensors)