vikram0B commited on
Commit
abe0ac5
·
verified ·
1 Parent(s): 785921b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import google.generativeai as genai
6
+ import os
7
+ import markdown2
8
+
9
+ # Load the TensorFlow model
10
+ model_path = 'model'
11
+ model = tf.saved_model.load(model_path)
12
+
13
+ # Configure Gemini API
14
+ api_key = os.getenv("GEMINI_API_KEY")
15
+ genai.configure(api_key=api_key)
16
+
17
+ labels = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']
18
+
19
+ def get_disease_detail(disease_name):
20
+ if disease_name == "normal":
21
+ prompt = (
22
+ "Create a text that congratulates having healthy eyes and gives bullet point tips to keep eyes healthy."
23
+ )
24
+ else:
25
+ prompt = (
26
+ f"Diagnosis: {disease_name}\n\n"
27
+ "What is it?\n(Description about {disease_name})\n\n"
28
+ "What causes it?\n(Explain what causes {disease_name})\n\n"
29
+ "Suggestion\n(Suggestion to user)\n\n"
30
+ "Reminder: Always seek professional help, such as a doctor."
31
+ )
32
+ try:
33
+ response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
34
+ return markdown2.markdown(response.text.strip())
35
+ except Exception as e:
36
+ return f"Error: {e}"
37
+
38
+ def predict_image(image):
39
+ image_resized = image.resize((224, 224))
40
+ image_array = np.array(image_resized).astype(np.float32) / 255.0
41
+ image_array = np.expand_dims(image_array, axis=0)
42
+
43
+ predictions = model.signatures['serving_default'](tf.convert_to_tensor(image_array, dtype=tf.float32))['output_0']
44
+
45
+ # Highest prediction
46
+ top_index = np.argmax(predictions.numpy(), axis=1)[0]
47
+ top_label = labels[top_index]
48
+ top_probability = predictions.numpy()[0][top_index]
49
+
50
+ explanation = get_disease_detail(top_label)
51
+
52
+ return {top_label: top_probability}, explanation
53
+
54
+ # Example images
55
+ example_images = [
56
+ ["exp_eye_images/0_right_h.png"],
57
+ ["exp_eye_images/03fd50da928d_dr.png"],
58
+ ["exp_eye_images/108_right_h.png"],
59
+ ["exp_eye_images/1062_right_c.png"],
60
+ ["exp_eye_images/1084_right_c.png"],
61
+ ["exp_eye_images/image_1002_g.jpg"]
62
+ ]
63
+
64
+ # Custom CSS for HTML height
65
+ css = """
66
+ .scrollable-html {
67
+ height: 206px;
68
+ overflow-y: auto;
69
+ border: 1px solid #ccc;
70
+ padding: 10px;
71
+ box-sizing: border-box;
72
+ }
73
+ """
74
+
75
+ # Gradio Interface
76
+ interface = gr.Interface(
77
+ fn=predict_image,
78
+ inputs=gr.Image(type="pil"),
79
+ outputs=[
80
+ gr.Label(num_top_classes=1, label="Prediction"),
81
+ gr.HTML(label="Explanation", elem_classes=["scrollable-html"])
82
+ ],
83
+ examples=example_images,
84
+ title="Eye Diseases Classifier",
85
+ description=(
86
+ "Upload an image of an eye fundus, and the model will predict it.\n\n"
87
+ "**Disclaimer:** This model is intended as a form of learning process in the field of health-related machine learning and was trained with a limited amount and variety of data with a total of about 4000 data, so the prediction results may not always be correct. There is still a lot of room for improvisation on this model in the future."
88
+ ),
89
+ allow_flagging="never",
90
+ css=css
91
+ )
92
+
93
+ interface.launch(share=True)