File size: 2,203 Bytes
d566fee
 
 
 
3fdf143
ee36b3c
ba699eb
d566fee
db0ef97
3fdf143
d5bc862
 
db0ef97
 
 
 
 
 
 
 
 
 
 
a3cb4b9
db0ef97
 
a3cb4b9
 
ee36b3c
db0ef97
d566fee
db0ef97
 
d566fee
3fdf143
db0ef97
 
 
 
 
 
d566fee
db0ef97
d566fee
 
 
3fdf143
db0ef97
3fdf143
db0ef97
ba699eb
3fdf143
d566fee
 
fe44a5a
3fdf143
d5bc862
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import google.generativeai as genai
import os
import markdown2

# Load the exported TensorFlow SavedModel from the relative path "model".
model = tf.saved_model.load("model")

# Class names; predict_image indexes this list with the argmax of the model
# output, so the order here must match the order of the model's outputs.
labels = [
    "cataract",
    "diabetic_retinopathy",
    "glaucoma",
    "normal",
]

# Configure the Gemini client with the key read from the GOOGLE_API_KEY
# environment variable (None is passed through when the variable is unset).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

# Generate AI-based explanation for the predicted disease
def get_disease_detail(disease):
    """Return an HTML explanation of the predicted condition written by Gemini.

    Args:
        disease: Predicted class name. The value "normal" selects a
            healthy-eyes prompt; any other value is inserted into a
            diagnosis prompt verbatim.

    Returns:
        An HTML string. On success it is the model's Markdown answer rendered
        by markdown2 (or "No response." when the answer is empty). On any
        failure it is "Error: <message>" with the message HTML-escaped.
    """
    import html  # local import keeps this block self-contained

    if disease == "normal":
        prompt = "Create a text congratulating on healthy eyes with tips to keep them healthy."
    else:
        prompt = (
            f"Diagnosis: {disease}\n\n"
            f"What is {disease}?\nCauses and suggestions to prevent {disease}."
        )

    # Deliberately broad: this is a best-effort UI helper, and a failed API
    # call (or an unreadable response) should surface as text in the
    # explanation panel rather than abort the prediction.
    try:
        response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
        text = response.text.strip() if response and response.text else ""
        # Fall back for whitespace-only answers too, so the panel is never blank.
        return markdown2.markdown(text or "No response.")
    except Exception as e:
        # The result is rendered as HTML, so escape the exception text to keep
        # characters such as "<" or "&" from being interpreted as markup.
        return f"Error: {html.escape(str(e))}"

# Process and predict uploaded image
def predict_image(image):
    """Classify an eye image and return the top label with an explanation.

    Args:
        image: PIL image to classify, or None when no image was provided.

    Returns:
        A ``(confidences, explanation)`` tuple. ``confidences`` is a
        single-entry dict mapping the top label to its score as a Python
        float; ``explanation`` is the HTML returned by get_disease_detail.
        When ``image`` is None, returns an empty dict and a short HTML
        message asking for an upload.
    """
    # Submitting with no image passes None; answer with a hint instead of
    # failing with AttributeError on image.resize.
    if image is None:
        return {}, "<p>Please upload an eye fundus image first.</p>"

    # Force three channels so grayscale, palette or RGBA inputs yield the same
    # array layout as RGB ones (presumably the model expects 224x224 RGB
    # scaled to [0, 1] -- confirm against the exported signature).
    resized = image.convert("RGB").resize((224, 224))
    img_array = np.expand_dims(np.asarray(resized, dtype=np.float32) / 255.0, axis=0)

    outputs = model.signatures['serving_default'](tf.convert_to_tensor(img_array, dtype=tf.float32))
    # Materialize the tensor once and reuse it for both argmax and max.
    scores = outputs['output_0'].numpy()

    top_label = labels[int(np.argmax(scores))]
    explanation = get_disease_detail(top_label)

    # Cast to a built-in float so the value serializes cleanly.
    return {top_label: float(scores.max())}, explanation

# Sample images offered as one-click examples; each entry is a one-element
# list because the interface has a single input component.
_EXAMPLE_FILES = (
    "0_right_h.png",
    "03fd50da928d_dr.png",
    "108_right_h.png",
    "1062_right_c.png",
    "1084_right_c.png",
    "image_1002_g.jpg",
)
example_images = [["exp_eye_images/" + file_name] for file_name in _EXAMPLE_FILES]

# Fixed-height, scrollable box for the explanation panel.
_SCROLLABLE_CSS = ".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"

# Components are created input first, then outputs, and handed to the interface by name.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Label(num_top_classes=1, label="Prediction")
_explanation_output = gr.HTML(label="Explanation", elem_classes=["scrollable-html"])

# Gradio interface wiring predict_image to the components above.
interface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=[_prediction_output, _explanation_output],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition. This model is for educational use only.",
    allow_flagging="never",
    css=_SCROLLABLE_CSS,
)

# Start the app and request a public share link.
interface.launch(share=True)