Muhammad Abdiel Al Hafiz commited on
Commit
4a077d0
·
1 Parent(s): b3627a5

adjust output for disease explanation

Browse files
Files changed (1) hide show
  1. app.py +14 -44
app.py CHANGED
@@ -5,75 +5,45 @@ from PIL import Image
5
  import google.generativeai as genai
6
  import os
7
 
8
- # Load TensorFlow model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
- # Set up Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
16
- # Labels for classification
17
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
18
 
19
  def get_disease_detail(disease_name):
20
  prompt = (
21
  f"Diagnosis: {disease_name}\n\n"
22
- "What is it?\n(Description about the disease)\n\n"
23
- "What causes it?\n(Explain what causes the disease)\n\n"
24
- "Suggestions\n(Suggestion to user)\n\n"
25
  "Reminder: Always seek professional help, such as a doctor."
26
  )
27
- response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
28
-
29
- # Make sure we check for candidates and handle possible missing attributes correctly
30
- if response.candidates and response.candidates[0].text:
31
- return response.candidates[0].text.strip()
32
- else:
33
- return "No detailed explanation available."
34
-
35
- def safe_extract_section(text, start_keyword, end_keyword):
36
- """ Safely extract sections from the Gemini response based on start and end keywords."""
37
- if start_keyword in text and end_keyword in text:
38
- return text.split(start_keyword)[1].split(end_keyword)[0].strip()
39
- elif start_keyword in text:
40
- return text.split(start_keyword)[1].strip()
41
- else:
42
- return "Information not available."
43
 
44
  def predict_image(image):
45
- # Preprocess the image
46
  image_resized = image.resize((224, 224))
47
  image_array = np.array(image_resized).astype(np.float32) / 255.0
48
  image_array = np.expand_dims(image_array, axis=0)
49
 
50
- # Run prediction
51
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
 
 
52
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
53
  top_label = labels[top_index]
54
  top_probability = predictions.numpy()[0][top_index] * 100 # Convert to percentage
55
 
56
- # Get explanation from Gemini API
57
  explanation = get_disease_detail(top_label)
58
 
59
- # Extract relevant sections from the explanation
60
- diagnosis_section = f"**Diagnosis:** {top_label}"
61
- what_is_it = safe_extract_section(explanation, "What is it?", "What causes it?")
62
- causes = safe_extract_section(explanation, "What causes it?", "Suggestions")
63
- suggestions = safe_extract_section(explanation, "Suggestions", "Reminder")
64
- reminder = "Always seek professional help, such as a doctor."
65
-
66
- # Format explanation
67
- formatted_explanation = (
68
- f"{diagnosis_section}\n\n"
69
- f"**What is it?** {what_is_it}\n\n"
70
- f"**What causes it?** {causes}\n\n"
71
- f"**Suggestions:** {suggestions}\n\n"
72
- f"**Reminder:** {reminder}"
73
- )
74
-
75
- # Return both the prediction and the explanation
76
- return {top_label: top_probability}, formatted_explanation
77
 
78
  # Example images
79
  example_images = [
@@ -91,7 +61,7 @@ interface = gr.Interface(
91
  inputs=gr.Image(type="pil"),
92
  outputs=[
93
  gr.Label(num_top_classes=1, label="Prediction"),
94
- gr.Textbox(label="Explanation")
95
  ],
96
  examples=example_images,
97
  title="Eye Diseases Classifier",
 
5
  import google.generativeai as genai
6
  import os
7
 
8
+ # Load the TensorFlow model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
+ # Configure Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
 
16
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
17
 
18
  def get_disease_detail(disease_name):
19
  prompt = (
20
  f"Diagnosis: {disease_name}\n\n"
21
+ "What is it?\n(Description about {disease_name})\n\n"
22
+ "What causes it?\n(Explain what causes {disease_name})\n\n"
23
+ "Suggestion\n(Suggestion to user)\n\n"
24
  "Reminder: Always seek professional help, such as a doctor."
25
  )
26
+ try:
27
+ response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
28
+ return response.text.strip()
29
+ except Exception as e:
30
+ return f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def predict_image(image):
 
33
  image_resized = image.resize((224, 224))
34
  image_array = np.array(image_resized).astype(np.float32) / 255.0
35
  image_array = np.expand_dims(image_array, axis=0)
36
 
 
37
  predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
38
+
39
+ # Highest prediction
40
  top_index = np.argmax(predictions.numpy(), axis=1)[0]
41
  top_label = labels[top_index]
42
  top_probability = predictions.numpy()[0][top_index] * 100 # Convert to percentage
43
 
 
44
  explanation = get_disease_detail(top_label)
45
 
46
+ return {top_label: top_probability}, explanation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Example images
49
  example_images = [
 
61
  inputs=gr.Image(type="pil"),
62
  outputs=[
63
  gr.Label(num_top_classes=1, label="Prediction"),
64
+ gr.Textbox(label="Explanation", lines=15)
65
  ],
66
  examples=example_images,
67
  title="Eye Diseases Classifier",