Spaces: Running on Zero

Upload 6 files

- src/utils/__init__.py +4 -0
- src/utils/logging_utils.py +55 -0
- src/utils/metrics.py +23 -0
- src/utils/pylogger.py +53 -0
- src/utils/rich_utils.py +98 -0
- src/utils/utils.py +93 -0
src/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import RankedLogger
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, task_wrapper
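Usage sketch (illustrative, not part of the upload): these re-exports let downstream code import the helpers directly from the package instead of from the individual modules, e.g.

from src.utils import RankedLogger, enforce_tags, extras, log_hyperparameters, task_wrapper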
src/utils/logging_utils.py
ADDED
@@ -0,0 +1,55 @@
from typing import Any, Dict

from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]

    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    # hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
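Usage sketch (illustrative, not part of the upload): log_hyperparameters expects the caller to assemble the object_dict itself, and under Lightning the trainer normally sets the global rank, so outside of a run it has to be pinned manually. The toy config, the nn.Linear stand-in model, and the fake trainer/logger below are placeholders invented for this sketch:

import torch.nn as nn
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import log_hyperparameters


class _PrintLogger:
    """Stand-in for a Lightning logger, just for this sketch."""

    def log_hyperparams(self, hparams):
        print(hparams)


class _FakeTrainer:
    """Exposes the two attributes log_hyperparameters reads: logger and loggers."""

    logger = _PrintLogger()
    loggers = [logger]


# outside a Lightning run the global rank may not be set yet; pin it for this sketch
rank_zero_only.rank = 0

cfg = OmegaConf.create({"model": {"lr": 1e-3}, "data": {"batch_size": 8}})

# counts the model's parameters, picks out the cfg sections listed above,
# and forwards the assembled dict to every logger attached to the trainer
log_hyperparameters({"cfg": cfg, "model": nn.Linear(4, 2), "trainer": _FakeTrainer()})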
src/utils/metrics.py
ADDED
@@ -0,0 +1,23 @@
import torch
from torchmetrics import Metric


class PRFMetric(Metric):
    """Running precision/recall/F1 averaged over the number of `update` calls.

    Each `update` receives one video's precision and recall; the F1 for that
    pair is computed on the fly, and `compute` returns the three running
    averages scaled to percentages.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("t_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("t_recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("t_f1", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, vid_p, vid_r) -> None:
        self.t_precision += vid_p
        self.t_recall += vid_r
        # per-video F1; guard against division by zero when both values are 0
        self.t_f1 += 2 * (vid_p * vid_r) / (vid_p + vid_r) if vid_p + vid_r else 0.0
        self.n += 1

    def compute(self):
        avg_p = self.t_precision * 100 / self.n
        avg_r = self.t_recall * 100 / self.n
        avg_f1 = self.t_f1 * 100 / self.n
        return {"precision": avg_p, "recall": avg_r, "f1": avg_f1}
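A quick illustration of the arithmetic (made-up numbers, not part of the upload): each update adds one video's precision and recall plus the F1 computed from that pair, and compute returns the running averages as percentages.

from src.utils.metrics import PRFMetric

metric = PRFMetric()

# per-video precision/recall pairs (illustrative values)
for p, r in [(0.8, 0.5), (1.0, 1.0), (0.0, 0.0)]:
    metric.update(p, r)

print(metric.compute())
# roughly {'precision': 60.0, 'recall': 50.0, 'f1': 53.8}, returned as tensors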
src/utils/pylogger.py
ADDED
@@ -0,0 +1,53 @@
import logging
from typing import Mapping, Optional

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``False``.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError(
                    "The `rank_zero_only.rank` needs to be set before use"
                )
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None or current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)
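Standalone sketch of the RankedLogger behavior (illustrative, not part of the upload): with rank_zero_only=True only the rank-0 process emits; otherwise every rank logs with its rank prefixed, and the optional rank argument limits a single call to one process. Under Lightning the global rank is set by the trainer; here it is pinned manually:

import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from src.utils.pylogger import RankedLogger

logging.basicConfig(level=logging.INFO)
rank_zero_only.rank = 0  # no Trainer in this sketch, so set the rank ourselves

log = RankedLogger(__name__, rank_zero_only=False)
log.info("prefixed with the caller's rank, emitted on every rank")
log.info("emitted only on the process whose rank matches", rank=0)

strict_log = RankedLogger(__name__, rank_zero_only=True)
strict_log.warning("emitted on rank 0 only, regardless of the rank argument")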
src/utils/rich_utils.py
ADDED
@@ -0,0 +1,98 @@
from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "logger",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model", "logger", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        queue.append(field) if field in cfg else log.warning(
            f"Field '{field}' not found in config. Skipping '{field}' config printing..."
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(
            Path(cfg.paths.output_dir, f"config_{cfg.run_type}.log"), "w"
        ) as file:
            rich.print(tree, file=file)


@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t != ""]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
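Usage sketch (illustrative, not part of the upload): both helpers are normally called from inside a Hydra application (enforce_tags inspects HydraConfig, and saving to file relies on cfg.paths.output_dir / cfg.run_type), but print_config_tree can be exercised on a hand-built config. The toy config below is a placeholder:

from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils.rich_utils import print_config_tree

rank_zero_only.rank = 0  # required by the @rank_zero_only decorator outside a Lightning run

cfg = OmegaConf.create(
    {
        "data": {"batch_size": 8},
        "model": {"lr": 1e-3},
        "seed": 42,
    }
)

# prints a Rich tree with one branch per top-level key; the print_order entries
# missing here ("logger", "paths", "extras") only produce warnings
print_config_tree(cfg, resolve=True, save_to_file=False)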
src/utils/utils.py
ADDED
@@ -0,0 +1,93 @@
import warnings
from importlib.util import find_spec
from typing import Any, Callable, Dict, Tuple

from omegaconf import DictConfig

from src.utils import pylogger, rich_utils

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task and pass its result through
        try:
            return task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

    return wrap
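Usage sketch (illustrative, not part of the upload): extras reads an extras block with the keys used above (ignore_warnings, enforce_tags, print_config), and task_wrapper wraps the task so the finally block always logs the output dir and closes wandb. A condensed, hypothetical entry point with placeholder config values and a dummy task body:

from typing import Any, Dict, Tuple

from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf

from src import utils

rank_zero_only.rank = 0  # the RankedLogger calls need a rank outside a Lightning run

cfg = OmegaConf.create(
    {
        "paths": {"output_dir": "."},
        "extras": {
            "ignore_warnings": True,  # -> warnings.filterwarnings("ignore")
            "enforce_tags": False,    # skip the interactive tag prompt in this sketch
            "print_config": False,    # skip the Rich config tree in this sketch
        },
    }
)


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # placeholder body; a real entry point would instantiate and fit the model here
    return {"loss": 0.0}, {}


utils.extras(cfg)                      # applies the optional utilities listed above
metric_dict, object_dict = train(cfg)  # the wrapper re-raises on failure and always runs its cleanup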