"""
LLM implementation using Hugging Face Inference Endpoint with OpenAI compatibility.
"""
import os
import logging
from typing import Any, Dict, List

import requests
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Endpoint configuration
HF_API_KEY = os.environ.get("HF_API_KEY", "")
ENDPOINT_URL = os.environ.get("ENDPOINT_URL", "https://cg01ow7izccjx1b2.us-east-1.aws.endpoints.huggingface.cloud/v1/chat/completions")
# Verify configuration
if not HF_API_KEY:
logger.warning("HF_API_KEY environment variable not set")
if not ENDPOINT_URL:
logger.warning("ENDPOINT_URL environment variable not set")
# Memory store for conversation history
conversation_memory: Dict[str, List[Dict[str, str]]] = {}
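# Each entry maps a session_id to an OpenAI-style message list, e.g.
#   {"default": [{"role": "system", "content": "..."},
#                {"role": "user", "content": "..."},
#                {"role": "assistant", "content": "..."}]}
# Note that this store is in-process only: history is lost on restart and is not
# shared across workers; a persistent backend would be needed for that.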
def run_llm(input_text: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
"""
Process input text through HF Inference Endpoint.
Args:
input_text: User input to process
max_tokens: Maximum tokens to generate
temperature: Temperature for sampling (higher = more random)
Returns:
Generated response text
"""
headers = {
"Authorization": f"Bearer {HF_API_KEY}",
"Content-Type": "application/json"
}
# Format messages in OpenAI format
messages = [
{"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."},
{"role": "user", "content": input_text}
]
payload = {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature
}
logger.debug(f"Sending request to endpoint with temperature={temperature}, max_tokens={max_tokens}")
try:
        # A finite timeout (here 60s) keeps a hung endpoint from blocking the caller indefinitely
        response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=60)
response.raise_for_status()
result = response.json()
response_text = result["choices"][0]["message"]["content"]
logger.debug(f"Generated response of {len(response_text)} characters")
return response_text
except requests.exceptions.RequestException as e:
error_msg = f"Error calling endpoint: {str(e)}"
if hasattr(e, 'response') and e.response is not None:
error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
logger.error(error_msg)
return f"Error generating response: {str(e)}"
def run_llm_with_memory(input_text: str, session_id: str = "default", max_tokens: int = 512, temperature: float = 0.7) -> str:
"""
Process input with conversation memory.
Args:
input_text: User input to process
session_id: Unique identifier for conversation
max_tokens: Maximum tokens to generate
temperature: Temperature for sampling
Returns:
Generated response text
"""
# Initialize memory if needed
if session_id not in conversation_memory:
conversation_memory[session_id] = [
{"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
]
# Add current input to memory
conversation_memory[session_id].append({"role": "user", "content": input_text})
# Prepare the full conversation history
messages = conversation_memory[session_id].copy()
# Keep only the last 10 messages to avoid context length issues
if len(messages) > 10:
# Always keep the system message
messages = [messages[0]] + messages[-9:]
headers = {
"Authorization": f"Bearer {HF_API_KEY}",
"Content-Type": "application/json"
}
payload = {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature
}
logger.debug(f"Sending memory-based request for session {session_id}")
try:
        # A finite timeout (here 60s) keeps a hung endpoint from blocking the caller indefinitely
        response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=60)
response.raise_for_status()
result = response.json()
response_text = result["choices"][0]["message"]["content"]
# Save response to memory
conversation_memory[session_id].append({"role": "assistant", "content": response_text})
return response_text
except requests.exceptions.RequestException as e:
error_msg = f"Error calling endpoint: {str(e)}"
if hasattr(e, 'response') and e.response is not None:
error_msg += f" - Status code: {e.response.status_code}, Response: {e.response.text}"
logger.error(error_msg)
return f"Error generating response: {str(e)}"
def clear_memory(session_id: str = "default") -> bool:
"""
Clear conversation memory for a specific session.
    Args:
        session_id: Unique identifier for conversation

    Returns:
        True if the session's history was reset, False if the session did not exist.
    """
if session_id in conversation_memory:
conversation_memory[session_id] = [
{"role": "system", "content": "You are a helpful AI assistant for a telecom service. Answer questions clearly and concisely."}
]
return True
return False
def get_memory_sessions() -> List[str]:
"""
Get list of active memory sessions.
Returns:
List of session IDs
"""
return list(conversation_memory.keys())
def get_model_info() -> Dict[str, Any]:
"""
Get information about the connected model endpoint.
Returns:
Dictionary with endpoint information
"""
return {
"endpoint_url": ENDPOINT_URL,
"memory_sessions": len(conversation_memory),
"model_type": "Meta-Llama-3.1-8B-Instruct (Inference Endpoint)"
}
def test_endpoint() -> Dict[str, Any]:
"""
Test the endpoint connection.
Returns:
Status information
"""
    try:
        response = run_llm("Hello, this is a test message. Please respond with a short greeting.")
        # run_llm reports request failures as an error string rather than raising,
        # so detect that case explicitly instead of treating it as a success.
        if response.startswith("Error generating response"):
            return {
                "status": "error",
                "message": response
            }
        return {
            "status": "connected",
            "message": "Successfully connected to endpoint",
            "sample_response": (response[:50] + "...") if len(response) > 50 else response
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to connect to endpoint: {str(e)}"
        }
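
if __name__ == "__main__":
    # Minimal smoke test for this module: a sketch, not part of the deployed app.
    # It assumes HF_API_KEY is set and that ENDPOINT_URL points at a running
    # OpenAI-compatible chat completions endpoint.
    print(get_model_info())
    print(test_endpoint())

    # Two turns against the same session to exercise the conversation memory.
    session = "smoke-test"
    print(run_llm_with_memory("My router keeps dropping the Wi-Fi signal.", session_id=session))
    print(run_llm_with_memory("What did I just tell you about my router?", session_id=session))
    clear_memory(session)
    print(f"Active sessions: {get_memory_sessions()}")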