"""Clothing segmentation using a SegFormer model fine-tuned for clothes parsing."""

import logging
import time
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor

logger = logging.getLogger('looks.studio.segformer')


class SegformerParser:
    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")
        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)
            # Define clothing-related labels (IDs follow the model's label map)
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf",
            }
            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to initialize SegformerParser: {str(e)}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Resize image while maintaining aspect ratio if it exceeds max_size."""
        width, height = image.size
        scale = 1.0
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Validate input image type, mode, and dimensions."""
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False
        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False
        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False
        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False
        return True

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate a single binary segmentation mask covering all clothing labels."""
        start_time = time.time()
        logger.info(f"Starting segmentation for image size: {image.size}")
        try:
            # Validate input image
            if not self._validate_image(image):
                return None

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Record the original size before resizing, so the mask can be
            # restored exactly (recomputing it from the resized image and the
            # scale factor can drift by a pixel due to rounding)
            original_size = image.size

            # Resize image if too large
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()

            # Upsample logits to the (possibly resized) image size; PIL's size
            # is (width, height) while interpolate expects (height, width)
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )

            # Get the segmentation mask
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Create a binary mask for clothing
            mask = torch.zeros_like(pred_seg)
            for label_id in self.clothing_labels:
                mask[pred_seg == label_id] = 255

            # Convert to PIL Image
            mask = Image.fromarray(mask.numpy().astype(np.uint8))

            # Resize mask back to original size if needed
            if scale != 1.0:
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)

            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask
        except Exception as e:
            logger.error(f"Error during segmentation: {str(e)}")
            return None

    def get_all_masks(self, image: Image.Image) -> Dict[str, Image.Image]:
        """Return a dict of binary masks, one per clothing part label."""
        start_time = time.time()
        logger.info(f"Starting per-part segmentation for image size: {image.size}")
        masks = {}
        try:
            # Validate input image
            if not self._validate_image(image):
                return masks

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Record the original size before resizing (see get_image_mask)
            original_size = image.size

            # Resize image if too large
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer for all masks")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu()

            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # For each clothing label, create a binary mask
            for label_id, part_name in self.clothing_labels.items():
                mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
                mask_img = Image.fromarray(mask)
                # Resize mask back to original size if needed
                if scale != 1.0:
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img

            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.error(f"Error during per-part segmentation: {str(e)}")
            return masks
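

# Minimal usage sketch showing both entry points. The input path "person.jpg"
# and the output filenames are illustrative assumptions, not project files;
# the first run downloads the checkpoint from the Hugging Face Hub.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = SegformerParser()
    image = Image.open("person.jpg")

    # Combined mask: white (255) where any clothing label was predicted
    combined = parser.get_image_mask(image)
    if combined is not None:
        combined.save("clothing_mask.png")

    # One binary mask per detected part, keyed by label name
    for part_name, part_mask in parser.get_all_masks(image).items():
        part_mask.save(f"mask_{part_name}.png")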