from starvector.metrics.compute_l2 import L2DistanceCalculator
from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator
from starvector.metrics.compute_SSIM import SSIMDistanceCalculator
from starvector.metrics.compute_fid import FIDCalculator
from starvector.metrics.compute_clip_score import CLIPScoreCalculator
from starvector.data.util import rasterize_svg
from starvector.metrics.util import AverageMeter
from starvector.metrics.compute_dino_score import DINOScoreCalculator
from starvector.metrics.count_token_length import CountTokenLength
import os
from tqdm import tqdm


class SVGMetrics:
    """Compute image- and SVG-level similarity metrics between ground-truth and generated SVGs.

    Metrics are toggled through a config dict; only the enabled ones are
    instantiated into `self.active_metrics`.
    """

    def __init__(self, config=None):
        self.class_name = self.__class__.__name__

        default_config = {
            'L2': True,
            'Masked-L2': False,
            'LPIPS': False,
            'SSIM': False,
            'FID': False,
            'FID_clip': False,
            'CLIPScore': False,
            'CountTokenLength': False,
            'ratio_post_processed': True,
            'ratio_non_compiling': True,
            'DinoScore': True,
        }
        self.config = config or default_config

        self.metrics = {
            'L2': L2DistanceCalculator,
            'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True),
            'LPIPS': LPIPSDistanceCalculator,
            'SSIM': SSIMDistanceCalculator,
            'FID': lambda: FIDCalculator(model_name='InceptionV3'),
            'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'),
            'CLIPScore': CLIPScoreCalculator,
            'CountTokenLength': CountTokenLength,
            'ratio_post_processed': AverageMeter,
            'ratio_non_compiling': AverageMeter,
            'DinoScore': DINOScoreCalculator,
        }

        # Instantiate only the metrics enabled in the config.
        self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)}

    def reset(self):
        for metric in self.active_metrics.values():
            metric.reset()

    def batch_contains_raster(self, batch):
        return "gt_im" in batch and "gen_im" in batch

    def batch_contains_svg(self, batch):
        return "gt_svg" in batch and "gen_svg" in batch

    @staticmethod
    def get_sample_id(json_item):
        # Prefer the rasterized output filename; fall back to the raw sample id.
        return json_item.get('outpath_filename') or json_item.get('sample_id')

    def calculate_metrics(self, batch, update=True):
        # Rasterize the SVGs if the batch does not already carry images.
        if not self.batch_contains_raster(batch):
            batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
            batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]

        avg_results_dict = {}
        all_results_dict = {}

        # Initialize all_results_dict with one entry per sample
        for i, json_item in enumerate(batch['json']):
            sample_id = self.get_sample_id(json_item)
            if sample_id is None:
                raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]")
            all_results_dict[sample_id] = {}

        for metric_name, metric in self.active_metrics.items():
            print(f"Calculating {metric_name}...")

            # Handle metrics that return both average and per-sample results
            if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']:
                avg_result, list_result = metric.calculate_score(batch, update=update)
                avg_results_dict[metric_name] = avg_result

                # Store individual results
                for i, result in enumerate(list_result):
                    sample_id = self.get_sample_id(batch['json'][i])
                    all_results_dict[sample_id][metric_name] = result

            # Handle FID metrics that only return a batch-level average
            elif metric_name in ['FID', 'FID_clip']:
                avg_results_dict[metric_name] = metric.calculate_score(batch)

            # Handle ratio metrics (post-processing and non-compiling rates)
            else:
                self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict)

            metric.reset()

        print("Average results: \n", avg_results_dict)
        return avg_results_dict, all_results_dict

    def calculate_fid(self, batch):
        # Requires 'FID' to be enabled in the config.
        if not self.batch_contains_raster(batch):
            batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
            batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]

        return self.active_metrics['FID'].calculate_score(batch).item()

    def get_average_metrics(self):
        metrics = {}
        for metric_name, metric in self.active_metrics.items():
            if hasattr(metric, 'avg'):
                metrics[metric_name] = metric.avg
            elif hasattr(metric, 'get_average_score'):
                metrics[metric_name] = metric.get_average_score()
        return metrics

    def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict):
        """Helper method to handle ratio-based metrics ('ratio_post_processed', 'ratio_non_compiling')."""
        # Map the metric name to the per-sample key in batch['json'],
        # e.g. 'ratio_post_processed' -> 'post_processed'.
        metric_key = metric_name.replace('avg_', '').replace('ratio_', '')

        for item in batch['json']:
            sample_id = self.get_sample_id(item)
            value = item[metric_key]
            all_results_dict[sample_id][metric_name] = value
            metric.update(value, 1)

        avg_results_dict[metric_name] = metric.avg
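
# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a batch laid out the way calculate_metrics()
# expects above: parallel lists of ground-truth and generated SVG strings,
# plus a 'json' list of per-sample metadata providing 'sample_id' (or
# 'outpath_filename'). The SVG strings, sample id, and config below are
# illustrative placeholders, not values from the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    metrics = SVGMetrics(config={'L2': True})

    batch = {
        "gt_svg": ['<svg xmlns="http://www.w3.org/2000/svg"><rect width="10" height="10"/></svg>'],
        "gen_svg": ['<svg xmlns="http://www.w3.org/2000/svg"><rect width="9" height="10"/></svg>'],
        "json": [{"sample_id": "sample_0"}],
    }

    avg_results, per_sample_results = metrics.calculate_metrics(batch)
    print(avg_results)         # e.g. {'L2': <batch-average L2 distance>}
    print(per_sample_results)  # e.g. {'sample_0': {'L2': <per-sample L2 distance>}}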