import json
import os
import logging

from huggingface_hub import HfApi, InferenceClient

import utils.interface_utils as interface_utils

# Default/fallback LLM endpoint URL (overridden via HF_LLM_ENDPOINT_URL where used)
DEFAULT_LLM_ENDPOINT_URL = (
    "https://r5lahjemc2zuajga.us-east-1.aws.endpoints.huggingface.cloud"
)

# Inference Endpoint name, taken from the environment with a fallback default
LLM_ENDPOINT_NAME = os.getenv("HF_LLM_ENDPOINT_NAME", "phi-4-max")

RETRIEVAL_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a document excerpt and a question.
Your job is to retrieve the most relevant passages from the provided document excerpt that help answer the question.

For each passage retrieved from the document, provide:
- a brief summary of the context leading up to the passage (2 sentences max)
- the supporting passage quoted exactly
- a brief summary of how the points in the passage are relevant to the question (2 sentences max)

The supporting passages should be a JSON-formatted list of dictionaries with the keys 'context', 'quote', and 'relevance'.
Provide up to 4 different supporting passages covering as many different aspects of the topic in question as possible.
Only include passages that are relevant to the question. If there are fewer or no relevant passages in the document, just return a shorter or empty list.
"""

QA_RETRIEVAL_PROMPT = """Find passages from the following documents that help answer the question.

**Document Content:**
```markdown
{document}
```

**Question:**
{question}

JSON Output:"""

ANSWER_SYSTEM_PROMPT = """**Instructions:**
You are a helpful assistant presented with a list of snippets extracted from documents and a question.
The snippets are presented in a JSON-formatted list that includes a unique id (`id`), context, relevance, and the exact quote.
Your job is to answer the question based *only* on the most relevant provided snippet quotes, citing the snippets used for each sentence.

**Output Format:**
Your response *must* be a JSON-formatted list of dictionaries. Each dictionary represents a sentence in your answer and must have the following keys:
- `sentence`: A string containing the sentence.
- `citations`: A list of integers, where each integer is the `id` of a snippet that supports the sentence. 

**Example Output:**
```json
[
  {
    "sentence": "This is the first sentence of the answer.",
    "citations": [1, 3]
  },
  {
    "sentence": "This is the second sentence, supported by another snippet.",
    "citations": [5]
  }
]
```

**Constraints:**
- Base your answer *only* on the information within the provided snippets.
- Do *not* use external knowledge.
- The sentences should flow together coherently.
- A single sentence can cite multiple snippets.
- The final answer should be no more than 5-6 sentences long.
- Ensure the output is valid JSON.
"""

ANSWER_PROMPT = """
Given the following snippets, answer the question.
```json
{snippets}
```

**Question:**
{question}

JSON Output:"""
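
# The `{snippets}` placeholder above expects a JSON list shaped roughly like this
# (illustrative example; the field values are made up):
# [
#   {"id": 0, "context": "...", "relevance": "...", "quote": "exact quoted passage"},
#   {"id": 1, "context": "...", "relevance": "...", "quote": "another exact quote"}
# ]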

# Initialize client using token from environment variables
client = InferenceClient(token=os.getenv("HF_TOKEN"))


# --- Endpoint Status Check Function ---
def check_endpoint_status(token: str | None, endpoint_name: str = LLM_ENDPOINT_NAME):
    """Checks the Inference Endpoint status and returns status dict."""
    # (Originally in app.py; logging is expected to be configured by the caller.)
    logging.info(f"Checking endpoint status for '{endpoint_name}'...")
    if not token:
        logging.warning("HF Token not available, cannot check endpoint status.")
        return {
            "status": "ready",
            "warning": "HF Token not available for status check.",
        }
    try:
        api = HfApi(token=token)
        endpoint = api.get_inference_endpoint(name=endpoint_name, token=token)
        status = endpoint.status
        logging.info(f"Endpoint '{endpoint_name}' status: {status}")
        if status == "running":
            return {"status": "ready"}
        else:
            if status == "scaledToZero":
                logging.info(
                    f"Endpoint '{endpoint_name}' is scaled to zero. Attempting to resume..."
                )
                try:
                    endpoint.resume()
                    user_message = f"The required LLM endpoint ('{endpoint_name}') was scaled to zero and is **now restarting**. Please wait a few minutes and try submitting your query again."
                    logging.info(f"Resume command sent for '{endpoint_name}'.")
                    return {"status": "error", "ui_message": user_message}
                except Exception as resume_error:
                    logging.error(
                        f"Failed to resume endpoint '{endpoint_name}': {resume_error}"
                    )
                    user_message = f"The required LLM endpoint ('{endpoint_name}') is scaled to zero. An attempt to automatically resume it failed: {resume_error}. Please check the endpoint status on Hugging Face."
                    return {"status": "error", "ui_message": user_message}
            else:
                user_message = f"The required LLM endpoint ('{endpoint_name}') is currently **{status}**. Analysis cannot proceed until it is running. Please check the endpoint status on Hugging Face."
                logging.warning(
                    f"Endpoint '{endpoint_name}' is not ready (Status: {status})."
                )
                return {"status": "error", "ui_message": user_message}
    except Exception as e:
        error_msg = f"Error checking endpoint status for {endpoint_name}: {e}"
        logging.error(error_msg)
        return {
            "status": "error",
            "ui_message": f"Failed to check endpoint status. Please verify the endpoint name ('{endpoint_name}') and your token. Error: {e}",
        }
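
# A hypothetical caller-side sketch (the `show_error` helper is illustrative, not
# part of this codebase):
# status = check_endpoint_status(os.getenv("HF_TOKEN"))
# if status["status"] != "ready":
#     show_error(status.get("ui_message", "LLM endpoint unavailable"))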


def retrieve_passages(
    query, doc_embeds, passages, processed_docs, embed_model, max_docs=3
):
    """Retrieves relevant passages based on embedding similarity, limited by max_docs."""
    queries = [query]
    query_embeddings = embed_model.encode(queries, prompt_name="query")
    scores = embed_model.similarity(query_embeddings, doc_embeds)
    sorted_scores = scores.sort(descending=True)
    sorted_vals = sorted_scores.values[0].tolist()
    sorted_idx = sorted_scores.indices[0].tolist()
    results = [
        {
            "passage_id": i,
            "document_id": passages[i][0],
            "chunk_id": passages[i][1],
            "document_url": processed_docs[passages[i][0]]["url"],
            "passage_text": passages[i][2],
            "relevance": v,
        }
        for i, v in zip(sorted_idx, sorted_vals)
    ]
    # Slice the results here
    return results[:max_docs]
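
# A minimal usage sketch, assuming `embed_model` is a sentence-transformers model
# exposing .encode(..., prompt_name=...) and .similarity(), and that `passages` is
# a list of (document_id, chunk_id, passage_text) tuples built upstream (both are
# assumptions about the calling code, not defined here):
# top_excerpts = retrieve_passages(
#     query="What does the document say about data retention?",
#     doc_embeds=doc_embeds,
#     passages=passages,
#     processed_docs=processed_docs,
#     embed_model=embed_model,
#     max_docs=3,
# )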


# --- Excerpt Processing Function ---
def process_single_excerpt(
    excerpt_index: int, excerpt: dict, query: str, hf_client: InferenceClient
):
    """Processes a single retrieved excerpt using an LLM to find citations and spans."""

    passage_text = excerpt.get("passage_text", "")
    if not passage_text:
        return {
            "citations": [],
            "all_spans": [],
            "parse_successful": False,
            "raw_error_response": "Empty passage text",
        }

    citations = []
    all_spans = []
    is_parse_successful = False
    raw_error_response = None

    try:
        retrieval_prompt = QA_RETRIEVAL_PROMPT.format(
            document=passage_text, question=query
        )
        response = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": RETRIEVAL_SYSTEM_PROMPT},
                {"role": "user", "content": retrieval_prompt},
            ],
            model=os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL),
            max_tokens=2048,
            temperature=0.01,
        )

        # Attempt to parse JSON
        response_content = response.choices[0].message.content.strip()
        try:
            # Find JSON block
            json_match = response_content.split("```json", 1)
            if len(json_match) > 1:
                json_str = json_match[1].split("```", 1)[0]
                parsed_json = json.loads(json_str)
                if not isinstance(parsed_json, list):
                    raise ValueError("Parsed JSON is not a list of citation dicts")
                citations = parsed_json
                is_parse_successful = True
                # Find spans for each citation
                for cit in citations:
                    quote = cit.get("quote", "")
                    if quote:
                        # Call find_citation_spans from interface_utils
                        spans = interface_utils.find_citation_spans(
                            document=passage_text, citation=quote
                        )
                        cit["char_spans"] = spans  # Store spans in the citation dict
                        all_spans.extend(spans)
            else:
                raise ValueError("No ```json block found in response")
        except (json.JSONDecodeError, ValueError, IndexError) as json_e:
            logging.error(f"Error parsing JSON for excerpt {excerpt_index}: {json_e}")
            is_parse_successful = False
            raw_error_response = f"LLM Response (failed to parse): {response_content}"

    except Exception as llm_e:
        logging.error(f"Error during LLM call for excerpt {excerpt_index}: {llm_e}")
        is_parse_successful = False
        raw_error_response = f"LLM API Error: {llm_e}"

    return {
        "citations": citations,
        "all_spans": all_spans,
        "parse_successful": is_parse_successful,
        "raw_error_response": raw_error_response,
    }
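
# Hypothetical app-side loop over retrieved excerpts (a sketch; `top_excerpts` and
# `user_query` are assumed to come from the calling code):
# excerpt_results = [
#     process_single_excerpt(i, excerpt, user_query, client)
#     for i, excerpt in enumerate(top_excerpts)
# ]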


def generate_summary_answer(snippets: list, query: str, hf_client: InferenceClient):
    """Generates a summarized answer based on provided snippets using an LLM."""
    # NOTE: the endpoint URL is read from the environment rather than passed as a parameter.
    endpoint_url = os.getenv("HF_LLM_ENDPOINT_URL", DEFAULT_LLM_ENDPOINT_URL)
    if not snippets:
        return {
            "answer_sentences": [],
            "parse_successful": False,
            "raw_error_response": "No snippets provided for summarization.",
        }

    try:
        # Ensure snippets are formatted as a JSON string for the prompt
        snippets_json_string = json.dumps(snippets, indent=2)

        answer_prompt_formatted = ANSWER_PROMPT.format(
            snippets=snippets_json_string, question=query
        )

        response = hf_client.chat_completion(
            messages=[
                {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                {"role": "user", "content": answer_prompt_formatted},
            ],
            model=endpoint_url,
            max_tokens=512,
            temperature=0.01,
        )

        # Attempt to parse JSON response
        response_content = response.choices[0].message.content.strip()
        try:
            # Find JSON block (assuming it might be wrapped in ```json ... ```)
            json_match = response_content.split("```json", 1)
            if len(json_match) > 1:
                json_str = json_match[1].split("```", 1)[0]
            else:  # Assume the response *is* the JSON if no backticks found
                json_str = response_content

            parsed_json = json.loads(json_str)

            # Basic validation: check if it's a list of dictionaries with expected keys
            if isinstance(parsed_json, list) and all(
                isinstance(item, dict) and "sentence" in item and "citations" in item
                for item in parsed_json
            ):
                return {
                    "answer_sentences": parsed_json,
                    "parse_successful": True,
                    "raw_error_response": None,
                }
            else:
                raise ValueError(
                    "Parsed JSON does not match expected format (list of {'sentence':..., 'citations':...})"
                )

        except (json.JSONDecodeError, ValueError, IndexError) as json_e:
            logging.error(f"Error parsing summary JSON: {json_e}")
            return {
                "answer_sentences": [],
                "parse_successful": False,
                "raw_error_response": f"LLM Response (failed to parse summary): {response_content}",
            }

    except Exception as llm_e:
        logging.error(f"Error during LLM summary call: {llm_e}")
        return {
            "answer_sentences": [],
            "parse_successful": False,
            "raw_error_response": f"LLM API Error during summary generation: {llm_e}",
        }


# NOTE: make_supporting_snippets(...) no longer lives in this module; building the
# supporting snippets is now handled excerpt by excerpt in app.py (or
# interface_utils.py), using process_single_excerpt above.
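
# End-to-end sketch of how the pieces in this module chain together (illustrative
# only; the snippet `id` scheme and the surrounding glue are assumptions about the
# app-side code, not defined here):
# excerpts = retrieve_passages(query, doc_embeds, passages, processed_docs, embed_model)
# snippets = []
# for i, excerpt in enumerate(excerpts):
#     result = process_single_excerpt(i, excerpt, query, client)
#     for cit in result["citations"]:
#         snippets.append({
#             "id": len(snippets),
#             "context": cit.get("context"),
#             "relevance": cit.get("relevance"),
#             "quote": cit.get("quote"),
#         })
# answer = generate_summary_answer(snippets, query, client)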