File size: 2,032 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torch.utils.data import DataLoader
from starvector.metrics.base_metric import BaseMetric
from tqdm import tqdm
from starvector.metrics.util import AverageMeter

from transformers import AutoTokenizer

class CountTokenLength(BaseMetric):
    """Metric reporting tokenized lengths of ground-truth and generated SVGs.

    For every (gt_svg, gen_svg) pair the metric records the number of tokens
    in each string and their signed difference (generated minus ground
    truth). Three AverageMeter instances accumulate these values across
    calls to ``calculate_score`` until ``reset`` is called.
    """

    # Checkpoint whose tokenizer is loaded when no override is supplied.
    DEFAULT_TOKENIZER_NAME = "bigcode/starcoder2-7b"

    def __init__(self, config=None, device='cuda', *, tokenizer_name=None):
        """Load the tokenizer and create the running-average meters.

        Args:
            config: Accepted but not used by this metric.
            device: Accepted but not used by this metric.
            tokenizer_name: Optional identifier passed to
                ``AutoTokenizer.from_pretrained``. Defaults to
                ``DEFAULT_TOKENIZER_NAME`` so existing callers keep the
                original tokenizer.
        """
        super().__init__()
        if tokenizer_name is None:
            tokenizer_name = self.DEFAULT_TOKENIZER_NAME
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # NOTE(review): `metric` is presumably the per-sample hook that
        # BaseMetric expects subclasses to set -- confirm against BaseMetric.
        self.metric = self.calculate_token_length
        self.meter_gt_tokens = AverageMeter()
        self.meter_gen_tokens = AverageMeter()
        self.meter_diff = AverageMeter()

    def calculate_token_length(self, **kwargs):
        """Tokenize one ground-truth / generated SVG pair.

        Args:
            **kwargs: Must provide ``gt_svg`` and ``gen_svg``. A missing key
                is read as ``None`` and handed to the tokenizer as-is.

        Returns:
            Tuple ``(gt_token_count, gen_token_count, diff)`` where ``diff``
            is ``gen_token_count - gt_token_count``.
        """
        svg = kwargs.get('gt_svg')
        tokens = self.tokenizer.encode(svg)
        gen_svg = kwargs.get('gen_svg')
        gen_tokens = self.tokenizer.encode(gen_svg)
        diff = len(gen_tokens) - len(tokens)
        return len(tokens), len(gen_tokens), diff

    def calculate_score(self, batch, update=None):
        """Score every SVG pair in ``batch`` and update the running meters.

        Args:
            batch: Mapping with ``'gt_svg'`` and ``'gen_svg'`` entries holding
                parallel sequences of SVG strings. Pairs are formed with
                ``zip``, so extra items in the longer sequence are ignored.
            update: Accepted but not read; the meters are always updated.

        Returns:
            ``(avg_score, values)`` where ``avg_score`` maps ``'gt_tokens'``,
            ``'gen_tokens'`` and ``'diff'`` to the corresponding meter's
            ``avg`` and ``values`` is a list of per-sample dicts with the
            same keys. If the batch yields no pairs, a bare ``float('nan')``
            is returned instead of a tuple.
        """
        gt_svgs = batch['gt_svg']
        gen_svgs = batch['gen_svg']
        values = []
        pairs = tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs")
        for gt_svg, gen_svg in pairs:
            gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg)
            self.meter_gt_tokens.update(gt_tokens, 1)
            self.meter_gen_tokens.update(gen_tokens, 1)
            self.meter_diff.update(diff, 1)
            values.append({
                'gt_tokens': gt_tokens,
                'gen_tokens': gen_tokens,
                'diff': diff
            })

        # Nothing was scored: keep the historical bare-NaN return value.
        if not values:
            print("No valid values found for metric calculation.")
            return float("nan")

        avg_score = {
            'gt_tokens': self.meter_gt_tokens.avg,
            'gen_tokens': self.meter_gen_tokens.avg,
            'diff': self.meter_diff.avg
        }
        return avg_score, values

    def reset(self):
        """Reset the three token-count meters."""
        self.meter_gt_tokens.reset()
        self.meter_gen_tokens.reset()
        self.meter_diff.reset()