Muhammad Abdiel Al Hafiz commited on
Commit
6704e8f
·
1 Parent(s): 910566d

adjust output for disease explanation

Browse files
Files changed (1) hide show
  1. app.py +54 -31
app.py CHANGED
@@ -5,48 +5,70 @@ from PIL import Image
5
  import google.generativeai as genai
6
  import os
7
 
 
8
  model_path = 'model'
9
  model = tf.saved_model.load(model_path)
10
 
 
11
  api_key = os.getenv("GEMINI_API_KEY")
12
  genai.configure(api_key=api_key)
13
 
 
14
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
15
 
16
  def get_disease_detail(disease_name):
17
- prompt = (
18
- f"Diagnosis: {disease_name}\n\n"
19
- "What is it?\n(Description about {disease_name})\n\n"
20
- "What cause it?\n(Explain what causes {disease_name})\n\n"
21
- "Suggestion\n(Suggestion to user)\n\n"
22
- "Reminder: Always seek professional help, such as a doctor."
23
- )
24
- response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
25
- return response.text.strip()
 
 
 
 
 
 
 
 
 
26
 
27
  def predict_image(image):
28
- image_resized = image.resize((224, 224))
29
- image_array = np.array(image_resized).astype(np.float32) / 255.0
30
- image_array = np.expand_dims(image_array, axis=0)
 
 
 
 
 
 
 
31
 
32
- predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
33
-
34
- # Highest prediction
35
- top_index = np.argmax(predictions.numpy(), axis=1)[0]
36
- top_label = labels[top_index]
37
- top_probability = predictions.numpy()[0][top_index]
38
 
39
- explanation = get_disease_detail(top_label)
 
 
 
 
 
40
 
41
- formatted_explanation = (
42
- f"**Diagnosis:** {top_label}\n\n"
43
- f"**What is it?** {explanation.split('What causes it?')[0].split('Suggestions')[0].strip()}\n\n"
44
- f"**What causes it?** {explanation.split('What causes it?')[1].split('Suggestions')[0].strip()}\n\n"
45
- f"**Suggestions:** {explanation.split('Suggestions')[1].split('Reminder')[0].strip()}\n\n"
46
- f"**Reminder:** Always seek professional help, such as a doctor."
47
- )
 
48
 
49
- return {top_label: top_probability}, formatted_explanation
 
50
 
51
  # Example images
52
  example_images = [
@@ -63,13 +85,14 @@ interface = gr.Interface(
63
  fn=predict_image,
64
  inputs=gr.Image(type="pil"),
65
  outputs=[
66
- gr.Label(num_top_classes=1, label="Prediction"),
67
- gr.Textbox(label="Explanation")],
 
68
  examples=example_images,
69
  title="Eye Diseases Classifier",
70
  description=(
71
- "Upload an image of an eye fundus, and the model will predict it.\n\n"
72
- "**Disclaimer:** This model is intended as a form of learning process in the field of health-related machine learning and was trained with a limited amount and variety of data with a total of about 4000 data, so the prediction results may not always be correct. There is still a lot of room for improvisation on this model in the future."
73
  ),
74
  allow_flagging="never"
75
  )
 
5
  import google.generativeai as genai
6
  import os
7
 
8
+ # Load TensorFlow model
9
  model_path = 'model'
10
  model = tf.saved_model.load(model_path)
11
 
12
+ # Set up Gemini API
13
  api_key = os.getenv("GEMINI_API_KEY")
14
  genai.configure(api_key=api_key)
15
 
16
+ # Labels for classification
17
  labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
18
 
19
  def get_disease_detail(disease_name):
20
+ prompt = (
21
+ f"Diagnosis: {disease_name}\n\n"
22
+ "What is it?\n(Description about the disease)\n\n"
23
+ "What causes it?\n(Explain what causes the disease)\n\n"
24
+ "Suggestions\n(Suggestion to user)\n\n"
25
+ "Reminder: Always seek professional help, such as a doctor."
26
+ )
27
+ response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
28
+ return response.text.strip()
29
+
30
+ def safe_extract_section(text, start_keyword, end_keyword):
31
+ """ Safely extract sections from the Gemini response based on start and end keywords."""
32
+ if start_keyword in text and end_keyword in text:
33
+ return text.split(start_keyword)[1].split(end_keyword)[0].strip()
34
+ elif start_keyword in text:
35
+ return text.split(start_keyword)[1].strip()
36
+ else:
37
+ return "Information not available."
38
 
39
  def predict_image(image):
40
+ # Preprocess the image
41
+ image_resized = image.resize((224, 224))
42
+ image_array = np.array(image_resized).astype(np.float32) / 255.0
43
+ image_array = np.expand_dims(image_array, axis=0)
44
+
45
+ # Run prediction
46
+ predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
47
+ top_index = np.argmax(predictions.numpy(), axis=1)[0]
48
+ top_label = labels[top_index]
49
+ top_probability = predictions.numpy()[0][top_index] * 100 # Convert to percentage
50
 
51
+ # Get explanation from Gemini API
52
+ explanation = get_disease_detail(top_label)
 
 
 
 
53
 
54
+ # Extract relevant sections from the explanation
55
+ diagnosis_section = f"**Diagnosis:** {top_label}"
56
+ what_is_it = safe_extract_section(explanation, "What is it?", "What causes it?")
57
+ causes = safe_extract_section(explanation, "What causes it?", "Suggestions")
58
+ suggestions = safe_extract_section(explanation, "Suggestions", "Reminder")
59
+ reminder = "Always seek professional help, such as a doctor."
60
 
61
+ # Format explanation
62
+ formatted_explanation = (
63
+ f"{diagnosis_section}\n\n"
64
+ f"**What is it?** {what_is_it}\n\n"
65
+ f"**What causes it?** {causes}\n\n"
66
+ f"**Suggestions:** {suggestions}\n\n"
67
+ f"**Reminder:** {reminder}"
68
+ )
69
 
70
+ # Return both the prediction and the explanation
71
+ return {top_label: top_probability}, formatted_explanation
72
 
73
  # Example images
74
  example_images = [
 
85
  fn=predict_image,
86
  inputs=gr.Image(type="pil"),
87
  outputs=[
88
+ gr.Label(num_top_classes=1, label="Prediction"),
89
+ gr.Textbox(label="Explanation")
90
+ ],
91
  examples=example_images,
92
  title="Eye Diseases Classifier",
93
  description=(
94
+ "Upload an image of an eye fundus, and the model will predict it.\n\n"
95
+ "**Disclaimer:** This model is intended as a form of learning process in the field of health-related machine learning and was trained with a limited amount and variety of data with a total of about 4000 data, so the prediction results may not always be correct. There is still a lot of room for improvisation on this model in the future."
96
  ),
97
  allow_flagging="never"
98
  )