File size: 3,619 Bytes
c6f1571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import requests
import os
import json
import logging
from typing import List, Dict, Any, Optional, Union

# Module-level logger; GroqClient reports failed API requests and
# undecodable streaming chunks through it.
logger = logging.getLogger(__name__)

class GroqClient:
    """Direct implementation of a Groq API client.

    Talks to Groq's OpenAI-compatible REST endpoints (``/models`` and
    ``/chat/completions``) with ``requests``. Sampling parameters are fixed
    at construction time and sent with every completion request.
    """

    # Seconds passed to ``requests`` as the connect/read timeout. ``requests``
    # has no default timeout and would otherwise wait forever on a stalled socket.
    DEFAULT_TIMEOUT: float = 120.0

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 temperature: float = 0.7,
                 max_tokens: int = 2048,
                 groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = DEFAULT_TIMEOUT):
        """Initialize the Groq client with model parameters.

        Args:
            model: Model identifier sent as ``model`` in completion requests.
            temperature: Sampling temperature sent with every completion.
            max_tokens: ``max_tokens`` value sent with every completion.
            groq_api_key: API key. When omitted or empty, the ``GROQ_API_KEY``
                environment variable is used, then ``GROQ_API_KEY_FALLBACK``.
            timeout: Connect/read timeout in seconds for every HTTP call;
                ``None`` disables the timeout (waits indefinitely).

        Raises:
            ValueError: If no non-empty API key can be found.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        # An explicit argument wins. Otherwise prefer the primary env var and
        # consult the fallback only when the primary is unset or empty.
        self.api_key = (
            groq_api_key
            or os.environ.get("GROQ_API_KEY")
            or os.environ.get("GROQ_API_KEY_FALLBACK")
        )
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")

        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def list_models(self) -> Dict[str, Any]:
        """Return the parsed JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, or a non-2xx HTTP status.
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self,
                 messages: List[Dict[str, str]],
                 stream: bool = False) -> Union[Dict[str, Any], Any]:
        """Request a chat completion for *messages*.

        Args:
            messages: Chat messages, sent verbatim as the ``messages`` field.
            stream: When true, request a streamed (server-sent-event) reply.

        Returns:
            The parsed JSON response when ``stream`` is false. Otherwise a
            generator yielding each decoded chunk (see
            ``_handle_streaming_response``).

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, or a non-2xx HTTP status. The error (and the API's
                error body, when present) is logged before re-raising.
        """
        url = f"{self.base_url}/chat/completions"

        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }

        try:
            response = requests.post(
                url,
                headers=self.headers,
                json=payload,
                stream=stream,
                timeout=self.timeout,
            )
            response.raise_for_status()

            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            self._log_request_error(e)
            raise

    @staticmethod
    def _log_request_error(error: "requests.exceptions.RequestException") -> None:
        """Log a failed request, including the API's error body when available."""
        logger.error("Error calling Groq API: %s", error)
        response = getattr(error, "response", None)
        if response is None:
            return
        try:
            logger.error("API error details: %s", response.json())
        except ValueError:
            # The error body was not valid JSON; report the status code instead.
            logger.error("API error status code: %s", response.status_code)

    def _handle_streaming_response(self, response):
        """Yield each JSON chunk from a server-sent-event response.

        Reads ``data:`` lines until the ``[DONE]`` sentinel or the end of the
        stream. Blank lines and other SSE fields are skipped. Chunks that fail
        to parse as JSON are logged and skipped. The response is closed when
        iteration finishes, raises, or the caller closes the generator early.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                line = raw_line.decode("utf-8")
                if not line.startswith("data:"):
                    continue  # comments / non-data SSE fields
                # The SSE spec allows "data:" with or without a following space.
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
                    continue
                yield chunk
        finally:
            # A stream=True response keeps its pooled connection until the
            # body is fully read or the response is closed explicitly.
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Send *prompt* as a single user message and return the reply text.

        ``kwargs`` is accepted for call-site compatibility and is ignored.
        """
        reply = self.generate([{"role": "user", "content": prompt}])
        return reply["choices"][0]["message"]["content"]