|
from copy import copy |
|
from typing import Any, Dict, List, Optional |
|
|
|
from hydra.utils import instantiate |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from src.utils.logging_utils import get_pylogger |
|
|
|
# Module-scoped logger obtained from the project's get_pylogger helper, named after this module.
logger = get_pylogger(__name__)
|
|
|
|
|
def execute_pipeline(
    input: Any,
    setup: Optional[Any] = None,
    **processors,
) -> Any:
    """Thread ``input`` through a chain of Hydra-instantiated processors.

    Args:
        input: Initial value handed to the first enabled processor.
        setup: Optional callable invoked once before any processor runs.
            It is ignored when it is ``None`` or not callable.
        **processors: Processor configs, visited in keyword order.
            - Only ``dict`` values (including dict subclasses) that contain a
              ``_processor_`` key are considered. Anything else is skipped
              silently.
            - A config whose ``_enabled_`` entry is falsy is skipped with a
              warning.
            - For the rest, ``_processor_`` is renamed to Hydra's
              ``_target_``, and the config is instantiated with the running
              value as its first positional argument.

    Returns:
        The output of the last processor that returned a non-``None`` value,
        or ``input`` unchanged if no processor did.

    Raises:
        ValueError: If an enabled processor config already contains a
            ``_target_`` key.
    """
    # callable(None) is False, so this also covers the "no setup" case.
    if callable(setup):
        setup()

    current = input
    for name, proc_cfg in processors.items():
        is_processor = isinstance(proc_cfg, dict) and "_processor_" in proc_cfg
        if not is_processor:
            continue

        logger.info(f"call processor: {name}")
        # Shallow copy so popping the control keys never mutates the caller's config.
        cfg = copy(proc_cfg)
        enabled = cfg.pop("_enabled_", True)
        if not enabled:
            logger.warning(f"skip processor because it is disabled: {name}")
            continue

        # _target_ is derived from _processor_ below; an explicit one would be ambiguous.
        if "_target_" in cfg:
            raise ValueError(
                f"processor {name} has a key '_target_', which is not allowed"
            )
        cfg["_target_"] = cfg.pop("_processor_")

        # The running value is passed as the target's first positional argument.
        output = instantiate(cfg, current, _convert_="partial")
        if output is None:
            logger.warning(f'processor "{name}" did not return a result')
        else:
            current = output

    return current
|
|
|
|
|
def instantiate_dict_entries(
    config: DictConfig, key: str, entry_description: Optional[str] = None
) -> List:
    """Instantiate every ``_target_`` sub-config found under ``config[key]``.

    Args:
        config: Config whose ``key`` entry is expected to be a mapping of
            name -> sub-config.
        key: Name of the entry in ``config`` to read.
        entry_description: Label used in log messages instead of ``key``.

    Returns:
        The instantiated objects, in the iteration order of ``config[key]``.
        The result is empty when the entry is missing or empty. Sub-configs
        that are not ``DictConfig`` instances, or that lack ``_target_``, are
        skipped.

    Raises:
        TypeError: If ``config[key]`` is non-empty but not a ``DictConfig``.
    """
    entries: List = []

    key_config = config.get(key)
    if not key_config:
        logger.warning(f"{key} config is empty.")
        return entries

    if not isinstance(key_config, DictConfig):
        # Name the offending key so the error is accurate for every caller,
        # not only for logger configs.
        raise TypeError(f"{key} config must be a DictConfig!")

    label = entry_description or key
    for entry_conf in key_config.values():
        if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
            logger.info(f"Instantiating {label} <{entry_conf._target_}>")
            entries.append(instantiate(entry_conf, _convert_="partial"))

    return entries
|
|
|
|
|
def prepare_omegaconf():
    """Register the project's custom OmegaConf resolvers.

    Currently this registers only ``replace``: an interpolation such as
    ``${replace:s,x,y}`` evaluates to ``s.replace(x, y)``.

    Calling this more than once is harmless. If the resolver already exists,
    a warning is logged and the existing registration is left in place.
    """
    if OmegaConf.has_resolver("replace"):
        logger.warning("OmegaConf resolver 'replace' is already registered")
        return

    def _replace(s, x, y):
        return s.replace(x, y)

    OmegaConf.register_new_resolver("replace", _replace)
|
|
|
|
|
def parse_config(config_string: str, format: str) -> Dict[str, Any]:
    """Parse a configuration string in JSON or YAML.

    Args:
        config_string: Raw configuration text.
        format: ``"json"`` or ``"yaml"``. Matching ignores case and
            surrounding whitespace, and ``"yml"`` is accepted as an alias for
            ``"yaml"``.

    Returns:
        The parsed document. The top-level value is returned as-is and is not
        checked to be a mapping. For example, a JSON array yields a list, and
        an empty YAML document yields ``None``.

    Raises:
        ValueError: If ``format`` is not a supported format name.
    """
    # Normalise so "JSON", " yaml " and the common "yml" extension are accepted;
    # non-string values fall through to the ValueError below.
    normalized = format.strip().lower() if isinstance(format, str) else format

    if normalized == "json":
        import json

        return json.loads(config_string)

    if normalized in ("yaml", "yml"):
        import yaml

        # safe_load never constructs arbitrary Python objects from the input.
        return yaml.safe_load(config_string)

    raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
|
|