t4n15hq committed on
Commit
94fe463
·
1 Parent(s): 36b36c1

Initial working FastAPI deployment

Browse files
Files changed (3) hide show
  1. Dockerfile +14 -0
  2. main.py +198 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9

# Run as a non-root user with UID 1000 (the convention Hugging Face Spaces expects).
RUN useradd -m -u 1000 user
USER user

# User-level pip installs land in ~/.local/bin; make them runnable.
ENV PATH="/home/user/.local/bin:$PATH"
WORKDIR /app

# Copy only the dependency manifest first so the expensive pip-install layer
# is cached and re-run only when requirements.txt changes, not on every code edit.
COPY --chown=user requirements.txt .

RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Application code changes frequently, so copy it after the install layer.
COPY --chown=user . .

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import Optional
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
+ import io
8
+ import keras
9
+ import os
10
+ import gdown
11
+ from keras.saving import register_keras_serializable
12
+
13
# The saved models contain Lambda layers referencing custom callables, which
# Keras refuses to deserialize unless explicitly allowed.
keras.config.enable_unsafe_deserialization()

# --- CUSTOM FUNCTION ---
@register_keras_serializable(package="Custom", name="custom_max_pool")
def custom_max_pool(x):
    """Global max-pool over the spatial axes, keeping them as size-1 dims."""
    pooled = tf.reduce_max(x, axis=[1, 2], keepdims=True)
    return pooled
20
+
21
# --- CONFIG ---
# Input resolution fed to every backbone, and the number of disease classes.
IMG_SIZE = 224
NUM_CLASSES = 23

# Local cache paths for the recovered .keras model files (downloaded on startup).
MODEL_PATHS = {
    "EfficientNetB3": "saved_models/EfficientNetB3_recovered.keras",
    "ResNet50": "saved_models/ResNet50_recovered.keras",
    "MobileNetV2": "saved_models/MobileNetV2_recovered.keras",
    "DenseNet121": "saved_models/DenseNet121_recovered.keras"
}

# Google Drive download URLs for each model file.
MODEL_URLS = {
    "EfficientNetB3": "https://drive.google.com/uc?id=1jP4-HoFFbGIFugqgRpVVt0V3LhOoKVkY",
    "ResNet50": "https://drive.google.com/uc?id=1yv4duVkGHTyLEpw92CCJcUy9Y1VxC6Ec",
    "MobileNetV2": "https://drive.google.com/uc?id=1fJtogp6fH7F2Wa2YvN_KTklgK2-ufqMN",
    "DenseNet121": "https://drive.google.com/uc?id=1lJ0nlTP7cMTglEM6XIaTvAEZHVJ4dsN8"
}

# Blending weights for the weighted-average ensemble (sum ≈ 1.0).
# NOTE(review): presumably tuned from per-model validation accuracy — confirm.
MODEL_WEIGHTS = {
    "EfficientNetB3": 0.260,
    "ResNet50": 0.256,
    "MobileNetV2": 0.222,
    "DenseNet121": 0.261
}

# Each architecture expects its own input normalization; map model name to
# the matching Keras preprocess_input function.
PREPROCESS_FUNCS = {
    "EfficientNetB3": tf.keras.applications.efficientnet.preprocess_input,
    "ResNet50": tf.keras.applications.resnet.preprocess_input,
    "MobileNetV2": tf.keras.applications.mobilenet_v2.preprocess_input,
    "DenseNet121": tf.keras.applications.densenet.preprocess_input
}
52
+
53
# --- FASTAPI SETUP ---
app = FastAPI()

# Open CORS so a separately hosted frontend can call this API from the browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
63
+
64
# --- HEALTH CHECK ENDPOINT ---
@app.get("/")
def root():
    """Liveness probe: confirms the service is up and responding."""
    status_payload = {"status": "🩺 App is running."}
    return status_payload
68
+
69
# --- GLOBAL MODELS DICT ---
# name -> loaded Keras model; populated once by the startup hook below.
models = {}

# --- LOAD MODELS AT STARTUP ---
@app.on_event("startup")
def load_models():
    """Download (if missing) and load every ensemble member into `models`.

    Raises:
        RuntimeError: if a model file is still absent after the download
            attempt — gdown reports failure by returning None rather than
            raising, which would otherwise surface later as a confusing
            load_model error.
    """
    global models
    print("Loading models...")
    os.makedirs("saved_models", exist_ok=True)

    for name, path in MODEL_PATHS.items():
        if not os.path.exists(path):
            print(f"Downloading {name} from Google Drive...")
            gdown.download(MODEL_URLS[name], path, quiet=False)
            # Fail fast with a clear message if the download silently failed.
            if not os.path.exists(path):
                raise RuntimeError(f"Failed to download model '{name}' to {path}")

        print(f"Loading {name}...")
        models[name] = tf.keras.models.load_model(path)

    print("✅ All models loaded.")
