import json
import os
import warnings
from typing import Dict, Any, List, Union, Type, get_origin, get_args

from .variables.default import DEFAULT_CONFIG
from .variables.base import BaseConfig
from ..retrievers.utils import get_all_retriever_names


class Config:
"""Config class for GPT Researcher."""
CONFIG_DIR = os.path.join(os.path.dirname(__file__), "variables")
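
    # Typical usage (a sketch; "my_config.json" is a hypothetical path):
    #
    #     cfg = Config()                    # DEFAULT_CONFIG plus env overrides
    #     cfg = Config("my_config.json")    # custom JSON merged over the defaults
    #     cfg.retrievers                    # e.g. ["tavily"]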

    def __init__(self, config_path: str | None = None):
"""Initialize the config class."""
self.config_path = config_path
self.llm_kwargs: Dict[str, Any] = {}
self.embedding_kwargs: Dict[str, Any] = {}
config_to_use = self.load_config(config_path)
self._set_attributes(config_to_use)
self._set_embedding_attributes()
self._set_llm_attributes()
self._handle_deprecated_attributes()
self._set_doc_path(config_to_use)

    def _set_attributes(self, config: Dict[str, Any]) -> None:
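        """Set each config value as a lowercase attribute, letting env vars override."""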
for key, value in config.items():
env_value = os.getenv(key)
if env_value is not None:
value = self.convert_env_value(key, env_value, BaseConfig.__annotations__[key])
setattr(self, key.lower(), value)
# Handle RETRIEVER with default value
retriever_env = os.environ.get("RETRIEVER", config.get("RETRIEVER", "tavily"))
try:
self.retrievers = self.parse_retrievers(retriever_env)
except ValueError as e:
print(f"Warning: {str(e)}. Defaulting to 'tavily' retriever.")
self.retrievers = ["tavily"]

    def _set_embedding_attributes(self) -> None:
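        """Split the EMBEDDING setting into embedding_provider and embedding_model."""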
self.embedding_provider, self.embedding_model = self.parse_embedding(
self.embedding
)

    def _set_llm_attributes(self) -> None:
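        """Split FAST_LLM, SMART_LLM and STRATEGIC_LLM into provider/model pairs."""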
self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)

    def _handle_deprecated_attributes(self) -> None:
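        """Map deprecated env vars onto the current attributes, emitting FutureWarnings."""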
if os.getenv("EMBEDDING_PROVIDER") is not None:
warnings.warn(
"EMBEDDING_PROVIDER is deprecated and will be removed soon. Use EMBEDDING instead.",
FutureWarning,
stacklevel=2,
)
self.embedding_provider = (
os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
)
match os.environ["EMBEDDING_PROVIDER"]:
case "ollama":
self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
case "custom":
self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
case "openai":
self.embedding_model = "text-embedding-3-large"
case "azure_openai":
self.embedding_model = "text-embedding-3-large"
case "huggingface":
self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
case _:
                    raise ValueError(f"Unsupported EMBEDDING_PROVIDER: {os.environ['EMBEDDING_PROVIDER']}")

        _deprecation_warning = (
"LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
"will be removed soon. Use FAST_LLM and SMART_LLM instead."
)
if os.getenv("LLM_PROVIDER") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.fast_llm_provider = (
os.environ["LLM_PROVIDER"] or self.fast_llm_provider
)
self.smart_llm_provider = (
os.environ["LLM_PROVIDER"] or self.smart_llm_provider
)
if os.getenv("FAST_LLM_MODEL") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.fast_llm_model = os.environ["FAST_LLM_MODEL"] or self.fast_llm_model
if os.getenv("SMART_LLM_MODEL") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.smart_llm_model = os.environ["SMART_LLM_MODEL"] or self.smart_llm_model

    def _set_doc_path(self, config: Dict[str, Any]) -> None:
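        """Set doc_path from the config, reverting to the default if validation fails."""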
self.doc_path = config['DOC_PATH']
if self.doc_path:
try:
self.validate_doc_path()
except Exception as e:
print(f"Warning: Error validating doc_path: {str(e)}. Using default doc_path.")
self.doc_path = DEFAULT_CONFIG['DOC_PATH']

    @classmethod
    def load_config(cls, config_path: str | None) -> Dict[str, Any]:
        """Load a configuration from a JSON file path, falling back to the defaults."""
if config_path is None:
return DEFAULT_CONFIG
# config_path = os.path.join(cls.CONFIG_DIR, config_path)
if not os.path.exists(config_path):
if config_path and config_path != "default":
print(f"Warning: Configuration not found at '{config_path}'. Using default configuration.")
if not config_path.endswith(".json"):
                    print(f"Did you mean '{config_path}.json'?")
return DEFAULT_CONFIG
with open(config_path, "r") as f:
custom_config = json.load(f)
# Merge with default config to ensure all keys are present
merged_config = DEFAULT_CONFIG.copy()
merged_config.update(custom_config)
return merged_config

    @classmethod
def list_available_configs(cls) -> List[str]:
"""List all available configuration names."""
configs = ["default"]
for file in os.listdir(cls.CONFIG_DIR):
if file.endswith(".json"):
configs.append(file[:-5]) # Remove .json extension
return configs

    def parse_retrievers(self, retriever_str: str) -> List[str]:
"""Parse the retriever string into a list of retrievers and validate them."""
retrievers = [retriever.strip()
for retriever in retriever_str.split(",")]
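        # e.g. "tavily, bing" -> ["tavily", "bing"], assuming both names are registered retrievers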
valid_retrievers = get_all_retriever_names() or []
invalid_retrievers = [r for r in retrievers if r not in valid_retrievers]
if invalid_retrievers:
raise ValueError(
f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
f"Valid options are: {', '.join(valid_retrievers)}."
)
return retrievers

    @staticmethod
def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
"""Parse llm string into (llm_provider, llm_model)."""
from gpt_researcher.llm_provider.generic.base import _SUPPORTED_PROVIDERS
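        # e.g. "openai:gpt-4o-mini" -> ("openai", "gpt-4o-mini")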
if llm_str is None:
return None, None
        try:
            llm_provider, llm_model = llm_str.split(":", 1)
        except ValueError:
            # No ":" separator; re-raise with a hint on the expected format.
            raise ValueError(
                "Set SMART_LLM or FAST_LLM = '<llm_provider>:<llm_model>', "
                "e.g. 'openai:gpt-4o-mini'."
            ) from None
        # Validate explicitly rather than with assert, which is stripped under `python -O`.
        if llm_provider not in _SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported llm provider: {llm_provider}. Supported llm providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
        return llm_provider, llm_model

    @staticmethod
def parse_embedding(embedding_str: str | None) -> tuple[str | None, str | None]:
"""Parse embedding string into (embedding_provider, embedding_model)."""
from gpt_researcher.memory.embeddings import _SUPPORTED_PROVIDERS
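        # e.g. "openai:text-embedding-3-large" -> ("openai", "text-embedding-3-large")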
if embedding_str is None:
return None, None
        try:
            embedding_provider, embedding_model = embedding_str.split(":", 1)
        except ValueError:
            # No ":" separator; re-raise with a hint on the expected format.
            raise ValueError(
                "Set EMBEDDING = '<embedding_provider>:<embedding_model>', "
                "e.g. 'openai:text-embedding-3-large'."
            ) from None
        # Validate explicitly rather than with assert, which is stripped under `python -O`.
        if embedding_provider not in _SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported embedding provider: {embedding_provider}. "
                "Supported embedding providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
        return embedding_provider, embedding_model

    def validate_doc_path(self):
        """Ensure the folder at doc_path exists, creating it if necessary."""
os.makedirs(self.doc_path, exist_ok=True)

    @staticmethod
def convert_env_value(key: str, env_value: str, type_hint: Type) -> Any:
"""Convert environment variable to the appropriate type based on the type hint."""
origin = get_origin(type_hint)
args = get_args(type_hint)
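        # e.g. ("VERBOSE", "true", bool) -> True; ("MAX_ITERATIONS", "5", int) -> 5
        # (hypothetical keys; any BaseConfig field name works)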
if origin is Union:
# Handle Union types (e.g., Union[str, None])
for arg in args:
if arg is type(None):
if env_value.lower() in ("none", "null", ""):
return None
else:
try:
return Config.convert_env_value(key, env_value, arg)
except ValueError:
continue
raise ValueError(f"Cannot convert {env_value} to any of {args}")
if type_hint is bool:
return env_value.lower() in ("true", "1", "yes", "on")
elif type_hint is int:
return int(env_value)
elif type_hint is float:
return float(env_value)
elif type_hint in (str, Any):
return env_value
elif origin is list or origin is List:
return json.loads(env_value)
else:
raise ValueError(f"Unsupported type {type_hint} for key {key}")