|
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image
|
|
|
|
|
# Directory of the exported TensorFlow SavedModel (a relative path, so it is
# resolved against the current working directory). The model is loaded once
# at import time and shared by every request.
model_path = 'model'
model = tf.saved_model.load(model_path)

# The Gemini key is taken from the environment; when the variable is unset,
# api_key is None and is passed through to genai.configure unchanged.
api_key = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=api_key)

# Class names indexed by the classifier's output position (predict_image uses
# the argmax index into this list). Presumably this order must match the one
# used at training time -- TODO confirm.
labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
|
|
|
def get_disease_detail(disease_name):
    """Ask Gemini for patient-management recommendations for a diagnosis.

    Args:
        disease_name: Predicted class label, e.g. ``"diabetic_retinopathy"``.
            Underscores are replaced with spaces before it is put in the
            prompt so the model sees a natural-language diagnosis.

    Returns:
        The model's answer converted from Markdown to HTML, or an
        ``"Error: ..."`` string (HTML-escaped) if the request fails.
    """
    readable_name = disease_name.replace("_", " ")
    prompt = (
        "You are an Ophthalmologist with over 25 years of experience, you have treated thousands of patients with various eye diseases including cataracts, diabetic retinopathy and glaucoma. The entire medical process from disease identification to patient management is second nature to you and you are used to it. Your job is to critically and comprehensively make recommendations based on the diagnosis, the recommendations contain actions that can be taken on the patient, no need to re-explain the disease. In every recommendation you must remind the patient to always see the Ophthalmologist to validate the diagnosis and recommendation.\n"
        f"The diagnosis is {readable_name}, what are your recommendations?"
    )

    try:
        response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
        # The response is Markdown; the UI shows this value in a gr.HTML output.
        return markdown2.markdown(response.text.strip())
    except Exception as e:
        # Best-effort: a Gemini failure should not hide the classification
        # result, so report it as text. Escape it because the caller's output
        # is rendered as HTML and exception messages can contain '<' or '&'.
        return f"Error: {html.escape(str(e))}"
|
|
|
def predict_image(image):
    """Classify a fundus photograph and fetch recommendations for the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(label_to_confidence, explanation_html)`` tuple: a one-entry dict
        mapping the top label to its score (for ``gr.Label``) and the HTML
        string returned by ``get_disease_detail``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an eye fundus image.")

    # Force 3 channels so palette/RGBA/grayscale inputs produce an array of
    # shape (224, 224, 3); the classifier presumably expects RGB -- confirm.
    image_resized = image.convert("RGB").resize((224, 224))
    # Scale pixel values from [0, 255] to [0, 1] and add a batch axis.
    image_array = np.asarray(image_resized, dtype=np.float32) / 255.0
    image_array = np.expand_dims(image_array, axis=0)

    infer = model.signatures['serving_default']
    outputs = infer(tf.convert_to_tensor(image_array, dtype=tf.float32))
    # Materialise the output tensor once; row 0 is the single image in the batch.
    scores = outputs['output_0'].numpy()[0]

    top_index = int(np.argmax(scores))
    top_label = labels[top_index]
    # Cast the NumPy scalar to a plain Python float for the Label payload.
    top_probability = float(scores[top_index])

    explanation = get_disease_detail(top_label)

    return {top_label: top_probability}, explanation
|
|
|
|
|
# Sample images offered as clickable examples in the UI. Gradio expects one
# inner list per example, holding a value for each input -- here just the
# path for the single image input.
example_images = [
    [f"exp_eye_images/{file_name}"]
    for file_name in (
        "0_right_h.png",
        "03fd50da928d_dr.png",
        "108_right_h.png",
        "1062_right_c.png",
        "1084_right_c.png",
        "image_1002_g.jpg",
    )
]
|
|
|
|
|
# Fixed-height, scrollable frame for the explanation HTML; attached to the
# gr.HTML output below through elem_classes=["scrollable-html"].
css = """
.scrollable-html {
    height: 206px;
    overflow-y: auto;
    border: 1px solid #ccc;
    padding: 10px;
    box-sizing: border-box;
}
"""

interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        # predict_image returns a single {label: score} entry, so show one class.
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.HTML(label="Explanation", elem_classes=["scrollable-html"]),
    ],
    examples=example_images,
    title="Eye Diseases Classifier",
    description=(
        "Upload an image of an eye fundus, and the model will predict whether it shows cataract, diabetic retinopathy, glaucoma, or a normal eye.\n\n"
        "**Disclaimer:** This model was built as a learning exercise in health-related machine learning and was trained on a limited amount and variety of data (about 4,000 images in total), so its predictions may not always be correct. There is still a lot of room for improvement in future versions of this model."
    ),
    allow_flagging="never",
    css=css,
)

# share=True requests a public share link in addition to the local server.
interface.launch(share=True)
|
|