File size: 2,203 Bytes
f0a19a1
 
 
 
5b48288
 
f0a19a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0da33f9
 
 
 
f0a19a1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# api.py
from fastapi import FastAPI, UploadFile, File
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
import torch
from fonctions_api import charger_segformer, remap_classes
from fonctions_api import decode_cityscapes_mask
from PIL import Image
import io
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn.functional as F

app = FastAPI()

# Inference device: the CUDA GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the SegFormer network with 8 output classes, restore the saved weights
# from "segformer_b5.pth" (mapped directly onto `device`), move the module to
# `device`, and switch it to evaluation mode for inference.
model = charger_segformer(num_classes=8)
model.load_state_dict(
    torch.load("segformer_b5.pth", map_location=device)
)
model.to(device)
model.eval()

# Albumentations preprocessing pipeline. It is built once at import time instead
# of on every request, because its configuration never changes:
#   - resize to 256x256;
#   - normalise with mean=std=0.5 (maps 0..255 pixels to [-1, 1]);
#   - convert the HWC numpy array to a CHW torch tensor.
_PREPROCESS_TRANSFORM = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])

def preprocess(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a model-ready batch tensor.

    The image is first converted to RGB, so grayscale, palette and RGBA inputs
    all yield 3 channels. It is then passed through ``_PREPROCESS_TRANSFORM``.

    Args:
        image: Input PIL image of any size and mode.

    Returns:
        Float tensor of shape ``(1, 3, 256, 256)`` on the module-level
        ``device``, holding a single-image batch.
    """
    image_np = np.array(image.convert("RGB"))
    transformed = _PREPROCESS_TRANSFORM(image=image_np)
    # Add the leading batch dimension expected by the model.
    return transformed["image"].unsqueeze(0).to(device)

# Colour palette used by decode_mask: class id -> (R, G, B) triplet.
PALETTE = {
    0: (0, 0, 0),
    1: (50, 50, 150),
    2: (102, 0, 204),
    3: (255, 85, 0),
    4: (255, 255, 0),
    5: (0, 255, 255),
    6: (255, 0, 255),
    7: (255, 255, 255),
}

def decode_mask(mask, *, palette=None):
    """Colourise a mask of class ids using a class -> RGB palette.

    Args:
        mask: Integer array (or array-like) of class ids, of any shape,
            e.g. ``(H, W)`` for one image or ``(N, H, W)`` for a batch.
        palette: Optional mapping ``class_id -> (R, G, B)``. When omitted,
            the module-level ``PALETTE`` is used.

    Returns:
        ``uint8`` array of shape ``mask.shape + (3,)``. Pixels whose class id
        has no palette entry are left black ``(0, 0, 0)``.
    """
    if palette is None:
        palette = PALETTE
    mask = np.asarray(mask)
    # Append a trailing colour axis instead of unpacking (h, w), so masks of
    # any rank are supported.
    mask_rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, color in palette.items():
        mask_rgb[mask == class_id] = color
    return mask_rgb

@app.get("/")
def home():
    """Root endpoint: return a static status message saying the API is running."""
    payload = {"status": "API avec modèle 'SegFormer' opérationnelle"}
    return payload

@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Segment an uploaded image and return the raw class-id mask as a PNG.

    Args:
        image: Uploaded file (multipart field ``image``). Any format PIL can
            decode is accepted; ``preprocess`` converts it to RGB.

    Returns:
        ``StreamingResponse`` with media type ``image/png``. The body is a
        single-channel 8-bit PNG in which each pixel holds its predicted class
        id. The mask has the spatial size of the model input produced by
        ``preprocess`` (currently 256x256), not the size of the uploaded image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await image.read()
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open is lazy. Force decoding here so that corrupt or truncated
        # uploads become a client error instead of a 500 later in preprocess.
        img.load()
    except (OSError, Image.DecompressionBombError) as exc:  # OSError covers UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    tensor = preprocess(img)

    # NOTE(review): inference runs synchronously inside this async endpoint and
    # therefore blocks the event loop while it runs; consider offloading it to a
    # worker thread if concurrent requests matter.
    with torch.no_grad():
        logits = model(tensor).logits
        # Upsample the logits to the spatial size of the model input. The
        # model's logits may come out at a lower resolution than its input.
        logits = F.interpolate(
            logits, size=tuple(tensor.shape[-2:]), mode="bilinear", align_corners=False
        )
        # Argmax over the class dimension, then take the single batch element.
        pred_mask = logits.argmax(dim=1)[0].cpu().numpy()

    # Return the raw class-id mask (one byte per pixel). For a colourised
    # preview, decode_mask(pred_mask) produces an RGB array instead.
    mask_img = Image.fromarray(pred_mask.astype(np.uint8))

    buf = io.BytesIO()
    mask_img.save(buf, format="PNG")
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/png")