JimSmith007's picture
Update color mask app.py
0da33f9 verified
# api.py
import io

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from fastapi import FastAPI, UploadFile, File
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image

from fonctions_api import charger_segformer, remap_classes
from fonctions_api import decode_cityscapes_mask
app = FastAPI()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the SegFormer model (8 output classes) once at import time, move it to
# `device` and switch it to inference mode. nn.Module.to() and .eval() both
# return the module itself, so the calls can be chained.
# NOTE(review): torch.load unpickles the checkpoint, so the .pth file must come
# from a trusted source; consider `weights_only=True` once the minimum PyTorch
# version (>= 1.13) is confirmed.
model = charger_segformer(num_classes=8)
model.load_state_dict(torch.load("segformer_b5.pth", map_location=device))
model = model.to(device).eval()
# Albumentations preprocessing pipeline. It is deterministic (a fixed resize
# plus normalisation), so it is built once at import time rather than being
# reconstructed on every request.
_TRANSFORM = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])


def preprocess(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a model-ready batch containing one image.

    The image is converted to RGB and resized to 256x256. It is then
    normalised with mean=std=0.5 per channel. Albumentations first divides by
    its default max_pixel_value of 255, so values land in [-1, 1]. Finally the
    HWC array becomes a CHW tensor.

    Args:
        image: Any PIL image; non-RGB modes (L, RGBA, P, ...) are converted.

    Returns:
        A tensor of shape (1, 3, 256, 256) on the module-level ``device``.
    """
    image_np = np.array(image.convert("RGB"))
    transformed = _TRANSFORM(image=image_np)
    # Add the leading batch dimension expected by the model.
    return transformed["image"].unsqueeze(0).to(device)
# Colour palette: class index -> RGB triple used to render a predicted mask.
PALETTE = {
    0: (0, 0, 0),
    1: (50, 50, 150),
    2: (102, 0, 204),
    3: (255, 85, 0),
    4: (255, 255, 0),
    5: (0, 255, 255),
    6: (255, 0, 255),
    7: (255, 255, 255),
}


def decode_mask(mask):
    """Colourise a 2-D array of class indices using PALETTE.

    Args:
        mask: 2-D array of integer class ids. A non-2-D input raises
            ValueError when its shape is unpacked.

    Returns:
        A uint8 array of shape (height, width, 3). Pixels whose id is not a
        PALETTE key stay black, which is the same colour as class 0.
    """
    rows, cols = mask.shape
    colored = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label in PALETTE:
        colored[mask == label] = PALETTE[label]
    return colored
# Root endpoint: returns a static status message confirming the API is up.
# (A comment rather than a docstring, so the generated OpenAPI description
# of this route is unchanged.)
@app.get("/")
def home():
    status_message = "API avec modèle 'SegFormer' opérationnelle"
    return {"status": status_message}
def _segment(img: Image.Image) -> np.ndarray:
    """Run the model on *img* and return a (256, 256) uint8 class-index mask.

    This is synchronous and compute-heavy; async code should call it through
    a threadpool.
    """
    tensor = preprocess(img)
    with torch.no_grad():
        logits = model(tensor).logits
        # The logits are not guaranteed to match the 256x256 input fed by
        # preprocess(), so upsample them before taking the per-pixel argmax.
        logits = F.interpolate(
            logits, size=(256, 256), mode="bilinear", align_corners=False
        )
    # Index the single batch element explicitly instead of squeeze(), which
    # would also drop any other size-1 dimension.
    return logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)


def _encode_png(mask: np.ndarray) -> io.BytesIO:
    """Encode a 2-D uint8 mask as a grayscale PNG in a rewound in-memory buffer."""
    buf = io.BytesIO()
    Image.fromarray(mask).save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Segment an uploaded image and return the raw class-index mask as a PNG.

    The response is a 256x256 grayscale PNG. Each pixel value is the predicted
    class id (0-7), not a colour. The module's ``decode_mask`` can turn it
    into an RGB visualisation.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await image.read()
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force a full decode so empty, corrupt or
        # truncated uploads fail here rather than later inside the model code.
        img.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Invalid image file") from err
    # Inference is blocking and CPU/GPU-bound. Run it in the threadpool so it
    # does not stall the event loop (and every other request) while it runs.
    pred_mask = await run_in_threadpool(_segment, img)
    return StreamingResponse(_encode_png(pred_mask), media_type="image/png")