Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
A standalone module for aggregating metrics.

Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values are aggregated dynamically based on
the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
""" | |
import contextlib | |
import time | |
import uuid | |
from collections import OrderedDict, defaultdict | |
from typing import Callable, Dict, List, Optional | |
from .meters import * | |
# Named aggregators, keyed by name. Temporary (unnamed) aggregators created
# by :func:`aggregate` are never stored here.
_aggregators = OrderedDict()
# Aggregation contexts count as "active" while execution is inside the scope
# created by the :func:`aggregate` context manager. This maps the name of
# each active context to its MetersDict ...
_active_aggregators = OrderedDict()
# ... and this tracks how many nested scopes currently hold each name
# (missing names read as 0).
_active_aggregators_cnt = defaultdict(int)
def reset() -> None:
    """Drop every metrics aggregator and reinstate the global "default" one.

    After this call "default" is the only registered aggregator. It is
    marked active with a nesting count of one, so values logged outside
    any ``new_root`` aggregation are recorded into it.
    """
    for registry in (_aggregators, _active_aggregators, _active_aggregators_cnt):
        registry.clear()

    # The "default" aggregator observes all logged values; the registry and
    # the active map share the very same MetersDict object.
    default_meters = MetersDict()
    _aggregators["default"] = default_meters
    _active_aggregators["default"] = default_meters
    _active_aggregators_cnt["default"] = 1


reset()
@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *new_root* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *new_root*
    is ``True``, then this aggregator will be the root of a new
    aggregation stack, thus bypassing any parent aggregators.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given, then a temporary aggregator will be created.

    The active-aggregator state is restored when the ``with`` block exits,
    including when it exits by raising an exception.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate("train_inner") as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
                        agg.reset()
        print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        new_root (bool): make this aggregation the root of a new
            aggregation stack.

    Yields:
        MetersDict: the aggregator that values logged inside the block
        are recorded into.
    """
    if name is None:
        # Generate a temporary name. Temporary aggregators are not registered
        # in ``_aggregators``, so they cannot be looked up by name later.
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        # Only build a new MetersDict when the name is not registered yet.
        if name not in _aggregators:
            _aggregators[name] = MetersDict()
        agg = _aggregators[name]

    if new_root:
        # Start a fresh stack: stash the current active state so it can be
        # restored verbatim on exit.
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1
    try:
        yield agg
    finally:
        # Nested contexts may reuse the same name; only deactivate it once
        # the outermost holder exits.
        _active_aggregators_cnt[name] -= 1
        if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
            del _active_aggregators[name]

        if new_root:
            _active_aggregators.clear()
            _active_aggregators.update(backup_aggregators)
            _active_aggregators_cnt.clear()
            _active_aggregators_cnt.update(backup_aggregators_cnt)
def get_active_aggregators() -> List[MetersDict]:
    """Return a new list of the currently active aggregators.

    The list follows the insertion order of the active-aggregator map and
    is a copy, independent of later changes to that map.
    """
    return [*_active_aggregators.values()]
def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Record the scalar *value* under *key* in every active aggregator.

    An :class:`AverageMeter` is created (with the given *priority* and
    display rounding) the first time an aggregator sees *key*; the value
    is then folded into that aggregator's meter.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for meters in get_active_aggregators():
        if key not in meters:
            fresh_meter = AverageMeter(round=round)
            meters.add_meter(key, fresh_meter, priority)
        meters[key].update(value, weight)
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Register a meter whose value is computed from other meters.

    The derived meter is only added to active aggregators that do not yet
    have *key*. Existing entries are left as they are, so a later call
    with a different *fn* does not replace the first one.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for meters in get_active_aggregators():
        if key in meters:
            continue
        meters.add_meter(key, MetersDict._DerivedMeter(fn), priority)
