File size: 5,042 Bytes
4b722ec
 
 
 
 
8842640
 
4b722ec
8842640
4b722ec
8842640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b722ec
 
 
 
8842640
 
4b722ec
 
8842640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b722ec
 
8842640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbebe0
8842640
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from vertexai.generative_models import GenerativeModel
from dotenv import load_dotenv
from anthropic import AnthropicVertex
import os
from openai import OpenAI
from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config
from huggingface_hub import InferenceClient

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()
# OpenAI API key read once at import time; may be None if the variable is unset.
OAI_API_KEY = os.getenv("OPENAI_API_KEY")


def _validate_tokens(max_tokens: int) -> int:
    """
    Validates the max_tokens parameter. Ensures it's within a valid range (1 to 8192).
    If invalid, defaults to 8192.
    """
    if 1 <= max_tokens <= 8192:
        return max_tokens
    return 8192


def _validate_temperature(temp: float) -> float:
    """
    Validates the temperature parameter. Ensures it's within a valid range (0 to 1).
    If invalid, defaults to 0.49.
    """
    if 0 <= temp <= 1:
        return temp
    return 0.49


class LLMBaseClass:
    """
    Base class for text generation.

    Users provide an OpenAI model name, a Claude (Anthropic-on-Vertex) model
    name, a Gemini model name, or any other Hugging Face model ID, and call
    :meth:`generate` to obtain a text response.
    """

    def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
        """
        Args:
            model_id: Identifier selecting the backend ("gpt-4o-mini",
                "claude-3-5-sonnet@20240620", a "gemini-*" name, or an HF model ID).
            max_tokens: Generation token budget; invalid values default to 8192.
            temp: Sampling temperature; invalid values default to 0.49.
        """
        self.model_id = model_id
        # Populated by the backend-specific initializer; for Claude this holds
        # the Vertex AI project ID rather than a literal API key.
        self.api_key = None
        self.temp = _validate_temperature(temp)
        self.tokens = _validate_tokens(max_tokens)
        self.model = self._initialize_model()

    def _initialize_model(self):
        """
        Build the backend client matching the configured model ID.

        Note: the original code also listed the Claude model in the Vertex AI
        branch, but that branch was unreachable because the dedicated Claude
        check ran first — the duplicate entry has been removed.
        """
        if self.model_id == "gpt-4o-mini":
            return self._initialize_openai_model()
        elif self.model_id == "claude-3-5-sonnet@20240620":
            return self._initialize_claude_model()
        elif self.model_id in ("gemini-1.0-pro", "gemini-1.5-flash-001",
                               "gemini-1.5-pro-001"):
            return self._initialize_vertexai_model()
        else:
            # Anything else is treated as a Hugging Face model ID.
            return self._initialize_hf_model()

    def _initialize_openai_model(self):
        """
        Initialize the OpenAI client using the key loaded from the environment.
        """
        self.api_key = OAI_API_KEY
        return OpenAI(api_key=self.api_key)

    def _initialize_claude_model(self):
        """
        Initialize the Claude client (Anthropic via Vertex AI).

        NOTE(review): ``self.api_key`` is reused to store the Vertex AI
        project ID here — kept for interface compatibility.
        """
        self.api_key = os.getenv("VERTEXAI_PROJECTID")
        return AnthropicVertex(region="europe-west1", project_id=self.api_key)

    def _initialize_vertexai_model(self):
        """
        Initialize a Google Gemini model via Vertex AI.

        The generation config is always built from the validated instance
        settings; only the default safety settings are taken from
        ``get_default_config()``. (The original ``gen_config is None``
        fallback was dead code — ``gen_config`` was always a dict.)
        """
        _, default_safe_settings = get_default_config()
        gen_config = {
            "temperature": self.temp,
            "max_output_tokens": self.tokens,
        }
        return GenerativeModel(self.model_id,
                               generation_config=gen_config,
                               safety_settings=default_safe_settings)

    def _initialize_hf_model(self):
        """
        Initialize a Hugging Face Inference client authenticated via HF_TOKEN.
        """
        self.api_key = os.getenv("HF_TOKEN")
        return InferenceClient(token=self.api_key, model=self.model_id)

    def generate(self, messages):
        """
        Generate a response for ``messages`` using the configured backend.

        Args:
            messages: List of dicts with at least a "content" key
                (OpenAI-style chat messages).

        Returns:
            The generated text as a string.
        """
        if self.model_id == "gpt-4o-mini":
            return self._generate_openai(messages)
        elif self.model_id in ["claude-3-5-sonnet@20240620",
                               "gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
            return self._generate_vertexai(messages)
        else:
            return self._generate_hf(messages)

    def _generate_openai(self, messages):
        """
        Generate a response using the OpenAI chat completions API.
        """
        completion = self.model.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temp,
            max_tokens=self.tokens,
        )
        return completion.choices[0].message.content

    def _generate_vertexai(self, messages):
        """
        Generate a response using Claude or Gemini models via Vertex AI.

        All message contents are flattened into a single user turn before
        being sent to the backend.
        """
        initialize_vertexai_params()
        content = " ".join(message["content"] for message in messages)
        if "claude" in self.model_id:
            message = self.model.messages.create(
                max_tokens=self.tokens,
                model=self.model_id,
                messages=[{"role": "user", "content": content}],
            )
            return message.content[0].text
        else:
            response = self.model.generate_content(content)
            return response.text

    def _generate_hf(self, messages):
        """
        Generate a response using a Hugging Face model.

        Bug fix: the original joined all message contents into ``content``
        but then ignored it and indexed ``messages[0]``/``messages[1]``
        directly, raising IndexError for conversations with fewer than two
        messages. All contents are now flattened consistently with
        ``_generate_vertexai``.
        """
        content = " ".join(message["content"] for message in messages)
        response = self.model.chat_completion(
            messages=[{"role": "user", "content": content}],
            max_tokens=self.tokens, temperature=self.temp)
        return response.choices[0].message.content