ArneBinder's picture
update from https://github.com/ArneBinder/pie-document-level/pull/397
ced4316 verified
from copy import copy
from typing import Any, Dict, List, Optional
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from src.utils.logging_utils import get_pylogger
logger = get_pylogger(__name__)
def execute_pipeline(
    input: Any,
    setup: Optional[Any] = None,
    **processors,
) -> Any:
    """Pass ``input`` through a sequence of configured processors.

    Args:
        input: The initial value handed to the first processor.
        setup: If callable, it is invoked once (without arguments) before any
            processor runs. Any non-callable value is ignored.
        **processors: Processor configs by name, visited in the order given.
            Only dict values containing the key ``"_processor_"`` are treated as
            processors; everything else is skipped without a message. A processor
            whose ``"_enabled_"`` entry is falsy is skipped with a warning.

    Returns:
        The value produced by the last processor that returned something other
        than ``None`` (or ``input`` itself if there was none).

    Raises:
        ValueError: If a processor config already contains the key ``"_target_"``.
    """
    # callable(None) is False, so this single check also covers "no setup given"
    if callable(setup):
        setup()

    current = input
    for name, raw_config in processors.items():
        is_processor = isinstance(raw_config, dict) and "_processor_" in raw_config
        if not is_processor:
            continue

        logger.info(f"call processor: {name}")

        # work on a shallow copy so the pops below leave the passed config untouched
        call_config = copy(raw_config)
        enabled = call_config.pop("_enabled_", True)
        if not enabled:
            logger.warning(f"skip processor because it is disabled: {name}")
            continue

        # the key "_processor_" is renamed to "_target_", so the latter must be free
        if "_target_" in call_config:
            raise ValueError(
                f"processor {name} has a key '_target_', which is not allowed"
            )
        call_config["_target_"] = call_config.pop("_processor_")

        # IMPORTANT: The current value is passed as the first positional argument
        # instead of being added to the config. That way it is not converted into
        # an OmegaConf object, which would be converted back to a simple dict and
        # break all the DatasetDict methods.
        output = instantiate(call_config, current, _convert_="partial")
        if output is None:
            # keep the previous value when the processor returns nothing
            logger.warning(f'processor "{name}" did not return a result')
            continue
        current = output

    return current
def instantiate_dict_entries(
    config: DictConfig, key: str, entry_description: Optional[str] = None
) -> List:
    """Instantiate all entries of the sub-config ``config[key]`` that define a ``_target_``.

    Args:
        config: The config that holds the entries under ``key``.
        key: Name of the sub-config whose values should be instantiated.
        entry_description: Optional label used in the log message in place of ``key``.

    Returns:
        The instantiated objects, in iteration order of the sub-config. The list is
        empty if the sub-config is missing or empty. Entries that are not a
        ``DictConfig`` or that lack a ``_target_`` key are skipped.

    Raises:
        TypeError: If ``config[key]`` is non-empty but not a ``DictConfig``.
    """
    entries: List = []

    key_config = config.get(key)
    if not key_config:
        logger.warning(f"{key} config is empty.")
        return entries

    if not isinstance(key_config, DictConfig):
        # name the offending section: this function is not specific to loggers
        raise TypeError(
            f"{key} config must be a DictConfig, but got {type(key_config).__name__}!"
        )

    for _, entry_conf in key_config.items():
        if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
            logger.info(f"Instantiating {entry_description or key} <{entry_conf._target_}>")
            entries.append(instantiate(entry_conf, _convert_="partial"))

    return entries
def prepare_omegaconf():
    """Register the custom ``replace`` resolver with OmegaConf.

    The resolver returns its first argument with every occurrence of the second
    argument replaced by the third. If a resolver named ``replace`` is already
    registered, nothing is registered and a warning is logged instead.
    """
    if OmegaConf.has_resolver("replace"):
        logger.warning("OmegaConf resolver 'replace' is already registered")
        return

    def _replace(text, old, new):
        # used to replace "/" with "-" in names to use them as e.g. wandb project names
        return text.replace(old, new)

    OmegaConf.register_new_resolver("replace", _replace)
def parse_config(config_string: str, format: str) -> Dict[str, Any]:
    """Parse a configuration string that is serialized as JSON or YAML.

    Args:
        config_string: The serialized configuration.
        format: Serialization format, either ``"json"`` or ``"yaml"``.

    Returns:
        The parsed content, as returned by ``json.loads`` or ``yaml.safe_load``.

    Raises:
        ValueError: If ``format`` is neither ``"json"`` nor ``"yaml"``.
    """
    if format == "json":
        import json

        parsed = json.loads(config_string)
        return parsed

    if format == "yaml":
        # imported only in this branch, mirroring the json case
        import yaml

        parsed = yaml.safe_load(config_string)
        return parsed

    raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")