Muhammad Abdiel Al Hafiz commited on
Commit
a3cb4b9
·
1 Parent(s): f132889

just trying to fix again

Browse files
Files changed (1) hide show
  1. app.py +10 -17
app.py CHANGED
@@ -2,21 +2,19 @@ import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
- import google.generativeai as genai
6
  import os
7
 
8
- # Load the model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
- # Configure Google Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
16
- # Labels for the classification model
17
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
18
 
19
- # Function to get disease details from Gemini API
20
  def get_disease_detail(disease_name):
21
  prompt = (
22
  f"Diagnosis: {disease_name}\n\n"
@@ -25,30 +23,26 @@ def get_disease_detail(disease_name):
25
  "Suggestion\n(Suggestion to user)\n\n"
26
  "Reminder: Always seek professional help, such as a doctor."
27
  )
28
- response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
29
-
30
- # Safely extract the content of the response (adjust if different field is used)
31
- return response.result.strip() if hasattr(response, 'result') else "No explanation available."
 
32
 
33
- # Prediction function for the image
34
  def predict_image(image):
35
- # Preprocess image
36
  image_resized = image.resize((224, 224))
37
  image_array = np.array(image_resized).astype(np.float32) / 255.0
38
  image_array = np.expand_dims(image_array, axis=0)
39
 
40
- # Get model predictions
41
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
42
 
43
- # Get highest probability prediction
44
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
45
  top_label = labels[top_index]
46
  top_probability = predictions.numpy()[0][top_index]
47
 
48
- # Fetch explanation from Gemini API
49
  explanation = get_disease_detail(top_label)
50
 
51
- # Return the prediction and the explanation
52
  return {top_label: top_probability}, explanation
53
 
54
  # Example images
@@ -67,7 +61,7 @@ interface = gr.Interface(
67
  inputs=gr.Image(type="pil"),
68
  outputs=[
69
  gr.Label(num_top_classes=1, label="Prediction"),
70
- gr.Textbox(label="Explanation") # Regular Textbox for normal text
71
  ],
72
  examples=example_images,
73
  title="Eye Diseases Classifier",
@@ -78,5 +72,4 @@ interface = gr.Interface(
78
  allow_flagging="never"
79
  )
80
 
81
- # Launch the interface
82
  interface.launch(share=True)
 
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
+ import google.generativeai as genai
6
  import os
7
 
8
+ # Load the TensorFlow model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
+ # Configure Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
 
16
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
17
 
 
18
  def get_disease_detail(disease_name):
19
  prompt = (
20
  f"Diagnosis: {disease_name}\n\n"
 
23
  "Suggestion\n(Suggestion to user)\n\n"
24
  "Reminder: Always seek professional help, such as a doctor."
25
  )
26
+ try:
27
+ response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
28
+ return response.text.strip()
29
+ except Exception as e:
30
+ return f"Error: {e}"
31
 
 
32
  def predict_image(image):
 
33
  image_resized = image.resize((224, 224))
34
  image_array = np.array(image_resized).astype(np.float32) / 255.0
35
  image_array = np.expand_dims(image_array, axis=0)
36
 
 
37
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
38
 
39
+ # Highest prediction
40
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
41
  top_label = labels[top_index]
42
  top_probability = predictions.numpy()[0][top_index]
43
 
 
44
  explanation = get_disease_detail(top_label)
45
 
 
46
  return {top_label: top_probability}, explanation
47
 
48
  # Example images
 
61
  inputs=gr.Image(type="pil"),
62
  outputs=[
63
  gr.Label(num_top_classes=1, label="Prediction"),
64
+ gr.Textbox(label="Explanation", lines=15)
65
  ],
66
  examples=example_images,
67
  title="Eye Diseases Classifier",
 
72
  allow_flagging="never"
73
  )
74
 
 
75
  interface.launch(share=True)