Commit 5b48288 · Parent(s): ce6cd26
Use a lightweight functions file for the API
Files changed:
- app.py +2 -1
- fonctions_api.py +58 -0
app.py CHANGED
@@ -2,7 +2,8 @@
 from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import StreamingResponse
 import torch
-from
+from fonctions_api import charger_segformer, remap_classes
+from fonctions_api import decode_cityscapes_mask
 from PIL import Image
 import io
 import numpy as np
fonctions_api.py ADDED
@@ -0,0 +1,58 @@
+# fonctions_api.py (lightweight version for the FastAPI API with SegFormer)
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from pathlib import Path
+from transformers import SegformerForSemanticSegmentation
+
+# -------- Palette (optional, for colorization) --------
+PALETTE = {
+    0: (0, 0, 0),        # void
+    1: (50, 50, 150),    # flat
+    2: (102, 0, 204),    # construction
+    3: (255, 85, 0),     # object
+    4: (255, 255, 0),    # nature
+    5: (0, 255, 255),    # sky
+    6: (255, 0, 255),    # human
+    7: (255, 255, 255),  # vehicle
+}
+
+# -------- Main function to load SegFormer --------
+def charger_segformer(num_classes=8):
+    model = SegformerForSemanticSegmentation.from_pretrained(
+        "nvidia/segformer-b5-finetuned-ade-640-640",
+        num_labels=num_classes,
+        ignore_mismatched_sizes=True
+    )
+    model.config.num_labels = num_classes
+    model.config.output_hidden_states = False
+    return model
+
+# -------- Remap Cityscapes labelIds to 8 classes --------
+def remap_classes(mask: np.ndarray) -> np.ndarray:
+    labelIds_to_main_classes = {
+        0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0,
+        7: 1, 8: 1,
+        9: 0, 10: 0,
+        11: 2, 12: 2, 13: 2,
+        14: 0, 15: 0, 16: 0,
+        17: 3, 18: 3, 19: 3, 20: 3,
+        21: 4, 22: 4,
+        23: 5,
+        24: 6, 25: 6,
+        26: 7, 27: 7, 28: 7, 29: 7, 30: 7, 31: 7, 32: 7, 33: 7
+    }
+    remapped_mask = np.copy(mask)
+    for original_class, new_class in labelIds_to_main_classes.items():
+        remapped_mask[mask == original_class] = new_class
+    remapped_mask[mask > 33] = 0
+    return remapped_mask.astype(np.uint8)
+
+# -------- Convert a 2D mask to an RGB image (optional) --------
+def decode_cityscapes_mask(mask):
+    h, w = mask.shape
+    mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
+    for class_id, color in PALETTE.items():
+        mask_rgb[mask == class_id] = color
+    return mask_rgb
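For context, below is a minimal sketch of how these helpers could be wired into the FastAPI app. The /segment route name, the use of SegformerImageProcessor for preprocessing, and the overall request flow are assumptions for illustration; the actual endpoint code in app.py is not part of this diff.

# Hypothetical usage sketch (not part of this commit): a FastAPI endpoint that
# loads the model once with charger_segformer() and returns a colorized mask.
import io

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from transformers import SegformerImageProcessor  # assumed preprocessing helper

from fonctions_api import charger_segformer, decode_cityscapes_mask

app = FastAPI()
model = charger_segformer(num_classes=8)
model.eval()
processor = SegformerImageProcessor()  # resizes and normalizes the input image

@app.post("/segment")  # hypothetical route name
async def segment(file: UploadFile = File(...)):
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(pixel_values=inputs["pixel_values"]).logits  # (1, 8, H/4, W/4)

    # Upsample logits to the original resolution, then take the per-pixel argmax.
    logits = F.interpolate(logits, size=image.size[::-1], mode="bilinear", align_corners=False)
    mask = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    # Colorize the 8-class mask with the PALETTE defined in fonctions_api.py.
    mask_rgb = decode_cityscapes_mask(mask)
    buffer = io.BytesIO()
    Image.fromarray(mask_rgb).save(buffer, format="PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")

Note that remap_classes is aimed at ground-truth Cityscapes masks (labelIds to 8 macro classes), so it would typically be used when preparing reference masks rather than on model predictions.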