Phoenix21 committed on
Commit
c6f1571
·
verified ·
1 Parent(s): ffff0e6

Create groq_client.py

Browse files
Files changed (1) hide show
  1. groq_client.py +94 -0
groq_client.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ import json
4
+ import logging
5
+ from typing import List, Dict, Any, Optional, Union
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
class GroqClient:
    """Small HTTP client for Groq's OpenAI-compatible REST endpoints.

    The client stores the model name and sampling parameters once and reuses
    them for every chat-completion request. It talks to the API directly
    through ``requests``.

    Attributes:
        model: Model identifier sent in each chat-completion payload.
        temperature: Sampling temperature sent in each payload.
        max_tokens: Completion token limit sent in each payload.
        timeout: Value passed as ``timeout`` to every ``requests`` call.
        api_key: Resolved API key.
        base_url: Root URL of the API.
        headers: HTTP headers (bearer auth + JSON content type) sent with
            every request.
    """

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 temperature: float = 0.7,
                 max_tokens: int = 2048,
                 groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = 60.0):
        """Initialize the Groq client with model parameters.

        Args:
            model: Model identifier to request.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            groq_api_key: API key. When omitted or empty, the key is read
                from the ``GROQ_API_KEY_FALLBACK`` environment variable and
                then from ``GROQ_API_KEY``.
            timeout: Seconds handed to ``requests`` as the request timeout.
                Pass ``None`` to wait indefinitely.

        Raises:
            ValueError: If no non-empty API key can be resolved.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        # Resolve the key with an `or` chain so that an *empty* value at any
        # step falls through to the next source instead of masking it.
        # NOTE(review): GROQ_API_KEY_FALLBACK is consulted before
        # GROQ_API_KEY; that precedence is kept as-is - confirm it is intended.
        self.api_key = (
            groq_api_key
            or os.environ.get("GROQ_API_KEY_FALLBACK")
            or os.environ.get("GROQ_API_KEY")
        )
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")

        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def list_models(self) -> Dict[str, Any]:
        """List available models from Groq.

        Returns:
            The decoded JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts, or a non-success HTTP status.
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self,
                 messages: List[Dict[str, str]],
                 stream: bool = False) -> Union[Dict[str, Any], Any]:
        """Generate a response given a list of messages.

        Args:
            messages: Chat messages, sent unchanged as the ``messages`` field
                of the request payload.
            stream: When true, request a streamed response and return a
                generator of decoded chunks instead of a single dict.

        Returns:
            The decoded JSON response when ``stream`` is false; otherwise a
            generator yielding one decoded JSON object per streamed event.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts, or a non-success HTTP status. The error (and the
                error body, when available) is logged before re-raising.
        """
        url = f"{self.base_url}/chat/completions"

        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream
        }

        try:
            response = requests.post(
                url,
                headers=self.headers,
                json=payload,
                stream=stream,
                timeout=self.timeout,
            )
            response.raise_for_status()

            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error calling Groq API: %s", e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    logger.error("API error details: %s", error_response.json())
                except ValueError:
                    # The error body was not JSON; the status code is the
                    # only structured detail left to report.
                    logger.error("API error status code: %s", error_response.status_code)
            raise

    def _handle_streaming_response(self, response):
        """Yield decoded JSON chunks from a streamed (server-sent events) response.

        Lines that are empty or do not carry a ``data:`` field are skipped.
        Iteration stops at the ``[DONE]`` sentinel. Payloads that are not
        valid JSON are logged and skipped.

        Args:
            response: A streamed ``requests`` response (anything exposing
                ``iter_lines()`` and ``close()``).

        Yields:
            One decoded JSON object per ``data:`` event.
        """
        prefix = 'data:'
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8')
                # The space after the colon is optional in the SSE format.
                if not line.startswith(prefix):
                    continue
                data = line[len(prefix):].strip()
                if data == '[DONE]':
                    break
                try:
                    json_data = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
                    continue
                # Yield outside the try so exceptions thrown into the
                # generator by the consumer are never swallowed here.
                yield json_data
        finally:
            # Release the connection whether the stream finished, hit
            # [DONE], failed, or was abandoned by the consumer.
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Make the client callable with a prompt for compatibility.

        Args:
            prompt: Text sent as a single ``user`` message.
            **kwargs: Accepted so callers may pass extra options; they are
                currently ignored.

        Returns:
            The ``content`` of the first choice's message in the response.
        """
        messages = [
            {"role": "user", "content": prompt}
        ]

        response = self.generate(messages)
        return response['choices'][0]['message']['content']