# api.py
"""FastAPI service exposing a SegFormer semantic-segmentation model.

Endpoints:
    GET  /         -- liveness check returning a status message.
    POST /predict  -- upload an image, receive the predicted per-pixel class
                      mask as a PNG (raw class indices by default, or a
                      colorized rendering with ``?colorize=true``).
"""
import io

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image

from fonctions_api import charger_segformer, remap_classes
from fonctions_api import decode_cityscapes_mask

app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Side length of the square image fed to the model and of the returned mask.
INPUT_SIZE = 256
# Number of output classes the checkpoint was trained with (matches PALETTE).
NUM_CLASSES = 8

# Load the SegFormer model once, at import time, and switch it to eval mode.
# NOTE(review): torch.load unpickles the checkpoint file; only ever point it at
# a trusted file. On torch >= 1.13 consider passing weights_only=True.
model = charger_segformer(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load("segformer_b5.pth", map_location=device))
model.to(device)
model.eval()

# Albumentations preprocessing pipeline. It is stateless, so build it once
# here instead of on every request.
_TRANSFORM = A.Compose([
    A.Resize(INPUT_SIZE, INPUT_SIZE),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])


def preprocess(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a batched, normalized model input on ``device``.

    The image is converted to RGB, resized to INPUT_SIZE x INPUT_SIZE,
    normalized with per-channel mean/std of 0.5, converted to a CHW tensor
    and given a leading batch dimension of 1.
    """
    image_np = np.array(image.convert("RGB"))
    transformed = _TRANSFORM(image=image_np)
    return transformed["image"].unsqueeze(0).to(device)


# Display color (RGB) for each class index predicted by the model.
PALETTE = {
    0: (0, 0, 0),
    1: (50, 50, 150),
    2: (102, 0, 204),
    3: (255, 85, 0),
    4: (255, 255, 0),
    5: (0, 255, 255),
    6: (255, 0, 255),
    7: (255, 255, 255),
}


def decode_mask(mask: np.ndarray) -> np.ndarray:
    """Map a 2-D array of class indices to an (h, w, 3) uint8 RGB image.

    Pixels whose value is not a key of PALETTE are left black.
    """
    h, w = mask.shape
    mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in PALETTE.items():
        mask_rgb[mask == class_id] = color
    return mask_rgb


def _load_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into a PIL image, or raise HTTP 400 if unreadable."""
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force the decode so truncated/corrupt data fails
        # here (as a client error) rather than later inside inference.
        img.load()
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400,
            detail="Le fichier envoyé n'est pas une image lisible.",
        ) from exc
    return img


def _predict_mask(img: Image.Image) -> np.ndarray:
    """Run the model on *img*; return an (INPUT_SIZE, INPUT_SIZE) class-index mask."""
    tensor = preprocess(img)
    with torch.no_grad():
        logits = model(tensor).logits
        # Upsample the logits to the input resolution before taking argmax
        # (SegFormer presumably emits logits at a reduced resolution).
        logits = F.interpolate(
            logits,
            size=(INPUT_SIZE, INPUT_SIZE),
            mode="bilinear",
            align_corners=False,
        )
    # [0] drops the batch dimension explicitly; squeeze() would also drop a
    # spatial dimension of size 1.
    return logits.argmax(dim=1)[0].cpu().numpy()


def _segment_to_png(contents: bytes, colorize: bool) -> io.BytesIO:
    """Decode, segment and PNG-encode an upload; return a rewound buffer."""
    img = _load_image(contents)
    pred_mask = _predict_mask(img)
    if colorize:
        mask_img = Image.fromarray(decode_mask(pred_mask))
    else:
        # Raw mask: each pixel value is the predicted class index.
        mask_img = Image.fromarray(pred_mask.astype(np.uint8))
    buf = io.BytesIO()
    mask_img.save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.get("/")
def home():
    """Liveness check."""
    return {"status": "API avec modèle 'SegFormer' opérationnelle"}


@app.post("/predict")
async def predict(image: UploadFile = File(...), colorize: bool = False):
    """Segment an uploaded image and return the mask as a PNG.

    Args:
        image: uploaded image file (any format PIL can decode).
        colorize: if true, return an RGB rendering using PALETTE instead of
            the raw single-channel class-index mask (default: raw mask).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await image.read()
    # Image decoding and torch inference are blocking, CPU/GPU-bound work;
    # run them in the threadpool so the event loop stays responsive.
    buf = await run_in_threadpool(_segment_to_png, contents, colorize)
    return StreamingResponse(buf, media_type="image/png")