from torchvision.transforms import ToTensor, Normalize import torch from torch.utils.data import DataLoader from starvector.metrics.base_metric import BaseMetric import lpips from tqdm import tqdm class LPIPSDistanceCalculator(BaseMetric): def __init__(self, config=None, device='cuda'): super().__init__() self.class_name = self.__class__.__name__ self.config = config self.model = lpips.LPIPS(net='vgg').to(device) self.metric = self.LPIPS self.to_tensor = ToTensor() self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.device = device def LPIPS(self, tensor_image1, tensor_image2): tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device) return self.model(tensor_image1, tensor_image2) def to_tensor_transform(self, pil_img): return self.normalize(self.to_tensor(pil_img)) def collate_fn(self, batch): gt_imgs, gen_imgs = zip(*batch) tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs]) tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs]) return tensor_gt_imgs, tensor_gen_imgs def calculate_score(self, batch, batch_size=8, update=True): gt_images = batch['gt_im'] gen_images = batch['gen_im'] # Create DataLoader with custom collate function data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False) values = [] for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader): # Compute LPIPS lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch) values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist()) if not values: print("No valid values found for metric calculation.") return float("nan") avg_score = sum(values) / len(values) if update: self.meter.update(avg_score, len(values)) return self.meter.avg, values else: return avg_score, values