File size: 2,203 Bytes
d566fee 3fdf143 ee36b3c ba699eb d566fee db0ef97 3fdf143 d5bc862 db0ef97 a3cb4b9 db0ef97 a3cb4b9 ee36b3c db0ef97 d566fee db0ef97 d566fee 3fdf143 db0ef97 d566fee db0ef97 d566fee 3fdf143 db0ef97 3fdf143 db0ef97 ba699eb 3fdf143 d566fee fe44a5a 3fdf143 d5bc862 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image
# Load the exported TensorFlow SavedModel from the local "model" directory.
model = tf.saved_model.load(export_dir='model')
# Class names, looked up by the index of the model's highest output score.
# NOTE(review): this order must match the label order used when the model
# was trained/exported — confirm.
labels = [
    'cataract',
    'diabetic_retinopathy',
    'glaucoma',
    'normal',
]
# Configure the Gemini client with the key from the environment (None if unset).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Generate AI-based explanation for the predicted disease
def get_disease_detail(disease, model_name="gemini-1.5-flash"):
    """Ask Gemini to explain ``disease`` and return the reply as HTML.

    Args:
        disease: A classifier label such as ``"glaucoma"``. The label
            ``"normal"`` selects a "healthy eyes" prompt instead.
        model_name: Name of the Gemini model to query.

    Returns:
        HTML rendered from the model's Markdown reply, the rendered text
        ``"No response."`` when the reply is missing or blank, or an
        HTML-escaped ``"Error: ..."`` string if anything in the request fails.
        This function never raises; errors are reported in the return value.
    """
    # Labels like "diabetic_retinopathy" read poorly in a natural-language prompt.
    readable = disease.replace("_", " ")
    prompt = (
        "Create a text congratulating on healthy eyes with tips to keep them healthy."
        if disease == "normal"
        else f"Diagnosis: {readable}\n\n"
        f"What is {readable}?\nCauses and suggestions to prevent {readable}."
    )
    try:
        response = genai.GenerativeModel(model_name).generate_content(prompt)
        text = response.text.strip() if response and response.text else ""
        # A blank (whitespace-only) reply falls back to the placeholder too.
        return markdown2.markdown(text or "No response.")
    except Exception as e:  # UI boundary: report any API failure to the user.
        # The result is shown in a gr.HTML component, so escape the message;
        # otherwise '<...>' fragments in exception text are parsed as tags.
        return html.escape(f"Error: {e}")
# Process and predict uploaded image
def predict_image(image):
    """Classify a fundus image and fetch an explanation for the top class.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A ``(label_scores, explanation_html)`` pair. ``label_scores`` is
        ``{label: score}`` for the highest-scoring class, the mapping format
        ``gr.Label`` displays, and ``explanation_html`` comes from
        ``get_disease_detail``. Returns ``(None, message)`` when no image is
        given.
    """
    if image is None:
        return None, "Please upload an eye fundus image."
    # Force three channels: PNGs may be RGBA/palette and some uploads are
    # grayscale, which would otherwise give the model a wrongly shaped array.
    rgb = image.convert("RGB").resize((224, 224))
    # Scale pixels to [0, 1] and add a leading batch axis -> (1, 224, 224, 3).
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32) / 255.0, axis=0)
    outputs = model.signatures['serving_default'](
        tf.convert_to_tensor(batch, dtype=tf.float32)
    )
    # Flatten the (batch of one) output so indices line up with `labels`.
    # NOTE(review): assumes 'output_0' has one score per label, in `labels`
    # order — confirm against the exported model.
    scores = outputs['output_0'].numpy().ravel()
    best = int(np.argmax(scores))
    top_label = labels[best]
    explanation = get_disease_detail(top_label)
    # Cast to a plain float so the Label payload serializes cleanly.
    return {top_label: float(scores[best])}, explanation
# Example images shipped with the app; each inner list is one example row
# holding the value for the single image input.
example_images = [
    [f"exp_eye_images/{img}"]
    for img in [
        "0_right_h.png",
        "03fd50da928d_dr.png",
        "108_right_h.png",
        "1062_right_c.png",
        "1084_right_c.png",
        "image_1002_g.jpg",
    ]
]
# Gradio Interface: one PIL image in; a top-1 label plus a scrollable HTML
# explanation out.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.HTML(label="Explanation", elem_classes=["scrollable-html"]),
    ],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition. This model is for educational use only.",
    allow_flagging="never",
    # Fixed-height, scrollable box for the explanation HTML.
    css=".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}",
)

# Launch only when run as a script, so importing this module does not start
# a server as a side effect.
if __name__ == "__main__":
    interface.launch(share=True)
|