# Source: starvector-1b-im2svg/starvector/metrics/count_token_length.py
# Last commit: "init" (72f684c), author hz2475
import torch
from torch.utils.data import DataLoader
from starvector.metrics.base_metric import BaseMetric
from tqdm import tqdm
from starvector.metrics.util import AverageMeter
from transformers import AutoTokenizer
class CountTokenLength(BaseMetric):
    """Compare token counts of ground-truth and generated SVG strings.

    Each SVG string is tokenized with the ``bigcode/starcoder2-7b``
    tokenizer. Three meters track the ground-truth token count, the
    generated token count, and their signed difference (generated minus
    ground truth).
    """

    def __init__(self, config=None, device='cuda'):
        """Load the tokenizer and create the three meters.

        Args:
            config: Not read by this metric.
            device: Not read by this metric.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
        self.metric = self.calculate_token_length
        self.meter_gt_tokens = AverageMeter()
        self.meter_gen_tokens = AverageMeter()
        self.meter_diff = AverageMeter()

    def calculate_token_length(self, **kwargs):
        """Tokenize one ground-truth/generated SVG pair and count tokens.

        Keyword Args:
            gt_svg: Ground-truth SVG text.
            gen_svg: Generated SVG text.

        Returns:
            Tuple ``(gt_count, gen_count, diff)`` where ``diff`` is
            ``gen_count - gt_count``.
        """
        gt_count = len(self.tokenizer.encode(kwargs.get('gt_svg')))
        gen_count = len(self.tokenizer.encode(kwargs.get('gen_svg')))
        return gt_count, gen_count, gen_count - gt_count

    def calculate_score(self, batch, update=None):
        """Score every SVG pair in ``batch``.

        Args:
            batch: Mapping with ``'gt_svg'`` and ``'gen_svg'`` entries, each
                a sequence of SVG strings; pairs are formed with ``zip``.
            update: Controls meter accumulation. ``None`` (the default) or
                any truthy value feeds each sample into the meters and
                returns the meters' ``avg`` values. An explicit falsy value
                (e.g. ``False``) leaves the meters untouched and returns
                averages over this batch only.

        Returns:
            ``(avg_score, values)`` where ``avg_score`` is a dict with keys
            ``'gt_tokens'``, ``'gen_tokens'`` and ``'diff'`` and ``values``
            is a list of per-sample dicts with the same keys. When no pair
            is processed, a bare ``float("nan")`` is returned instead
            (kept as-is so existing callers that test for NaN still work).
        """
        gt_svgs = batch['gt_svg']
        gen_svgs = batch['gen_svg']

        # Previously ``update`` was accepted but ignored, so the meters were
        # always mutated. The default (None) keeps that behavior; only an
        # explicitly falsy argument opts out of accumulation.
        accumulate = update is None or bool(update)

        values = []
        for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"):
            gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg)
            if accumulate:
                self.meter_gt_tokens.update(gt_tokens, 1)
                self.meter_gen_tokens.update(gen_tokens, 1)
                self.meter_diff.update(diff, 1)
            values.append({
                'gt_tokens': gt_tokens,
                'gen_tokens': gen_tokens,
                'diff': diff
            })

        # Bail out before building the averages: nothing was measured.
        if not values:
            print("No valid values found for metric calculation.")
            return float("nan")

        if accumulate:
            avg_score = {
                'gt_tokens': self.meter_gt_tokens.avg,
                'gen_tokens': self.meter_gen_tokens.avg,
                'diff': self.meter_diff.avg
            }
        else:
            # Meters were not touched, so average this batch directly.
            count = len(values)
            avg_score = {
                key: sum(sample[key] for sample in values) / count
                for key in ('gt_tokens', 'gen_tokens', 'diff')
            }
        return avg_score, values

    def reset(self):
        """Reset all three meters."""
        self.meter_gt_tokens.reset()
        self.meter_gen_tokens.reset()
        self.meter_diff.reset()