import logging
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated.types import ChatCompletionOutput
from huggingface_hub.utils import HfHubHTTPError

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Load environment variables from .env so they are available even when this
# module is imported or used on its own (not just via the main entry point, app.py)
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
HF_INFERENCE_ENDPOINT_URL = os.getenv("HF_INFERENCE_ENDPOINT_URL")
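
# For reference, the .env file is expected to define both variables, e.g.
# (placeholder values, not real credentials or a real endpoint URL):
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   HF_INFERENCE_ENDPOINT_URL=https://<your-endpoint>.endpoints.huggingface.cloud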

# Default parameters for the LLM call
DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.1  # Lower temperature for more deterministic analysis

# Special dictionary to indicate a 503 error
ERROR_503_DICT = {"error_type": "503", "message": "Service Unavailable"}


def query_qwen_endpoint(
    formatted_prompt: list[dict[str, str]], max_tokens: int = DEFAULT_MAX_TOKENS
) -> ChatCompletionOutput | dict | None:
    """
    Queries the specified Qwen Inference Endpoint with the formatted prompt.

    Args:
        formatted_prompt: A list of message dictionaries for the chat completion API.
        max_tokens: The maximum number of tokens to generate.

    Returns:
        The ChatCompletionOutput object from the inference client,
        a specific dictionary (ERROR_503_DICT) if a 503 error occurs,
        or None if another error occurs.
    """
    if not HF_INFERENCE_ENDPOINT_URL:
        logging.error("HF_INFERENCE_ENDPOINT_URL environment variable not set.")
        return None
    if not HF_TOKEN:
        logging.warning(
            "HF_TOKEN environment variable not set. Requests might fail if the endpoint requires authentication."
        )
        # Depending on endpoint config, it might still work without a token

    logging.info(f"Querying Inference Endpoint: {HF_INFERENCE_ENDPOINT_URL}")
    client = InferenceClient(model=HF_INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
    try:
        response = client.chat_completion(
            messages=formatted_prompt,
            max_tokens=max_tokens,
            temperature=DEFAULT_TEMPERATURE,
            # Qwen models often benefit from explicit stop sequences if known,
            # but we rely on max_tokens and the model's natural stopping for now.
            # stop=["<|im_end|>"]  # Example stop token for specific Qwen finetunes
        )
        logging.info("Successfully received response from Inference Endpoint.")
        return response
    except HfHubHTTPError as e:
        # Check specifically for 503 Service Unavailable
        if e.response is not None and e.response.status_code == 503:
            logging.warning(
                f"Encountered 503 Service Unavailable from endpoint: {HF_INFERENCE_ENDPOINT_URL}"
            )
            return ERROR_503_DICT  # Return special dict for 503
        else:
            # Handle other HTTP errors
            logging.error(f"HTTP error querying Inference Endpoint: {e}")
            if e.response is not None:
                logging.error(f"Response details: {e.response.text}")
            return None  # Return None for other HTTP errors
    except Exception as e:
        logging.error(f"An unexpected error occurred querying Inference Endpoint: {e}")
        return None
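

# Illustrative sketch, not part of the original module: one way a caller might
# handle the ERROR_503_DICT sentinel. A 503 from an Inference Endpoint often
# means it is still scaling up (e.g. from zero), so a short wait-and-retry loop
# can succeed. The function name, retry count, and delay are assumed values,
# not tuned recommendations from this codebase.
def query_qwen_endpoint_with_retry(
    formatted_prompt: list[dict[str, str]],
    retries: int = 3,
    delay_seconds: float = 30.0,
) -> ChatCompletionOutput | dict | None:
    import time

    for attempt in range(1, retries + 1):
        response = query_qwen_endpoint(formatted_prompt)
        # Anything other than the 503 sentinel (success or a hard failure) is final.
        if not (isinstance(response, dict) and response.get("error_type") == "503"):
            return response
        if attempt < retries:
            logging.warning(
                f"Endpoint returned 503 (attempt {attempt}/{retries}); "
                f"retrying in {delay_seconds}s..."
            )
            time.sleep(delay_seconds)
    return ERROR_503_DICT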


def parse_qwen_response(response: ChatCompletionOutput | dict | None) -> str:
    """
    Parses the response from the Qwen model to extract the generated text.

    Handles potential None or error-dict inputs.

    Args:
        response: The ChatCompletionOutput object, ERROR_503_DICT, or None.

    Returns:
        The extracted response text as a string, or an error message string.
    """
    if response is None:
        return "Error: Failed to get response from the language model."

    # Check for our specific 503 signal before trying to parse as ChatCompletionOutput
    if isinstance(response, dict) and response.get("error_type") == "503":
        return f"Error: {response['error_type']} {response['message']}"

    # Check that the response looks like the expected ChatCompletionOutput structure
    if not hasattr(response, "choices"):
        logging.error(
            f"Unexpected response type received by parse_qwen_response: {type(response)}. Content: {response}"
        )
        return "Error: Received an unexpected response format from the language model endpoint."
    try:
        # Access the generated content according to the ChatCompletionOutput structure
        if response.choices and len(response.choices) > 0:
            content = response.choices[0].message.content
            if content:
                logging.info("Successfully parsed response content.")
                return content.strip()
            else:
                logging.warning("Response received, but content is empty.")
                return "Error: Received an empty response from the language model."
        else:
            logging.warning("Response received, but no choices found.")
            return "Error: No response choices found in the language model output."
    except AttributeError as e:
        # Catches cases where the response resembles the object but lacks expected attributes
        logging.error(
            f"Attribute error parsing response: {e}. Response structure might be unexpected."
        )
        logging.error(f"Raw response object: {response}")
        return "Error: Could not parse the structure of the language model response."
    except Exception as e:
        logging.error(f"An unexpected error occurred parsing the response: {e}")
        return "Error: An unexpected error occurred while parsing the language model response."


# Example Usage (for testing - requires .env setup and potentially prompts.py)
# if __name__ == '__main__':
#     # This example assumes you have a prompts.py that can generate a test prompt
#     try:
#         from prompts import format_code_for_analysis
#
#         # Create a dummy prompt for testing
#         test_files = {"app.py": "print('hello')"}
#         test_prompt = format_code_for_analysis("test/minimal", test_files)
#         print("--- Sending Test Prompt ---")
#         print(test_prompt)
#
#         api_response = query_qwen_endpoint(test_prompt)
#         print("\n--- Raw API Response ---")
#         print(api_response)
#
#         print("\n--- Parsed Response ---")
#         parsed_text = parse_qwen_response(api_response)
#         print(parsed_text)
#     except ImportError:
#         print("Could not import prompts.py for testing. Run this test from the project root.")
#     except Exception as e:
#         print(f"An error occurred during testing: {e}")