File size: 2,281 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from torchvision.transforms import ToTensor, Normalize
import torch
from torch.utils.data import DataLoader
from starvector.metrics.base_metric import BaseMetric
import lpips
from tqdm import tqdm


class LPIPSDistanceCalculator(BaseMetric):
    """Perceptual (LPIPS, VGG backbone) distance between image pairs.

    Pairs of ground-truth / generated PIL images are converted to tensors,
    batched with a ``DataLoader`` and scored by ``lpips.LPIPS``. Lower
    values mean the two images are perceptually closer.
    """

    def __init__(self, config=None, device='cuda'):
        """Build the LPIPS network and the image preprocessing transforms.

        Args:
            config: Optional configuration object; stored but not read by
                this class.
            device: Torch device the LPIPS model and the input batches are
                moved to.
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.model = lpips.LPIPS(net='vgg').to(device)
        self.metric = self.LPIPS
        self.to_tensor = ToTensor()
        # NOTE(review): the lpips package documents its inputs as RGB scaled
        # to [-1, 1] (or [0, 1] with ``normalize=True``), while ImageNet
        # mean/std normalisation is applied here. Kept as is because changing
        # it would shift every reported score -- confirm this is intended.
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.device = device

    def LPIPS(self, tensor_image1, tensor_image2):
        """Return the LPIPS distance tensor for two batches of image tensors.

        Both inputs are moved to ``self.device`` before the forward pass.
        Autograd is deliberately not disabled here so the method remains
        usable by callers that need gradients.
        """
        tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device)
        return self.model(tensor_image1, tensor_image2)

    def to_tensor_transform(self, pil_img):
        """Convert one PIL image to a normalised tensor."""
        return self.normalize(self.to_tensor(pil_img))

    def collate_fn(self, batch):
        """Turn a list of ``(gt_img, gen_img)`` pairs into two stacked tensors.

        ``torch.stack`` requires every image of a batch to produce a tensor
        of the same shape.
        """
        gt_imgs, gen_imgs = zip(*batch)
        tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs])
        tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs])
        return tensor_gt_imgs, tensor_gen_imgs

    def calculate_score(self, batch, batch_size=8, update=True):
        """Compute per-pair LPIPS distances and their average.

        Args:
            batch: Mapping with keys ``'gt_im'`` and ``'gen_im'``, each a
                sequence of images; the two sequences are paired by position.
            batch_size: Number of image pairs scored per forward pass.
            update: When true, the average is also pushed into ``self.meter``
                (inherited from ``BaseMetric``) and the meter's running
                average is returned instead of this call's average.

        Returns:
            ``(average, values)`` where ``values`` is the list of per-pair
            distances. When there are no pairs to score, a bare
            ``float('nan')`` is returned instead of a tuple; this shape is
            kept unchanged for backward compatibility.
        """
        gt_images = batch['gt_im']
        gen_images = batch['gen_im']

        # Create DataLoader with custom collate function
        data_loader = DataLoader(
            list(zip(gt_images, gen_images)),
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            shuffle=False,
        )

        values = []
        # Pure evaluation: skip building the autograd graph for each batch.
        with torch.no_grad():
            for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader):
                # Compute LPIPS
                lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch)
                # flatten() always yields a 1-D tensor, so tolist() returns a
                # list even when the batch holds a single pair.
                values.extend(lpips_values.flatten().cpu().tolist())

        if not values:
            print("No valid values found for metric calculation.")
            return float("nan")

        avg_score = sum(values) / len(values)
        if update:
            self.meter.update(avg_score, len(values))
            return self.meter.avg, values
        return avg_score, values