from functools import partial from typing import Tuple, List from .utils import ( CBLog, LongCepoConfig, get_prompt_response, concurrent_map, logger, loop_until_match, ) from .chunking import ( chunk_context, get_prompt_length, ) format_chunk_list = lambda chunk_list: [ f"Information of Chunk {index}:\n{doc}\n" for index, doc in enumerate(chunk_list) ] def remove_chunks(chunks: List[str], irrelevance_tags: Tuple[str]) -> List[str]: """ Filter out chunks that contain at least one of irrelevance tags. """ new_chunks = [] for chunk in chunks: # Skip None values resulting from failed API calls if chunk is None: continue flag = False for tag in irrelevance_tags: # Ensure tag comparison is safe even if tag is None (though unlikely) if tag and tag.upper() in chunk.upper(): flag = True break if not flag: new_chunks.append(chunk) return new_chunks def mapreduce( system_prompt: str, query: str, context: str, qa_history: str, client, model: str, tokenizer, longcepo_config: LongCepoConfig, cb_log: CBLog, answer_tags: Tuple[str] = ("Answer:", "**Answer**:", "**Answer**"), irrelevance_tags: Tuple[str] = ("[NO INFORMATION]",), ) -> Tuple[str, CBLog]: """ Executes a MapReduce-style inference pipeline to answer a query from long context. The function splits the input context into chunks, summarizes and evaluates each with the model (Map), collapses intermediate answers to reduce redundancy, and then generates a final answer (Reduce). Irrelevant responses are filtered based on `irrelevance_tags`. Args: system_prompt (str): System prompt string. query (str): User query. context (str): Long-form input context to process. qa_history (str): QA history string for prompt injection. client: LLM API client. model (str): Base model name. tokenizer: Tokenizer to compute token lengths. longcepo_config (LongCepoConfig): Config with hyper-parameters and token limits. cb_log (CBLog): Log object for tracking model calls. answer_tags (Tuple[str]): Tags used to extract the final answer from model output. irrelevance_tags (Tuple[str]): Tags used to identify and remove irrelevant outputs. Returns: Tuple[str, CBLog]: Final extracted answer and updated log object. """ logger.info(f"MapReduce query: {query}") qa_history_stub = ( f"\n\nAnswers to related questions:\n\n{qa_history}" if qa_history else "" ) context_chunks = chunk_context(context, longcepo_config.chunk_size, tokenizer) # Get short summaries of each chunk def fetch_chunk_summary(client, model, chunk, query, system_prompt): return get_prompt_response( client, model, longcepo_config.summary_prompt.format(question=query, context=chunk), system_prompt, max_tokens=longcepo_config.max_output_tokens_summary, temperature=longcepo_config.temperature_map, ) summaries, cb_log = concurrent_map( fetch_chunk_summary, client, model, context_chunks, query, system_prompt, cb_log, ) num_summaries = longcepo_config.num_neighbor_summaries # For each chunk position, get a neighborhood of `num_summaries` before and after the position summaries_per_chunk = [ "\n\n".join( summaries[ max(0, (summary_idx - num_summaries)) : min( len(summaries) - 1, (summary_idx + num_summaries) ) ] ) for summary_idx in range(len(summaries)) ] # Map stage def fetch_map_response(client, model, chunk, query, system_prompt, summary): return get_prompt_response( client, model, longcepo_config.map_prompt.format( question=query, context=chunk, summary=summary, qa_history_stub=qa_history_stub, ), system_prompt, max_tokens=longcepo_config.max_output_tokens, temperature=longcepo_config.temperature_map, ) result, cb_log = concurrent_map( fetch_map_response, client, model, context_chunks, query, system_prompt, cb_log, summaries_per_chunk=summaries_per_chunk, ) result = remove_chunks(result, irrelevance_tags) if not result: return "No information", cb_log logger.info( f"Removed {len(context_chunks) - len(result)} chunks from total {len(context_chunks)} chunks" ) # Collapse stage result, cb_log = collapse_chunks( client, model, result, query, system_prompt, qa_history_stub, tokenizer, cb_log, longcepo_config, ) result = remove_chunks(result, irrelevance_tags) if not result: return "No information", cb_log # Reduce stage prompt = longcepo_config.reduce_prompt.format( question=query, context=format_chunk_list(result), qa_history_stub=qa_history_stub, ) gen_fn = partial( get_prompt_response, client=client, model=model, prompt=prompt, system_prompt=system_prompt, max_tokens=longcepo_config.max_output_tokens, temperature=longcepo_config.temperature_reduce, ) reduce_result, upd_log = loop_until_match(gen_fn, answer_tags,) cb_log.update(upd_log) final_answer = reduce_result for answer_tag in answer_tags: if answer_tag in reduce_result: final_answer = reduce_result.split(answer_tag)[-1].strip() break return final_answer, cb_log def collapse_chunks( client, model: str, context_chunks: List[str], query: str, system_prompt: str, qa_history_stub: str, tokenizer, cb_log: CBLog, longcepo_config: LongCepoConfig, ) -> Tuple[List[str], CBLog]: """ Collapses context chunk pairs in sliding window until the total token count fits within the context window. Args: client: LLM API client. model (str): Base model name. context_chunks (List[str]): Input context chunks. query (str): User query. system_prompt (str): System prompt string. qa_history_stub (str): QA history prefix. tokenizer: Tokenizer to compute token lengths. cb_log (CBLog): Log object for tracking model calls. longcepo_config (LongCepoConfig): Config with hyper-parameters and token limits. Returns: Tuple[List[str], CBLog]: Final context chunks and updated logs. """ num_tokens = get_prompt_length(format_chunk_list(context_chunks), tokenizer) token_budget = ( longcepo_config.max_context_window - get_prompt_length(longcepo_config.collapse_prompt, tokenizer) - longcepo_config.max_output_tokens ) logger.info(f"Pre-collapse length of chunks {num_tokens}, allowed {token_budget}") def fetch_collapse_response(client, model, docs, query, system_prompt): if len(docs) == 1: return docs[0], CBLog() return get_prompt_response( client, model, longcepo_config.collapse_prompt.format( question=query, context="\n\n".join(docs), qa_history_stub=qa_history_stub, ), system_prompt, max_tokens=longcepo_config.max_output_tokens, temperature=longcepo_config.temperature_collapse, ) merge_pair_idx = 0 collapse_step = 0 while num_tokens is not None and num_tokens > token_budget: logger.info(f"Length at collapse stage {collapse_step}: {collapse_step}") if len(context_chunks) == 1: logger.info(f"Post-collapse length of chunks {num_tokens}") return context_chunks, cb_log # Merge chunk pair in a sliding window (merge_pair_idx:merge_pair_idx+2) chunk_groups = ( [(context_chunks[i],) for i in range(merge_pair_idx)] + [(context_chunks[merge_pair_idx], context_chunks[merge_pair_idx + 1])] + [ (context_chunks[i],) for i in range(merge_pair_idx + 2, len(context_chunks)) ] ) context_chunks, cb_log = concurrent_map( fetch_collapse_response, client, model, chunk_groups, query, system_prompt, cb_log, ) merge_pair_idx = (merge_pair_idx + 1) % max(len(context_chunks) - 1, 1) num_tokens = get_prompt_length(format_chunk_list(context_chunks), tokenizer) collapse_step += 1 logger.info(f"Post-collapse length of chunks {num_tokens}") return context_chunks, cb_log