Supprimer backend_huggingfacemodel.py
Browse files- backend_huggingfacemodel.py +0 -186
backend_huggingfacemodel.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
2 |
-
from fastapi.responses import JSONResponse
|
3 |
-
from pydantic import BaseModel
|
4 |
-
from contextlib import asynccontextmanager
|
5 |
-
import torch
|
6 |
-
import numpy as np
|
7 |
-
from PIL import Image
|
8 |
-
import cv2
|
9 |
-
import base64
|
10 |
-
import io
|
11 |
-
import logging
|
12 |
-
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
|
13 |
-
import rasterio
|
14 |
-
import matplotlib.pyplot as plt
|
15 |
-
|
16 |
-
#app = FastAPI() #delete when charging the model
|
17 |
-
|
18 |
-
# Configure logging
|
19 |
-
logging.basicConfig(level=logging.INFO)
|
20 |
-
|
21 |
-
# Chemin vers le dossier contenant les fichiers de votre modèle
|
22 |
-
MODEL_DIR = "AndreaLeylavergne/segformer_B0_B5"
|
23 |
-
|
24 |
-
@asynccontextmanager
|
25 |
-
async def lifespan(app: FastAPI):
|
26 |
-
global model, processor
|
27 |
-
processor = SegformerImageProcessor.from_pretrained(MODEL_DIR)
|
28 |
-
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_DIR)
|
29 |
-
logging.info("Model loaded successfully")
|
30 |
-
|
31 |
-
yield
|
32 |
-
|
33 |
-
app = FastAPI(lifespan=lifespan)
|
34 |
-
|
35 |
-
class ImageID(BaseModel):
|
36 |
-
image_id: str
|
37 |
-
|
38 |
-
|
39 |
-
# Paths of images and annotated masks
|
40 |
-
images_paths = {
|
41 |
-
"image1": "./dataset/images_prepped/val/0000FT_000294.png",
|
42 |
-
"image2": "./dataset/images_prepped/val/0000FT_000576.png",
|
43 |
-
"image3": "./dataset/images_prepped/val/0000FT_001016.png"
|
44 |
-
}
|
45 |
-
|
46 |
-
annotated_masks_paths = {
|
47 |
-
"image1": "./dataset/annotations_prepped_grouped/val/0000FT_000294.png",
|
48 |
-
"image2": "./dataset/annotations_prepped_grouped/val/0000FT_000576.png",
|
49 |
-
"image3": "./dataset/annotations_prepped_grouped/val/0000FT_001016.png"
|
50 |
-
}
|
51 |
-
|
52 |
-
# Utility functions
|
53 |
-
def read_img(raster_file: str) -> np.ndarray:
|
54 |
-
print(f"Reading image from {raster_file}")
|
55 |
-
with rasterio.open(raster_file) as src_img:
|
56 |
-
rgb = src_img.read([1, 2, 3]).transpose(1, 2, 0)
|
57 |
-
rgb = rgb.astype(np.float32)
|
58 |
-
return rgb
|
59 |
-
|
60 |
-
def read_msk(raster_file: str) -> np.ndarray:
|
61 |
-
print(f"Reading mask from {raster_file}")
|
62 |
-
with rasterio.open(raster_file) as src_msk:
|
63 |
-
array = src_msk.read(1)
|
64 |
-
array = np.squeeze(array)
|
65 |
-
return array
|
66 |
-
|
67 |
-
def map_colors(mask):
|
68 |
-
# Create a color map with high contrast colors
|
69 |
-
color_map = np.array([
|
70 |
-
[0, 0, 0], # Class 0 - background
|
71 |
-
[128, 0, 0], # Class 1 - red
|
72 |
-
[0, 128, 0], # Class 2 - green
|
73 |
-
[128, 128, 0], # Class 3 - yellow
|
74 |
-
[0, 0, 128], # Class 4 - blue
|
75 |
-
[128, 0, 128], # Class 5 - magenta
|
76 |
-
[0, 128, 128], # Class 6 - cyan
|
77 |
-
[128, 128, 128], # Class 7 - white
|
78 |
-
# Add more colors if you have more classes
|
79 |
-
])
|
80 |
-
|
81 |
-
# Apply the color map to the mask
|
82 |
-
color_mask = color_map[mask]
|
83 |
-
|
84 |
-
return color_mask
|
85 |
-
|
86 |
-
def predict(model, processor, image_path):
|
87 |
-
image = read_img(image_path)
|
88 |
-
inputs = processor(images=image, return_tensors="pt")
|
89 |
-
|
90 |
-
with torch.no_grad():
|
91 |
-
outputs = model(**inputs)
|
92 |
-
logits = outputs.logits
|
93 |
-
upsampled_logits = torch.nn.functional.interpolate(
|
94 |
-
logits,
|
95 |
-
size=image.shape[:2],
|
96 |
-
mode='bilinear',
|
97 |
-
align_corners=False
|
98 |
-
)
|
99 |
-
pred_seg = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
|
100 |
-
|
101 |
-
return pred_seg, image
|
102 |
-
|
103 |
-
def calculate_iou(array1, array2):
|
104 |
-
assert array1.shape == array2.shape, "Arrays must have the same shape"
|
105 |
-
|
106 |
-
array1_binary = (array1 > 0).astype(int)
|
107 |
-
array2_binary = (array2 > 0).astype(int)
|
108 |
-
|
109 |
-
intersection = np.sum(array1_binary * array2_binary)
|
110 |
-
union = np.sum(array1_binary) + np.sum(array2_binary) - intersection
|
111 |
-
|
112 |
-
return intersection / union if union > 0 else 0
|
113 |
-
|
114 |
-
@app.post("/predict/")
|
115 |
-
async def predict_mask(data: ImageID):
|
116 |
-
image_id = data.image_id
|
117 |
-
if image_id not in images_paths:
|
118 |
-
raise HTTPException(status_code=404, detail="Image ID not found")
|
119 |
-
|
120 |
-
image_path = images_paths[image_id]
|
121 |
-
#pr, overlay_image = predict(model, processor, image_path)
|
122 |
-
pred_seg, image = predict(model, processor, image_path)
|
123 |
-
|
124 |
-
# Convert the predicted mask to a color image for better visualization
|
125 |
-
color_pred_mask = map_colors(pred_seg)
|
126 |
-
color_pred_mask_image = Image.fromarray(color_pred_mask.astype(np.uint8))
|
127 |
-
|
128 |
-
annotated_mask_path = annotated_masks_paths[image_id]
|
129 |
-
annotated_mask_image = cv2.imread(annotated_mask_path, cv2.IMREAD_GRAYSCALE)
|
130 |
-
annotated_mask_image = (annotated_mask_image / annotated_mask_image.max()) * 255
|
131 |
-
annotated_mask_image = Image.fromarray(annotated_mask_image.astype(np.uint8)).resize((256, 128))
|
132 |
-
|
133 |
-
color_pred_mask_stream = io.BytesIO()
|
134 |
-
annotated_mask_stream = io.BytesIO()
|
135 |
-
|
136 |
-
color_pred_mask_image.save(color_pred_mask_stream, format='PNG')
|
137 |
-
annotated_mask_image.save(annotated_mask_stream, format='PNG')
|
138 |
-
|
139 |
-
color_pred_mask_stream.seek(0)
|
140 |
-
annotated_mask_stream.seek(0)
|
141 |
-
|
142 |
-
color_pred_mask_data_url = base64.b64encode(color_pred_mask_stream.read()).decode('utf8')
|
143 |
-
annotated_data_url = base64.b64encode(annotated_mask_stream.read()).decode('utf8')
|
144 |
-
|
145 |
-
return JSONResponse(content={
|
146 |
-
"annotated_mask": "data:image/png;base64," + annotated_data_url,
|
147 |
-
"predicted_mask": "data:image/png;base64," + color_pred_mask_data_url
|
148 |
-
})
|
149 |
-
|
150 |
-
@app.post("/evaluate/")
|
151 |
-
async def evaluate_masks(data: dict):
|
152 |
-
annotated_mask_data = data['annotated_mask']
|
153 |
-
predicted_mask_data = data['predicted_mask']
|
154 |
-
|
155 |
-
try:
|
156 |
-
annotated_mask = Image.open(io.BytesIO(base64.b64decode(annotated_mask_data.split(',')[1]))).resize((256, 128))
|
157 |
-
predicted_mask = Image.open(io.BytesIO(base64.b64decode(predicted_mask_data.split(',')[1]))).resize((256, 128))
|
158 |
-
predicted_mask = predicted_mask.convert("L") #Covert "L" converts en grayscale the mask before evaluating. Important for obtainig IoU
|
159 |
-
except Exception as e:
|
160 |
-
raise HTTPException(status_code=400, detail=f"Error decoding images: {str(e)}")
|
161 |
-
|
162 |
-
annotated_mask_array = np.array(annotated_mask)
|
163 |
-
predicted_mask_array = np.array(predicted_mask)
|
164 |
-
iou_score = calculate_iou(annotated_mask_array, predicted_mask_array)
|
165 |
-
|
166 |
-
annotated_mask_stream = io.BytesIO()
|
167 |
-
mask_image_stream = io.BytesIO()
|
168 |
-
|
169 |
-
annotated_mask.save(annotated_mask_stream, format='PNG')
|
170 |
-
predicted_mask.save(mask_image_stream, format='PNG')
|
171 |
-
|
172 |
-
annotated_mask_stream.seek(0)
|
173 |
-
mask_image_stream.seek(0)
|
174 |
-
|
175 |
-
annotated_data_url = base64.b64encode(annotated_mask_stream.read()).decode('utf8')
|
176 |
-
predicted_data_url = base64.b64encode(mask_image_stream.read()).decode('utf8')
|
177 |
-
|
178 |
-
return JSONResponse(content={
|
179 |
-
"iou_score": iou_score,
|
180 |
-
"annotated_mask": "data:image/png;base64," + annotated_data_url,
|
181 |
-
"predicted_mask": "data:image/png;base64," + predicted_data_url
|
182 |
-
})
|
183 |
-
|
184 |
-
@app.get("/")
|
185 |
-
def root():
|
186 |
-
return {"Greeting": "Hello, World!"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|