Spaces:
Running
Running
File size: 5,042 Bytes
4b722ec 8842640 4b722ec 8842640 4b722ec 8842640 4b722ec 8842640 4b722ec 8842640 4b722ec 8842640 adbebe0 8842640 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from vertexai.generative_models import GenerativeModel
from dotenv import load_dotenv
from anthropic import AnthropicVertex
import os
from openai import OpenAI
from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config
from huggingface_hub import InferenceClient
# Load environment variables
load_dotenv()
OAI_API_KEY = os.getenv("OPENAI_API_KEY")
def _validate_tokens(max_tokens: int) -> int:
"""
Validates the max_tokens parameter. Ensures it's within a valid range (1 to 8192).
If invalid, defaults to 8192.
"""
if 1 <= max_tokens <= 8192:
return max_tokens
return 8192
def _validate_temperature(temp: float) -> float:
"""
Validates the temperature parameter. Ensures it's within a valid range (0 to 1).
If invalid, defaults to 0.49.
"""
if 0 <= temp <= 1:
return temp
return 0.49
class LLMBaseClass:
    """
    Base class for text generation.

    Users provide the HF model ID or other model identifiers and can call
    the ``generate`` method to get responses. The underlying client
    (OpenAI, Anthropic-on-Vertex, Gemini-on-Vertex, or Hugging Face
    InferenceClient) is chosen from the model ID at construction time.
    """

    def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
        """
        Args:
            model_id: Identifier of the model (e.g. "gpt-4o-mini",
                "claude-3-5-sonnet@20240620", a Gemini ID, or an HF repo ID).
            max_tokens: Maximum tokens to generate; clamped to [1, 8192],
                defaulting to 8192 when out of range.
            temp: Sampling temperature; clamped to [0, 1], defaulting to
                0.49 when out of range.
        """
        self.model_id = model_id
        # Populated by the _initialize_* helpers with whichever credential
        # the chosen backend needs (API key, project ID, or HF token).
        self.api_key = None
        self.temp = _validate_temperature(temp)
        self.tokens = _validate_tokens(max_tokens)
        self.model = self._initialize_model()

    def _initialize_model(self):
        """
        Initialize the backend client based on the provided model ID.

        Falls through to a Hugging Face InferenceClient for any ID that is
        not one of the explicitly known OpenAI/Claude/Gemini models.
        """
        if self.model_id == "gpt-4o-mini":
            return self._initialize_openai_model()
        if self.model_id == "claude-3-5-sonnet@20240620":
            return self._initialize_claude_model()
        # NOTE: the original code also listed the Claude ID here, but that
        # branch was unreachable because the check above already caught it.
        if self.model_id in ("gemini-1.0-pro", "gemini-1.5-flash-001",
                             "gemini-1.5-pro-001"):
            return self._initialize_vertexai_model()
        return self._initialize_hf_model()

    def _initialize_openai_model(self):
        """
        Initialize an OpenAI client authenticated with OPENAI_API_KEY.
        """
        self.api_key = OAI_API_KEY
        return OpenAI(api_key=self.api_key)

    def _initialize_claude_model(self):
        """
        Initialize a Claude client via Anthropic on Vertex AI.

        Uses VERTEXAI_PROJECTID as the project identifier (stored in
        ``self.api_key`` for consistency with the other backends).
        """
        self.api_key = os.getenv("VERTEXAI_PROJECTID")
        return AnthropicVertex(region="europe-west1", project_id=self.api_key)

    def _initialize_vertexai_model(self):
        """
        Initialize a Google Gemini model via Vertex AI.

        Generation parameters come from this instance; safety settings come
        from the project-wide defaults.
        """
        # Only the default safety settings are needed; generation config is
        # always built from this instance's validated temp/tokens (the old
        # ``gen_config is None`` fallback could never trigger).
        _, default_safe_settings = get_default_config()
        gen_config = {
            "temperature": self.temp,
            "max_output_tokens": self.tokens,
        }
        return GenerativeModel(self.model_id,
                               generation_config=gen_config,
                               safety_settings=default_safe_settings)

    def _initialize_hf_model(self):
        """
        Initialize a Hugging Face InferenceClient authenticated with HF_TOKEN.
        """
        self.api_key = os.getenv("HF_TOKEN")
        return InferenceClient(token=self.api_key, model=self.model_id)

    def generate(self, messages):
        """
        Generate a response for ``messages`` using the configured backend.

        Args:
            messages: List of dicts with at least a "content" key
                (chat-style message objects).

        Returns:
            The generated text as a string.
        """
        if self.model_id == "gpt-4o-mini":
            return self._generate_openai(messages)
        if self.model_id in ("claude-3-5-sonnet@20240620",
                             "gemini-1.0-pro", "gemini-1.5-flash-001",
                             "gemini-1.5-pro-001"):
            return self._generate_vertexai(messages)
        return self._generate_hf(messages)

    def _generate_openai(self, messages):
        """
        Generate a response using the OpenAI chat completions API.
        """
        completion = self.model.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temp,
            max_tokens=self.tokens,
        )
        return completion.choices[0].message.content

    def _generate_vertexai(self, messages):
        """
        Generate a response using Claude or Gemini models via Vertex AI.

        All message contents are flattened into one user prompt, since both
        backends are invoked here with a single combined message.
        """
        initialize_vertexai_params()
        content = " ".join(message["content"] for message in messages)
        if "claude" in self.model_id:
            message = self.model.messages.create(
                max_tokens=self.tokens,
                model=self.model_id,
                messages=[{"role": "user", "content": content}],
            )
            return message.content[0].text
        response = self.model.generate_content(content)
        return response.text

    def _generate_hf(self, messages):
        """
        Generate a response using a Hugging Face model.

        Fix: the original built ``content`` but then ignored it and
        concatenated exactly ``messages[0]`` and ``messages[1]``, which
        raised IndexError for any other message count. Now the joined
        content is actually used, handling any number of messages.
        """
        content = " ".join(message["content"] for message in messages)
        response = self.model.chat_completion(
            messages=[{"role": "user", "content": content}],
            max_tokens=self.tokens,
            temperature=self.temp,
        )
        return response.choices[0].message.content
|