Spaces:
Running
Running
# api.py
import io

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from fastapi import FastAPI, UploadFile, File
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image

from fonctions_api import charger_segformer, remap_classes
from fonctions_api import decode_cityscapes_mask
# Application object and the device used for inference (CUDA when available,
# CPU otherwise).
app = FastAPI()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the SegFormer model: build it with 8 output classes, restore the
# weights stored in "segformer_b5.pth" (mapped onto `device`), move the model
# to that device and switch it to evaluation mode.
model = charger_segformer(num_classes=8)
model.load_state_dict(
    torch.load("segformer_b5.pth", map_location=device)
)
model.to(device)
model.eval()
# Albumentations preprocessing
def preprocess(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a batched tensor placed on ``device``.

    The image is converted to RGB, resized to 256x256, normalized with a mean
    and a standard deviation of 0.5 on each channel and converted to a tensor
    by ``ToTensorV2``. A leading batch dimension is then added and the result
    is moved to the module-level ``device``.

    Args:
        image: Input picture in any mode PIL can convert to RGB.

    Returns:
        The preprocessed tensor with a batch dimension of size 1.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    pipeline = A.Compose(
        [
            A.Resize(256, 256),
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]
    )
    batch = pipeline(image=rgb_pixels)["image"].unsqueeze(0)
    return batch.to(device)
# Color palette: class id -> RGB triple
PALETTE = {
    0: (0, 0, 0),
    1: (50, 50, 150),
    2: (102, 0, 204),
    3: (255, 85, 0),
    4: (255, 255, 0),
    5: (0, 255, 255),
    6: (255, 0, 255),
    7: (255, 255, 255),
}


def decode_mask(mask):
    """Render a 2-D array of class ids as an RGB image array.

    Each pixel whose value is a key of ``PALETTE`` receives the matching
    color; any other value is left black, since the output starts as zeros.

    Args:
        mask: Two-dimensional array of class ids.

    Returns:
        A ``uint8`` array of shape ``(h, w, 3)``.
    """
    rows, cols = mask.shape
    colored = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in PALETTE:
        colored[mask == class_id] = PALETTE[class_id]
    return colored
# NOTE(review): no route registration was visible for this handler, so it is
# registered on the root path here — confirm the path expected by clients.
@app.get("/")
def home():
    """Return a static status payload confirming that the API is running."""
    return {"status": "API avec modèle 'SegFormer' opérationnelle"}
# NOTE(review): no route registration was visible for this handler, so it is
# registered as POST /predict here — confirm the path expected by clients.
@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Segment an uploaded image and return the raw class-id mask as a PNG.

    The upload is decoded, preprocessed with ``preprocess``, passed through
    the module-level ``model``, and the per-pixel argmax over the class axis
    is encoded as a single-channel 8-bit PNG.

    Args:
        image: Uploaded image file (multipart form field ``image``).

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` whose pixel
        values are class indices (not colors).

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    contents = await image.read()
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open only reads the header; force full decoding here so that
        # corrupt or truncated uploads are reported as a client error instead
        # of failing later inside preprocessing.
        img.load()
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    tensor = preprocess(img)
    with torch.no_grad():
        logits = model(tensor).logits
        # Resize the logits to 256x256, the size `preprocess` resizes to.
        logits = F.interpolate(
            logits, size=(256, 256), mode="bilinear", align_corners=False
        )
        # argmax over dim=1 picks the best class per pixel; squeeze(0) drops
        # only the batch axis.
        pred_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()

    # The raw class-id mask is returned on purpose; decode_mask(pred_mask)
    # would give the PALETTE-colored RGB rendering instead.
    mask_img = Image.fromarray(pred_mask.astype(np.uint8))
    buf = io.BytesIO()
    mask_img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")