Spaces:
Sleeping
Sleeping
File size: 1,126 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
from collections import defaultdict
from .metrics import *
class EvaluatorBase():
    ''' Base class for metric evaluators.

    To use this class, you should inherit it and implement the `eval` method.
    Results are kept in `self.accumulator`, which maps a metric name to a list
    of tensors. (Presumably the subclass's `eval` appends to it -- confirm
    against the concrete evaluators.)
    '''

    def __init__(self):
        # metric name -> list of tensors, concatenated along dim 0 on demand.
        self.accumulator = defaultdict(list)

    def eval(self, **kwargs):
        ''' Evaluate the metrics on the data. Subclasses must override this. '''
        raise NotImplementedError

    def get_results(self, chosen_metric=None):
        ''' Get the current mean results.

        Args:
            chosen_metric: `None` to report every metric, a single metric name,
                or a collection of metric names. Only the chosen metrics are
                compacted and returned.

        Returns:
            A dict mapping each chosen metric name to the mean (over dim 0) of
            its accumulated values, as a Python number. Metrics without any
            accumulated value are omitted.
        '''
        compacted = self._compact_accumulator(chosen_metric)
        # NOTE: `.item()` needs a single-element tensor, so the accumulated
        # tensors are assumed to be 1-D -- confirm against the metric functions.
        return {name: values.mean(dim=0).item() for name, values in compacted.items()}

    def _compact_accumulator(self, chosen_metric=None):
        ''' Compact the accumulator lists and return the compacted results.

        For every chosen metric, the list of tensors is concatenated along
        dim 0 and the stored list is replaced by a one-element list holding
        the concatenated tensor, so later calls do not redo the work.

        Args:
            chosen_metric: `None` for all metrics, a single metric name, or a
                collection of metric names.

        Returns:
            A dict mapping each chosen, non-empty metric name to its
            concatenated tensor.
        '''
        # A bare string would make `k in chosen_metric` a substring test
        # (e.g. 'a' matching 'ab'); treat it as exactly one metric name.
        if isinstance(chosen_metric, str):
            chosen_metric = (chosen_metric,)

        ret = {}
        for k, v in self.accumulator.items():
            # Only chosen metrics will be compacted.
            if chosen_metric is not None and k not in chosen_metric:
                continue
            # `torch.cat` rejects an empty list; a defaultdict lookup can
            # create such an entry without anything having been accumulated.
            if not v:
                continue
            ret[k] = torch.cat(v, dim=0)
            self.accumulator[k] = [ret[k]]
        return ret
|