import json
import logging
import os
from typing import Any, Dict, Iterator, List, Optional, Union

import requests

logger = logging.getLogger(__name__)


class GroqClient:
    """Direct implementation of Groq API client.

    Talks to Groq's OpenAI-compatible REST endpoint using plain ``requests``
    calls; no vendor SDK is needed.
    """

    def __init__(self, model: str = "llama3-70b-8192", temperature: float = 0.7,
                 max_tokens: int = 2048, groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = 60.0):
        """Initialize the Groq client with model parameters.

        Args:
            model: Model identifier sent with every chat-completion request.
            temperature: Sampling temperature sent with every request.
            max_tokens: Value sent as ``max_tokens`` with every request.
            groq_api_key: Explicit API key. When omitted or empty, the key is
                read from ``GROQ_API_KEY_FALLBACK`` and then ``GROQ_API_KEY``.
            timeout: Seconds ``requests`` waits for the connection and between
                received bytes (it is not a total deadline). ``None`` waits
                forever, which was the previous behavior.

        Raises:
            ValueError: If no API key could be found.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        # Explicit argument wins, then the environment. Using `or` (instead of
        # a nested .get() default) skips a variable that is exported but empty
        # rather than treating "" as the key.
        # NOTE(review): GROQ_API_KEY_FALLBACK takes precedence over
        # GROQ_API_KEY despite its name; confirm this ordering is intended.
        self.api_key = (groq_api_key
                        or os.environ.get("GROQ_API_KEY_FALLBACK")
                        or os.environ.get("GROQ_API_KEY"))
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")

        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def list_models(self) -> Dict[str, Any]:
        """List available models from Groq.

        Returns:
            The decoded JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, or a non-2xx status.
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self, messages: List[Dict[str, str]], stream: bool = False
                 ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Generate a chat completion for ``messages``.

        Args:
            messages: Chat messages posted as-is in the ``messages`` field,
                e.g. ``[{"role": "user", "content": "..."}]``.
            stream: When true, request a streamed response.

        Returns:
            The decoded JSON response when ``stream`` is false; otherwise a
            generator yielding each decoded ``data:`` chunk of the stream.

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, non-2xx status, or an undecodable JSON body. The
                error (and the API's error body, when available) is logged
                before re-raising.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }

        try:
            response = requests.post(url, headers=self.headers, json=payload,
                                     stream=stream, timeout=self.timeout)
            response.raise_for_status()
            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error calling Groq API: %s", e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    error_details = error_response.json()
                    logger.error("API error details: %s", error_details)
                except (ValueError, requests.exceptions.RequestException):
                    # The body was not JSON (e.g. an HTML page from a proxy) or
                    # could not be read; the status code is still useful.
                    logger.error("API error status code: %s", error_response.status_code)
            raise

    def _handle_streaming_response(self, response) -> Iterator[Dict[str, Any]]:
        """Handle a streaming (server-sent events) response from the Groq API.

        Each ``data:`` line is parsed as JSON and yielded. Iteration stops at
        the ``[DONE]`` sentinel. Payloads that are not valid JSON are logged
        and skipped.

        The response is closed when the generator finishes, is closed or
        garbage-collected early, or raises. This releases the ``stream=True``
        connection even if the caller stops consuming partway through.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                line = raw_line.decode('utf-8')
                if not line.startswith('data:'):
                    continue  # ignore SSE comments and non-data fields
                # SSE permits both "data: value" and "data:value".
                data = line[len('data:'):].strip()
                if data == '[DONE]':
                    break
                try:
                    yield json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
        finally:
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Make the client callable with a prompt for compatibility.

        Sends ``prompt`` as a single user message (non-streaming) and returns
        the content of the first choice. Extra keyword arguments are accepted
        so callers that pass them do not fail; they are ignored.
        """
        messages = [{"role": "user", "content": prompt}]
        response = self.generate(messages)
        return response['choices'][0]['message']['content']