t4n15hq commited on
Commit
289d09a
·
1 Parent(s): 682f3ef

adjust metadata logic

Browse files
Files changed (1) hide show
  1. main.py +31 -89
main.py CHANGED
@@ -1,6 +1,5 @@
1
  from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from typing import Optional
4
  import numpy as np
5
  import tensorflow as tf
6
  from PIL import Image
@@ -21,7 +20,7 @@ def custom_max_pool(x):
21
  # --- CONFIG ---
22
  IMG_SIZE = 224
23
  NUM_CLASSES = 23
24
- USE_METADATA_ADJUSTMENT = True # Toggle this to False if you want raw model output
25
 
26
  MODEL_PATHS = {
27
  "EfficientNetB3": "saved_models/EfficientNetB3_recovered.keras",
@@ -62,31 +61,22 @@ app.add_middleware(
62
  allow_headers=["*"],
63
  )
64
 
65
- # --- HEALTH CHECK ENDPOINT ---
66
  @app.get("/")
67
  def root():
68
  return {"status": "🩺 App is running."}
69
 
70
- # --- LOAD MODELS AT STARTUP ---
71
  models = {}
72
 
73
  @app.on_event("startup")
74
  def load_models():
75
  global models
76
- print("Loading models...")
77
  os.makedirs("saved_models", exist_ok=True)
78
-
79
  for name, path in MODEL_PATHS.items():
80
  if not os.path.exists(path):
81
- print(f"Downloading {name} from Google Drive...")
82
  gdown.download(MODEL_URLS[name], path, quiet=False)
83
-
84
- print(f"Loading {name}...")
85
  models[name] = tf.keras.models.load_model(path)
86
 
87
- print("✅ All models loaded.")
88
-
89
- # --- IMAGE UTILS ---
90
  def read_imagefile(file) -> Image.Image:
91
  image = Image.open(io.BytesIO(file))
92
  return image.convert("RGB")
@@ -112,80 +102,31 @@ def adjust_with_metadata(predictions, metadata):
112
  race = metadata.get("race", "").lower()
113
 
114
  rules = [
115
- # 0. Acne and Rosacea
116
- {"keyword": "acne", "condition": age < 25 or "oily" in skin_type, "factor": 1.3},
117
- {"keyword": "rosacea", "condition": "redness" in condition or "cheek" in condition, "factor": 1.3},
118
-
119
- # 1. AK/Basal Cell/Malignant Lesions
120
- {"keyword": "keratosis", "condition": age > 60 or "lesion" in condition, "factor": 1.3},
121
- {"keyword": "carcinoma", "condition": "bleeding" in condition or "non-healing" in condition, "factor": 1.3},
122
-
123
- # 2. Atopic Dermatitis
124
- {"keyword": "atopic", "condition": age < 10 or "itchy" in condition, "factor": 1.3},
125
-
126
- # 3. Bullous Disease
127
- {"keyword": "bullous", "condition": "blisters" in condition or age > 60, "factor": 1.3},
128
-
129
- # 4. Cellulitis/Impetigo/Bacterial
130
- {"keyword": "impetigo", "condition": "crust" in condition or "yellow" in condition, "factor": 1.3},
131
- {"keyword": "cellulitis", "condition": "swelling" in condition or "fever" in condition, "factor": 1.3},
132
-
133
- # 5. Eczema
134
- {"keyword": "eczema", "condition": "dry" in skin_type or "itchy" in condition, "factor": 1.3},
135
-
136
- # 6. Exanthems/Drug Eruption
137
- {"keyword": "drug eruption", "condition": "medication" in condition or "rash" in condition, "factor": 1.3},
138
-
139
- # 7. Hair Loss
140
- {"keyword": "hair loss", "condition": gender == "male" and age > 30, "factor": 1.3},
141
-
142
- # 8. Herpes/HPV/STDs
143
- {"keyword": "herpes", "condition": "painful" in condition or "genital" in condition, "factor": 1.3},
144
-
145
- # 9. Pigmentation Disorders
146
- {"keyword": "pigmentation", "condition": skin_type in ["dark", "brown"], "factor": 1.3},
147
-
148
- # 10. Lupus/Connective Tissue
149
- {"keyword": "lupus", "condition": "malar" in condition or gender == "female", "factor": 1.3},
150
- {"keyword": "connective", "condition": "joint" in condition or "fatigue" in condition, "factor": 1.3},
151
-
152
- # 11. Melanoma/Skin Cancer/Nevi
153
  {"keyword": "melanoma", "condition": skin_type == "light" and age > 50, "factor": 1.3},
154
-
155
- # 12. Nail Fungus
156
- {"keyword": "nail fungus", "condition": "thickened" in condition or age > 60, "factor": 1.3},
157
-
158
- # 13. Poison Ivy / Contact Dermatitis
159
- {"keyword": "poison ivy", "condition": "itchy" in condition or "blister" in condition, "factor": 1.3},
160
- {"keyword": "contact dermatitis", "condition": "burning" in condition or "irritant" in condition, "factor": 1.3},
161
-
162
- # 14. Psoriasis / Lichen Planus
163
- {"keyword": "psoriasis", "condition": "scaly" in condition or "elbow" in condition, "factor": 1.3},
164
-
165
- # 15. Scabies/Lyme/Bites
166
- {"keyword": "scabies", "condition": "night" in condition or "burrow" in condition, "factor": 1.3},
167
-
168
- # 16. Seborrheic Keratoses/Benign Tumors
169
- {"keyword": "seborrheic", "condition": age > 60 or "waxy" in condition, "factor": 1.3},
170
-
171
- # 17. Systemic Disease
172
- {"keyword": "systemic", "condition": "fatigue" in condition or "multiple" in condition, "factor": 1.3},
173
-
174
- # 18. Fungal (Tinea, Ringworm, etc.)
175
- {"keyword": "fungal", "condition": "itchy" in condition or "ring" in condition, "factor": 1.3},
176
-
177
- # 19. Urticaria (Hives)
178
- {"keyword": "urticaria", "condition": "hives" in condition or "allergy" in condition, "factor": 1.3},
179
-
180
- # 20. Vascular Tumors
181
- {"keyword": "vascular tumor", "condition": age < 5 or "birthmark" in condition, "factor": 1.3},
182
-
183
- # 21. Vasculitis
184
- {"keyword": "vasculitis", "condition": "purpura" in condition or "painful" in condition, "factor": 1.3},
185
-
186
- # 22. Warts/Molluscum/Viral
187
- {"keyword": "warts", "condition": "raised" in condition or "cauliflower" in condition, "factor": 1.3},
188
- {"keyword": "molluscum", "condition": "umbilicated" in condition or age < 12, "factor": 1.3},
189
  ]
190
 
191
  adjusted = []
@@ -199,7 +140,6 @@ def adjust_with_metadata(predictions, metadata):
199
 
200
  return sorted(adjusted, key=lambda x: x["confidence"], reverse=True)[:3]
201
 
202
-
203
  # --- PREDICT ENDPOINT ---
204
  @app.post("/predict/")
205
  async def predict(
@@ -241,7 +181,10 @@ async def predict(
241
  ]
242
 
243
  top3_indices = prediction.argsort()[-3:][::-1]
244
- top_preds = [{"label": class_labels[i], "confidence": float(prediction[i])} for i in top3_indices]
 
 
 
245
 
246
  metadata = {
247
  "age": age,
@@ -256,6 +199,5 @@ async def predict(
256
 
257
  return {
258
  "prediction": final_preds,
259
- "metadata": metadata,
260
- "metadata_adjustment_applied": USE_METADATA_ADJUSTMENT
261
  }
 
1
  from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  import numpy as np
4
  import tensorflow as tf
5
  from PIL import Image
 
20
  # --- CONFIG ---
21
  IMG_SIZE = 224
22
  NUM_CLASSES = 23
23
+ USE_METADATA_ADJUSTMENT = True # Toggle on/off
24
 
25
  MODEL_PATHS = {
26
  "EfficientNetB3": "saved_models/EfficientNetB3_recovered.keras",
 
61
  allow_headers=["*"],
62
  )
63
 
 
64
  @app.get("/")
65
  def root():
66
  return {"status": "🩺 App is running."}
67
 
 
68
  models = {}
69
 
70
  @app.on_event("startup")
71
  def load_models():
72
  global models
 
73
  os.makedirs("saved_models", exist_ok=True)
 
74
  for name, path in MODEL_PATHS.items():
75
  if not os.path.exists(path):
 
76
  gdown.download(MODEL_URLS[name], path, quiet=False)
 
 
77
  models[name] = tf.keras.models.load_model(path)
78
 
79
+ # --- UTILS ---
 
 
80
  def read_imagefile(file) -> Image.Image:
81
  image = Image.open(io.BytesIO(file))
82
  return image.convert("RGB")
 
102
  race = metadata.get("race", "").lower()
103
 
104
  rules = [
105
+ {"keyword": "acne", "condition": age > 40, "factor": 0.6},
106
+ {"keyword": "acne", "condition": age < 25, "factor": 1.2},
107
+ {"keyword": "eczema", "condition": skin_type == "dry", "factor": 1.2},
108
+ {"keyword": "eczema", "condition": skin_type == "oily", "factor": 0.8},
109
+ {"keyword": "warts", "condition": age < 12, "factor": 1.3},
110
+ {"keyword": "fungal", "condition": "itchy" in condition, "factor": 1.2},
111
+ {"keyword": "hair loss", "condition": gender == "male" and age > 40, "factor": 1.3},
112
+ {"keyword": "lupus", "condition": gender == "female", "factor": 1.2},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  {"keyword": "melanoma", "condition": skin_type == "light" and age > 50, "factor": 1.3},
114
+ {"keyword": "nail fungus", "condition": age > 60, "factor": 1.3},
115
+ {"keyword": "psoriasis", "condition": "flaky" in condition or "dry" in skin_type, "factor": 1.2},
116
+ {"keyword": "systemic", "condition": "fatigue" in condition or "pain" in condition, "factor": 1.2},
117
+ {"keyword": "urticaria", "condition": "allergy" in condition or "hives" in condition, "factor": 1.3},
118
+ {"keyword": "contact dermatitis", "condition": "red" in condition or "burning" in condition, "factor": 1.2},
119
+ {"keyword": "seborrheic", "condition": age > 60, "factor": 1.2},
120
+ {"keyword": "bullous", "condition": age > 60, "factor": 1.3},
121
+ {"keyword": "vasculitis", "condition": "joint" in condition or "swelling" in condition, "factor": 1.2},
122
+ {"keyword": "atopic dermatitis", "condition": age < 10, "factor": 1.2},
123
+ {"keyword": "pigmentation", "condition": skin_type == "dark", "factor": 1.3},
124
+ {"keyword": "hpv", "condition": age > 18, "factor": 1.2},
125
+ {"keyword": "vascular tumors", "condition": age < 5, "factor": 1.3},
126
+ {"keyword": "poison ivy", "condition": "rash" in condition or "camping" in condition, "factor": 1.3},
127
+ {"keyword": "bacterial", "condition": "pus" in condition or "fever" in condition, "factor": 1.3},
128
+ {"keyword": "drug eruption", "condition": "medication" in condition or "rash" in condition, "factor": 1.3},
129
+ {"keyword": "connective tissue", "condition": "joint" in condition or "fatigue" in condition, "factor": 1.3},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  ]
131
 
132
  adjusted = []
 
140
 
141
  return sorted(adjusted, key=lambda x: x["confidence"], reverse=True)[:3]
142
 
 
143
  # --- PREDICT ENDPOINT ---
144
  @app.post("/predict/")
145
  async def predict(
 
181
  ]
182
 
183
  top3_indices = prediction.argsort()[-3:][::-1]
184
+ top_preds = [
185
+ {"label": class_labels[i], "confidence": float(prediction[i])}
186
+ for i in top3_indices
187
+ ]
188
 
189
  metadata = {
190
  "age": age,
 
199
 
200
  return {
201
  "prediction": final_preds,
202
+ "metadata": metadata
 
203
  }