# src/llms/openai_llm.py
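"""
OpenAI-backed implementation of BaseLLM.

Example (a sketch, not run against a live account; it assumes this package
is importable as ``src.llms``, as the file path above suggests, and that an
API key is stored in the OPENAI_API_KEY environment variable)::

    import os
    from src.llms.openai_llm import OpenAILanguageModel

    llm = OpenAILanguageModel(api_key=os.environ["OPENAI_API_KEY"])
    n_tokens = llm.count_tokens("How many tokens is this?")
    reply = llm.generate("Reply with a single word: hello.", max_tokens=5)
"""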
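from typing import List, Optional

import tiktoken
from openai import OpenAI

from .base_llm import BaseLLM


class OpenAILanguageModel(BaseLLM):
    def __init__(
        self,
        api_key: str,
        model: str = 'gpt-3.5-turbo'
    ):
        """
        Initialize OpenAI Language Model

        Args:
            api_key (str): OpenAI API key
            model (str): Name of the OpenAI chat model to use
        """
        # openai>=1.0: keep a per-instance client instead of setting the
        # module-global openai.api_key, so instances created with different
        # keys do not overwrite each other.
        self.client = OpenAI(api_key=api_key)
        self.model = model
        # Tokenizer matching the model. tiktoken raises KeyError for model
        # names it cannot map, so fall back to the cl100k_base encoding.
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            self.encoding = tiktoken.get_encoding('cl100k_base')

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = 150,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response using the OpenAI Chat Completions API

        Args:
            prompt (str): Input prompt, sent as a single user message
            max_tokens (Optional[int]): Maximum number of tokens to generate
            temperature (float): Sampling temperature
            **kwargs: Extra parameters forwarded to chat.completions.create

        Returns:
            str: Generated response with surrounding whitespace stripped
        """
        # openai>=1.0 removed openai.ChatCompletion; use the client instead.
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

        # message.content can be None (e.g. when the model returns a tool
        # call), so guard before calling strip().
        content = response.choices[0].message.content
        return (content or "").strip()

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize text using the model's tiktoken encoding

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: Decoded text of each token. A token that covers only
            part of a multi-byte character decodes to U+FFFD.
        """
        # Tokenization runs locally with OpenAI's tiktoken library; no API
        # request is made (tiktoken downloads and caches the encoding file
        # on first use). disallowed_special=() treats special-token strings
        # such as "<|endoftext|>" as ordinary text instead of raising.
        token_ids = self.encoding.encode(text, disallowed_special=())
        return [
            self.encoding.decode_single_token_bytes(token_id).decode(
                'utf-8', errors='replace'
            )
            for token_id in token_ids
        ]

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the text

        Args:
            text (str): Input text to count tokens

        Returns:
            int: Number of tokens under the model's encoding. This counts
            the text only, not the per-message overhead a chat request adds.
        """
        return len(self.encoding.encode(text, disallowed_special=()))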