88
+
89
# --- UTILS ---
def read_imagefile(file) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image."""
    buffer = io.BytesIO(file)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")
93
+
94
# --- ENSEMBLE PREDICTION ---
def predict_ensemble(image):
    """Run every loaded model on `image` and return the weighted sum of
    their class-probability vectors as a (NUM_CLASSES,) array."""
    resized = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
    weighted_sum = np.zeros((NUM_CLASSES,))
    for name, model in models.items():
        # Each backbone gets its own normalization; copy so the raw pixels
        # are untouched for the next model.
        batch = np.expand_dims(PREPROCESS_FUNCS[name](resized.copy()), axis=0)
        probs = model.predict(batch, verbose=0)[0]
        weighted_sum = weighted_sum + probs * MODEL_WEIGHTS[name]
    return weighted_sum
106
+
107
# --- METADATA-BASED ADJUSTMENT ---
def adjust_with_metadata(predictions, metadata):
    """Re-weight prediction confidences using patient metadata heuristics.

    Args:
        predictions: list of {"label": str, "confidence": float} dicts.
        metadata: dict with optional "age", "condition", "skin_type" keys.

    Returns:
        The (up to) top-3 adjusted predictions, sorted by confidence
        descending. If the metadata cannot be parsed (e.g. missing or
        non-numeric age), scores are returned unadjusted.
    """
    # Parse the loop-invariant metadata ONCE, instead of re-parsing (and
    # re-printing the same error) for every prediction as before.
    try:
        age = int(metadata["age"])
        condition = metadata.get("condition", "").lower()
        skin_type = metadata.get("skin_type", "").lower()
    except Exception as e:
        print("Metadata adjustment error:", e)
        age = None

    adjusted = []
    for pred in predictions:
        label = pred["label"]
        score = pred["confidence"]

        if age is not None:
            lower = label.lower()
            # Heuristic priors: acne is less likely past 40, eczema more
            # likely on dry skin, warts more likely in children, fungal
            # infections more likely when the condition mentions itching.
            if "acne" in lower and age > 40:
                score *= 0.6
            if "eczema" in lower and skin_type == "dry":
                score *= 1.2
            if "warts" in lower and age < 12:
                score *= 1.3
            if "fungal" in lower and "itchy" in condition:
                score *= 1.2

        adjusted.append({"label": label, "confidence": score})

    adjusted.sort(key=lambda x: x["confidence"], reverse=True)
    return adjusted[:3]
134
+
135
# --- PREDICT ENDPOINT ---
@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    age: int = Form(...),
    race: str = Form(...),
    gender: str = Form(...),
    skin_color: str = Form(...),
    skin_type: str = Form(...),
    condition_description: str = Form(...)
):
    """Classify an uploaded skin image with the model ensemble, then refine
    the top-3 predictions using the submitted patient metadata."""
    raw_bytes = await file.read()
    scores = predict_ensemble(read_imagefile(raw_bytes))

    # Class index -> human-readable dataset label.
    class_labels = [
        "Acne and Rosacea Photos",
        "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
        "Atopic Dermatitis Photos",
        "Bullous Disease Photos",
        "Cellulitis Impetigo and other Bacterial Infections",
        "Eczema Photos",
        "Exanthems and Drug Eruptions",
        "Hair Loss Photos Alopecia and other Hair Diseases",
        "Herpes HPV and other STDs Photos",
        "Light Diseases and Disorders of Pigmentation",
        "Lupus and other Connective Tissue diseases",
        "Melanoma Skin Cancer Nevi and Moles",
        "Nail Fungus and other Nail Disease",
        "Poison Ivy Photos and other Contact Dermatitis",
        "Psoriasis pictures Lichen Planus and related diseases",
        "Scabies Lyme Disease and other Infestations and Bites",
        "Seborrheic Keratoses and other Benign Tumors",
        "Systemic Disease",
        "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Urticaria Hives",
        "Vascular Tumors",
        "Vasculitis Photos",
        "Warts Molluscum and other Viral Infections"
    ]

    # Indices of the three highest-scoring classes, best first.
    ranked = scores.argsort()[-3:][::-1]
    top_preds = [
        {"label": class_labels[i], "confidence": float(scores[i])}
        for i in ranked
    ]

    metadata = {
        "age": age,
        "race": race,
        "gender": gender,
        "skin_color": skin_color,
        "skin_type": skin_type,
        "condition": condition_description
    }

    return {
        "prediction": adjust_with_metadata(top_preds, metadata),
        "metadata": metadata
    }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
fastapi
uvicorn[standard]
# Required by FastAPI for endpoints declaring Form(...) / File(...) parameters;
# without it the /predict/ route fails at request time.
python-multipart
tensorflow
pillow
numpy
gdown