import logging
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated.types import ChatCompletionOutput
from huggingface_hub.utils import HfHubHTTPError

# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Load environment variables from .env file so the HF credentials below are
# available even when this module is imported outside the main entry point (app.py).
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_INFERENCE_ENDPOINT_URL = os.getenv("HF_INFERENCE_ENDPOINT_URL")

# Default parameters for the LLM call
DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.1 # Lower temperature for more deterministic analysis

# Special dictionary to indicate a 503 error
ERROR_503_DICT = {"error_type": "503", "message": "Service Unavailable"}
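
# Example of the message list expected by query_qwen_endpoint below (illustrative
# only; real prompts are constructed elsewhere, e.g. in prompts.py):
#   [
#       {"role": "system", "content": "You are a code analysis assistant."},
#       {"role": "user", "content": "Analyze the following Space files..."},
#   ]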


def query_qwen_endpoint(
formatted_prompt: list[dict[str, str]], max_tokens: int = DEFAULT_MAX_TOKENS
) -> ChatCompletionOutput | dict | None:
"""
Queries the specified Qwen Inference Endpoint with the formatted prompt.
Args:
formatted_prompt: A list of message dictionaries for the chat completion API.
max_tokens: The maximum number of tokens to generate.
Returns:
The ChatCompletionOutput object from the inference client,
a specific dictionary (ERROR_503_DICT) if a 503 error occurs,
or None if another error occurs.
"""
if not HF_INFERENCE_ENDPOINT_URL:
logging.error("HF_INFERENCE_ENDPOINT_URL environment variable not set.")
return None
if not HF_TOKEN:
logging.warning(
"HF_TOKEN environment variable not set. Requests might fail if the endpoint requires authentication."
)
# Depending on endpoint config, it might still work without token
logging.info(f"Querying Inference Endpoint: {HF_INFERENCE_ENDPOINT_URL}")
client = InferenceClient(model=HF_INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
try:
response = client.chat_completion(
messages=formatted_prompt,
max_tokens=max_tokens,
temperature=DEFAULT_TEMPERATURE,
# Qwen models often benefit from setting stop sequences if known,
# but we'll rely on max_tokens and model's natural stopping for now.
# stop=["<|im_end|>"] # Example stop token if needed for specific Qwen finetunes
)
logging.info("Successfully received response from Inference Endpoint.")
return response
except HfHubHTTPError as e:
# Check specifically for 503 Service Unavailable
if e.response is not None and e.response.status_code == 503:
logging.warning(
f"Encountered 503 Service Unavailable from endpoint: {HF_INFERENCE_ENDPOINT_URL}"
)
return ERROR_503_DICT # Return special dict for 503
else:
# Handle other HTTP errors
logging.error(f"HTTP error querying Inference Endpoint: {e}")
if e.response is not None:
logging.error(f"Response details: {e.response.text}")
return None # Return None for other HTTP errors
except Exception as e:
logging.error(f"An unexpected error occurred querying Inference Endpoint: {e}")
print(f"An unexpected error occurred querying Inference Endpoint: {e}")
return None
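

# Optional convenience wrapper (illustrative sketch; not used elsewhere in this
# module): retries query_qwen_endpoint when it returns the 503 sentinel, e.g.
# while a scale-to-zero Inference Endpoint is still starting up. The attempt
# count and delay below are arbitrary example values.
def query_qwen_endpoint_with_retry(
    formatted_prompt: list[dict[str, str]],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    max_attempts: int = 3,
    wait_seconds: float = 5.0,
) -> ChatCompletionOutput | dict | None:
    import time

    response = None
    for attempt in range(1, max_attempts + 1):
        response = query_qwen_endpoint(formatted_prompt, max_tokens=max_tokens)
        # Only the 503 sentinel triggers a retry; any other result is returned as-is.
        if not (isinstance(response, dict) and response.get("error_type") == "503"):
            return response
        if attempt < max_attempts:
            logging.info(
                f"Endpoint returned 503 (attempt {attempt}/{max_attempts}); "
                f"retrying in {wait_seconds}s..."
            )
            time.sleep(wait_seconds)
    return response


# Usage sketch:
#   raw = query_qwen_endpoint_with_retry(prompt_messages)
#   text = parse_qwen_response(raw)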


def parse_qwen_response(response: ChatCompletionOutput | dict | None) -> str:
"""
Parses the response from the Qwen model to extract the generated text.
Handles potential None or error dict inputs.
Args:
response: The ChatCompletionOutput object, ERROR_503_DICT, or None.
Returns:
The extracted response text as a string, or an error message string.
"""
if response is None:
return "Error: Failed to get response from the language model."
# Check if it's our specific 503 error signal before trying to parse as ChatCompletionOutput
if isinstance(response, dict) and response.get("error_type") == "503":
return f"Error: {response['error_type']} {response['message']}"
# Check if it's likely the expected ChatCompletionOutput structure
if not hasattr(response, "choices"):
logging.error(
f"Unexpected response type received by parse_qwen_response: {type(response)}. Content: {response}"
)
return "Error: Received an unexpected response format from the language model endpoint."
try:
# Access the generated content according to the ChatCompletionOutput structure
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
if content:
logging.info("Successfully parsed response content.")
return content.strip()
else:
logging.warning("Response received, but content is empty.")
return "Error: Received an empty response from the language model."
else:
logging.warning("Response received, but no choices found.")
return "Error: No response choices found in the language model output."
except AttributeError as e:
# This might catch cases where response looks like the object but lacks expected attributes
logging.error(
f"Attribute error parsing response: {e}. Response structure might be unexpected."
)
logging.error(f"Raw response object: {response}")
return "Error: Could not parse the structure of the language model response."
except Exception as e:
logging.error(f"An unexpected error occurred parsing the response: {e}")
return "Error: An unexpected error occurred while parsing the language model response."


# Example Usage (for testing - requires .env setup and potentially prompts.py)
# if __name__ == '__main__':
# # This example assumes you have a prompts.py that can generate a test prompt
# try:
# from prompts import format_code_for_analysis
# # Create a dummy prompt for testing
# test_files = {"app.py": "print('hello')"}
# test_prompt = format_code_for_analysis("test/minimal", test_files)
# print("--- Sending Test Prompt ---")
# print(test_prompt)
# api_response = query_qwen_endpoint(test_prompt)
# print("\n--- Raw API Response ---")
# print(api_response)
# print("\n--- Parsed Response ---")
# parsed_text = parse_qwen_response(api_response)
# print(parsed_text)
# except ImportError:
# print("Could not import prompts.py for testing. Run this test from the project root.")
# except Exception as e:
# print(f"An error occurred during testing: {e}")