File size: 4,676 Bytes
1c72248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
from PIL import Image
import torch
from torchvision.transforms import Resize, ToTensor
from diffusers import AutoencoderKL
from pytorch_fid import fid_score
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
from tqdm import tqdm
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_images(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder_path, filename)
            images.append(img_path)
    return images


def paramiter_count(model):
    state_dict = model.state_dict()
    paramiter_count = 0
    for key in state_dict:
        paramiter_count += torch.numel(state_dict[key])
    return int(paramiter_count)


def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    lpips_model = lpips.LPIPS(net='alex').to(device)

    rfid_scores = []
    psnr_scores = []
    lpips_scores = []

    # transform = transforms.Compose([
    #     transforms.Resize(256, antialias=True),
    #     transforms.CenterCrop(256)
    # ])
    # needs values between -1 and 1
    to_tensor = ToTensor()
    
    # remove _reconstructed.png files
    images = [img for img in images if not img.endswith("_reconstructed.png")]

    if max_imgs > 0 and len(images) > max_imgs:
        images = images[:max_imgs]

    for img_path in tqdm(images):
        try:
            img = Image.open(img_path).convert('RGB')
            # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device)
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            img_tensor = 2 * img_tensor - 1
            # if width or height is not divisible by 8, crop it
            if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
                img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]

        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue


        with torch.no_grad():
            reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample

        # Calculate rFID
        # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
        # rfid_scores.append(rfid)

        # Calculate PSNR
        psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy())
        psnr_scores.append(psnr_val)

        # Calculate LPIPS
        lpips_val = lpips_model(img_tensor, reconstructed).item()
        lpips_scores.append(lpips_val)

    # avg_rfid = sum(rfid_scores) / len(rfid_scores)
    avg_rfid = 0
    avg_psnr = sum(psnr_scores) / len(psnr_scores)
    avg_lpips = sum(lpips_scores) / len(lpips_scores)
    
    if save_output:
        filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
        folder = os.path.dirname(img_path)
        save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
        reconstructed = (reconstructed + 1) / 2
        reconstructed = reconstructed.clamp(0, 1)
        reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
        reconstructed.save(save_path)

    return avg_rfid, avg_psnr, avg_lpips


def main():
    parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
    # boolean store true
    parser.add_argument("--save_output", action="store_true", help="Save the output images")
    args = parser.parse_args()

    if  os.path.isfile(args.vae_path):
        vae = AutoencoderKL.from_single_file(args.vae_path)
    else:
        try:
            vae = AutoencoderKL.from_pretrained(args.vae_path)
        except:
            vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.eval()
    vae = vae.to(device)
    print(f"Model has {paramiter_count(vae)} parameters")
    images = load_images(args.image_folder)

    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)

    # print(f"Average rFID: {avg_rfid}")
    print(f"Average PSNR: {avg_psnr}")
    print(f"Average LPIPS: {avg_lpips}")


if __name__ == "__main__":
    main()