import os

from anthropic import AnthropicVertex
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from openai import OpenAI
from vertexai.generative_models import GenerativeModel

from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config

# Load environment variables
load_dotenv()

OAI_API_KEY = os.getenv("OPENAI_API_KEY")

def _validate_tokens(max_tokens: int) -> int:
    """
    Validates the max_tokens parameter. Ensures it is within a valid range (1 to 8192).
    If invalid, defaults to 8192.
    """
    if 1 <= max_tokens <= 8192:
        return max_tokens
    return 8192


def _validate_temperature(temp: float) -> float:
    """
    Validates the temperature parameter. Ensures it is within a valid range (0 to 1).
    If invalid, defaults to 0.49.
    """
    if 0 <= temp <= 1:
        return temp
    return 0.49
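
# A quick illustration of the fallback behaviour (derived from the two
# validators above, not from any external spec):
#   _validate_tokens(512)       -> 512
#   _validate_tokens(100_000)   -> 8192  (out of range, falls back)
#   _validate_temperature(0.7)  -> 0.7
#   _validate_temperature(1.5)  -> 0.49  (out of range, falls back)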

class LLMBaseClass:
    """
    Base class for text generation. Users provide an HF model ID or another
    supported model identifier and call the generate method to get responses.
    """

    def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
        self.model_id = model_id
        self.api_key = None
        self.temp = _validate_temperature(temp)
        self.tokens = _validate_tokens(max_tokens)
        self.model = self._initialize_model()
    def _initialize_model(self):
        """
        Initialize the model based on the provided model ID.
        """
        if self.model_id == "gpt-4o-mini":
            return self._initialize_openai_model()
        elif self.model_id == "claude-3-5-sonnet@20240620":
            return self._initialize_claude_model()
        elif self.model_id in ["gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
            return self._initialize_vertexai_model()
        else:
            return self._initialize_hf_model()
    def _initialize_openai_model(self):
        """
        Initialize the OpenAI client.
        """
        self.api_key = OAI_API_KEY
        return OpenAI(api_key=self.api_key)

    def _initialize_claude_model(self):
        """
        Initialize the Claude model served through Anthropic on Vertex AI.
        """
        # Note: this stores the Vertex AI project ID, not a secret API key.
        self.api_key = os.getenv("VERTEXAI_PROJECTID")
        return AnthropicVertex(region="europe-west1", project_id=self.api_key)
    def _initialize_vertexai_model(self):
        """
        Initialize a Google Gemini model via Vertex AI.
        """
        # gen_config is always built here, so it takes precedence; only the
        # default safety settings are reused.
        _, default_safe_settings = get_default_config()
        gen_config = {
            "temperature": self.temp,
            "max_output_tokens": self.tokens,
        }
        return GenerativeModel(self.model_id,
                               generation_config=gen_config,
                               safety_settings=default_safe_settings)
    def _initialize_hf_model(self):
        """
        Initialize a Hugging Face Inference API client.
        """
        self.api_key = os.getenv("HF_TOKEN")
        return InferenceClient(token=self.api_key, model=self.model_id)
    def generate(self, messages):
        """
        Generate a response based on the model type and the provided messages.
        """
        if self.model_id == "gpt-4o-mini":
            return self._generate_openai(messages)
        elif self.model_id in ["claude-3-5-sonnet@20240620",
                               "gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
            return self._generate_vertexai(messages)
        else:
            return self._generate_hf(messages)
    def _generate_openai(self, messages):
        """
        Generate a response using the OpenAI chat completions API.
        """
        completion = self.model.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temp,
            max_tokens=self.tokens,
        )
        return completion.choices[0].message.content
    def _generate_vertexai(self, messages):
        """
        Generate a response using Claude or Gemini models via Vertex AI.
        """
        initialize_vertexai_params()
        # Flatten the chat history into a single user turn; role information
        # is discarded in the process.
        content = " ".join(message["content"] for message in messages)
        if "claude" in self.model_id:
            message = self.model.messages.create(
                max_tokens=self.tokens,
                model=self.model_id,
                messages=[{"role": "user", "content": content}],
            )
            return message.content[0].text
        else:
            response = self.model.generate_content(content)
            return response.text
    def _generate_hf(self, messages):
        """
        Generate a response using a Hugging Face model.
        """
        # Flatten the chat history into one user message so the call works
        # for any number of input messages.
        content = " ".join(message["content"] for message in messages)
        response = self.model.chat_completion(
            messages=[{"role": "user", "content": content}],
            max_tokens=self.tokens,
            temperature=self.temp,
        )
        return response.choices[0].message.content
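

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library proper. The model ID and
    # prompt below are illustrative: "gpt-4o-mini" routes to the OpenAI branch
    # and requires OPENAI_API_KEY in the environment, while any unrecognized
    # ID falls through to the Hugging Face InferenceClient (HF_TOKEN).
    llm = LLMBaseClass(model_id="gpt-4o-mini", max_tokens=1024, temp=0.49)
    print(llm.generate([{"role": "user", "content": "Say hello in one sentence."}]))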