JimSmith007 commited on
Commit
5b48288
·
1 Parent(s): ce6cd26

Tutilisation d'un fichier de fonction allégé pour l'API

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. fonctions_api.py +58 -0
app.py CHANGED
@@ -2,7 +2,8 @@
2
  from fastapi import FastAPI, UploadFile, File
3
  from fastapi.responses import StreamingResponse
4
  import torch
5
- from fonctions import charger_segformer
 
6
  from PIL import Image
7
  import io
8
  import numpy as np
 
2
  from fastapi import FastAPI, UploadFile, File
3
  from fastapi.responses import StreamingResponse
4
  import torch
5
+ from fonctions_api import charger_segformer, remap_classes
6
+ from fonctions_api import decode_cityscapes_mask
7
  from PIL import Image
8
  import io
9
  import numpy as np
fonctions_api.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fonctions_api.py (version allégée pour l'API FastAPI avec SegFormer)
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from transformers import SegformerForSemanticSegmentation
8
+
9
+ # -------- Palette (optionnelle pour colorisation) --------
10
+ PALETTE = {
11
+ 0: (0, 0, 0), # void
12
+ 1: (50, 50, 150), # flat
13
+ 2: (102, 0, 204), # construction
14
+ 3: (255, 85, 0), # object
15
+ 4: (255, 255, 0), # nature
16
+ 5: (0, 255, 255), # sky
17
+ 6: (255, 0, 255), # human
18
+ 7: (255, 255, 255), # vehicle
19
+ }
20
+
21
+ # -------- Fonction principale pour charger SegFormer --------
22
+ def charger_segformer(num_classes=8):
23
+ model = SegformerForSemanticSegmentation.from_pretrained(
24
+ "nvidia/segformer-b5-finetuned-ade-640-640",
25
+ num_labels=num_classes,
26
+ ignore_mismatched_sizes=True
27
+ )
28
+ model.config.num_labels = num_classes
29
+ model.config.output_hidden_states = False
30
+ return model
31
+
32
+ # -------- Remapping Cityscapes labelIds vers 8 classes --------
33
+ def remap_classes(mask: np.ndarray) -> np.ndarray:
34
+ labelIds_to_main_classes = {
35
+ 0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0,
36
+ 7: 1, 8: 1,
37
+ 9: 0, 10: 0,
38
+ 11: 2, 12: 2, 13: 2,
39
+ 14: 0, 15: 0, 16: 0,
40
+ 17: 3, 18: 3, 19: 3, 20: 3,
41
+ 21: 4, 22: 4,
42
+ 23: 5,
43
+ 24: 6, 25: 6,
44
+ 26: 7, 27: 7, 28: 7, 29: 7, 30: 7, 31: 7, 32: 7, 33: 7
45
+ }
46
+ remapped_mask = np.copy(mask)
47
+ for original_class, new_class in labelIds_to_main_classes.items():
48
+ remapped_mask[mask == original_class] = new_class
49
+ remapped_mask[mask > 33] = 0
50
+ return remapped_mask.astype(np.uint8)
51
+
52
+ # -------- Convertit un masque 2D en image RGB (optionnel) --------
53
+ def decode_cityscapes_mask(mask):
54
+ h, w = mask.shape
55
+ mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
56
+ for class_id, color in PALETTE.items():
57
+ mask_rgb[mask == class_id] = color
58
+ return mask_rgb