lucas-ventura committed · Commit 6303c5d · verified · 1 Parent(s): c4f6fa0

Upload 6 files
src/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from src.utils.logging_utils import log_hyperparameters
+ from src.utils.pylogger import RankedLogger
+ from src.utils.rich_utils import enforce_tags, print_config_tree
+ from src.utils.utils import extras, task_wrapper
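
These re-exports flatten the package surface so call sites can import the helpers from the package root. A one-line sketch (not part of this commit):

```
from src.utils import RankedLogger, extras, log_hyperparameters, task_wrapper
```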
src/utils/logging_utils.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Any, Dict
+
+ from lightning_utilities.core.rank_zero import rank_zero_only
+ from omegaconf import OmegaConf
+
+ from src.utils import pylogger
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+     """Controls which config parts are saved by Lightning loggers.
+
+     Additionally saves:
+         - Number of model parameters
+
+     :param object_dict: A dictionary containing the following objects:
+         - `"cfg"`: A DictConfig object containing the main config.
+         - `"model"`: The Lightning model.
+         - `"trainer"`: The Lightning trainer.
+     """
+     hparams = {}
+
+     cfg = OmegaConf.to_container(object_dict["cfg"])
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     hparams["model"] = cfg["model"]
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(
+         p.numel() for p in model.parameters() if p.requires_grad
+     )
+     hparams["model/params/non_trainable"] = sum(
+         p.numel() for p in model.parameters() if not p.requires_grad
+     )
+
+     hparams["data"] = cfg["data"]
+
+     hparams["extras"] = cfg.get("extras")
+
+     hparams["task_name"] = cfg.get("task_name")
+     hparams["tags"] = cfg.get("tags")
+     # hparams["ckpt_path"] = cfg.get("ckpt_path")
+     hparams["seed"] = cfg.get("seed")
+
+     # send hparams to all loggers
+     for logger in trainer.loggers:
+         logger.log_hyperparams(hparams)
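
For context, a minimal call-site sketch is below. The stand-ins are for illustration only: a real run passes the Hydra-composed config, the project's LightningModule, and the Trainer. Note that `cfg` must contain `model` and `data` keys, since those are accessed without `.get()`.

```
from types import SimpleNamespace

from omegaconf import OmegaConf

from src.utils.logging_utils import log_hyperparameters

# Toy config; a real one is composed by Hydra.
cfg = OmegaConf.create({"model": {"lr": 1e-3}, "data": {"batch_size": 32}})
# Stand-in trainer with no logger: the hook warns and returns early,
# so the model is never touched in this sketch.
trainer = SimpleNamespace(logger=None, loggers=[])
log_hyperparameters({"cfg": cfg, "model": None, "trainer": trainer})
```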
src/utils/metrics.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ from torchmetrics import Metric
+
+
+ class PRFMetric(Metric):
+     """Accumulates per-video precision/recall and reports their averages,
+     plus the averaged per-video F1, all scaled to percentages."""
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.add_state("t_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("t_recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("t_f1", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+     def update(self, vid_p, vid_r) -> None:
+         self.t_precision += vid_p
+         self.t_recall += vid_r
+         # per-video F1; guard against division by zero when p + r == 0
+         self.t_f1 += 2 * (vid_p * vid_r) / (vid_p + vid_r) if vid_p + vid_r else 0.0
+         self.n += 1
+
+     def compute(self):
+         avg_p = self.t_precision * 100 / self.n
+         avg_r = self.t_recall * 100 / self.n
+         avg_f1 = self.t_f1 * 100 / self.n
+         return {"precision": avg_p, "recall": avg_r, "f1": avg_f1}
src/utils/pylogger.py ADDED
@@ -0,0 +1,53 @@
+ import logging
+ from typing import Mapping, Optional
+
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+ class RankedLogger(logging.LoggerAdapter):
+     """A multi-GPU-friendly python command line logger."""
+
+     def __init__(
+         self,
+         name: str = __name__,
+         rank_zero_only: bool = False,
+         extra: Optional[Mapping[str, object]] = None,
+     ) -> None:
+         """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+         with their rank prefixed in the log message.
+
+         :param name: The name of the logger. Default is ``__name__``.
+         :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``False``.
+         :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+         """
+         logger = logging.getLogger(name)
+         super().__init__(logger=logger, extra=extra)
+         self.rank_zero_only = rank_zero_only
+
+     def log(
+         self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+     ) -> None:
+         """Delegate a log call to the underlying logger, after prefixing its message with the rank
+         of the process it's being logged from. If `'rank'` is provided, then the log will only
+         occur on that rank/process.
+
+         :param level: The level to log at. Look at `logging.__init__.py` for more information.
+         :param msg: The message to log.
+         :param rank: The rank to log at.
+         :param args: Additional args to pass to the underlying logging function.
+         :param kwargs: Any additional keyword args to pass to the underlying logging function.
+         """
+         if self.isEnabledFor(level):
+             msg, kwargs = self.process(msg, kwargs)
+             current_rank = getattr(rank_zero_only, "rank", None)
+             if current_rank is None:
+                 raise RuntimeError(
+                     "The `rank_zero_only.rank` needs to be set before use"
+                 )
+             msg = rank_prefixed_message(msg, current_rank)
+             if self.rank_zero_only:
+                 if current_rank == 0:
+                     self.logger.log(level, msg, *args, **kwargs)
+             else:
+                 if rank is None or current_rank == rank:
+                     self.logger.log(level, msg, *args, **kwargs)
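
A usage sketch, assuming `lightning_utilities` has set `rank_zero_only.rank` from the environment (in a plain single-process run it resolves to 0; under DDP each process gets its own rank):

```
import logging

from src.utils.pylogger import RankedLogger

logging.basicConfig(level=logging.INFO)

log = RankedLogger(__name__, rank_zero_only=False)
log.info("logged on every rank, prefixed with its rank")
log.info("logged on rank 1 only (a no-op elsewhere)", rank=1)

log0 = RankedLogger(__name__, rank_zero_only=True)
log0.info("logged on rank 0 only, regardless of any `rank` argument")
```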
src/utils/rich_utils.py ADDED
@@ -0,0 +1,98 @@
+ from pathlib import Path
+ from typing import Sequence
+
+ import rich
+ import rich.syntax
+ import rich.tree
+ from hydra.core.hydra_config import HydraConfig
+ from lightning_utilities.core.rank_zero import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf, open_dict
+ from rich.prompt import Prompt
+
+ from src.utils import pylogger
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ @rank_zero_only
+ def print_config_tree(
+     cfg: DictConfig,
+     print_order: Sequence[str] = (
+         "data",
+         "model",
+         "logger",
+         "paths",
+         "extras",
+     ),
+     resolve: bool = False,
+     save_to_file: bool = False,
+ ) -> None:
+     """Prints the contents of a DictConfig as a tree structure using the Rich library.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param print_order: Determines in what order config components are printed. Default is ``("data", "model", "logger", "paths", "extras")``.
+     :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+     :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+     """
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     queue = []
+
+     # add fields from `print_order` to queue
+     for field in print_order:
+         if field in cfg:
+             queue.append(field)
+         else:
+             log.warning(
+                 f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+             )
+
+     # add all the other fields to queue (not specified in `print_order`)
+     for field in cfg:
+         if field not in queue:
+             queue.append(field)
+
+     # generate config tree from queue
+     for field in queue:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_group = cfg[field]
+         if isinstance(config_group, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+         else:
+             branch_content = str(config_group)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     # print config tree
+     rich.print(tree)
+
+     # save config tree to file
+     if save_to_file:
+         with open(
+             Path(cfg.paths.output_dir, f"config_{cfg.run_type}.log"), "w"
+         ) as file:
+             rich.print(tree, file=file)
+
+
+ @rank_zero_only
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+     """Prompts user to input tags from command line if no tags are provided in config.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+     """
+     if not cfg.get("tags"):
+         if "id" in HydraConfig().cfg.hydra.job:
+             raise ValueError("Specify tags before launching a multirun!")
+
+         log.warning("No tags provided in config. Prompting user to input tags...")
+         tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+         tags = [t.strip() for t in tags.split(",") if t != ""]
+
+         with open_dict(cfg):
+             cfg.tags = tags
+
+         log.info(f"Tags: {cfg.tags}")
+
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+             rich.print(cfg.tags, file=file)
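
A self-contained sketch of `print_config_tree`. The toy config below is hand-built rather than Hydra-composed; fields from `print_order` that are missing (here `logger`, `paths`, `extras`) are only warned about and skipped. `save_to_file=True` would additionally require `cfg.paths.output_dir` and `cfg.run_type`.

```
from omegaconf import OmegaConf

from src.utils.rich_utils import print_config_tree

# Toy config for illustration only; real configs are composed by Hydra.
cfg = OmegaConf.create({"data": {"batch_size": 32}, "model": {"lr": 1e-3}})
print_config_tree(cfg)  # renders "data" and "model" branches as YAML
```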
src/utils/utils.py ADDED
@@ -0,0 +1,93 @@
+ import warnings
+ from importlib.util import find_spec
+ from typing import Any, Callable, Dict, Tuple
+
+ from omegaconf import DictConfig
+
+ from src.utils import pylogger, rich_utils
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ def extras(cfg: DictConfig) -> None:
+     """Applies optional utilities before the task is started.
+
+     Utilities:
+         - Ignoring python warnings
+         - Setting tags from command line
+         - Rich config printing
+
+     :param cfg: A DictConfig object containing the config tree.
+     """
+     # return if no `extras` config
+     if not cfg.get("extras"):
+         log.warning("Extras config not found! <cfg.extras=null>")
+         return
+
+     # disable python warnings
+     if cfg.extras.get("ignore_warnings"):
+         log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+         warnings.filterwarnings("ignore")
+
+     # prompt user to input tags from command line if none are provided in the config
+     if cfg.extras.get("enforce_tags"):
+         log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+         rich_utils.enforce_tags(cfg, save_to_file=True)
+
+     # pretty print config tree using Rich library
+     if cfg.extras.get("print_config"):
+         log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+         rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+ def task_wrapper(task_func: Callable) -> Callable:
+     """Optional decorator that controls the failure behavior when executing the task function.
+
+     This wrapper can be used to:
+         - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+         - save the exception to a `.log` file
+         - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+         - etc. (adjust depending on your needs)
+
+     Example:
+     ```
+     @utils.task_wrapper
+     def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         ...
+         return metric_dict, object_dict
+     ```
+
+     :param task_func: The task function to be wrapped.
+
+     :return: The wrapped task function.
+     """
+
+     def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         # execute the task, capturing its return value so it can be passed through
+         try:
+             metric_dict, object_dict = task_func(cfg=cfg)
+
+         # things to do if exception occurs
+         except Exception as ex:
+             # save exception to `.log` file
+             log.exception("")
+
+             # some hyperparameter combinations might be invalid or cause out-of-memory errors
+             # so when using hparam search plugins like Optuna, you might want to disable
+             # raising the below exception to avoid multirun failure
+             raise ex
+
+         # things to always do after either success or exception
+         finally:
+             # display output dir path in terminal
+             log.info(f"Output dir: {cfg.paths.output_dir}")
+
+             # always close wandb run (even if exception occurs so multirun won't fail)
+             if find_spec("wandb"):  # check if wandb is installed
+                 import wandb
+
+                 if wandb.run:
+                     log.info("Closing wandb!")
+                     wandb.finish()
+
+         return metric_dict, object_dict
+
+     return wrap
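
To show the flags `extras` reads, here is a hand-built config sketch; in a real run these fields come from the Hydra-composed `extras` yaml, and the values below are illustrative:

```
from omegaconf import OmegaConf

from src.utils.utils import extras

cfg = OmegaConf.create({
    "extras": {
        "ignore_warnings": True,  # applies warnings.filterwarnings("ignore")
        "enforce_tags": False,    # skip the interactive tag prompt in this sketch
        "print_config": False,    # skip Rich printing (saving needs paths.output_dir)
    }
})
extras(cfg)
```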