tommytracx committed
Commit da518bc · verified · 1 Parent(s): 98f2d6e

Update local_llm.py

Files changed (1)
  1. local_llm.py +127 -5
local_llm.py CHANGED
@@ -20,14 +20,17 @@ if not HF_API_KEY:
 if not ENDPOINT_URL:
     logger.warning("ENDPOINT_URL environment variable not set")

-def run_llm(prompt, max_tokens=512, temperature=0.7):
+# Memory store for conversation history
+conversation_memory = {}
+
+def run_llm(input_text, max_tokens=512, temperature=0.7):
     """
     Process input text through HF Inference Endpoint.

     Args:
-        prompt: Input prompt to process
+        input_text: User input to process
         max_tokens: Maximum tokens to generate
-        temperature: Temperature for sampling
+        temperature: Temperature for sampling (higher = more random)

     Returns:
         Generated response text
@@ -40,7 +43,7 @@ def run_llm(prompt, max_tokens=512, temperature=0.7):
     # Format messages in OpenAI format
     messages = [
         {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."},
-        {"role": "user", "content": prompt}
+        {"role": "user", "content": input_text}
     ]

     payload = {
@@ -65,4 +68,123 @@
         if hasattr(e, 'response') and e.response is not None:
             error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
         logger.error(error_msg)
-        return f"Error generating response: {str(e)}"
+        return f"Error generating response: {str(e)}"
+
+def run_llm_with_memory(input_text, session_id="default", max_tokens=512, temperature=0.7):
+    """
+    Process input with conversation memory.
+
+    Args:
+        input_text: User input to process
+        session_id: Unique identifier for conversation
+        max_tokens: Maximum tokens to generate
+        temperature: Temperature for sampling
+
+    Returns:
+        Generated response text
+    """
+    # Initialize memory if needed
+    if session_id not in conversation_memory:
+        conversation_memory[session_id] = [
+            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
+        ]
+
+    # Add current input to memory
+    conversation_memory[session_id].append({"role": "user", "content": input_text})
+
+    # Prepare the full conversation history
+    messages = conversation_memory[session_id].copy()
+
+    # Keep only the last 10 messages to avoid context length issues
+    if len(messages) > 10:
+        # Always keep the system message
+        messages = [messages[0]] + messages[-9:]
+
+    headers = {
+        "Authorization": f"Bearer {HF_API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "messages": messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature
+    }
+
+    logger.info(f"Sending memory-based request for session {session_id}")
+
+    try:
+        response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
+        response.raise_for_status()
+
+        result = response.json()
+        response_text = result["choices"][0]["message"]["content"]
+
+        # Save response to memory
+        conversation_memory[session_id].append({"role": "assistant", "content": response_text})
+
+        return response_text
+
+    except requests.exceptions.RequestException as e:
+        error_msg = f"Error calling endpoint: {str(e)}"
+        if hasattr(e, 'response') and e.response is not None:
+            error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
+        logger.error(error_msg)
+        return f"Error generating response: {str(e)}"
+
+def clear_memory(session_id="default"):
+    """
+    Clear conversation memory for a specific session.
+
+    Args:
+        session_id: Unique identifier for conversation
+    """
+    if session_id in conversation_memory:
+        conversation_memory[session_id] = [
+            {"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
+        ]
+        return True
+    return False
+
+def get_memory_sessions():
+    """
+    Get list of active memory sessions.
+
+    Returns:
+        List of session IDs
+    """
+    return list(conversation_memory.keys())
+
+def get_model_info():
+    """
+    Get information about the connected model endpoint.
+
+    Returns:
+        Dictionary with endpoint information
+    """
+    return {
+        "endpoint_url": ENDPOINT_URL,
+        "memory_sessions": len(conversation_memory),
+        "model_type": "Meta-Llama-3.1-8B-Instruct (Inference Endpoint)"
+    }
+
+def test_endpoint():
+    """
+    Test the endpoint connection.
+
+    Returns:
+        Status information
+    """
+    try:
+        response = run_llm("Hello, this is a test message. Please respond with a short greeting.")
+        return {
+            "status": "connected",
+            "message": "Successfully connected to endpoint",
+            "sample_response": response[:50] + "..." if len(response) > 50 else response
+        }
+    except Exception as e:
+        return {
+            "status": "error",
+            "message": f"Failed to connect to endpoint: {str(e)}"
+        }
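
For orientation, a minimal usage sketch of the helpers this commit adds. It is not part of the commit itself: it assumes the file is importable as local_llm, that HF_API_KEY and ENDPOINT_URL are set in the environment before import, and that the endpoint is reachable; the session id "customer-42" is purely illustrative.

# Hypothetical usage sketch -- assumes local_llm is importable and that
# HF_API_KEY / ENDPOINT_URL are exported; "customer-42" is illustrative.
from local_llm import (
    run_llm_with_memory,
    clear_memory,
    get_memory_sessions,
    test_endpoint,
)

# Check connectivity first; returns a dict with "status" and "message"
print(test_endpoint())

# Two turns in one session: the second call resends the first exchange
# in the "messages" payload, so the model can resolve "those"
session = "customer-42"
print(run_llm_with_memory("What prepaid data plans do you offer?", session_id=session))
print(run_llm_with_memory("Which of those is the cheapest?", session_id=session))

print(get_memory_sessions())  # ["customer-42"]
clear_memory(session)         # resets the session to just the system prompt

Two properties of the design are worth noting: conversation_memory is a module-level dict, so history is per-process and lost on restart; and the trimming in run_llm_with_memory applies only to the request payload (the stored history keeps growing), sending the system message plus the last nine messages of the conversation.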