def log_speed(
    key: str,
    value: float,
    priority: int = 30,
    round: Optional[int] = None,
):
    """Log the rate of some quantity per second.

    The first call for *key* in an aggregator only creates a
    :class:`TimeMeter` and resets it; *value* is not recorded on that
    call. Later calls feed *value* into the existing meter.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for meters in get_active_aggregators():
        if key in meters:
            meters[key].update(value)
            continue
        # First sighting of *key*: create the meter and restart its clock.
        meters.add_meter(key, TimeMeter(round=round), priority)
        meters[key].reset()
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Start timing an event under *key* in every active aggregator.

    A :class:`StopwatchMeter` is created for *key* on first use. The
    duration is computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for meters in get_active_aggregators():
        if key not in meters:
            meters.add_meter(key, StopwatchMeter(round=round), priority)
        stopwatch = meters[key]
        stopwatch.start()
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Stop the timer under *key* in every active aggregator that has one.

    The duration is measured since :func:`log_start_time` was called.
    Aggregators without a meter for *key* are skipped silently.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
            is stopped. For example, use prehook=torch.cuda.synchronize to
            make sure all gpu operations are done before timer is stopped.
    """
    for meters in get_active_aggregators():
        if key not in meters:
            continue  # nothing was registered under *key* here
        meters[key].stop(weight, prehook)
def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log using a custom Meter.

    For each active aggregator lacking *key*, a meter is built by calling
    *new_meter_fn* (once per such aggregator). Any extra *args* or *kwargs*
    are then passed through to that meter's *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for meters in get_active_aggregators():
        if key not in meters:
            meters.add_meter(key, new_meter_fn(), priority)
        target = meters[key]
        target.update(*args, **kwargs)
def reset_meter(name: str, key: str) -> None:
    """Reset the Meter stored under *name* and *key*; a no-op if absent."""
    meter = get_meter(name, key)
    if meter is None:
        return
    meter.reset()
def reset_meters(name: str) -> None:
    """Reset every Meter aggregated under *name*; a no-op if *name* is unknown."""
    aggregator = get_meters(name)
    if aggregator is None:
        return
    aggregator.reset()
def get_meter(name: str, key: str) -> Optional[Meter]:
    """Get a single Meter instance aggregated under *name* and *key*.

    Returns:
        Meter or None if no metrics have been logged under *name* and *key*.
    """
    meters = _aggregators.get(name)
    # Compare against None explicitly: an empty MetersDict is falsy.
    if meters is None:
        return None
    return meters.get(key, None)
def get_meters(name: str) -> Optional[MetersDict]:
    """Get Meter instances aggregated under a given *name*.

    Returns:
        MetersDict or None if no metrics have been logged under *name*.
    """
    return _aggregators.get(name, None)
def get_smoothed_value(name: str, key: str) -> float:
    """Return the smoothed value of *key* within the aggregator *name*.

    Raises:
        KeyError: if no metrics have been logged under *name* and *key*.
    """
    aggregator = _aggregators[name]
    return aggregator.get_smoothed_value(key)
def get_smoothed_values(name: str) -> Dict[str, float]:
    """Return all smoothed values held by the aggregator *name*.

    Raises:
        KeyError: if no metrics have been logged under *name*.
    """
    aggregator = _aggregators[name]
    return aggregator.get_smoothed_values()
def state_dict():
    """Snapshot every named aggregator as an ``OrderedDict`` of name -> state.

    Entries follow the insertion order of the aggregator registry. Each
    value is that aggregator's own ``state_dict()``.
    """
    snapshot = OrderedDict()
    for agg_name, agg in _aggregators.items():
        snapshot[agg_name] = agg.state_dict()
    return snapshot
def load_state_dict(state_dict):
    """Restore aggregators from a snapshot produced by :func:`state_dict`.

    Each entry replaces the aggregator of the same name with a fresh
    :class:`MetersDict` loaded from the saved state. Aggregators whose
    names are absent from *state_dict* are left untouched.

    Args:
        state_dict: mapping from aggregator name to the state returned by
            that aggregator's ``state_dict()``.
    """
    for name, agg_state in state_dict.items():
        # Load first, then register: a failing load leaves the previous
        # aggregator in place instead of an empty one.
        agg = MetersDict()
        agg.load_state_dict(agg_state)
        _aggregators[name] = agg
        # The log_* functions write into ``_active_aggregators`` while the
        # getters read ``_aggregators``. If this name is currently active
        # (always true for "default"), rebind it too. Otherwise new values
        # would land in the replaced, orphaned MetersDict.
        # NOTE(review): a backup taken by an enclosing
        # ``aggregate(new_root=True)`` still references the old object.
        if name in _active_aggregators:
            _active_aggregators[name] = agg