# try2 / app.py
# Last commit: 868baca "Update app.py" by vikram0B (verified)
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image
# Load the TensorFlow SavedModel stored in the "model" directory (path is
# relative to the process working directory).
model = tf.saved_model.load("model")

# Class names, indexed by the position of the classifier's highest output score.
labels = [
    "cataract",
    "diabetic_retinopathy",
    "glaucoma",
    "normal",
]

# Configure the Gemini client from the environment.
# NOTE(review): if GOOGLE_API_KEY is unset this passes api_key=None — confirm the
# deployment always provides it.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Generate AI-based explanation for the predicted disease
def get_disease_detail(disease, model_name="gemini-1.5-flash"):
    """Ask Gemini to explain ``disease`` and return the answer rendered as HTML.

    Args:
        disease: A classifier label such as ``"glaucoma"``. The value
            ``"normal"`` selects a congratulatory healthy-eyes prompt instead
            of a diagnosis prompt.
        model_name: Gemini model to query. The default is the model this app
            has always used, so existing callers are unaffected.

    Returns:
        An HTML string for the ``gr.HTML`` output. It is one of:
        the model's markdown reply converted by ``markdown2``;
        ``"No response."`` (converted the same way) when the reply is empty;
        or an HTML-escaped ``"Error: ..."`` message when the request fails.
    """
    if disease == "normal":
        prompt = "Create a text congratulating on healthy eyes with tips to keep them healthy."
    else:
        # Labels are snake_case identifiers; give the LLM the readable name.
        name = disease.replace("_", " ")
        prompt = (
            f"Diagnosis: {name}\n\n"
            f"What is {name}?\nCauses and suggestions to prevent {name}."
        )
    try:
        response = genai.GenerativeModel(model_name).generate_content(prompt)
        # Reading ``response.text`` may itself raise (e.g. when the reply has no
        # text part), so it stays inside the try.
        text = response.text.strip() if response and response.text else ""
    except Exception as e:  # UI boundary: report the failure instead of crashing the request
        # The return value is inserted as HTML, so escape the message text.
        return f"Error: {html.escape(str(e))}"
    # safe_mode="escape" neutralises raw HTML in the model's reply before it
    # reaches the gr.HTML component.
    return markdown2.markdown(text or "No response.", safe_mode="escape")
# Process and predict uploaded image
def predict_image(image):
    """Classify a fundus image and fetch an explanation of the predicted condition.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(scores, explanation)`` pair.
        ``scores`` maps the top label to its model score as a plain ``float``,
        the mapping format ``gr.Label`` consumes.
        ``explanation`` is the HTML string returned by ``get_disease_detail``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an eye fundus image.")
    # Uploads (especially PNG) may be RGBA, grayscale or palette images; force
    # 3 channels so the array shape matches RGB input.
    # NOTE(review): 224x224 RGB scaled to [0, 1] mirrors the original
    # preprocessing — confirm against how the SavedModel was trained.
    rgb = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32) / 255.0, axis=0)
    outputs = model.signatures["serving_default"](tf.convert_to_tensor(batch, dtype=tf.float32))
    scores = outputs["output_0"].numpy()  # convert once; reused for label and score
    top_index = int(np.argmax(scores))  # flat index, same as the original argmax
    top_label = labels[top_index]
    explanation = get_disease_detail(top_label)
    return {top_label: float(scores.flat[top_index])}, explanation
# Example images: each entry is a one-element list (one value per Interface
# input), pointing into the bundled exp_eye_images/ folder.
example_images = [
    ["exp_eye_images/" + filename]
    for filename in (
        "0_right_h.png",
        "03fd50da928d_dr.png",
        "108_right_h.png",
        "1062_right_c.png",
        "1084_right_c.png",
        "image_1002_g.jpg",
    )
]
# Gradio Interface: one PIL image in; a top-1 label plus a scrollable HTML
# explanation out, both produced by predict_image.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        # "scrollable-html" is styled by the css= rule below so long text scrolls.
        gr.HTML(label="Explanation", elem_classes=["scrollable-html"]),
    ],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition.",
    # NOTE(review): newer Gradio releases rename allow_flagging to
    # flagging_mode — confirm against the pinned Gradio version.
    allow_flagging="never",
    css=(
        ".scrollable-html {height: 206px; overflow-y: auto; "
        "border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"
    ),
)
# share=True asks Gradio for a public share link in addition to the local server.
interface.launch(share=True)