# Origin: try2/app.py by vikram0B — commit 868baca ("Update app.py"), 2.16 kB.
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image
# Class names for the model's output scores, in output order: prediction code
# picks the entry at the argmax of the score vector.
labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']

# Load the TensorFlow SavedModel stored in the local "model" directory.
model = tf.saved_model.load('model')

# Configure the Gemini client with the key from the GOOGLE_API_KEY
# environment variable (os.getenv yields None when it is unset).
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Generate AI-based explanation for the predicted disease
def get_disease_detail(disease):
prompt = (
"Create a text congratulating on healthy eyes with tips to keep them healthy."
if disease == "normal" else
f"Diagnosis: {disease}\n\n"
f"What is {disease}?\nCauses and suggestions to prevent {disease}."
)
try:
response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
return markdown2.markdown(response.text.strip() if response and response.text else "No response.")
except Exception as e:
return f"Error: {e}"
def predict_image(image):
    """Classify an uploaded fundus image and explain the predicted condition.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(scores, explanation)`` pair. ``scores`` maps the top label to its
        score, for the ``gr.Label`` output. ``explanation`` is the HTML
        returned by ``get_disease_detail`` for that label.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard, image.resize() raises an opaque AttributeError.
        raise gr.Error("Please upload an eye fundus image.")
    # Defensive: PNGs may carry an alpha channel and some images are
    # grayscale. Assumes the model expects 3-channel RGB; TODO confirm
    # against the training pipeline. Resize and the [0, 1] scaling match the
    # original preprocessing.
    rgb = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32) / 255.0, axis=0)
    output = model.signatures['serving_default'](
        tf.convert_to_tensor(batch, dtype=tf.float32)
    )['output_0']
    # Flatten once, so argmax and the reported score come from the same
    # vector without re-converting the tensor.
    scores = output.numpy().ravel()
    best = int(np.argmax(scores))
    top_label = labels[best]
    # float() so the Label component receives a plain Python number.
    return {top_label: float(scores[best])}, get_disease_detail(top_label)
# Bundled sample images offered as one-click examples. Each example row holds
# a single value because the interface has a single (image) input.
_EXAMPLE_FILES = (
    "0_right_h.png",
    "03fd50da928d_dr.png",
    "108_right_h.png",
    "1062_right_c.png",
    "1084_right_c.png",
    "image_1002_g.jpg",
)
example_images = [["exp_eye_images/" + filename] for filename in _EXAMPLE_FILES]

# Explanation panel: fixed 206px height with a vertical scrollbar for long replies.
_SCROLLABLE_CSS = ".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"

# Gradio UI: one PIL image in; the top prediction label and an HTML
# explanation out.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.HTML(label="Explanation", elem_classes=["scrollable-html"]),
    ],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition.",
    allow_flagging="never",
    css=_SCROLLABLE_CSS,
)
interface.launch(share=True)