# policy-docs-qa/utils/llm_utils.py
import json
import os
import logging
from huggingface_hub import HfApi, InferenceClient
import utils.interface_utils as interface_utils
# Renamed constant to indicate it's a default/fallback
DEFAULT_LLM_ENDPOINT_URL = (
"https://r5lahjemc2zuajga.us-east-1.aws.endpoints.huggingface.cloud"
)
# Added Endpoint name constant
LLM_ENDPOINT_NAME = os.getenv(
"HF_LLM_ENDPOINT_NAME", "phi-4-max"
) # Get from env or default
RETRIEVAL_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a document excerpt and a question.
Your job is to retrieve the most relevant passages from the provided document excerpt that help answer the question.
For each passage retrieved from the documents, provide:
- a brief summary of the context leading up to the passage (2 sentences max)
- the supported passage quoted exactly
- a brief summary of how the points in the passage are relevant to the question (2 sentences max)
The supporting passages should be a JSON-formatted list of dictionaries with the keys 'context', 'quote', and 'relevance'.
Provide up to 4 different supporting passages covering as many different aspects of the topic in question as possible.
Only include passages that are relevant to the question. If the document contains fewer relevant passages, or none at all, return a shorter or empty list.
"""
QA_RETRIEVAL_PROMPT = """Find passages from the following documents that help answer the question.
**Document Content:**
```markdown
{document}
```
**Question:**
{question}
JSON Output:"""
ANSWER_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a list of snippets extracted from documents and a question.
The snippets are presented in a JSON-formatted list that includes a unique id (`id`), context, relevance, and the exact quote.
Your job is to answer the question based *only* on the most relevant provided snippet quotes, citing the snippets used for each sentence.
**Output Format:**
Your response *must* be a JSON-formatted list of dictionaries. Each dictionary represents a sentence in your answer and must have the following keys:
- `sentence`: A string containing the sentence.
- `citations`: A list of integers, where each integer is the `id` of a snippet that supports the sentence.
**Example Output:**
```json
[
    {
        "sentence": "This is the first sentence of the answer.",
        "citations": [1, 3]
    },
    {
        "sentence": "This is the second sentence, supported by another snippet.",
        "citations": [5]
    }
]
```
**Constraints:**
- Base your answer *only* on the information within the provided snippets.
- Do *not* use external knowledge.
- The sentences should flow together coherently.
- A single sentence can cite multiple snippets.
- The final answer should be no more than 5-6 sentences long.
- Ensure the output is valid JSON.
"""
ANSWER_PROMPT = """
Given the following snippets, answer the question.
```json
{snippets}
```
**Question:**
{question}
JSON Output:"""
# Initialize client using token from environment variables
client = InferenceClient(token=os.getenv("HF_TOKEN"))
# --- Endpoint Status Check Function ---
def check_endpoint_status(token: str | None, endpoint_name: str = LLM_ENDPOINT_NAME):
"""Checks the Inference Endpoint status and returns status dict."""
# (Function body moved from app.py - Ensure logging is configured)
logging.info(f"Checking endpoint status for '{endpoint_name}'...")
if not token:
logging.warning("HF Token not available, cannot check endpoint status.")
return {
"status": "ready",
"warning": "HF Token not available for status check.",
}
try:
api = HfApi(token=token)
endpoint = api.get_inference_endpoint(name=endpoint_name, token=token)
status = endpoint.status
logging.info(f"Endpoint '{endpoint_name}' status: {status}")
if status == "running":
return {"status": "ready"}
else:
if status == "scaledToZero":
logging.info(
f"Endpoint '{endpoint_name}' is scaled to zero. Attempting to resume..."
)
try:
endpoint.resume()
user_message = f"The required LLM endpoint ('{endpoint_name}') was scaled to zero and is **now restarting**. Please wait a few minutes and try submitting your query again."
logging.info(f"Resume command sent for '{endpoint_name}'.")
return {"status": "error", "ui_message": user_message}
except Exception as resume_error:
logging.error(
f"Failed to resume endpoint '{endpoint_name}': {resume_error}"
)
user_message = f"The required LLM endpoint ('{endpoint_name}') is scaled to zero. An attempt to automatically resume it failed: {resume_error}. Please check the endpoint status on Hugging Face."
return {"status": "error", "ui_message": user_message}
else:
user_message = f"The required LLM endpoint ('{endpoint_name}') is currently **{status}**. Analysis cannot proceed until it is running. Please check the endpoint status on Hugging Face."
logging.warning(
f"Endpoint '{endpoint_name}' is not ready (Status: {status})."
)
return {"status": "error", "ui_message": user_message}
except Exception as e:
error_msg = f"Error checking endpoint status for {endpoint_name}: {e}"
logging.error(error_msg)
return {
"status": "error",
"ui_message": f"Failed to check endpoint status. Please verify the endpoint name ('{endpoint_name}') and your token. Error: {e}",
}
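
# A minimal usage sketch (assumes HF_TOKEN is set in the environment). Callers are
# expected to gate LLM calls on the returned status and surface `ui_message` to the
# user otherwise. `_example_endpoint_gate` is an illustrative helper, not app code.
def _example_endpoint_gate() -> bool:
    status = check_endpoint_status(token=os.getenv("HF_TOKEN"))
    if status["status"] != "ready":
        logging.warning(status.get("ui_message", "Endpoint not ready."))
        return False
    return True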
def retrieve_passages(
    query, doc_embeds, passages, processed_docs, embed_model, max_docs=3
):
    """Retrieves relevant passages based on embedding similarity, limited by max_docs."""
    queries = [query]
    query_embeddings = embed_model.encode(queries, prompt_name="query")
    scores = embed_model.similarity(query_embeddings, doc_embeds)
    sorted_scores = scores.sort(descending=True)
    sorted_vals = sorted_scores.values[0].tolist()
    sorted_idx = sorted_scores.indices[0].tolist()
    results = [
        {
            "passage_id": i,
            "document_id": passages[i][0],
            "chunk_id": passages[i][1],
            "document_url": processed_docs[passages[i][0]]["url"],
            "passage_text": passages[i][2],
            "relevance": v,
        }
        for i, v in zip(sorted_idx, sorted_vals)
    ]
    # Slice the results here
    return results[:max_docs]
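
# A minimal usage sketch. It assumes `embed_model` behaves like a sentence-transformers
# model exposing `encode(...)` and `similarity(...)`, that `passages` is a list of
# (document_id, chunk_id, passage_text) tuples, and that `processed_docs` maps
# document_id -> {"url": ...}. All names below are illustrative, not part of this module.
def _example_retrieval(embed_model, passages, processed_docs):
    passage_texts = [p[2] for p in passages]
    doc_embeds = embed_model.encode(passage_texts)
    top = retrieve_passages(
        "How is personal data shared with third parties?",
        doc_embeds,
        passages,
        processed_docs,
        embed_model,
        max_docs=3,
    )
    # Each result dict carries passage_id, document_url, passage_text, relevance, ...
    return top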
# --- Excerpt Processing Function ---
def process_single_excerpt(
    excerpt_index: int, excerpt: dict, query: str, hf_client: InferenceClient
):
    """Processes a single retrieved excerpt using an LLM to find citations and spans."""
    passage_text = excerpt.get("passage_text", "")
    if not passage_text:
        return {
            "citations": [],
            "all_spans": [],
            "parse_successful": False,
            "raw_error_response": "Empty passage text",
        }
    citations = []
    all_spans = []
    is_parse_successful = False
    raw_error_response = None
    try:
        retrieval_prompt = QA_RETRIEVAL_PROMPT.format(
            document=passage_text, question=query
        )
        response = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": RETRIEVAL_SYSTEM_PROMPT},
                {"role": "user", "content": retrieval_prompt},
            ],
            model=os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL),
            max_tokens=2048,
            temperature=0.01,
        )
        # Attempt to parse JSON
        response_content = response.choices[0].message.content.strip()
        try:
            # Find JSON block
            json_match = response_content.split("```json", 1)
            if len(json_match) > 1:
                json_str = json_match[1].split("```", 1)[0]
                parsed_json = json.loads(json_str)
                citations = parsed_json
                is_parse_successful = True
                # Find spans for each citation
                for cit in citations:
                    quote = cit.get("quote", "")
                    if quote:
                        # Call find_citation_spans from interface_utils
                        spans = interface_utils.find_citation_spans(
                            document=passage_text, citation=quote
                        )
                        cit["char_spans"] = spans  # Store spans in the citation dict
                        all_spans.extend(spans)
            else:
                raise ValueError("No ```json block found in response")
        except (json.JSONDecodeError, ValueError, IndexError) as json_e:
            print(f"Error parsing JSON for excerpt {excerpt_index}: {json_e}")
            is_parse_successful = False
            raw_error_response = f"LLM Response (failed to parse): {response_content}"  # Fixed potential newline issue
    except Exception as llm_e:
        print(f"Error during LLM call for excerpt {excerpt_index}: {llm_e}")
        is_parse_successful = False
        raw_error_response = f"LLM API Error: {llm_e}"
    return {
        "citations": citations,
        "all_spans": all_spans,
        "parse_successful": is_parse_successful,
        "raw_error_response": raw_error_response,
    }
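
# A minimal sketch of how retrieved excerpts might be turned into the numbered snippet
# list that ANSWER_SYSTEM_PROMPT expects (each entry gets an `id`). Illustrative only;
# the real assembly happens excerpt by excerpt in app.py.
def _example_build_snippets(excerpts: list, query: str) -> list:
    snippets = []
    for idx, excerpt in enumerate(excerpts):
        result = process_single_excerpt(idx, excerpt, query, client)
        if not result["parse_successful"]:
            continue
        for citation in result["citations"]:
            snippets.append(
                {
                    "id": len(snippets) + 1,
                    "context": citation.get("context", ""),
                    "relevance": citation.get("relevance", ""),
                    "quote": citation.get("quote", ""),
                }
            )
    return snippets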
def generate_summary_answer(snippets: list, query: str, hf_client: InferenceClient):
"""Generates a summarized answer based on provided snippets using an LLM."""
# NOTE: Removed llm_endpoint_url parameter, using env var directly
endpoint_url = os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL)
if not snippets:
return {
"answer_sentences": [],
"parse_successful": False,
"raw_error_response": "No snippets provided for summarization.",
}
try:
# Ensure snippets are formatted as a JSON string for the prompt
snippets_json_string = json.dumps(snippets, indent=2)
answer_prompt_formatted = ANSWER_PROMPT.format(
snippets=snippets_json_string, question=query
)
response = hf_client.chat_completion(
messages=[
{"role": "system", "content": ANSWER_SYSTEM_PROMPT},
{"role": "user", "content": answer_prompt_formatted},
],
model=endpoint_url,
max_tokens=512,
temperature=0.01,
)
# Attempt to parse JSON response
response_content = response.choices[0].message.content.strip()
try:
# Find JSON block (assuming it might be wrapped in ```json ... ```)
json_match = response_content.split("```json", 1)
if len(json_match) > 1:
json_str = json_match[1].split("```", 1)[0]
else: # Assume the response *is* the JSON if no backticks found
json_str = response_content
parsed_json = json.loads(json_str)
# Basic validation: check if it's a list of dictionaries with expected keys
if isinstance(parsed_json, list) and all(
isinstance(item, dict) and "sentence" in item and "citations" in item
for item in parsed_json
):
return {
"answer_sentences": parsed_json,
"parse_successful": True,
"raw_error_response": None,
}
else:
raise ValueError(
"Parsed JSON does not match expected format (list of {'sentence':..., 'citations':...})"
)
except (json.JSONDecodeError, ValueError, IndexError) as json_e:
print(f"Error parsing summary JSON: {json_e}")
return {
"answer_sentences": [],
"parse_successful": False,
"raw_error_response": f"LLM Response (failed to parse summary): {response_content}",
}
except Exception as llm_e:
print(f"Error during LLM summary call: {llm_e}")
return {
"answer_sentences": [],
"parse_successful": False,
"raw_error_response": f"LLM API Error during summary generation: {llm_e}",
}
# NOTE: make_supporting_snippets(...) was removed from this module; excerpts are now
# handled one at a time in app.py via process_single_excerpt.
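
# End-to-end sketch (illustrative, not used by the app): generate a cited answer from
# previously collected snippets. Assumes a running endpoint and a valid HF_TOKEN.
def _example_answer(snippets: list, query: str) -> list:
    result = generate_summary_answer(snippets, query, client)
    if not result["parse_successful"]:
        logging.warning(result["raw_error_response"])
        return []
    # Each entry: {"sentence": ..., "citations": [snippet ids]}
    return result["answer_sentences"]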