# maya-persistence: src/llm/llmFactory.py
# anubhav77 | version v0.1.1 | commit 37419af | 1.68 kB
import logging
from baseInfra.dbInterface import DbInterface
from llm.hostedLLM import HostedLLM
from llm.togetherLLM import TogetherLLM
from llm.palmLLM import PalmLLM
from llm.geminiLLM import GeminiLLM
class LLMFactory:
    """
    Factory class for creating LLM objects from stored configs.

    A config is fetched by path through a DbInterface-like object and must
    provide an "llm_type" key (selecting the wrapper class) and an
    "llm_config" mapping (passed to that class as keyword arguments).
    """

    def __init__(self, db_interface=None):
        """
        Constructor for the LLMFactory class.

        Args:
            db_interface: Optional object used for getting LLM configs; it must
                provide ``get_config(path)``. When omitted (the previous and
                default behavior), a new ``DbInterface()`` is created.
        """
        # Accepting an injected interface makes the documented argument real
        # and lets callers supply their own config source.
        if db_interface is None:
            db_interface = DbInterface()
        self._db_interface = db_interface

    def get_llm(self, llm_path: str) -> object:
        """
        Gets an LLM object of the type named by the config at ``llm_path``.

        Args:
            llm_path: The path to the LLM config.

        Returns:
            A HostedLLM, TogetherLLM, PalmLLM or GeminiLLM instance constructed
            with ``**config["llm_config"]``.

        Raises:
            ValueError: If the config cannot be loaded or lacks the
                "llm_type"/"llm_config" keys (chained from the original error),
                or if "llm_type" names an unsupported LLM type.
        """
        logger = logging.getLogger(__name__)
        logger.debug("Loading LLM config from %s", llm_path)
        try:
            config = self._db_interface.get_config(llm_path)
            llm_type = config["llm_type"]
            llm_config = config["llm_config"]
        except Exception as ex:
            # Log once (with traceback) and keep the root cause attached rather
            # than falling through to a misleading "Invalid LLM type: " error.
            logger.exception("Failed to load LLM config from %s", llm_path)
            raise ValueError(f"Could not load LLM config from {llm_path!r}") from ex

        # NOTE(review): the full config is intentionally not logged; llm_config
        # presumably carries API keys/credentials — confirm before re-adding.
        logger.debug("LLM type for %s: %s", llm_path, llm_type)

        if llm_type == "hostedLLM":
            return HostedLLM(**llm_config)
        if llm_type == "togetherLLM":
            return TogetherLLM(**llm_config)
        if llm_type == "palmLLM":
            return PalmLLM(**llm_config)
        if llm_type == "geminiLLM":
            return GeminiLLM(**llm_config)

        logger.error("Invalid LLM type: %s", llm_type)
        raise ValueError(f"Invalid LLM type: {llm_type}")