Spaces:
Running
Running
import torch | |
from torch.utils.data import DataLoader | |
from starvector.metrics.base_metric import BaseMetric | |
from tqdm import tqdm | |
from starvector.metrics.util import AverageMeter | |
from transformers import AutoTokenizer | |
class CountTokenLength(BaseMetric):
    """Metric comparing the token counts of ground-truth and generated SVGs.

    Both SVG strings are encoded with the ``bigcode/starcoder2-7b``
    tokenizer. Three ``AverageMeter`` instances accumulate the ground-truth
    token count, the generated token count and their signed difference
    (generated minus ground truth). The meters are only cleared by
    ``reset()``, so they keep accumulating across ``calculate_score`` calls.
    """

    def __init__(self, config=None, device='cuda'):
        """Load the tokenizer and create the three meters.

        Args:
            config: Accepted but not used by this class.
            device: Accepted but not used by this class.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
        # Expose the per-sample function under the generic ``metric`` name.
        # NOTE(review): presumably read by BaseMetric or its callers; confirm.
        self.metric = self.calculate_token_length
        self.meter_gt_tokens = AverageMeter()
        self.meter_gen_tokens = AverageMeter()
        self.meter_diff = AverageMeter()

    def calculate_token_length(self, **kwargs):
        """Count tokens for one ground-truth / generated SVG pair.

        Args:
            **kwargs: Read via ``kwargs.get``; the keys used are ``'gt_svg'``
                and ``'gen_svg'``. A missing key is passed to the tokenizer
                as ``None``.

        Returns:
            A tuple ``(gt_token_count, gen_token_count, diff)`` where
            ``diff`` is ``gen_token_count - gt_token_count``.
        """
        reference_ids = self.tokenizer.encode(kwargs.get('gt_svg'))
        generated_ids = self.tokenizer.encode(kwargs.get('gen_svg'))
        n_reference = len(reference_ids)
        n_generated = len(generated_ids)
        return n_reference, n_generated, n_generated - n_reference

    def calculate_score(self, batch, update=None):
        """Score every SVG pair in ``batch`` and feed the meters.

        Args:
            batch: Mapping with ``'gt_svg'`` and ``'gen_svg'`` entries that
                are iterated in lockstep; ``len()`` of the ``'gt_svg'`` entry
                sets the progress-bar total.
            update: Accepted but not used by this method.

        Returns:
            ``(avg_score, values)`` where ``avg_score`` maps ``'gt_tokens'``,
            ``'gen_tokens'`` and ``'diff'`` to the ``avg`` attribute of the
            matching meter, and ``values`` is a list with one dict of the
            same keys per pair. If no pair was processed, a message is
            printed and a bare ``float("nan")`` is returned instead of a
            tuple.
        """
        references = batch['gt_svg']
        candidates = batch['gen_svg']

        # Same key order as the per-sample records: gt, gen, diff.
        meters = {
            'gt_tokens': self.meter_gt_tokens,
            'gen_tokens': self.meter_gen_tokens,
            'diff': self.meter_diff,
        }

        values = []
        pairs = tqdm(
            zip(references, candidates),
            total=len(references),
            desc="Processing SVGs",
        )
        for reference, candidate in pairs:
            n_reference, n_generated, delta = self.calculate_token_length(
                gt_svg=reference, gen_svg=candidate
            )
            record = {
                'gt_tokens': n_reference,
                'gen_tokens': n_generated,
                'diff': delta,
            }
            # Each sample contributes to its meter with weight 1.
            for key, meter in meters.items():
                meter.update(record[key], 1)
            values.append(record)

        avg_score = {key: meter.avg for key, meter in meters.items()}

        if len(values) == 0:
            print("No valid values found for metric calculation.")
            return float("nan")
        return avg_score, values

    def reset(self):
        """Call ``reset()`` on all three meters."""
        for meter in (
            self.meter_gt_tokens,
            self.meter_gen_tokens,
            self.meter_diff,
        ):
            meter.reset()