import os
import io
import uuid

import numpy as np
from PIL import Image, ImageFilter
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
import torch
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusionInpaintPipeline
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from huggingface_hub import HfApi, hf_hub_download
import uvicorn

# Configure caches before any model is loaded
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/transformers_cache", exist_ok=True)
os.makedirs("/tmp/huggingface", exist_ok=True)

app = FastAPI()

# Labels and threshold used to filter clothing regions
CLOTHING_LABELS = ["a piece of clothing", "shirt", "jacket", "pants", "dress", "skirt"]
CLIP_THRESHOLD = 0.25

print("Starting app.py...")


def process_image(pil_img: Image.Image, prompt: str, neg_prompt: str, hf_repo: str = None):
    # --- Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Load CLIP and compute normalized text embeddings
    # from_tf=True was added to handle weights in TensorFlow format
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        from_tf=True,
        cache_dir="/tmp/transformers_cache",
    ).to(device)
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir="/tmp/transformers_cache",
    )
    text_inputs = clip_processor(text=CLOTHING_LABELS, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_embeddings = clip_model.get_text_features(**text_inputs)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    # --- Prepare the image as a numpy array for SAM
    np_img = np.array(pil_img)

    # --- Download and load SAM2
    cache_dir = os.path.join("/tmp", "sam2_cache")
    os.makedirs(cache_dir, exist_ok=True)
    ckpt = os.path.join(cache_dir, "sam2_hiera_large.pt")
    cfg = os.path.join(cache_dir, "sam2_hiera_l.yaml")
    if not os.path.exists(ckpt):
        ckpt = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_large.pt",
                               repo_type="model", cache_dir=cache_dir)
    if not os.path.exists(cfg):
        cfg = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_l.yaml",
                              repo_type="model", cache_dir=cache_dir)
    sam2 = build_sam2("sam2_hiera_l", ckpt, device=device)
    mask_generator = SAM2AutomaticMaskGenerator(sam2)

    # --- Generate all masks
    masks = mask_generator.generate(np_img)

    # --- Keep only masks whose content looks like clothing, according to CLIP
    combined = np.zeros(np_img.shape[:2], dtype=bool)
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        ys, xs = np.where(seg)
        if ys.size == 0:
            continue
        y1, y2 = ys.min(), ys.max()
        x1, x2 = xs.min(), xs.max()
        patch = pil_img.crop((x1, y1, x2, y2))
        inputs = clip_processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            img_emb = clip_model.get_image_features(**inputs)
            img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        sims = (img_emb @ text_embeddings.T).squeeze(0)
        if float(sims.max().cpu()) > CLIP_THRESHOLD:
            combined |= seg

    # --- Build and post-process the clothing mask
    mask_bin = Image.fromarray(combined.astype(np.uint8) * 255)
    mask_dilated = mask_bin.filter(ImageFilter.MaxFilter(15))
    mask_for_inpaint = mask_dilated.filter(ImageFilter.GaussianBlur(7))

    # --- Inpainting with Stable Diffusion
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "sd-legacy/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        cache_dir="/tmp/diffusers_cache",
    ).to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    if not combined.any():
        result = pil_img.copy()
    else:
        result = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            image=pil_img,
            mask_image=mask_for_inpaint,
        ).images[0]

    # --- Build a visualization of the SAM segmentations
    viz = np.array(pil_img).astype(np.float32)
    rnd = np.random.RandomState(42)
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        color = rnd.randint(0, 256, size=3, dtype=np.uint8)
        ys, xs = np.where(seg)
        viz[ys, xs] = viz[ys, xs] * 0.5 + color * 0.5
    seg_viz = Image.fromarray(viz.astype(np.uint8))

    # --- Upload to the HF Hub (dataset repo)
    token = os.getenv("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN is not set in the environment variables")
    api = HfApi()
    if hf_repo is None:
        user = api.whoami(token=token)["name"]
        hf_repo = f"{user}/sam2-inpaint-outputs"
    api.create_repo(repo_id=hf_repo, repo_type="dataset", token=token, exist_ok=True)

    uid = uuid.uuid4().hex[:8]
    # Use a temporary directory for the intermediate files
    temp_dir = "/tmp/sam2_outputs"
    os.makedirs(temp_dir, exist_ok=True)
    names = {
        "seg": os.path.join(temp_dir, f"sam_seg_{uid}.png"),
        "mask": os.path.join(temp_dir, f"mask_{uid}.png"),
        "out": os.path.join(temp_dir, f"inpaint_{uid}.png"),
    }

    # Save the temporary files
    seg_viz.save(names["seg"])
    mask_bin.save(names["mask"])
    result.save(names["out"])

    # File names used in the returned URLs
    url_names = {
        "seg": f"sam_seg_{uid}.png",
        "mask": f"mask_{uid}.png",
        "out": f"inpaint_{uid}.png",
    }

    # Upload and clean up
    for key, fname in names.items():
        api.upload_file(
            path_or_fileobj=fname,
            path_in_repo=url_names[key],
            repo_id=hf_repo,
            repo_type="dataset",
            token=token,
        )
        os.remove(fname)

    base = f"https://huggingface.co/datasets/{hf_repo}/resolve/main"
    return (
        f"{base}/{url_names['seg']}",
        f"{base}/{url_names['mask']}",
        f"{base}/{url_names['out']}",
    )


@app.post("/inpaint/")
async def inpaint(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    neg_prompt: str = Form("old clothes, residue, artifacts"),
    hf_repo: str = Form(None),
):
    # Read the uploaded image
    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    # Process
    try:
        seg_url, mask_url, out_url = process_image(img, prompt, neg_prompt, hf_repo)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    # Return the JSON response
    return JSONResponse({
        "sam_segmentation": seg_url,
        "clothing_mask": mask_url,
        "inpainted": out_url,
    })


# Entry point to run the app directly
if __name__ == "__main__":
    # Preload the models to verify they work before starting the server
    print("Preloading CLIP model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Use from_tf=True to fix model loading
        clip_model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            from_tf=True,
            cache_dir="/tmp/transformers_cache",
        ).to(device)
        clip_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir="/tmp/transformers_cache",
        )
        print("CLIP model loaded successfully!")
    except Exception as e:
        print(f"Error preloading CLIP model: {e}")
        # Do not exit - let it fail at runtime if necessary

    # Run the FastAPI application with uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)