Spaces:
Paused
Paused
import json
import os
import re
from collections import OrderedDict
from typing import Optional, Union

import oyaml as yaml

from toolkit.paths import TOOLKIT_ROOT
# Extensions tried, in this order, when a config name is given without one;
# the first match on disk wins.
possible_extensions = ['.json', '.jsonc', '.yaml', '.yml']
def get_cwd_abs_path(path):
    """Return ``path`` anchored at the current working directory.

    An absolute ``path`` is returned unchanged. A relative one is joined onto
    ``os.getcwd()``; the result is not normalized (``..`` segments are kept).
    """
    if not os.path.isabs(path):
        path = os.path.join(os.getcwd(), path)
    return path
def replace_env_vars_in_string(s: str) -> str:
    """
    Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
    If the environment variable is not set, raise an error.

    The substitution is purely textual: the value is inserted verbatim, with
    no quoting or escaping for the surrounding format. A variable that is set
    to the empty string counts as set and is replaced by nothing.

    Raises:
        ValueError: if a referenced environment variable is not defined.
    """

    def replacer(match):
        var_name = match.group(1)
        value = os.environ.get(var_name)
        if value is None:
            raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.")
        return value

    # Group 1 is everything between "${" and the first "}".
    return re.sub(r'\$\{([^}]+)\}', replacer, s)
def preprocess_config(config: OrderedDict, name: Optional[str] = None):
    """Validate a loaded config and expand the ``[name]`` tag throughout it.

    The config is serialized to JSON, every literal ``[name]`` is replaced by
    the job name, and the result is parsed back, so the tag is expanded in
    keys and values at any depth. The input mapping is not modified.

    Args:
        config: Parsed config mapping; must contain ``job`` and ``config``.
        name: Name to substitute. When ``None``, ``config["config"]["name"]``
            is used instead.

    Returns:
        A new ``OrderedDict`` tree with ``[name]`` expanded.

    Raises:
        ValueError: if ``job`` or ``config`` is missing, or if no name is
            given and ``config["config"]`` has no ``name`` key.
    """
    if "job" not in config:
        raise ValueError("config file must have a job key")
    if "config" not in config:
        raise ValueError("config file must have a config section")
    if "name" not in config["config"] and name is None:
        raise ValueError("config file must have a config.name key")

    # we need to replace tags. For now just [name]
    if name is None:
        name = config["config"]["name"]

    # The replacement happens inside serialized JSON, so the name must be
    # JSON-escaped first: a raw quote or backslash would otherwise corrupt the
    # document. json.dumps yields a quoted string; strip the outer quotes.
    # str() lets a non-string name (e.g. a numeric YAML value) be substituted.
    escaped_name = json.dumps(str(name))[1:-1]
    config_string = json.dumps(config).replace("[name]", escaped_name)
    return json.loads(config_string, object_pairs_hook=OrderedDict)
# Fixes issue where yaml doesnt load exponents correctly
# (e.g. a bare ``1e-4`` with no decimal point would not resolve as a float).
# NOTE: fixed_loader is an alias of yaml.SafeLoader, not a copy, so the
# resolver below is registered on SafeLoader itself.
fixed_loader = yaml.SafeLoader
fixed_loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    # Verbose-mode pattern: whitespace and newlines inside it are ignored.
    re.compile(r'''^(?:
     [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
    |[-+]?\.(?:inf|Inf|INF)
    |\.(?:nan|NaN|NAN))$''', re.X),
    # Only scalars starting with one of these characters are tested.
    list(u'-+0123456789.'))
def _resolve_config_path(config_file_path: str) -> str:
    """Return the on-disk path for ``config_file_path`` or raise ValueError."""
    # first check if it is in the config folder
    config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
    if os.path.isfile(config_path):
        # The name already carries its extension (or is an absolute path, in
        # which case os.path.join returns it unchanged).
        return config_path
    if not os.path.exists(config_path):
        # see if it is in the config folder with any of the possible
        # extensions if it doesnt have one
        for ext in possible_extensions:
            if os.path.exists(config_path + ext):
                return config_path + ext

    # if we didn't find it there, check if it is a full path
    if os.path.exists(config_file_path):
        return config_file_path
    cwd_path = get_cwd_abs_path(config_file_path)
    if os.path.exists(cwd_path):
        return cwd_path

    raise ValueError(f"Could not find config file {config_file_path}")


def get_config(
        config_file_path_or_dict: Union[str, dict, OrderedDict],
        name: Optional[str] = None
):
    """Load a job config from a mapping or a JSON/YAML file and preprocess it.

    A mapping is passed straight to ``preprocess_config``. A string is looked
    up, in order: inside ``TOOLKIT_ROOT/config`` as given, inside that folder
    with each of ``possible_extensions`` appended, as given, and finally
    relative to the current working directory. ``${VAR}`` placeholders in the
    file text are expanded from the environment before parsing.

    Args:
        config_file_path_or_dict: An already-parsed config, or a config file
            name/path.
        name: Optional name forwarded to ``preprocess_config``.

    Returns:
        The result of ``preprocess_config`` for the loaded config.

    Raises:
        ValueError: if the file cannot be found, or its extension is not one
            of ``.json``, ``.jsonc``, ``.yaml`` or ``.yml``. Errors raised by
            ``replace_env_vars_in_string`` and ``preprocess_config`` propagate.
    """
    # if we got a dict, process it and return it (OrderedDict is a dict)
    if isinstance(config_file_path_or_dict, dict):
        return preprocess_config(config_file_path_or_dict, name)

    config_file_path = config_file_path_or_dict
    real_config_path = _resolve_config_path(config_file_path)

    with open(real_config_path, 'r', encoding='utf-8') as f:
        content = f.read()
    content_with_env_replaced = replace_env_vars_in_string(content)

    # if we found it, check if it is a json or yaml file
    # NOTE: .jsonc goes through the strict JSON parser, so a file that
    # actually contains comments will fail to parse.
    if real_config_path.endswith(('.json', '.jsonc')):
        config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
    elif real_config_path.endswith(('.yaml', '.yml')):
        config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
    else:
        raise ValueError(f"Config file {config_file_path} must be a json or yaml file")

    return preprocess_config(config, name)