# Commit metadata (scrape artifact, preserved as a comment):
# Author: Ashmi Banerjee
# Message: update sustainability prompt and post processing
# Hash: adbebe0
from vertexai.generative_models import GenerativeModel
from dotenv import load_dotenv
from anthropic import AnthropicVertex
import os
from openai import OpenAI
from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config
from huggingface_hub import InferenceClient
# Load environment variables from a local .env file (if present) so the
# API keys read below resolve.
load_dotenv()
# OpenAI API key; will be None when OPENAI_API_KEY is not set in the environment.
OAI_API_KEY = os.getenv("OPENAI_API_KEY")
def _validate_tokens(max_tokens: int) -> int:
"""
Validates the max_tokens parameter. Ensures it's within a valid range (1 to 8192).
If invalid, defaults to 8192.
"""
if 1 <= max_tokens <= 8192:
return max_tokens
return 8192
def _validate_temperature(temp: float) -> float:
"""
Validates the temperature parameter. Ensures it's within a valid range (0 to 1).
If invalid, defaults to 0.49.
"""
if 0 <= temp <= 1:
return temp
return 0.49
class LLMBaseClass:
    """
    Base class for text generation. Users provide the HF model ID or other
    model identifiers and can call the generate method to get responses.

    Routing by model id:
      * ``"gpt-4o-mini"``                       -> OpenAI chat completions
      * ``"claude-3-5-sonnet@20240620"``        -> Anthropic via Vertex AI
      * ``"gemini-1.0-pro"`` / ``"gemini-1.5-*"`` -> Google Gemini via Vertex AI
      * anything else                           -> Hugging Face InferenceClient
    """

    # Model ids served through Vertex AI. Claude and Gemini share the
    # generate() routing but are initialized by different helpers.
    _VERTEXAI_MODEL_IDS = ("claude-3-5-sonnet@20240620",
                           "gemini-1.0-pro",
                           "gemini-1.5-flash-001",
                           "gemini-1.5-pro-001")

    def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
        """
        :param model_id: HF model id or one of the provider ids listed above.
        :param max_tokens: max generation tokens; clamped to 1..8192 (default 8192).
        :param temp: sampling temperature; clamped to 0..1 (default 0.49).
        """
        self.model_id = model_id
        # Filled in by the backend-specific initializer (API key or project id).
        self.api_key = None
        self.temp = _validate_temperature(temp)
        self.tokens = _validate_tokens(max_tokens)
        self.model = self._initialize_model()

    def _initialize_model(self):
        """
        Initialize the backing client based on the provided model ID.
        Unrecognized ids fall through to the Hugging Face client.
        """
        if self.model_id == "gpt-4o-mini":
            return self._initialize_openai_model()
        # Claude must be checked before the generic Vertex AI list: it needs
        # the AnthropicVertex client, not GenerativeModel.
        if self.model_id == "claude-3-5-sonnet@20240620":
            return self._initialize_claude_model()
        if self.model_id in self._VERTEXAI_MODEL_IDS:
            return self._initialize_vertexai_model()
        return self._initialize_hf_model()

    def _initialize_openai_model(self):
        """
        Initialize the OpenAI client using the module-level OPENAI_API_KEY.
        """
        self.api_key = OAI_API_KEY
        return OpenAI(api_key=self.api_key)

    def _initialize_claude_model(self):
        """
        Initialize a Claude client (Anthropic served via Vertex AI).
        NOTE: ``api_key`` here actually holds the Vertex AI project id.
        """
        self.api_key = os.getenv("VERTEXAI_PROJECTID")
        return AnthropicVertex(region="europe-west1", project_id=self.api_key)

    def _initialize_vertexai_model(self):
        """
        Initialize a Google Gemini model via Vertex AI, applying this
        instance's temperature and token limit; safety settings come from
        the project defaults.
        """
        # Only the default safety settings are needed; the generation config
        # is always built from the validated instance parameters.
        _, default_safe_settings = get_default_config()
        gen_config = {
            "temperature": self.temp,
            "max_output_tokens": self.tokens,
        }
        return GenerativeModel(self.model_id,
                               generation_config=gen_config,
                               safety_settings=default_safe_settings)

    def _initialize_hf_model(self):
        """
        Initialize a Hugging Face InferenceClient authenticated via HF_TOKEN.
        """
        self.api_key = os.getenv("HF_TOKEN")
        return InferenceClient(token=self.api_key, model=self.model_id)

    def generate(self, messages):
        """
        Generate a response for ``messages`` (list of ``{"role", "content"}``
        dicts), dispatching on the model type.
        """
        if self.model_id == "gpt-4o-mini":
            return self._generate_openai(messages)
        if self.model_id in self._VERTEXAI_MODEL_IDS:
            return self._generate_vertexai(messages)
        return self._generate_hf(messages)

    def _generate_openai(self, messages):
        """
        Generate a response using the OpenAI chat-completions API.
        """
        completion = self.model.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temp,
            max_tokens=self.tokens,
        )
        return completion.choices[0].message.content

    def _generate_vertexai(self, messages):
        """
        Generate a response using Claude or Gemini via Vertex AI.
        All message contents are flattened into a single prompt string.
        """
        initialize_vertexai_params()
        content = " ".join(message["content"] for message in messages)
        if "claude" in self.model_id:
            message = self.model.messages.create(
                max_tokens=self.tokens,
                temperature=self.temp,  # keep sampling consistent with other backends
                model=self.model_id,
                messages=[{"role": "user", "content": content}],
            )
            return message.content[0].text
        response = self.model.generate_content(content)
        return response.text

    def _generate_hf(self, messages):
        """
        Generate a response using a Hugging Face model.

        The joined content of *all* messages is sent as one user turn
        (previously this hard-coded ``messages[0] + messages[1]`` and
        crashed on any other message count).
        """
        content = " ".join(message["content"] for message in messages)
        response = self.model.chat_completion(
            messages=[{"role": "user", "content": content}],
            max_tokens=self.tokens, temperature=self.temp)
        return response.choices[0].message.content