Muhammad Abdiel Al Hafiz commited on
Commit
f132889
·
1 Parent(s): e9ff5ea

just trying to fix

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -5,14 +5,18 @@ from PIL import Image
5
  import google.generativeai as genai
6
  import os
7
 
 
8
  model_path = 'model'
9
  model = tf.saved_model.load(model_path)
10
 
 
11
  api_key = os.getenv("GEMINI_API_KEY")
12
  genai.configure(api_key=api_key)
13
 
 
14
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
15
 
 
16
  def get_disease_detail(disease_name):
17
  prompt = (
18
  f"Diagnosis: {disease_name}\n\n"
@@ -23,25 +27,28 @@ def get_disease_detail(disease_name):
23
  )
24
  response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
25
 
26
- # Safely extract text from response (without using candidates)
27
- return response.result if hasattr(response, 'result') else "No explanation available."
28
 
 
29
  def predict_image(image):
 
30
  image_resized = image.resize((224, 224))
31
  image_array = np.array(image_resized).astype(np.float32) / 255.0
32
  image_array = np.expand_dims(image_array, axis=0)
33
 
 
34
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
35
 
36
- # Highest prediction
37
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
38
  top_label = labels[top_index]
39
- top_probability = predictions.numpy()[0][top_index] * 100 # Convert to percentage
40
 
41
  # Fetch explanation from Gemini API
42
  explanation = get_disease_detail(top_label)
43
 
44
- # Construct output
45
  return {top_label: top_probability}, explanation
46
 
47
  # Example images
@@ -60,7 +67,7 @@ interface = gr.Interface(
60
  inputs=gr.Image(type="pil"),
61
  outputs=[
62
  gr.Label(num_top_classes=1, label="Prediction"),
63
- gr.Markdown(label="Explanation") # Using Markdown to render formatted text
64
  ],
65
  examples=example_images,
66
  title="Eye Diseases Classifier",
@@ -71,4 +78,5 @@ interface = gr.Interface(
71
  allow_flagging="never"
72
  )
73
 
 
74
  interface.launch(share=True)
 
5
  import google.generativeai as genai
6
  import os
7
 
8
+ # Load the model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
+ # Configure Google Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
16
+ # Labels for the classification model
17
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
18
 
19
+ # Function to get disease details from Gemini API
20
  def get_disease_detail(disease_name):
21
  prompt = (
22
  f"Diagnosis: {disease_name}\n\n"
 
27
  )
28
  response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
29
 
30
+ # Safely extract the content of the response (adjust if different field is used)
31
+ return response.result.strip() if hasattr(response, 'result') else "No explanation available."
32
 
33
+ # Prediction function for the image
34
  def predict_image(image):
35
+ # Preprocess image
36
  image_resized = image.resize((224, 224))
37
  image_array = np.array(image_resized).astype(np.float32) / 255.0
38
  image_array = np.expand_dims(image_array, axis=0)
39
 
40
+ # Get model predictions
41
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
42
 
43
+ # Get highest probability prediction
44
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
45
  top_label = labels[top_index]
46
+ top_probability = predictions.numpy()[0][top_index]
47
 
48
  # Fetch explanation from Gemini API
49
  explanation = get_disease_detail(top_label)
50
 
51
+ # Return the prediction and the explanation
52
  return {top_label: top_probability}, explanation
53
 
54
  # Example images
 
67
  inputs=gr.Image(type="pil"),
68
  outputs=[
69
  gr.Label(num_top_classes=1, label="Prediction"),
70
+ gr.Textbox(label="Explanation") # Regular Textbox for normal text
71
  ],
72
  examples=example_images,
73
  title="Eye Diseases Classifier",
 
78
  allow_flagging="never"
79
  )
80
 
81
+ # Launch the interface
82
  interface.launch(share=True)