File size: 1,971 Bytes
5b48288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# fonctions_api.py (version allégée pour l'API FastAPI avec SegFormer)

import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from transformers import SegformerForSemanticSegmentation

# -------- Colour palette (optional, used to colourise masks) --------
# Category id -> RGB triple. Ids are assigned in list order (0..7).
PALETTE = dict(enumerate([
    (0, 0, 0),          # 0: void
    (50, 50, 150),      # 1: flat
    (102, 0, 204),      # 2: construction
    (255, 85, 0),       # 3: object
    (255, 255, 0),      # 4: nature
    (0, 255, 255),      # 5: sky
    (255, 0, 255),      # 6: human
    (255, 255, 255),    # 7: vehicle
]))

# -------- Main entry point: load a SegFormer model --------
def charger_segformer(num_classes=8,
                      checkpoint="nvidia/segformer-b5-finetuned-ade-640-640"):
    """Load a pretrained SegFormer and resize its head to ``num_classes``.

    Args:
        num_classes: Number of output classes for the segmentation head.
        checkpoint: Hugging Face hub id or local directory passed to
            ``from_pretrained``. The default is the checkpoint this function
            always used, so existing callers are unaffected.

    Returns:
        A ``SegformerForSemanticSegmentation`` whose config has
        ``num_labels == num_classes`` and ``output_hidden_states`` disabled.
    """
    model = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint,
        num_labels=num_classes,
        # The checkpoint's head has a different number of labels; allow the
        # shape-mismatched weights to be skipped instead of raising.
        ignore_mismatched_sizes=True,
    )
    # Re-assert the label count explicitly on the loaded config.
    model.config.num_labels = num_classes
    # Only the segmentation logits are needed by the API.
    model.config.output_hidden_states = False
    return model

# -------- Remap Cityscapes labelIds to 8 categories --------
# labelId -> category id (0 void, 1 flat, 2 construction, 3 object,
# 4 nature, 5 sky, 6 human, 7 vehicle; names as listed next to PALETTE).
# NOTE(review): labelIds 9, 10 and 14-16 are sent to void (0) rather than to
# flat/construction. This is presumably deliberate (mirroring classes that are
# ignored in evaluation) — confirm it matches the training pipeline.
_LABELID_TO_CATEGORY = {
    0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0,
    7: 1, 8: 1,
    9: 0, 10: 0,
    11: 2, 12: 2, 13: 2,
    14: 0, 15: 0, 16: 0,
    17: 3, 18: 3, 19: 3, 20: 3,
    21: 4, 22: 4,
    23: 5,
    24: 6, 25: 6,
    26: 7, 27: 7, 28: 7, 29: 7, 30: 7, 31: 7, 32: 7, 33: 7,
}
# Dense lookup table indexed by labelId (0..33), built once at import time.
_LABELID_LUT = np.array(
    [_LABELID_TO_CATEGORY[label_id]
     for label_id in range(len(_LABELID_TO_CATEGORY))],
    dtype=np.uint8,
)


def remap_classes(mask: np.ndarray) -> np.ndarray:
    """Map a Cityscapes labelId mask to the 8 main categories.

    Args:
        mask: Array of labelIds, of any shape. Integer dtypes are expected;
            float arrays holding integral values are accepted too.

    Returns:
        A uint8 array of the same shape with values in 0..7. Any value
        outside the known labelId range 0..33, including negative values,
        becomes 0 (void).
    """
    mask = np.asarray(mask)
    remapped = np.zeros(mask.shape, dtype=np.uint8)
    # Look up only in-range pixels; everything else stays void (0).
    # Previously negative ids were left untouched and wrapped to 255 in the
    # final uint8 cast. A single table lookup also replaces 34 full passes.
    in_range = (mask >= 0) & (mask < _LABELID_LUT.size)
    remapped[in_range] = _LABELID_LUT[mask[in_range].astype(np.intp)]
    return remapped

# -------- Convert a category mask to an RGB image (optional) --------
def decode_cityscapes_mask(mask, palette=None):
    """Colourise a category-id mask.

    Args:
        mask: Array of category ids. A 2-D ``(H, W)`` mask behaves as before;
            other shapes (e.g. a ``(N, H, W)`` batch) are also accepted.
        palette: Mapping ``class_id -> (R, G, B)``. Defaults to ``PALETTE``.

    Returns:
        A uint8 array of shape ``mask.shape + (3,)``. Pixels whose id is not
        a key of ``palette`` stay black ``(0, 0, 0)``.
    """
    if palette is None:
        palette = PALETTE
    mask = np.asarray(mask)
    mask_rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, color in palette.items():
        mask_rgb[mask == class_id] = color
    return mask_rgb