Update app.py
Browse files
app.py
CHANGED
|
@@ -6,51 +6,45 @@ import google.generativeai as genai
|
|
| 6 |
import os
|
| 7 |
import markdown2
|
| 8 |
|
| 9 |
-
# Load TensorFlow model
|
| 10 |
model = tf.saved_model.load('model')
|
| 11 |
-
labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
|
| 12 |
-
|
| 13 |
-
# Configure Gemini API
|
| 14 |
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
try:
|
| 25 |
-
response = genai.GenerativeModel("gemini-1.5-flash").generate_content(
|
| 26 |
-
return markdown2.markdown(response.text.strip() if response and response.text else "No response."
|
| 27 |
except Exception as e:
|
| 28 |
return f"Error: {e}"
|
| 29 |
|
| 30 |
-
#
|
| 31 |
def predict_image(image):
|
| 32 |
-
img_array = np.expand_dims(np.array(image.resize((224, 224)))
|
| 33 |
-
predictions = model.signatures['serving_default'](tf.convert_to_tensor(img_array
|
| 34 |
|
| 35 |
top_label = labels[np.argmax(predictions.numpy())]
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
return {top_label: predictions.numpy().max()}, explanation
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
example_images = [[f"exp_eye_images/{img}"] for img in ["0_right_h.png", "03fd50da928d_dr.png", "108_right_h.png", "1062_right_c.png", "1084_right_c.png", "image_1002_g.jpg"]]
|
| 42 |
-
|
| 43 |
-
# Gradio Interface
|
| 44 |
interface = gr.Interface(
|
| 45 |
fn=predict_image,
|
| 46 |
inputs=gr.Image(type="pil"),
|
| 47 |
outputs=[gr.Label(num_top_classes=1, label="Prediction"), gr.HTML(label="Explanation", elem_classes=["scrollable-html"])],
|
| 48 |
-
examples=
|
| 49 |
title="DR Predictor",
|
| 50 |
-
description=
|
| 51 |
allow_flagging="never",
|
| 52 |
css=".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"
|
| 53 |
)
|
| 54 |
|
| 55 |
interface.launch(share=True)
|
| 56 |
|
|
|
|
|
|
| 6 |
import os
|
| 7 |
import markdown2
|
| 8 |
|
| 9 |
+
# Load TensorFlow model & configure Gemini API
|
| 10 |
model = tf.saved_model.load('model')
|
|
|
|
|
|
|
|
|
|
| 11 |
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
|
| 12 |
|
| 13 |
+
# Disease labels & inline prompt generation for AI response
|
| 14 |
+
labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
|
| 15 |
+
prompt_map = lambda disease: (
|
| 16 |
+
"Provide a congratulatory message for healthy eyes with tips for maintenance." if disease == "normal"
|
| 17 |
+
else f"Diagnosis: {disease}\nDescription, causes, and prevention advice for {disease}."
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Function for AI-based disease explanation
|
| 21 |
+
def generate_explanation(disease):
|
| 22 |
try:
|
| 23 |
+
response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt_map(disease))
|
| 24 |
+
return markdown2.markdown(response.text.strip()) if response and response.text else "No response available."
|
| 25 |
except Exception as e:
|
| 26 |
return f"Error: {e}"
|
| 27 |
|
| 28 |
+
# Image processing & prediction
|
| 29 |
def predict_image(image):
|
| 30 |
+
img_array = np.expand_dims(np.array(image.resize((224, 224))) / 255.0, axis=0).astype(np.float32)
|
| 31 |
+
predictions = model.signatures['serving_default'](tf.convert_to_tensor(img_array))['output_0']
|
| 32 |
|
| 33 |
top_label = labels[np.argmax(predictions.numpy())]
|
| 34 |
+
return {top_label: predictions.numpy().max()}, generate_explanation(top_label)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# Gradio Interface with minimalist style
|
|
|
|
|
|
|
|
|
|
| 37 |
interface = gr.Interface(
|
| 38 |
fn=predict_image,
|
| 39 |
inputs=gr.Image(type="pil"),
|
| 40 |
outputs=[gr.Label(num_top_classes=1, label="Prediction"), gr.HTML(label="Explanation", elem_classes=["scrollable-html"])],
|
| 41 |
+
examples=[[f"exp_eye_images/{img}"] for img in ["0_right_h.png", "03fd50da928d_dr.png", "108_right_h.png", "1062_right_c.png", "1084_right_c.png", "image_1002_g.jpg"]],
|
| 42 |
title="DR Predictor",
|
| 43 |
+
description="Upload an eye fundus image to receive a condition prediction. *For educational use only*.",
|
| 44 |
allow_flagging="never",
|
| 45 |
css=".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"
|
| 46 |
)
|
| 47 |
|
| 48 |
interface.launch(share=True)
|
| 49 |
|
| 50 |
+
|