import re
from typing import Tuple
from functools import partial

from .mapreduce import mapreduce
from .utils import (
    get_prompt_response,
    logger,
    longcepo_init,
    loop_until_match,
)


def run_longcepo(
    system_prompt: str, initial_query: str, client, model: str
) -> Tuple[str, int]:
"""
Executes the full LongCePO multi-stage pipeline to answer a complex query from long context.
The pipeline includes:
- Normalizing the format of the original query
- Generating a plan of sub-questions
- Iteratively answering each sub-question using a MapReduce-style question-answering engine
- Aggregating QA history and producing a final answer with MapReduce
Args:
system_prompt (str): System prompt string.
initial_query (str): Raw input string containing context and query separated by delimiter string.
client: LLM API client instance.
model (str): Base model name.
Returns:
Tuple[str, int]: Final answer and total number of tokens used across the pipeline.
"""
    context, query, tokenizer, cb_log, longcepo_config = longcepo_init(initial_query)

    # Normalize query
    normalized_query, upd_log = get_prompt_response(
        client,
        model,
        longcepo_config.query_format_prompt.format(full_query=query),
        system_prompt,
        max_tokens=longcepo_config.max_output_tokens,
    )
    cb_log.update(upd_log)
    logger.info(f"Normalized query: {normalized_query}")
    # Planning stage
    prompt = f"The question is: {normalized_query}"
    gen_fn = partial(
        get_prompt_response,
        client=client,
        model=model,
        prompt=prompt,
        system_prompt=longcepo_config.planning_system_prompt,
        max_tokens=longcepo_config.max_output_tokens,
    )
    planning_response, upd_log = loop_until_match(
        gen_fn, pattern_list=("<SUB-QUESTIONS>",)
    )
    logger.info(f"Planning stage output:\n\n{planning_response}")
    questions = (
        re.findall(
            r"<SUB-QUESTIONS>\s*(.*?)\s*</SUB-QUESTIONS>", planning_response, re.DOTALL
        )[0]
        .strip()
        .splitlines()
    )
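    # Illustrative example (assumed output shape, not taken verbatim from the
    # prompts): the planner is expected to reply with a tagged list such as
    #
    #   <SUB-QUESTIONS>
    #   1. When was the company founded?
    #   2. Who were its first investors?
    #   </SUB-QUESTIONS>
    #
    # which the regex above reduces to
    # ["1. When was the company founded?", "2. Who were its first investors?"];
    # the leading numbering is stripped inside the loop below.
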
    # Loop to answer sub-queries from the plan
    qa_system_prompt = (
        longcepo_config.system_prompt
        if longcepo_config.system_prompt is not None
        else system_prompt
    )
    qa_history = ""
    for question in questions:
        if not question:
            continue
        # Strip leading list numbering such as "1." from the sub-question
        question = re.sub(r"^\d+\.", "", question)
        answer, cb_log = mapreduce(
            qa_system_prompt,
            question,
            context,
            qa_history,
            client,
            model,
            tokenizer,
            longcepo_config=longcepo_config,
            cb_log=cb_log,
        )
        qa_history += f"- Previous question: {question}\n\n"
        # Strip any leading colons from the generated answer
        answer = re.sub(r"^:+", "", answer)
        qa_history += f"- Previous answer: {answer}\n\n"
        logger.info(f"QA history:\n\n{qa_history}")
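    # Illustrative shape of the accumulated history (hypothetical content) that
    # conditions each subsequent mapreduce call, including the final one:
    #
    #   - Previous question: When was the company founded?
    #
    #   - Previous answer: It was founded in 1998 by two graduate students.
    #
    #   - Previous question: Who were its first investors?
    #   ...
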
    # Final answer generation
    answer, cb_log = mapreduce(
        qa_system_prompt,
        query,
        context,
        qa_history,
        client,
        model,
        tokenizer,
        longcepo_config=longcepo_config,
        cb_log=cb_log,
    )
    return answer, cb_log["total_tokens"]
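

# Illustrative usage sketch (not part of the original module). It assumes an
# OpenAI-compatible client object and that `initial_query` packs the long
# context and the user question with whatever delimiter `longcepo_init`
# expects; the delimiter value is configuration-specific and left as a
# placeholder here.
#
#     from openai import OpenAI
#
#     client = OpenAI()
#     initial_query = long_document + DELIMITER + user_question
#     answer, total_tokens = run_longcepo(
#         system_prompt="You are a helpful assistant.",
#         initial_query=initial_query,
#         client=client,
#         model="gpt-4o-mini",
#     )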