import re from typing import Tuple from functools import partial from .mapreduce import mapreduce from .utils import ( get_prompt_response, logger, longcepo_init, loop_until_match, ) def run_longcepo( system_prompt: str, initial_query: str, client, model: str ) -> Tuple[str, int]: """ Executes the full LongCePO multi-stage pipeline to answer a complex query from long context. The pipeline includes: - Normalizing the format of the original query - Generating a plan of sub-questions - Iteratively answering each sub-question using a MapReduce-style question-answering engine - Aggregating QA history and producing a final answer with MapReduce Args: system_prompt (str): System prompt string. initial_query (str): Raw input string containing context and query separated by delimiter string. client: LLM API client instance. model (str): Base model name. Returns: Tuple[str, int]: Final answer and total number of tokens used across the pipeline. """ context, query, tokenizer, cb_log, longcepo_config = longcepo_init(initial_query) # Normalize query normalized_query, upd_log = get_prompt_response( client, model, longcepo_config.query_format_prompt.format(full_query=query), system_prompt, max_tokens=longcepo_config.max_output_tokens, ) cb_log.update(upd_log) logger.info(f"Normalized query: {normalized_query}") # Planning stage prompt = f"The question is: {normalized_query}" gen_fn = partial( get_prompt_response, client=client, model=model, prompt=prompt, system_prompt=longcepo_config.planning_system_prompt, max_tokens=longcepo_config.max_output_tokens, ) planning_response, upd_log = loop_until_match( gen_fn, pattern_list=("",) ) logger.info(f"Planning stage output:\n\n{planning_response}") questions = ( re.findall( r"\s*(.*?)\s*", planning_response, re.DOTALL )[0] .strip() .splitlines() ) # Loop to answer sub-queries from the plan qa_system_prompt = ( longcepo_config.system_prompt if longcepo_config.system_prompt is not None else system_prompt ) qa_history = "" for question in questions: if not question: continue question = re.sub(r"^\d+\.", "", question) answer, cb_log = mapreduce( qa_system_prompt, question, context, qa_history, client, model, tokenizer, longcepo_config=longcepo_config, cb_log=cb_log, ) qa_history += f"- Previous question: {question}\n\n" answer = re.sub(r"^:+", "", answer) qa_history += f"- Previous answer: {answer}\n\n" logger.info(f"QA history:\n\n{qa_history}") # Final answer generation answer, cb_log = mapreduce( qa_system_prompt, query, context, qa_history, client, model, tokenizer, longcepo_config=longcepo_config, cb_log=cb_log, ) return answer, cb_log["total_tokens"]