# DermaBot / app.py
# Last change: commit 84de2ae ("Update app.py") by amasood
import os

import streamlit as st
import torch
from groq import APIError
from groq import Groq
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Set page config (must be the first Streamlit command executed by the script)
st.set_page_config(page_title="DermaBot - AI Skin Disease Detector", page_icon="🩺", layout="wide")

# Load model and processor
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner="Loading skin-disease model...")
def _load_model_and_processor(model_name, device):
    """Load the image classifier and its processor once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the weights would be re-read from disk and moved to
    *device* on every rerun.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(MODEL_NAME, DEVICE)
# Groq client. The API key is read from the GROQ_API_KEY environment variable
# (e.g. a Hugging Face Space secret) and must never be hard-coded in source.
# A literal key was previously committed on this line: treat it as leaked and rotate it.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Without a key every detail lookup and chat request would fail, so stop early
    # with an actionable message instead of a traceback.
    st.error("GROQ_API_KEY is not set. Add it as an environment variable (or Space secret) and restart DermaBot.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# Seed the per-session slots that hold the latest prediction and its LLM
# description. Only missing keys are set, so values survive Streamlit reruns.
for _state_key in ("disease_name", "disease_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Function to predict skin disease
def predict_skin_disease(image):
    """Return the top-scoring class label for a PIL image.

    The image is converted to RGB, preprocessed by the module-level
    ``processor``, moved to ``DEVICE`` and scored by ``model`` with gradient
    tracking disabled. The arg-max logit index is mapped to a label via
    ``model.config.id2label``.
    """
    batch = processor(images=image.convert("RGB"), return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        scores = model(**batch).logits
    best_index = scores.argmax(-1).item()
    return model.config.id2label[best_index]
# Functions that query the Groq chat API
GROQ_MODEL = "llama-3.3-70b-versatile"


def _ask_groq(prompt):
    """Send *prompt* as a single user message to ``GROQ_MODEL`` and return the reply text.

    Uses the module-level ``client``. Exceptions from the Groq SDK are not
    caught here and propagate to the caller. Returns ``""`` when the
    response carries no message content.
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=GROQ_MODEL,
    )
    return chat_completion.choices[0].message.content or ""


def get_disease_info(disease_name):
    """Ask the LLM for an explanation of *disease_name*.

    The prompt requests a description, causes, precautions, risks and
    treatment options. Returns the model's reply text.
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including description of disease, causes, precautions, risk and treatment options."
    )
    return _ask_groq(prompt)


def chatbot_response(disease_name, user_query):
    """Answer *user_query* in the context of the detected *disease_name*.

    Returns a guidance message without calling the API when no disease has
    been detected yet, or when the query is empty/whitespace-only.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."
    if not user_query or not user_query.strip():
        return "Please enter a question about the detected disease."
    prompt = f"The detected skin disease is '{disease_name}'. {user_query}"
    return _ask_groq(prompt)
# Streamlit UI
# Optional logo: set DERMABOT_LOGO_URL to a real image URL to display one.
# Nothing is rendered when it is unset, which avoids a broken-image placeholder.
LOGO_URL = os.environ.get("DERMABOT_LOGO_URL", "")
if LOGO_URL:
    st.image(LOGO_URL, width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
# The prediction and the LLM text are automated outputs; make that explicit to users.
st.caption("⚠️ DermaBot is an AI demo and not a substitute for professional medical advice.")
def _open_uploaded_image(uploaded_file):
    """Decode *uploaded_file* with Pillow, or return None if it is not a readable image."""
    try:
        img = Image.open(uploaded_file)
        img.load()  # decode now so corrupt/truncated files fail here, not during inference
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        return None
    return img


# Step 1: Upload image
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_image:
    image = _open_uploaded_image(uploaded_image)
    if image is None:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # Step 2: Detect disease
        if st.button("Detect Disease"):
            with st.spinner("Analyzing..."):
                disease_name = predict_skin_disease(image)
                # Store the prediction before the LLM call so a Groq failure does not
                # discard it, and clear any details left over from a previous image.
                st.session_state.disease_name = disease_name
                st.session_state.disease_info = None
                try:
                    st.session_state.disease_info = get_disease_info(disease_name)
                except APIError as exc:
                    st.error(f"Could not fetch details for '{disease_name}' from Groq: {exc}")
# Display detected disease information if available
if st.session_state.disease_name:
    st.success(f"**Detected Disease:** {st.session_state.disease_name}")
    # disease_info may be None or empty (e.g. the lookup returned no text);
    # skip it rather than rendering a literal "None".
    if st.session_state.disease_info:
        st.write(f"**Details:** {st.session_state.disease_info}")
# Step 3: Chatbot
st.subheader("💬 Ask DermaBot about this disease")
user_query = st.text_input("Ask about the detected disease:")
if st.button("Ask"):
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except APIError as exc:
            # Show API failures (network, auth, rate limits) inline instead of a traceback.
            st.error(f"DermaBot could not reach the Groq API: {exc}")
        else:
            st.write(response)
# Page footer: a horizontal rule followed by the attribution line.
FOOTER_TEXT = "🔍 Powered by **AI & Groq API** | © 2025 DermaBot"
st.markdown("---")
st.write(FOOTER_TEXT)