vikram0B commited on
Commit
db0ef97
·
verified ·
1 Parent(s): d5bc862

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -6,41 +6,48 @@ import google.generativeai as genai
6
  import os
7
  import markdown2
8
 
9
- # Load TensorFlow model & configure Gemini API
10
  model = tf.saved_model.load('model')
11
- genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
12
-
13
- # Disease labels & inline prompt generation for AI response
14
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
15
- prompt_map = lambda disease: (
16
- "Provide a congratulatory message for healthy eyes with tips for maintenance." if disease == "normal"
17
- else f"Diagnosis: {disease}\nDescription, causes, and prevention advice for {disease}."
18
- )
19
 
20
- # Function for AI-based disease explanation
21
- def generate_explanation(disease):
 
 
 
 
 
 
 
 
 
22
  try:
23
- response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt_map(disease))
24
- return markdown2.markdown(response.text.strip()) if response and response.text else "No response available."
25
  except Exception as e:
26
  return f"Error: {e}"
27
 
28
- # Image processing & prediction
29
  def predict_image(image):
30
- img_array = np.expand_dims(np.array(image.resize((224, 224))) / 255.0, axis=0).astype(np.float32)
31
- predictions = model.signatures['serving_default'](tf.convert_to_tensor(img_array))['output_0']
32
 
33
  top_label = labels[np.argmax(predictions.numpy())]
34
- return {top_label: predictions.numpy().max()}, generate_explanation(top_label)
 
 
 
 
 
35
 
36
- # Gradio Interface with minimalist style
37
  interface = gr.Interface(
38
  fn=predict_image,
39
  inputs=gr.Image(type="pil"),
40
  outputs=[gr.Label(num_top_classes=1, label="Prediction"), gr.HTML(label="Explanation", elem_classes=["scrollable-html"])],
41
- examples=[[f"exp_eye_images/{img}"] for img in ["0_right_h.png", "03fd50da928d_dr.png", "108_right_h.png", "1062_right_c.png", "1084_right_c.png", "image_1002_g.jpg"]],
42
  title="DR Predictor",
43
- description="Upload an eye fundus image to receive a condition prediction. *For educational use only*.",
44
  allow_flagging="never",
45
  css=".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"
46
  )
 
6
  import os
7
  import markdown2
8
 
9
+ # Load TensorFlow model
10
  model = tf.saved_model.load('model')
 
 
 
11
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
 
 
 
 
12
 
13
+ # Configure Gemini API
14
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
15
+
16
+ # Generate AI-based explanation for the predicted disease
17
+ def get_disease_detail(disease):
18
+ prompt = (
19
+ "Create a text congratulating on healthy eyes with tips to keep them healthy."
20
+ if disease == "normal" else
21
+ f"Diagnosis: {disease}\n\n"
22
+ f"What is {disease}?\nCauses and suggestions to prevent {disease}."
23
+ )
24
  try:
25
+ response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
26
+ return markdown2.markdown(response.text.strip() if response and response.text else "No response.")
27
  except Exception as e:
28
  return f"Error: {e}"
29
 
30
+ # Process and predict uploaded image
31
  def predict_image(image):
32
+ img_array = np.expand_dims(np.array(image.resize((224, 224))).astype(np.float32) / 255.0, axis=0)
33
+ predictions = model.signatures['serving_default'](tf.convert_to_tensor(img_array, dtype=tf.float32))['output_0']
34
 
35
  top_label = labels[np.argmax(predictions.numpy())]
36
+ explanation = get_disease_detail(top_label)
37
+
38
+ return {top_label: predictions.numpy().max()}, explanation
39
+
40
+ # Example images
41
+ example_images = [[f"exp_eye_images/{img}"] for img in ["0_right_h.png", "03fd50da928d_dr.png", "108_right_h.png", "1062_right_c.png", "1084_right_c.png", "image_1002_g.jpg"]]
42
 
43
+ # Gradio Interface
44
  interface = gr.Interface(
45
  fn=predict_image,
46
  inputs=gr.Image(type="pil"),
47
  outputs=[gr.Label(num_top_classes=1, label="Prediction"), gr.HTML(label="Explanation", elem_classes=["scrollable-html"])],
48
+ examples=example_images,
49
  title="DR Predictor",
50
+ description=("Upload an eye fundus image, and the model predicts the condition. This model is for educational use only."),
51
  allow_flagging="never",
52
  css=".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; box-sizing: border-box;}"
53
  )