hz2475's picture
init
72f684c
from starvector.metrics.util import AverageMeter
from tqdm import tqdm
import math
class BaseMetric:
    """Base class for metrics that are computed per sample and averaged.

    Subclasses override :meth:`metric`. This class iterates over a batch,
    skips samples whose metric raises or evaluates to NaN, and accumulates
    the batch means in ``self.meter`` (an ``AverageMeter``).
    """

    # Batch keys that are forwarded to ``metric`` when present.
    _SAMPLE_KEYS = ("gt_im", "gen_im", "gt_svg", "gen_svg", "caption")

    def __init__(self):
        self.meter = AverageMeter()

    def reset(self):
        """Reset the underlying meter."""
        self.meter.reset()

    def calculate_score(self, batch, update=True):
        """Compute the metric for every sample in ``batch`` and average it.

        Args:
            batch: Mapping of column name to a list of per-sample values,
                e.g. ``{"gt_im": [PIL Image], "gen_im": [Image]}``. Only the
                keys ``gt_im``, ``gen_im``, ``gt_svg``, ``gen_svg`` and
                ``caption`` are read; the batch size is taken from the length
                of the first column in the mapping.
            update: When true, fold this batch's mean into ``self.meter``
                (weighted by the number of valid samples) and return the
                meter's average instead of the batch mean.

        Returns:
            A ``(score, values)`` tuple. ``values`` lists the per-sample
            results that were neither errors nor NaN. ``score`` is
            ``self.meter.avg`` when ``update`` is true, otherwise the mean of
            ``values``. If no sample produced a valid value the result is
            ``(nan, [])`` and the meter is left untouched.
        """
        values = []
        batch_size = len(next(iter(batch.values())))
        for index in tqdm(range(batch_size)):
            kwargs = {
                key: batch[key][index]
                for key in self._SAMPLE_KEYS
                if key in batch
            }
            try:
                measure = self.metric(**kwargs)
            except Exception as e:
                # Best effort: one failing sample must not abort the batch.
                print("Error calculating metric: {}".format(e))
                continue
            if math.isnan(measure):
                continue
            values.append(measure)

        if not values:
            print("No valid values found for metric calculation.")
            # Keep the (score, values) shape so callers can always unpack.
            return float("nan"), values

        score = sum(values) / len(values)
        if update:
            self.meter.update(score, len(values))
            return self.meter.avg, values
        return score, values

    def metric(self, **kwargs):
        """
        This method should be overridden by subclasses to provide the specific metric computation.
        """
        raise NotImplementedError("The metric method must be implemented by subclasses.")

    def get_average_score(self):
        """Return the average currently held by the meter."""
        return self.meter.avg