Spaces:
Running
Running
import requests | |
import os | |
import json | |
import logging | |
from typing import List, Dict, Any, Optional, Union | |
# Module-level logger named after this module (via __name__), so handlers and
# levels can be configured per module by the application.
logger = logging.getLogger(__name__)
class GroqClient:
    """Direct implementation of a Groq API client.

    Talks to the Groq REST endpoint under ``https://api.groq.com/openai/v1``
    using ``requests``. Instances are callable with a plain prompt string for
    compatibility with code that expects an ``llm(prompt) -> str`` interface.
    """

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 temperature: float = 0.7,
                 max_tokens: int = 2048,
                 groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = 60.0):
        """Initialize the Groq client with model parameters.

        Args:
            model: Model identifier sent with every chat completion request.
            temperature: Sampling temperature sent with every request.
            max_tokens: ``max_tokens`` value sent with every request.
            groq_api_key: API key. When omitted or empty, the key is read from
                the ``GROQ_API_KEY_FALLBACK`` and then ``GROQ_API_KEY``
                environment variables.
            timeout: Seconds passed as ``timeout`` to every ``requests`` call
                so a stalled connection cannot block forever. ``None``
                disables the timeout.

        Raises:
            ValueError: If no non-empty API key can be found.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        # Get API key from params or environment variables. Chaining with
        # ``or`` (instead of using the second lookup as the ``get`` default)
        # lets a variable that is set but empty fall through to the next
        # source rather than masking it.
        # NOTE(review): GROQ_API_KEY_FALLBACK is consulted *before*
        # GROQ_API_KEY; this order is kept from the original code - confirm
        # it is intentional.
        self.api_key = (
            groq_api_key
            or os.environ.get("GROQ_API_KEY_FALLBACK")
            or os.environ.get("GROQ_API_KEY")
        )
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")

        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def list_models(self) -> Dict[str, Any]:
        """List available models from Groq.

        Returns:
            The decoded JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts, or a non-success HTTP status.
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self,
                 messages: List[Dict[str, str]],
                 stream: bool = False) -> Union[Dict[str, Any], Any]:
        """Generate a response given a list of messages.

        Args:
            messages: Chat messages, sent unchanged as the ``messages`` field
                of the request payload.
            stream: When true, request a streamed response and return a
                generator of decoded chunks instead of a single dict.

        Returns:
            The decoded JSON response when ``stream`` is false; otherwise the
            generator produced by ``_handle_streaming_response``.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts, or a non-success HTTP status. The error is logged
                and then re-raised unchanged.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream
        }
        try:
            response = requests.post(
                url,
                headers=self.headers,
                json=payload,
                stream=stream,
                timeout=self.timeout,
            )
            response.raise_for_status()
            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error calling Groq API: %s", e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    logger.error("API error details: %s", error_response.json())
                except ValueError:
                    # The error body was not valid JSON; the status code is
                    # the only detail left to report.
                    logger.error("API error status code: %s", error_response.status_code)
            raise

    def _handle_streaming_response(self, response):
        """Handle streaming response from Groq API.

        Iterates the response line by line, decodes each ``data:`` line as
        JSON and yields the resulting object. Iteration stops at the
        ``[DONE]`` marker. Lines that are empty, lack the ``data:`` prefix,
        or fail to parse as JSON are skipped (parse failures are logged).

        The response is closed when the generator finishes or is closed by
        the consumer, so the underlying connection is released.

        Args:
            response: A streamed response object exposing ``iter_lines()``
                (yielding bytes) and ``close()``.

        Yields:
            One decoded JSON object per ``data:`` line.
        """
        prefix = "data:"
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                if not line.startswith(prefix):
                    continue
                # Accept both "data: {...}" and "data:{...}".
                data = line[len(prefix):].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
                    continue
                # Yield outside the try so the handler above only ever
                # covers the parsing step.
                yield chunk
        finally:
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Make the client callable with a prompt for compatibility.

        Args:
            prompt: Text sent as a single ``user`` message.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The ``content`` of the first choice's message in the response.
        """
        messages = [{"role": "user", "content": prompt}]
        response = self.generate(messages)
        return response["choices"][0]["message"]["content"]