AndreaLeylavergne commited on
Commit
704c92e
·
verified ·
1 Parent(s): 4830701

Supprimer backend_huggingfacemodel.py

Browse files
Files changed (1) hide show
  1. backend_huggingfacemodel.py +0 -186
backend_huggingfacemodel.py DELETED
@@ -1,186 +0,0 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from pydantic import BaseModel
4
- from contextlib import asynccontextmanager
5
- import torch
6
- import numpy as np
7
- from PIL import Image
8
- import cv2
9
- import base64
10
- import io
11
- import logging
12
- from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
13
- import rasterio
14
- import matplotlib.pyplot as plt
15
-
16
- #app = FastAPI() #delete when charging the model
17
-
18
- # Configure logging
19
- logging.basicConfig(level=logging.INFO)
20
-
21
- # Chemin vers le dossier contenant les fichiers de votre modèle
22
- MODEL_DIR = "AndreaLeylavergne/segformer_B0_B5"
23
-
24
- @asynccontextmanager
25
- async def lifespan(app: FastAPI):
26
- global model, processor
27
- processor = SegformerImageProcessor.from_pretrained(MODEL_DIR)
28
- model = SegformerForSemanticSegmentation.from_pretrained(MODEL_DIR)
29
- logging.info("Model loaded successfully")
30
-
31
- yield
32
-
33
- app = FastAPI(lifespan=lifespan)
34
-
35
- class ImageID(BaseModel):
36
- image_id: str
37
-
38
-
39
- # Paths of images and annotated masks
40
- images_paths = {
41
- "image1": "./dataset/images_prepped/val/0000FT_000294.png",
42
- "image2": "./dataset/images_prepped/val/0000FT_000576.png",
43
- "image3": "./dataset/images_prepped/val/0000FT_001016.png"
44
- }
45
-
46
- annotated_masks_paths = {
47
- "image1": "./dataset/annotations_prepped_grouped/val/0000FT_000294.png",
48
- "image2": "./dataset/annotations_prepped_grouped/val/0000FT_000576.png",
49
- "image3": "./dataset/annotations_prepped_grouped/val/0000FT_001016.png"
50
- }
51
-
52
- # Utility functions
53
- def read_img(raster_file: str) -> np.ndarray:
54
- print(f"Reading image from {raster_file}")
55
- with rasterio.open(raster_file) as src_img:
56
- rgb = src_img.read([1, 2, 3]).transpose(1, 2, 0)
57
- rgb = rgb.astype(np.float32)
58
- return rgb
59
-
60
- def read_msk(raster_file: str) -> np.ndarray:
61
- print(f"Reading mask from {raster_file}")
62
- with rasterio.open(raster_file) as src_msk:
63
- array = src_msk.read(1)
64
- array = np.squeeze(array)
65
- return array
66
-
67
- def map_colors(mask):
68
- # Create a color map with high contrast colors
69
- color_map = np.array([
70
- [0, 0, 0], # Class 0 - background
71
- [128, 0, 0], # Class 1 - red
72
- [0, 128, 0], # Class 2 - green
73
- [128, 128, 0], # Class 3 - yellow
74
- [0, 0, 128], # Class 4 - blue
75
- [128, 0, 128], # Class 5 - magenta
76
- [0, 128, 128], # Class 6 - cyan
77
- [128, 128, 128], # Class 7 - white
78
- # Add more colors if you have more classes
79
- ])
80
-
81
- # Apply the color map to the mask
82
- color_mask = color_map[mask]
83
-
84
- return color_mask
85
-
86
- def predict(model, processor, image_path):
87
- image = read_img(image_path)
88
- inputs = processor(images=image, return_tensors="pt")
89
-
90
- with torch.no_grad():
91
- outputs = model(**inputs)
92
- logits = outputs.logits
93
- upsampled_logits = torch.nn.functional.interpolate(
94
- logits,
95
- size=image.shape[:2],
96
- mode='bilinear',
97
- align_corners=False
98
- )
99
- pred_seg = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
100
-
101
- return pred_seg, image
102
-
103
- def calculate_iou(array1, array2):
104
- assert array1.shape == array2.shape, "Arrays must have the same shape"
105
-
106
- array1_binary = (array1 > 0).astype(int)
107
- array2_binary = (array2 > 0).astype(int)
108
-
109
- intersection = np.sum(array1_binary * array2_binary)
110
- union = np.sum(array1_binary) + np.sum(array2_binary) - intersection
111
-
112
- return intersection / union if union > 0 else 0
113
-
114
- @app.post("/predict/")
115
- async def predict_mask(data: ImageID):
116
- image_id = data.image_id
117
- if image_id not in images_paths:
118
- raise HTTPException(status_code=404, detail="Image ID not found")
119
-
120
- image_path = images_paths[image_id]
121
- #pr, overlay_image = predict(model, processor, image_path)
122
- pred_seg, image = predict(model, processor, image_path)
123
-
124
- # Convert the predicted mask to a color image for better visualization
125
- color_pred_mask = map_colors(pred_seg)
126
- color_pred_mask_image = Image.fromarray(color_pred_mask.astype(np.uint8))
127
-
128
- annotated_mask_path = annotated_masks_paths[image_id]
129
- annotated_mask_image = cv2.imread(annotated_mask_path, cv2.IMREAD_GRAYSCALE)
130
- annotated_mask_image = (annotated_mask_image / annotated_mask_image.max()) * 255
131
- annotated_mask_image = Image.fromarray(annotated_mask_image.astype(np.uint8)).resize((256, 128))
132
-
133
- color_pred_mask_stream = io.BytesIO()
134
- annotated_mask_stream = io.BytesIO()
135
-
136
- color_pred_mask_image.save(color_pred_mask_stream, format='PNG')
137
- annotated_mask_image.save(annotated_mask_stream, format='PNG')
138
-
139
- color_pred_mask_stream.seek(0)
140
- annotated_mask_stream.seek(0)
141
-
142
- color_pred_mask_data_url = base64.b64encode(color_pred_mask_stream.read()).decode('utf8')
143
- annotated_data_url = base64.b64encode(annotated_mask_stream.read()).decode('utf8')
144
-
145
- return JSONResponse(content={
146
- "annotated_mask": "data:image/png;base64," + annotated_data_url,
147
- "predicted_mask": "data:image/png;base64," + color_pred_mask_data_url
148
- })
149
-
150
- @app.post("/evaluate/")
151
- async def evaluate_masks(data: dict):
152
- annotated_mask_data = data['annotated_mask']
153
- predicted_mask_data = data['predicted_mask']
154
-
155
- try:
156
- annotated_mask = Image.open(io.BytesIO(base64.b64decode(annotated_mask_data.split(',')[1]))).resize((256, 128))
157
- predicted_mask = Image.open(io.BytesIO(base64.b64decode(predicted_mask_data.split(',')[1]))).resize((256, 128))
158
- predicted_mask = predicted_mask.convert("L") #Covert "L" converts en grayscale the mask before evaluating. Important for obtainig IoU
159
- except Exception as e:
160
- raise HTTPException(status_code=400, detail=f"Error decoding images: {str(e)}")
161
-
162
- annotated_mask_array = np.array(annotated_mask)
163
- predicted_mask_array = np.array(predicted_mask)
164
- iou_score = calculate_iou(annotated_mask_array, predicted_mask_array)
165
-
166
- annotated_mask_stream = io.BytesIO()
167
- mask_image_stream = io.BytesIO()
168
-
169
- annotated_mask.save(annotated_mask_stream, format='PNG')
170
- predicted_mask.save(mask_image_stream, format='PNG')
171
-
172
- annotated_mask_stream.seek(0)
173
- mask_image_stream.seek(0)
174
-
175
- annotated_data_url = base64.b64encode(annotated_mask_stream.read()).decode('utf8')
176
- predicted_data_url = base64.b64encode(mask_image_stream.read()).decode('utf8')
177
-
178
- return JSONResponse(content={
179
- "iou_score": iou_score,
180
- "annotated_mask": "data:image/png;base64," + annotated_data_url,
181
- "predicted_mask": "data:image/png;base64," + predicted_data_url
182
- })
183
-
184
- @app.get("/")
185
- def root():
186
- return {"Greeting": "Hello, World!"}