SkinCancerViT
Collection
A Multimodal Deep Learning Approach for Skin Cancer Classification Using ViTs (Vision Transformers)
•
2 items
•
Updated
First, clone the repository:
# Clone the SkinCancerViT repository and change into its root directory.
git clone https://github.com/ethicalabs-ai/SkinCancerViT.git
cd SkinCancerViT
Then, install the package in editable mode using uv (or pip):
# Install the package in editable mode (run from the repository root).
uv sync # Recommended if you use uv
# Or, if using pip:
# pip install -e .
This package allows you to load and use a pre-trained SkinCancerViT model for prediction.
import torch
from skincancer_vit.model import SkinCancerViTModel
from PIL import Image
from datasets import load_dataset # To get a random sample
# Load the model from Hugging Face Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerViTModel.from_pretrained("ethicalabs/SkinCancerViT")
model.to(device)  # Move model to the desired device
model.eval()  # Set model to evaluation mode

# Example Prediction from a Specific Image File
image_file_path = "images/patient-001.jpg"  # Specify your image file path here
# Open the file in a context manager so the handle is released
# deterministically; convert() returns a new image that remains usable
# after the file is closed.
with Image.open(image_file_path) as img:
    specific_image = img.convert("RGB")

# Example tabular data for this prediction
specific_age = 42
# Must match one of the localization categories the model was trained on.
specific_localization = "face"

predicted_dx, confidence = model.full_predict(
    raw_image=specific_image,
    raw_age=specific_age,
    raw_localization=specific_localization,
    device=device,
)
print(f"Predicted Diagnosis: {predicted_dx}")
print(f"Confidence: {confidence:.4f}")

# Example Prediction from a Test Sample of the Dataset
# The shuffle is seeded, so the same sample is drawn on every run;
# change or remove the seed to draw a different one.
dataset = load_dataset("marmal88/skin_cancer", split="test")
random_sample = dataset.shuffle(seed=42)[0]  # First row after the seeded shuffle

# Convert to RGB for parity with the file-based example above.
# Assumes the "image" column decodes to a PIL image (datasets Image feature).
sample_image = random_sample["image"].convert("RGB")
sample_age = random_sample["age"]
sample_localization = random_sample["localization"]
sample_true_dx = random_sample["dx"]

predicted_dx_sample, confidence_sample = model.full_predict(
    raw_image=sample_image,
    raw_age=sample_age,
    raw_localization=sample_localization,
    device=device,
)
print(f"True Diagnosis: {sample_true_dx}")
print(f"Predicted Diagnosis: {predicted_dx_sample}")
print(f"Confidence: {confidence_sample:.4f}")
print(f"Correct Prediction: {predicted_dx_sample == sample_true_dx}")
Base model
google/vit-base-patch16-224-in21k