File size: 2,161 Bytes
abe0ac5
 
 
 
868baca
abe0ac5
 
 
868baca
 
abe0ac5
 
868baca
 
 
 
 
 
 
 
 
 
 
abe0ac5
 
868baca
abe0ac5
 
 
868baca
abe0ac5
868baca
 
abe0ac5
868baca
abe0ac5
 
868baca
abe0ac5
 
868baca
abe0ac5
 
 
 
 
868baca
abe0ac5
868baca
 
abe0ac5
868baca
abe0ac5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image

# Load the exported TensorFlow SavedModel from the local "model" directory.
model = tf.saved_model.load('model')
# Class names, indexed by position in the model's score vector
# (predict_image selects labels[argmax(scores)]).
labels = ["cataract", "diabetic_retinopathy", "glaucoma", "normal"]

# Configure the Gemini client with the API key read from the environment.
# NOTE(review): if GOOGLE_API_KEY is unset this passes None, and Gemini calls
# presumably fail later inside get_disease_detail — confirm that is intended.
_gemini_api_key = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=_gemini_api_key)

# Generate AI-based explanation for the predicted disease
def get_disease_detail(disease):
    """Ask Gemini to explain a predicted condition and return it as HTML.

    Args:
        disease: One of the classifier labels (e.g. ``"glaucoma"``,
            ``"diabetic_retinopathy"``); ``"normal"`` requests a
            healthy-eyes message instead of a disease explanation.

    Returns:
        The model's Markdown answer rendered to HTML by ``markdown2``.
        ``"No response."`` is rendered when the reply is missing or blank.
        On any failure, an ``"Error: ..."`` string is returned with the
        exception text HTML-escaped, because the result is shown in a
        ``gr.HTML`` component.
    """
    if disease == "normal":
        prompt = "Create a text congratulating on healthy eyes with tips to keep them healthy."
    else:
        # Labels are snake_case identifiers; give the LLM natural words instead.
        readable = disease.replace("_", " ")
        prompt = (
            f"Diagnosis: {readable}\n\n"
            f"What is {readable}?\nCauses and suggestions to prevent {readable}."
        )
    try:
        response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
        # `response.text` may itself raise (e.g. when no valid candidate is
        # returned); that case is handled by the except clause below.
        text = response.text.strip() if response and response.text else ""
        return markdown2.markdown(text or "No response.")
    except Exception as e:  # UI boundary: report the failure instead of crashing the request
        return f"Error: {html.escape(str(e))}"

# Process and predict uploaded image
def predict_image(image):
    """Classify an eye fundus image and attach an AI-generated explanation.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(confidences, explanation_html)`` tuple. ``confidences`` maps the
        top-scoring label to its score as a plain Python ``float`` (for
        ``gr.Label``). ``explanation_html`` is the string returned by
        ``get_disease_detail``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an eye fundus image.")

    # Force 3 channels so RGBA PNGs or grayscale uploads don't yield a
    # (224, 224, 4) or (224, 224) array. Then resize to 224x224 and scale
    # pixels to [0, 1].
    pixels = np.asarray(image.convert("RGB").resize((224, 224)), dtype=np.float32) / 255.0
    batch = np.expand_dims(pixels, axis=0)  # add batch axis -> (1, 224, 224, 3)

    outputs = model.signatures['serving_default'](tf.convert_to_tensor(batch, dtype=tf.float32))
    # NOTE(review): assumes 'output_0' holds one score per entry of `labels`,
    # in the same order — confirm against the exported model.
    scores = outputs['output_0'].numpy().reshape(-1)

    best = int(np.argmax(scores))
    top_label = labels[best]
    explanation = get_disease_detail(top_label)

    # float() hands gr.Label a plain Python float instead of a NumPy scalar.
    return {top_label: float(scores[best])}, explanation

# Sample fundus images offered as clickable examples. Each inner list is one
# example row holding the value for the single image input.
_EXAMPLE_FILES = (
    "0_right_h.png",
    "03fd50da928d_dr.png",
    "108_right_h.png",
    "1062_right_c.png",
    "1084_right_c.png",
    "image_1002_g.jpg",
)
example_images = []
for _name in _EXAMPLE_FILES:
    example_images.append([f"exp_eye_images/{_name}"])

# Gradio Interface: one PIL image in; a top-1 label and an HTML explanation out.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Label(num_top_classes=1, label="Prediction")
# The explanation panel gets a CSS class so the stylesheet below can make it
# a fixed-height scrollable box.
_explanation_output = gr.HTML(label="Explanation", elem_classes=["scrollable-html"])
_explanation_css = ".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"

interface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=[_prediction_output, _explanation_output],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition.",
    allow_flagging="never",
    css=_explanation_css,
)

# share=True asks Gradio for a public share link in addition to the local server.
interface.launch(share=True)