import json
import os
import logging

from huggingface_hub import HfApi, InferenceClient

import utils.interface_utils as interface_utils

# Renamed constant to indicate it's a default/fallback
DEFAULT_LLM_ENDPOINT_URL = (
    "https://r5lahjemc2zuajga.us-east-1.aws.endpoints.huggingface.cloud"
)

# Added Endpoint name constant
LLM_ENDPOINT_NAME = os.getenv(
    "HF_LLM_ENDPOINT_NAME", "phi-4-max"
)  # Get from env or default

RETRIEVAL_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a document excerpt and a question.
Your job is to retrieve the most relevant passages from the provided document excerpt that help answer the question.
For each passage retrieved from the documents, provide:
- a brief summary of the context leading up to the passage (2 sentences max)
- the supporting passage quoted exactly
- a brief summary of how the points in the passage are relevant to the question (2 sentences max)
The supporting passages should be a JSON-formatted list of dictionaries with the keys 'context', 'quote', and 'relevance'.
Provide up to 4 different supporting passages covering as many different aspects of the topic in question as possible.
Only include passages that are relevant to the question. If there are fewer or no relevant passages in the document, just return a shorter or empty list.
"""

QA_RETRIEVAL_PROMPT = """Find passages from the following document that help answer the question.
**Document Content:**
```markdown
{document}
```
**Question:**
{question}
JSON Output:"""

ANSWER_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a list of snippets extracted from documents and a question.
The snippets are presented in a JSON-formatted list that includes a unique id (`id`), context, relevance, and the exact quote.
Your job is to answer the question based *only* on the most relevant provided snippet quotes, citing the snippets used for each sentence.
**Output Format:**
Your response *must* be a JSON-formatted list of dictionaries. Each dictionary represents a sentence in your answer and must have the following keys:
- `sentence`: A string containing the sentence.
- `citations`: A list of integers, where each integer is the `id` of a snippet that supports the sentence.
**Example Output:**
```json
[
  {
    "sentence": "This is the first sentence of the answer.",
    "citations": [1, 3]
  },
  {
    "sentence": "This is the second sentence, supported by another snippet.",
    "citations": [5]
  }
]
```
**Constraints:**
- Base your answer *only* on the information within the provided snippets.
- Do *not* use external knowledge.
- The sentences should flow together coherently.
- A single sentence can cite multiple snippets.
- The final answer should be no more than 5-6 sentences long.
- Ensure the output is valid JSON.
"""

ANSWER_PROMPT = """
Given the following snippets, answer the question.
```json
{snippets}
```
**Question:**
{question}
JSON Output:"""

# Initialize client using token from environment variables
client = InferenceClient(token=os.getenv("HF_TOKEN"))


# --- Endpoint Status Check Function ---
def check_endpoint_status(token: str | None, endpoint_name: str = LLM_ENDPOINT_NAME):
    """Checks the Inference Endpoint status and returns status dict."""
    # (Function body moved from app.py - Ensure logging is configured)
    logging.info(f"Checking endpoint status for '{endpoint_name}'...")
    if not token:
        logging.warning("HF Token not available, cannot check endpoint status.")
        return {
            "status": "ready",
            "warning": "HF Token not available for status check.",
        }
    try:
        api = HfApi(token=token)
        endpoint = api.get_inference_endpoint(name=endpoint_name, token=token)
        status = endpoint.status
        logging.info(f"Endpoint '{endpoint_name}' status: {status}")
        if status == "running":
            return {"status": "ready"}
        else:
            if status == "scaledToZero":
                logging.info(
                    f"Endpoint '{endpoint_name}' is scaled to zero. Attempting to resume..."
                )
                try:
                    endpoint.resume()
                    user_message = f"The required LLM endpoint ('{endpoint_name}') was scaled to zero and is **now restarting**. Please wait a few minutes and try submitting your query again."
                    logging.info(f"Resume command sent for '{endpoint_name}'.")
                    return {"status": "error", "ui_message": user_message}
                except Exception as resume_error:
                    logging.error(
                        f"Failed to resume endpoint '{endpoint_name}': {resume_error}"
                    )
                    user_message = f"The required LLM endpoint ('{endpoint_name}') is scaled to zero. An attempt to automatically resume it failed: {resume_error}. Please check the endpoint status on Hugging Face."
                    return {"status": "error", "ui_message": user_message}
            else:
                user_message = f"The required LLM endpoint ('{endpoint_name}') is currently **{status}**. Analysis cannot proceed until it is running. Please check the endpoint status on Hugging Face."
                logging.warning(
                    f"Endpoint '{endpoint_name}' is not ready (Status: {status})."
                )
                return {"status": "error", "ui_message": user_message}
    except Exception as e:
        error_msg = f"Error checking endpoint status for {endpoint_name}: {e}"
        logging.error(error_msg)
        return {
            "status": "error",
            "ui_message": f"Failed to check endpoint status. Please verify the endpoint name ('{endpoint_name}') and your token. Error: {e}",
        }
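
# Minimal usage sketch (assumes HF_TOKEN is set in the environment; the caller
# is expected to surface `ui_message` to the user when the status is not "ready"):
#
#   status_info = check_endpoint_status(token=os.getenv("HF_TOKEN"))
#   if status_info["status"] != "ready":
#       print(status_info.get("ui_message", "Endpoint not ready"))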

def retrieve_passages(
    query, doc_embeds, passages, processed_docs, embed_model, max_docs=3
):
    """Retrieves relevant passages based on embedding similarity, limited by max_docs."""
    queries = [query]
    query_embeddings = embed_model.encode(queries, prompt_name="query")
    scores = embed_model.similarity(query_embeddings, doc_embeds)
    sorted_scores = scores.sort(descending=True)
    sorted_vals = sorted_scores.values[0].tolist()
    sorted_idx = sorted_scores.indices[0].tolist()
    results = [
        {
            "passage_id": i,
            "document_id": passages[i][0],
            "chunk_id": passages[i][1],
            "document_url": processed_docs[passages[i][0]]["url"],
            "passage_text": passages[i][2],
            "relevance": v,
        }
        for i, v in zip(sorted_idx, sorted_vals)
    ]
    # Slice the results here
    return results[:max_docs]
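
# Usage sketch (illustrative): `embed_model` is assumed to be a
# sentence-transformers-style model exposing `encode` and `similarity`, and
# `passages` a list of (document_id, chunk_id, passage_text) tuples, as the
# indexing above implies:
#
#   top_excerpts = retrieve_passages(
#       query="What datasets were used?",
#       doc_embeds=doc_embeds,
#       passages=passages,
#       processed_docs=processed_docs,
#       embed_model=embed_model,
#       max_docs=3,
#   )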

# --- Excerpt Processing Function ---
def process_single_excerpt(
    excerpt_index: int, excerpt: dict, query: str, hf_client: InferenceClient
):
    """Processes a single retrieved excerpt using an LLM to find citations and spans."""
    passage_text = excerpt.get("passage_text", "")
    if not passage_text:
        return {
            "citations": [],
            "all_spans": [],
            "parse_successful": False,
            "raw_error_response": "Empty passage text",
        }
    citations = []
    all_spans = []
    is_parse_successful = False
    raw_error_response = None
    try:
        retrieval_prompt = QA_RETRIEVAL_PROMPT.format(
            document=passage_text, question=query
        )
        response = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": RETRIEVAL_SYSTEM_PROMPT},
                {"role": "user", "content": retrieval_prompt},
            ],
            model=os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL),
            max_tokens=2048,
            temperature=0.01,
        )
        # Attempt to parse JSON
        response_content = response.choices[0].message.content.strip()
        try:
            # Find JSON block
            json_match = response_content.split("```json", 1)
            if len(json_match) > 1:
                json_str = json_match[1].split("```", 1)[0]
                parsed_json = json.loads(json_str)
                citations = parsed_json
                is_parse_successful = True
                # Find spans for each citation
                for cit in citations:
                    quote = cit.get("quote", "")
                    if quote:
                        # Call find_citation_spans from interface_utils
                        spans = interface_utils.find_citation_spans(
                            document=passage_text, citation=quote
                        )
                        cit["char_spans"] = spans  # Store spans in the citation dict
                        all_spans.extend(spans)
            else:
                raise ValueError("No ```json block found in response")
        except (json.JSONDecodeError, ValueError, IndexError) as json_e:
            print(f"Error parsing JSON for excerpt {excerpt_index}: {json_e}")
            is_parse_successful = False
            raw_error_response = f"LLM Response (failed to parse): {response_content}"
    except Exception as llm_e:
        print(f"Error during LLM call for excerpt {excerpt_index}: {llm_e}")
        is_parse_successful = False
        raw_error_response = f"LLM API Error: {llm_e}"
    return {
        "citations": citations,
        "all_spans": all_spans,
        "parse_successful": is_parse_successful,
        "raw_error_response": raw_error_response,
    }
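
# Illustrative call, assuming `top_excerpts` comes from retrieve_passages above
# and `query` is the user's question:
#
#   excerpt_results = [
#       process_single_excerpt(i, excerpt, query, client)
#       for i, excerpt in enumerate(top_excerpts)
#   ]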

def generate_summary_answer(snippets: list, query: str, hf_client: InferenceClient):
    """Generates a summarized answer based on provided snippets using an LLM."""
    # NOTE: Removed llm_endpoint_url parameter, using env var directly
    endpoint_url = os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL)
    if not snippets:
        return {
            "answer_sentences": [],
            "parse_successful": False,
            "raw_error_response": "No snippets provided for summarization.",
        }
    try:
        # Ensure snippets are formatted as a JSON string for the prompt
        snippets_json_string = json.dumps(snippets, indent=2)
        answer_prompt_formatted = ANSWER_PROMPT.format(
            snippets=snippets_json_string, question=query
        )
        response = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                {"role": "user", "content": answer_prompt_formatted},
            ],
            model=endpoint_url,
            max_tokens=512,
            temperature=0.01,
        )
        # Attempt to parse JSON response
        response_content = response.choices[0].message.content.strip()
        try:
            # Find JSON block (assuming it might be wrapped in ```json ... ```)
            json_match = response_content.split("```json", 1)
            if len(json_match) > 1:
                json_str = json_match[1].split("```", 1)[0]
            else:  # Assume the response *is* the JSON if no backticks found
                json_str = response_content
            parsed_json = json.loads(json_str)
            # Basic validation: check if it's a list of dictionaries with expected keys
            if isinstance(parsed_json, list) and all(
                isinstance(item, dict) and "sentence" in item and "citations" in item
                for item in parsed_json
            ):
                return {
                    "answer_sentences": parsed_json,
                    "parse_successful": True,
                    "raw_error_response": None,
                }
            else:
                raise ValueError(
                    "Parsed JSON does not match expected format (list of {'sentence':..., 'citations':...})"
                )
        except (json.JSONDecodeError, ValueError, IndexError) as json_e:
            print(f"Error parsing summary JSON: {json_e}")
            return {
                "answer_sentences": [],
                "parse_successful": False,
                "raw_error_response": f"LLM Response (failed to parse summary): {response_content}",
            }
    except Exception as llm_e:
        print(f"Error during LLM summary call: {llm_e}")
        return {
            "answer_sentences": [],
            "parse_successful": False,
            "raw_error_response": f"LLM API Error during summary generation: {llm_e}",
        }
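
# Illustrative call: `snippets` is assumed to be a list of dicts carrying an
# integer `id` plus the `context`/`quote`/`relevance` fields produced per
# excerpt above:
#
#   summary = generate_summary_answer(snippets, query, client)
#   if summary["parse_successful"]:
#       for item in summary["answer_sentences"]:
#           print(item["sentence"], item["citations"])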

# NOTE: make_supporting_snippets(...) was removed from this module; excerpts are
# now processed one at a time in app.py (or interface_utils.py) via process_single_excerpt.