from fastapi import FastAPI, File, UploadFile, HTTPException from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor import torch from PIL import Image import numpy as np import io import base64 import logging # Inizializza l'app FastAPI app = FastAPI() # Configura il logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Carica il modello e il processore SegFormer try: logger.info("Caricamento del modello SegFormer...") model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion") processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion") model.to("cpu") # Usa CPU per il free tier logger.info("Modello caricato con successo.") except Exception as e: logger.error(f"Errore nel caricamento del modello: {str(e)}") raise RuntimeError(f"Errore nel caricamento del modello: {str(e)}") # Funzione per segmentare l'immagine def segment_image(image: Image.Image): # Prepara l'input per SegFormer logger.info("Preparazione dell'immagine per l'inferenza...") inputs = processor(images=image, return_tensors="pt").to("cpu") # Inferenza logger.info("Esecuzione dell'inferenza...") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits # Post-processa la maschera logger.info("Post-processing della maschera...") mask = torch.argmax(logits, dim=1)[0] mask = mask.cpu().numpy() # Converti la maschera in immagine mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8)) # Converti la maschera in base64 per la risposta buffered = io.BytesIO() mask_img.save(buffered, format="PNG") mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") # Annotazioni annotations = {"mask": mask.tolist(), "label": "fashion"} return mask_base64, annotations # Endpoint API @app.post("/segment") async def segment_endpoint(file: UploadFile = File(...)): try: logger.info("Ricezione del file...") image_data = await file.read() image = Image.open(io.BytesIO(image_data)).convert("RGB") logger.info("Segmentazione dell'immagine...") mask_base64, annotations = segment_image(image) return { "mask": f"data:image/png;base64,{mask_base64}", "annotations": annotations } except Exception as e: logger.error(f"Errore nell'endpoint: {str(e)}") raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}") # Per compatibilità con Hugging Face Spaces if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)