# ChatbotLangchain / groq_client.py
# Committed by Phoenix21 in commit c6f1571 (verified): "Create groq_client.py" (3.62 kB).
import requests
import os
import json
import logging
from typing import List, Dict, Any, Optional, Union
# Module-level logger; GroqClient uses it to report API request failures and
# undecodable streaming chunks.
logger = logging.getLogger(__name__)
class GroqClient:
    """Direct implementation of a Groq API client.

    Sends requests with ``requests`` to the ``/openai/v1`` endpoints under
    ``https://api.groq.com``, authenticating with a bearer-token API key.
    Instances are also callable with a single prompt string and return the
    reply text, so they can be used where a ``str -> str`` LLM callable is
    expected.
    """

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 temperature: float = 0.7,
                 max_tokens: int = 2048,
                 groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = 60.0):
        """Initialize the client with model parameters and resolve the API key.

        Args:
            model: Model identifier sent with every chat-completion request.
            temperature: Sampling temperature sent with every request.
            max_tokens: ``max_tokens`` value sent with every request.
            groq_api_key: Explicit API key. When falsy, the key is read from
                the ``GROQ_API_KEY_FALLBACK`` environment variable, then from
                ``GROQ_API_KEY``.
            timeout: Seconds passed to ``requests`` as ``timeout`` on every
                HTTP call (connect and per-read wait). ``None`` disables the
                timeout, so a stalled connection can block indefinitely.

        Raises:
            ValueError: If no non-empty API key can be found.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # Precedence: explicit argument, then GROQ_API_KEY_FALLBACK, then
        # GROQ_API_KEY. Chaining with ``or`` means an empty variable falls
        # through to the next source instead of shadowing it.
        self.api_key = (
            groq_api_key
            or os.environ.get("GROQ_API_KEY_FALLBACK")
            or os.environ.get("GROQ_API_KEY")
        )
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")
        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def list_models(self) -> Dict[str, Any]:
        """Return the decoded JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts, or a non-2xx status (via ``raise_for_status``).
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self,
                 messages: List[Dict[str, str]],
                 stream: bool = False) -> Union[Dict[str, Any], Any]:
        """Send a chat-completion request for ``messages``.

        Args:
            messages: Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
            stream: If True, request a streamed response and return a generator
                of decoded JSON chunks instead of a single dict.

        Returns:
            The decoded JSON body when ``stream`` is False. Otherwise, a
            generator yielding one dict per ``data:`` line until ``[DONE]``.

        Raises:
            requests.exceptions.RequestException: Transport, HTTP-status, or
                body-decoding errors. These are logged, then re-raised unchanged.
                In streaming mode, errors raised while iterating the returned
                generator come from the generator, not from this call.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream
        }
        try:
            response = requests.post(url, headers=self.headers, json=payload,
                                     stream=stream, timeout=self.timeout)
            response.raise_for_status()
            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error calling Groq API: %s", e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    error_details = error_response.json()
                except (ValueError, requests.exceptions.RequestException):
                    # Body is not JSON (or could not be read); fall back to
                    # reporting the status code alone.
                    logger.error("API error status code: %s", error_response.status_code)
                else:
                    logger.error("API error details: %s", error_details)
            raise

    def _handle_streaming_response(self, response):
        """Yield decoded JSON chunks from a streamed chat-completion response.

        Lines that start with ``data: `` carry one JSON payload each. A payload
        of ``[DONE]`` ends the stream. Empty lines and lines without that
        prefix are skipped. Payloads that fail to parse are logged and
        skipped. The response is closed however the generator ends
        (exhausted, ``[DONE]``, closed early, or an error), which releases
        the streamed connection.
        """
        prefix = 'data: '
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8')
                if not line.startswith(prefix):
                    continue
                data = line[len(prefix):]
                if data.strip() == '[DONE]':
                    break
                try:
                    json_data = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
                    continue
                yield json_data
        finally:
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Extra keyword arguments are accepted for call-site compatibility but
        are ignored.
        """
        messages = [
            {"role": "user", "content": prompt}
        ]
        response = self.generate(messages)
        return response['choices'][0]['message']['content']