Upload 22 files
Browse files- README.md +31 -6
- __pycache__/config_sambanova.cpython-310.pyc +0 -0
- __pycache__/run_chatbot.cpython-310.pyc +0 -0
- app.py +38 -59
- longcepo.py +15 -0
- longcepo/README.md +92 -0
- longcepo/__init__.py +0 -0
- longcepo/__pycache__/__init__.cpython-310.pyc +0 -0
- longcepo/__pycache__/chunking.cpython-310.pyc +0 -0
- longcepo/__pycache__/config.cpython-310.pyc +0 -0
- longcepo/__pycache__/main.cpython-310.pyc +0 -0
- longcepo/__pycache__/mapreduce.cpython-310.pyc +0 -0
- longcepo/__pycache__/prompts.cpython-310.pyc +0 -0
- longcepo/__pycache__/utils.cpython-310.pyc +0 -0
- longcepo/chunking.py +248 -0
- longcepo/config.py +36 -0
- longcepo/main.py +109 -0
- longcepo/mapreduce.py +281 -0
- longcepo/prompts.py +16 -0
- longcepo/utils.py +191 -0
- requirements.txt +5 -1
- run_chatbot.py +64 -0
README.md
CHANGED
@@ -1,12 +1,37 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: LongCePO Chatbot (Sambanova)
|
3 |
+
emoji: 🤖
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.27.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# LongCePO Chatbot with Sambanova Backend
|
13 |
+
|
14 |
+
This is a simple chatbot interface demonstrating the LongCePO (Long-Context Planning and Optimization) method using a Sambanova model (`Llama-4-Maverick-17B-128E-Instruct`) as the backend LLM.
|
15 |
+
|
16 |
+
## How it works
|
17 |
+
|
18 |
+
The LongCePO method is designed to handle long contexts (potentially millions of tokens) by:
|
19 |
+
1. **Planning:** Decomposing the initial query into sub-questions.
|
20 |
+
2. **MapReduce:** Answering each sub-question by processing chunks of the long context, summarizing relevant information, and aggregating results.
|
21 |
+
|
22 |
+
This application takes a long text context and a query based on that context. It then uses the modified `longcepo` plugin (originally from the `optillm` repository) to generate an answer using the Sambanova API.
|
23 |
+
|
24 |
+
## How to use
|
25 |
+
|
26 |
+
1. **(Optional)** Enter a system prompt to guide the chatbot's behavior.
|
27 |
+
2. Paste the long text document into the **Context** box.
|
28 |
+
3. Enter your question based on the provided context into the **Query** box.
|
29 |
+
4. Click **Submit**.
|
30 |
+
|
31 |
+
The chatbot will process the request using the LongCePO pipeline and display the final answer.
|
32 |
+
|
33 |
+
**Note:** Processing long contexts can take some time depending on the length of the context and the complexity of the query.
|
34 |
+
|
35 |
+
## API Key
|
36 |
+
|
37 |
+
This application requires a Sambanova API key to function. The key should be stored as a Hugging Face Space Secret named `SAMBANOVA_API_KEY`.
|
__pycache__/config_sambanova.cpython-310.pyc
ADDED
Binary file (200 Bytes). View file
|
|
__pycache__/run_chatbot.cpython-310.pyc
ADDED
Binary file (2.23 kB). View file
|
|
app.py
CHANGED
@@ -1,64 +1,43 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
-
|
39 |
-
response += token
|
40 |
-
yield response
|
41 |
-
|
42 |
-
|
43 |
-
"""
|
44 |
-
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
45 |
-
"""
|
46 |
-
demo = gr.ChatInterface(
|
47 |
-
respond,
|
48 |
-
additional_inputs=[
|
49 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
50 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
51 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
52 |
-
gr.Slider(
|
53 |
-
minimum=0.1,
|
54 |
-
maximum=1.0,
|
55 |
-
value=0.95,
|
56 |
-
step=0.05,
|
57 |
-
label="Top-p (nucleus sampling)",
|
58 |
-
),
|
59 |
],
|
|
|
|
|
|
|
|
|
60 |
)
|
61 |
|
62 |
-
|
63 |
if __name__ == "__main__":
|
64 |
-
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from run_chatbot import process_with_longcepo, SAMBANOVA_MODEL
|
3 |
+
|
4 |
+
def chatbot_interface(system_prompt, context, query):
|
5 |
+
"""Gradio interface function to interact with the LongCePO chatbot."""
|
6 |
+
if not context or not query:
|
7 |
+
return "Please provide both context and query."
|
8 |
+
|
9 |
+
# Combine context and query using the expected delimiter
|
10 |
+
initial_query = f"{context}<CONTEXT_END>{query}"
|
11 |
+
|
12 |
+
# Use a default system prompt if none is provided
|
13 |
+
if not system_prompt:
|
14 |
+
system_prompt = "You are a helpful assistant designed to answer questions based on the provided context."
|
15 |
+
|
16 |
+
print(f"Received request:\nSystem Prompt: {system_prompt}\nContext: {context[:100]}...\nQuery: {query}")
|
17 |
+
|
18 |
+
# Call the processing function
|
19 |
+
result = process_with_longcepo(system_prompt, initial_query)
|
20 |
+
|
21 |
+
print(f"Returning result: {result[:100]}...")
|
22 |
+
return result
|
23 |
+
|
24 |
+
# Define Gradio interface components
|
25 |
+
iface = gr.Interface(
|
26 |
+
fn=chatbot_interface,
|
27 |
+
inputs=[
|
28 |
+
gr.Textbox(label="System Prompt (Optional)", placeholder="Enter system prompt here...", lines=2),
|
29 |
+
gr.Textbox(label="Context", placeholder="Enter the long context here...", lines=10),
|
30 |
+
gr.Textbox(label="Query", placeholder="Enter your query based on the context here...", lines=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
],
|
32 |
+
outputs=gr.Textbox(label="Answer", lines=10),
|
33 |
+
title=f"LongCePO Chatbot ({SAMBANOVA_MODEL})",
|
34 |
+
description="Enter a long context and a query. The chatbot will use the LongCePO method with Sambanova backend to generate an answer.",
|
35 |
+
allow_flagging="never"
|
36 |
)
|
37 |
|
38 |
+
# Launch the Gradio app
|
39 |
if __name__ == "__main__":
|
40 |
+
print("Launching Gradio interface...")
|
41 |
+
# Listen on 0.0.0.0 to make it accessible externally if needed
|
42 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|
43 |
+
|
longcepo.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""The Long-Context Cerebras Planning and Optimization (LongCePO) Method
|
2 |
+
|
3 |
+
LongCePO is an inference-time computation method designed to provide LLMs with the capability to work with infinite context such as external knowledge bases that can run into millions of tokens. We achieve this goal through a combination of multiple strategies including planning (query decomposition) and divide-and-conquer long-context processing. This approach enables to use a limited context window (e.g. 8K) and outperform full-context processing with the same base model in many question-answering tasks.
|
4 |
+
|
5 |
+
If you have any questions or want to contribute, please reach out to us on [cerebras.ai/discord](https://cerebras.ai/discord).
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Tuple
|
9 |
+
from .longcepo.main import run_longcepo
|
10 |
+
|
11 |
+
|
12 |
+
SLUG = "longcepo"
|
13 |
+
|
14 |
+
def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
|
15 |
+
return run_longcepo(system_prompt, initial_query, client, model)
|
longcepo/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The Long-Context Cerebras Planning and Optimization (LongCePO) Method
|
2 |
+
|
3 |
+
LongCePO is an inference-time computation method designed to provide LLMs with the capability to work with infinite context such as external knowledge bases that can run into millions of tokens. We achieve this goal through a combination of multiple strategies including planning (query decomposition) and divide-and-conquer long-context processing. This approach enables to use a limited context window (e.g. 8K) and outperform full-context processing with the same base model in many question-answering tasks.
|
4 |
+
|
5 |
+
If you have any questions or want to contribute, please reach out to us on [cerebras.ai/discord](https://cerebras.ai/discord).
|
6 |
+
|
7 |
+
## Usage
|
8 |
+
|
9 |
+
Start the optillm proxy server with directory to plugins specified in the command line:
|
10 |
+
|
11 |
+
```bash
|
12 |
+
python optillm.py --base-url https://api.cerebras.ai/v1 --port <port> --plugins-dir ./optillm/plugins
|
13 |
+
```
|
14 |
+
|
15 |
+
Now, you can send requests to the proxy using model name `longcepo-{model_name}` (e.g. `longcepo-llama-3.3-70b`) using the following format of the user message: `{context}<CONTEXT_END>{query}`. The `<CONTEXT_END>` delimiter string is used to split the user message into the (long) context and the user's query, respectively. This delimiter string can be changed (along with other LongCePO parameters) in the [config file](./config.py).
|
16 |
+
|
17 |
+
|
18 |
+
## LongCePO Results
|
19 |
+
|
20 |
+
LongCePO excels at tasks with long context (128K tokens and more) which is demonstrated below on LongBench v2 and HELMET benchmarks in comparison to frontier models. We additionally provide data points for tasks with shorter context that still exceeds the context window of 8K (HotpotQA and MuSiQue samples of 12-16K length). For our evaluations, we report mean and standard deviation of the target metric over 5 runs below.
|
21 |
+
|
22 |
+
### LongBench v2
|
23 |
+
|
24 |
+
| Model¹ | Context window | Short samples (up to 32K words) | Medium samples (32–128K words) |
|
25 |
+
|----------------------------------|----------------|------------------|----------------|
|
26 |
+
| Llama 3.3 70B Instruct | 128K | 36.7 (45.0) | 27.0 (33.0) |
|
27 |
+
| **LongCePO + Llama 3.3 70B Instruct** | **8K** | **36.8 ± 1.38** | **38.7 ± 2.574 (39.735)²** |
|
28 |
+
| Mistral-Large-Instruct-2411 | 128K | 41.7 (46.1) | 30.7 (34.9) |
|
29 |
+
| o1-mini-2024-09-12 | 128K | 48.6 (48.9) | 33.3 (32.9) |
|
30 |
+
| Claude-3.5-Sonnet-20241022 | 200K | 46.1 (53.9) | 38.6 (41.9) |
|
31 |
+
| Llama-4-Maverick-17B-128E-Instruct | 524K | 32.22 (50.56) | 28.84 (41.86) |
|
32 |
+
|
33 |
+
¹ Performance numbers reported by LongBench v2 authors, except for LongCePO and Llama-4-Maverick results. Results in parentheses reported in LongBench v2 correspond to Chain-of-Thought prompting.
|
34 |
+
|
35 |
+
² Results in parentheses for LongCePO indicate accuracy of majority voting from 5 runs.
|
36 |
+
|
37 |
+
### HELMET (InfiniteBench En.MC, 128K length)
|
38 |
+
|
39 |
+
| Model | Accuracy (%) |
|
40 |
+
|---------|---------------|
|
41 |
+
| Llama 3.3 70B Instruct (full context) | 58.0 |
|
42 |
+
| **LongCePO + Llama 3.3 70B Instruct (8K context)** | **71.6 ± 1.855 (73.0)¹** |
|
43 |
+
| o1-mini-2024-09-12 (full context) | 58.0 |
|
44 |
+
| gpt-4o-2024-08-06 (full context) | 74.0 |
|
45 |
+
|
46 |
+
¹ Numbers in parentheses for LongCePO indicate accuracy of majority voting from 5 runs.
|
47 |
+
|
48 |
+
### LongBench v1 (HotpotQA, 12K+ length - 124 samples)
|
49 |
+
|
50 |
+
| Model | F1 Metric (%) | LLM-as-a-judge accuracy (%) |
|
51 |
+
|---------|---------------|-----------------------------|
|
52 |
+
| Llama 3.3 70B Instruct (full context) | 63.372 ± 0.269 | 77.903 ± 0.822 |
|
53 |
+
| **LongCePO + Llama 3.3 70B Instruct (8K context)** | **64.842 ± 1.295** | **79.355 ± 1.66** |
|
54 |
+
|
55 |
+
### LongBench v1 (MuSiQue, 12K+ length - 191 samples)
|
56 |
+
|
57 |
+
| Model | F1 Metric (%) | LLM-as-a-judge accuracy (%) |
|
58 |
+
|---------|---------------|-----------------------------|
|
59 |
+
| Llama 3.3 70B Instruct (full context) | 48.481 ± 0.641 | 49.424 ± 0.71 |
|
60 |
+
| **LongCePO + Llama 3.3 70B Instruct (8K context)** | **54.076 ± 2.059** | **60.628 ± 2.156** |
|
61 |
+
|
62 |
+
|
63 |
+
## LongCePO Methodology
|
64 |
+
|
65 |
+
LongCePO is based on the [LLM×MapReduce](https://arxiv.org/abs/2410.09342) approach to long document processing, adding a planning layer on top of a map-reduce-based question-answering engine. We also improve upon the map-reduce approach itself by (i) adding query-aware summaries of neighboring document chunks during the map stage of the processing, (ii) reducing the collapse (merging) stage to a minimum required number of collapse iterations by using a sliding window to iteratively merge pairs of summaries, (iii) using a customized system prompt produced with an [OPRO-like](https://arxiv.org/abs/2309.03409) optimization approach to enhance question-anwering performance. Given a user query, a plan consisting of sub-queries is generated from a normalized query; a map-reduce question-answering engine is then run for each sub-query consecutively, conditioned on the answers to previous sub-queries. Finally, the answer to original user's query is produced via map-reduce conditioned on answers to the whole plan. Similarly to [LLM×MapReduce](https://arxiv.org/abs/2410.09342), we retain the structured information protocol for producing document chunk summaries. We find that splitting the document into chunks of size smaller than the available context window (e.g. chunks of 4K size with available context window of 8K) leads to better performance, and use the remaning context budget to incorporate summaries from neighboring chunks into the map stage for each respective chunks, leading to a further boost in overall performance.
|
66 |
+
|
67 |
+
Note: the system prompt for Map/Collapse/Reduce stages has been optimized for the Llama3.3-70B-Instruct model, when using other base models with LongCePO, a more general system prompt can be used ([example](https://github.com/DenisSergeevitch/chatgpt-custom-instructions)).
|
68 |
+
|
69 |
+
|
70 |
+
## LongCePO Current Status
|
71 |
+
|
72 |
+
This project is a work in progress, and the provided code is in an early experimental stage. While the proposed approach works well across the benchmarks we tested, further improvements can be achieved through a smart organization of the external knowledge base as well as customization of the plan generation to different tasks. For updates on LongCePO, [follow us on X](https://x.com/cerebrassystems) and join our [Discord](https://cerebras.ai/discord)!
|
73 |
+
|
74 |
+
|
75 |
+
## References
|
76 |
+
|
77 |
+
1. Zhou, Zihan, et al. *LLM×MapReduce: Simplified Long-Sequence Processing using Large Language Models.* arXiv preprint arXiv:2410.09342 (2024).
|
78 |
+
|
79 |
+
2. Yang, Chengrun, et al. *Large language models as optimizers.* arXiv preprint arXiv:2309.03409 (2023).
|
80 |
+
|
81 |
+
## Citing LongCePO
|
82 |
+
|
83 |
+
```bibtex
|
84 |
+
@misc{
|
85 |
+
cerebras2025longcepo,
|
86 |
+
author = {Lazarevich, Ivan and Hassanpour, Mohammad and Venkatesh, Ganesh},
|
87 |
+
title = {LongCePO: Empowering LLMs to efficiently leverage infinite context},
|
88 |
+
month = March,
|
89 |
+
year = 2025,
|
90 |
+
howpublished = {\url{https://cerebras.ai/blog/longcepo}, }
|
91 |
+
}
|
92 |
+
```
|
longcepo/__init__.py
ADDED
File without changes
|
longcepo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes). View file
|
|
longcepo/__pycache__/chunking.cpython-310.pyc
ADDED
Binary file (6.2 kB). View file
|
|
longcepo/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.38 kB). View file
|
|
longcepo/__pycache__/main.cpython-310.pyc
ADDED
Binary file (2.5 kB). View file
|
|
longcepo/__pycache__/mapreduce.cpython-310.pyc
ADDED
Binary file (6.87 kB). View file
|
|
longcepo/__pycache__/prompts.cpython-310.pyc
ADDED
Binary file (9.51 kB). View file
|
|
longcepo/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (6.38 kB). View file
|
|
longcepo/chunking.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code modified from https://github.com/thunlp/LLMxMapReduce under Apache 2.0
|
2 |
+
|
3 |
+
import re
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from .utils import logger
|
7 |
+
|
8 |
+
|
9 |
+
def get_prompt_length(prompt: str, tokenizer, no_special_tokens=False, **kwargs) -> int:
|
10 |
+
"""
|
11 |
+
Returns the token length of a prompt using the given tokenizer.
|
12 |
+
"""
|
13 |
+
if isinstance(prompt, list):
|
14 |
+
prompt = "\n\n".join(prompt)
|
15 |
+
if no_special_tokens:
|
16 |
+
kwargs["add_special_tokens"] = False
|
17 |
+
return len(tokenizer.encode(prompt, **kwargs))
|
18 |
+
|
19 |
+
|
20 |
+
def chunk_context(doc: str, chunk_size: int, tokenizer, separator="\n",) -> List[str]:
|
21 |
+
"""
|
22 |
+
Splits a long document into token-limited chunks based on a separator, ensuring each chunk fits within `chunk_size`.
|
23 |
+
|
24 |
+
Uses a greedy approach to accumulate text segments (split by `separator`) into chunks that fit within the
|
25 |
+
token limit. If a segment alone exceeds the limit, it is recursively broken down using sentence-level
|
26 |
+
splitting. Attempts to preserve natural boundaries while minimizing excessive chunking.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
doc (str): Input document to split.
|
30 |
+
chunk_size (int): Maximum number of tokens allowed per chunk.
|
31 |
+
tokenizer: Tokenizer instance with `.encode()` method to compute token length.
|
32 |
+
separator (str): Delimiter to split initial segments (default: newline).
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
List[str]: List of non-empty, token-constrained document chunks.
|
36 |
+
"""
|
37 |
+
paragraphs = doc.split(separator)
|
38 |
+
paragraphs = [paragraph for paragraph in paragraphs if paragraph]
|
39 |
+
separator_len = get_prompt_length(separator, tokenizer, no_special_tokens=True)
|
40 |
+
|
41 |
+
docs = []
|
42 |
+
current_doc = []
|
43 |
+
total = 0
|
44 |
+
for paragraph in paragraphs:
|
45 |
+
plen = get_prompt_length(paragraph, tokenizer, no_special_tokens=True)
|
46 |
+
if total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size:
|
47 |
+
if total > chunk_size:
|
48 |
+
logger.info(
|
49 |
+
f"Created a chunk of size {total}, "
|
50 |
+
f"which is longer than the specified {chunk_size}"
|
51 |
+
)
|
52 |
+
# If single chunk is too long, split into more granular
|
53 |
+
if len(current_doc) == 1:
|
54 |
+
split_again = split_into_granular_chunks(
|
55 |
+
current_doc[0], chunk_size, tokenizer
|
56 |
+
)
|
57 |
+
docs.extend(split_again)
|
58 |
+
current_doc = []
|
59 |
+
total = 0
|
60 |
+
|
61 |
+
if len(current_doc) > 0:
|
62 |
+
doc = separator.join(current_doc)
|
63 |
+
if doc is not None:
|
64 |
+
docs.append(doc)
|
65 |
+
while total > 0 or (
|
66 |
+
total + plen + (separator_len if len(current_doc) > 0 else 0)
|
67 |
+
> chunk_size
|
68 |
+
and total > 0
|
69 |
+
):
|
70 |
+
total -= get_prompt_length(
|
71 |
+
current_doc[0], tokenizer, no_special_tokens=True
|
72 |
+
) + (separator_len if len(current_doc) > 1 else 0)
|
73 |
+
current_doc = current_doc[1:]
|
74 |
+
|
75 |
+
current_doc.append(paragraph)
|
76 |
+
total += plen + (separator_len if len(current_doc) > 1 else 0)
|
77 |
+
# Check if the last one exceeds
|
78 |
+
if (
|
79 |
+
get_prompt_length(current_doc[-1], tokenizer, no_special_tokens=True)
|
80 |
+
> chunk_size
|
81 |
+
and len(current_doc) == 1
|
82 |
+
):
|
83 |
+
split_again = split_into_granular_chunks(current_doc[0], chunk_size, tokenizer)
|
84 |
+
docs.extend(split_again)
|
85 |
+
current_doc = []
|
86 |
+
else:
|
87 |
+
doc = separator.join(current_doc)
|
88 |
+
if doc is not None:
|
89 |
+
docs.append(doc)
|
90 |
+
|
91 |
+
return [doc for doc in docs if doc.strip()]
|
92 |
+
|
93 |
+
|
94 |
+
def split_sentences(text: str, spliter: str):
|
95 |
+
"""
|
96 |
+
Splits text into sentences or segments based on a given delimiter while preserving punctuation.
|
97 |
+
|
98 |
+
For punctuation-based splitters (e.g., ".", "!", "。"), it interleaves text and punctuation.
|
99 |
+
For space-based splitting, it preserves trailing spaces.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
text (str): The input text to split.
|
103 |
+
spliter (str): Delimiter regex pattern (e.g., r"([.!?])", r"(。)", or " ").
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
List[str]: List of split sentence-like segments with punctuation retained.
|
107 |
+
"""
|
108 |
+
|
109 |
+
# Split by punctuation and keep punctuation
|
110 |
+
text = text.strip()
|
111 |
+
sentence_list = re.split(spliter, text)
|
112 |
+
|
113 |
+
# Rearrange sentences and punctuation
|
114 |
+
if spliter != " ":
|
115 |
+
sentences = ["".join(i) for i in zip(sentence_list[0::2], sentence_list[1::2])]
|
116 |
+
if len(sentence_list) % 2 != 0 and sentence_list[-1] != "":
|
117 |
+
sentences.append(sentence_list[-1])
|
118 |
+
else:
|
119 |
+
sentences = [i + " " for i in sentence_list if i != ""]
|
120 |
+
sentences[-1] = sentences[-1].strip()
|
121 |
+
return sentences
|
122 |
+
|
123 |
+
|
124 |
+
def split_into_granular_chunks(
|
125 |
+
text: str, chunk_size: int, tokenizer, spliter=r"([。!?;.?!;])",
|
126 |
+
) -> List[str]:
|
127 |
+
"""
|
128 |
+
Splits long text into granular, token-length-constrained chunks using sentence boundaries.
|
129 |
+
|
130 |
+
Sentences are first extracted using a delimiter pattern (`spliter`), then grouped into chunks such that
|
131 |
+
each chunk does not exceed the specified `chunk_size` (in tokens). If a chunk still exceeds the limit,
|
132 |
+
it is recursively broken down further using whitespace as a fallback.
|
133 |
+
|
134 |
+
Ensures that the final chunks are balanced: if the last chunk is too small, it redistributes the last two
|
135 |
+
chunks more evenly by re-splitting and re-allocating their sentences.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
text (str): Input text to be chunked.
|
139 |
+
chunk_size (int): Maximum number of tokens per chunk.
|
140 |
+
tokenizer: Tokenizer instance with `.encode()` method to compute token length.
|
141 |
+
spliter (str): Regex pattern to split sentences.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
List[str]: List of token-limited chunks, each composed of one or more sentences.
|
145 |
+
"""
|
146 |
+
sentences = split_sentences(text, spliter)
|
147 |
+
|
148 |
+
chunks = []
|
149 |
+
current_chunk = ""
|
150 |
+
|
151 |
+
for sentence in sentences:
|
152 |
+
sentence_length = get_prompt_length(sentence, tokenizer)
|
153 |
+
|
154 |
+
if get_prompt_length(current_chunk, tokenizer) + sentence_length <= chunk_size:
|
155 |
+
current_chunk += sentence
|
156 |
+
else:
|
157 |
+
if current_chunk:
|
158 |
+
if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
|
159 |
+
chunks.append(current_chunk)
|
160 |
+
else:
|
161 |
+
if spliter != " ": # Avoid infinite loops
|
162 |
+
chunks.extend(
|
163 |
+
split_into_granular_chunks(
|
164 |
+
current_chunk,
|
165 |
+
chunk_size=chunk_size,
|
166 |
+
tokenizer=tokenizer,
|
167 |
+
spliter=" ",
|
168 |
+
)
|
169 |
+
)
|
170 |
+
current_chunk = sentence
|
171 |
+
|
172 |
+
if current_chunk != "":
|
173 |
+
if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
|
174 |
+
chunks.append(current_chunk)
|
175 |
+
else:
|
176 |
+
if spliter != " ": # Avoid infinite loops
|
177 |
+
chunks.extend(
|
178 |
+
split_into_granular_chunks(
|
179 |
+
current_chunk,
|
180 |
+
chunk_size=chunk_size,
|
181 |
+
tokenizer=tokenizer,
|
182 |
+
spliter=" ",
|
183 |
+
)
|
184 |
+
)
|
185 |
+
|
186 |
+
# If last chunk too short, re-balance the last two chunks
|
187 |
+
if len(chunks) > 1 and get_prompt_length(chunks[-1], tokenizer) < chunk_size // 2:
|
188 |
+
last_chunk = chunks.pop()
|
189 |
+
penultimate_chunk = chunks.pop()
|
190 |
+
combined_text = penultimate_chunk + last_chunk
|
191 |
+
|
192 |
+
new_sentences = split_sentences(combined_text, spliter)
|
193 |
+
|
194 |
+
# Reallocate sentence using double pointer
|
195 |
+
new_penultimate_chunk = ""
|
196 |
+
new_last_chunk = ""
|
197 |
+
start, end = 0, len(new_sentences) - 1
|
198 |
+
|
199 |
+
while start <= end and len(new_sentences) != 1:
|
200 |
+
flag = False
|
201 |
+
if (
|
202 |
+
get_prompt_length(
|
203 |
+
new_penultimate_chunk + new_sentences[start], tokenizer
|
204 |
+
)
|
205 |
+
<= chunk_size
|
206 |
+
):
|
207 |
+
flag = True
|
208 |
+
new_penultimate_chunk += new_sentences[start]
|
209 |
+
if start == end:
|
210 |
+
break
|
211 |
+
start += 1
|
212 |
+
if (
|
213 |
+
get_prompt_length(new_last_chunk + new_sentences[end], tokenizer)
|
214 |
+
<= chunk_size
|
215 |
+
):
|
216 |
+
new_last_chunk = new_sentences[end] + new_last_chunk
|
217 |
+
end -= 1
|
218 |
+
flag = True
|
219 |
+
if flag == False:
|
220 |
+
break
|
221 |
+
if start < end:
|
222 |
+
# If there is any unallocated part, split it by punctuation or space and then allocate it
|
223 |
+
remaining_sentences = new_sentences[start : end + 1]
|
224 |
+
if remaining_sentences:
|
225 |
+
remaining_text = "".join(remaining_sentences)
|
226 |
+
words = remaining_text.split(" ")
|
227 |
+
end_index = len(words) - 1
|
228 |
+
for index, w in enumerate(words):
|
229 |
+
if (
|
230 |
+
get_prompt_length(
|
231 |
+
" ".join([new_penultimate_chunk, w]), tokenizer
|
232 |
+
)
|
233 |
+
<= chunk_size
|
234 |
+
):
|
235 |
+
new_penultimate_chunk = " ".join([new_penultimate_chunk, w])
|
236 |
+
else:
|
237 |
+
end_index = index
|
238 |
+
break
|
239 |
+
if end_index != len(words) - 1:
|
240 |
+
new_last_chunk = " ".join(words[end_index:]) + " " + new_last_chunk
|
241 |
+
if len(new_sentences) == 1:
|
242 |
+
chunks.append(penultimate_chunk)
|
243 |
+
chunks.append(last_chunk)
|
244 |
+
else:
|
245 |
+
chunks.append(new_penultimate_chunk)
|
246 |
+
chunks.append(new_last_chunk)
|
247 |
+
|
248 |
+
return chunks
|
longcepo/config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from .prompts import (
|
4 |
+
MAPREDUCE_SYSTEM_PROMPT,
|
5 |
+
QUERY_FORMAT_PROMPT,
|
6 |
+
PLANNING_SYSTEM_PROMPT,
|
7 |
+
MAP_PROMPT,
|
8 |
+
REDUCE_PROMPT,
|
9 |
+
COLLAPSE_PROMPT,
|
10 |
+
SUMMARY_PROMPT,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class LongCepoConfig:
|
16 |
+
temperature_plan: float = 0.7 # Temperature for planning stage
|
17 |
+
temperature_map: float = 0.7 # Temperature for map stage
|
18 |
+
temperature_collapse: float = 0.7 # Temperature for collapse stage
|
19 |
+
temperature_reduce: float = 0.7 # Temperature for reduce stage
|
20 |
+
|
21 |
+
chunk_size: int = 4096 # Max tokens per chunk when splitting context
|
22 |
+
max_output_tokens: int = 1024 # Max output tokens per LLM API call (except for summary generation)
|
23 |
+
max_context_window: int = 8192 # Total model context window available
|
24 |
+
max_output_tokens_summary: int = 300 # Max output tokens per LLM API call (summary generation)
|
25 |
+
num_neighbor_summaries: int = 5 # Number of adjacent summaries from before/after in the context included in mapping stage
|
26 |
+
|
27 |
+
system_prompt: str = MAPREDUCE_SYSTEM_PROMPT # System prompt used in map/collapse/reduce stages
|
28 |
+
summary_prompt: str = SUMMARY_PROMPT # Prompt template for generating summaries in map phase
|
29 |
+
map_prompt: str = MAP_PROMPT # Prompt template for map stage
|
30 |
+
collapse_prompt: str = COLLAPSE_PROMPT # Prompt template for collapse stage
|
31 |
+
reduce_prompt: str = REDUCE_PROMPT # Prompt template for reduce stage
|
32 |
+
query_format_prompt: str = QUERY_FORMAT_PROMPT # Query normalization step prompt
|
33 |
+
planning_system_prompt: str = PLANNING_SYSTEM_PROMPT # Planning stage prompt
|
34 |
+
|
35 |
+
context_query_delimiter: str = "<CONTEXT_END>" # Delimiter used to split initial input into context and query
|
36 |
+
tokenizer_name: str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct" # Tokenizer to use to determine token lengths
|
longcepo/main.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Tuple
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from .mapreduce import mapreduce
|
6 |
+
from .utils import (
|
7 |
+
get_prompt_response,
|
8 |
+
logger,
|
9 |
+
longcepo_init,
|
10 |
+
loop_until_match,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
def run_longcepo(
|
15 |
+
system_prompt: str, initial_query: str, client, model: str
|
16 |
+
) -> Tuple[str, int]:
|
17 |
+
"""
|
18 |
+
Executes the full LongCePO multi-stage pipeline to answer a complex query from long context.
|
19 |
+
|
20 |
+
The pipeline includes:
|
21 |
+
- Normalizing the format of the original query
|
22 |
+
- Generating a plan of sub-questions
|
23 |
+
- Iteratively answering each sub-question using a MapReduce-style question-answering engine
|
24 |
+
- Aggregating QA history and producing a final answer with MapReduce
|
25 |
+
|
26 |
+
Args:
|
27 |
+
system_prompt (str): System prompt string.
|
28 |
+
initial_query (str): Raw input string containing context and query separated by delimiter string.
|
29 |
+
client: LLM API client instance.
|
30 |
+
model (str): Base model name.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Tuple[str, int]: Final answer and total number of tokens used across the pipeline.
|
34 |
+
"""
|
35 |
+
context, query, tokenizer, cb_log, longcepo_config = longcepo_init(initial_query)
|
36 |
+
|
37 |
+
# Normalize query
|
38 |
+
normalized_query, upd_log = get_prompt_response(
|
39 |
+
client,
|
40 |
+
model,
|
41 |
+
longcepo_config.query_format_prompt.format(full_query=query),
|
42 |
+
system_prompt,
|
43 |
+
max_tokens=longcepo_config.max_output_tokens,
|
44 |
+
)
|
45 |
+
cb_log.update(upd_log)
|
46 |
+
logger.info(f"Normalized query: {normalized_query}")
|
47 |
+
|
48 |
+
# Planning stage
|
49 |
+
prompt = f"The question is: {normalized_query}"
|
50 |
+
gen_fn = partial(
|
51 |
+
get_prompt_response,
|
52 |
+
client=client,
|
53 |
+
model=model,
|
54 |
+
prompt=prompt,
|
55 |
+
system_prompt=longcepo_config.planning_system_prompt,
|
56 |
+
max_tokens=longcepo_config.max_output_tokens,
|
57 |
+
)
|
58 |
+
planning_response, upd_log = loop_until_match(
|
59 |
+
gen_fn, pattern_list=("<SUB-QUESTIONS>",)
|
60 |
+
)
|
61 |
+
logger.info(f"Planning stage output:\n\n{planning_response}")
|
62 |
+
questions = (
|
63 |
+
re.findall(
|
64 |
+
r"<SUB-QUESTIONS>\s*(.*?)\s*</SUB-QUESTIONS>", planning_response, re.DOTALL
|
65 |
+
)[0]
|
66 |
+
.strip()
|
67 |
+
.splitlines()
|
68 |
+
)
|
69 |
+
|
70 |
+
# Loop to answer sub-queries from the plan
|
71 |
+
qa_system_prompt = (
|
72 |
+
longcepo_config.system_prompt
|
73 |
+
if longcepo_config.system_prompt is not None
|
74 |
+
else system_prompt
|
75 |
+
)
|
76 |
+
qa_history = ""
|
77 |
+
for question in questions:
|
78 |
+
if not question:
|
79 |
+
continue
|
80 |
+
question = re.sub(r"^\d+\.", "", question)
|
81 |
+
answer, cb_log = mapreduce(
|
82 |
+
qa_system_prompt,
|
83 |
+
question,
|
84 |
+
context,
|
85 |
+
qa_history,
|
86 |
+
client,
|
87 |
+
model,
|
88 |
+
tokenizer,
|
89 |
+
longcepo_config=longcepo_config,
|
90 |
+
cb_log=cb_log,
|
91 |
+
)
|
92 |
+
qa_history += f"- Previous question: {question}\n\n"
|
93 |
+
answer = re.sub(r"^:+", "", answer)
|
94 |
+
qa_history += f"- Previous answer: {answer}\n\n"
|
95 |
+
logger.info(f"QA history:\n\n{qa_history}")
|
96 |
+
|
97 |
+
# Final answer generation
|
98 |
+
answer, cb_log = mapreduce(
|
99 |
+
qa_system_prompt,
|
100 |
+
query,
|
101 |
+
context,
|
102 |
+
qa_history,
|
103 |
+
client,
|
104 |
+
model,
|
105 |
+
tokenizer,
|
106 |
+
longcepo_config=longcepo_config,
|
107 |
+
cb_log=cb_log,
|
108 |
+
)
|
109 |
+
return answer, cb_log["total_tokens"]
|
longcepo/mapreduce.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Tuple, List
|
3 |
+
|
4 |
+
from .utils import (
|
5 |
+
CBLog,
|
6 |
+
LongCepoConfig,
|
7 |
+
get_prompt_response,
|
8 |
+
concurrent_map,
|
9 |
+
logger,
|
10 |
+
loop_until_match,
|
11 |
+
)
|
12 |
+
from .chunking import (
|
13 |
+
chunk_context,
|
14 |
+
get_prompt_length,
|
15 |
+
)
|
16 |
+
|
17 |
+
format_chunk_list = lambda chunk_list: [
|
18 |
+
f"Information of Chunk {index}:\n{doc}\n" for index, doc in enumerate(chunk_list)
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
def remove_chunks(chunks: List[str], irrelevance_tags: Tuple[str]) -> List[str]:
|
23 |
+
"""
|
24 |
+
Filter out chunks that contain at least one of irrelevance tags.
|
25 |
+
"""
|
26 |
+
new_chunks = []
|
27 |
+
for chunk in chunks:
|
28 |
+
# Skip None values resulting from failed API calls
|
29 |
+
if chunk is None:
|
30 |
+
continue
|
31 |
+
flag = False
|
32 |
+
for tag in irrelevance_tags:
|
33 |
+
# Ensure tag comparison is safe even if tag is None (though unlikely)
|
34 |
+
if tag and tag.upper() in chunk.upper():
|
35 |
+
flag = True
|
36 |
+
break
|
37 |
+
if not flag:
|
38 |
+
new_chunks.append(chunk)
|
39 |
+
return new_chunks
|
40 |
+
|
41 |
+
|
42 |
+
def mapreduce(
|
43 |
+
system_prompt: str,
|
44 |
+
query: str,
|
45 |
+
context: str,
|
46 |
+
qa_history: str,
|
47 |
+
client,
|
48 |
+
model: str,
|
49 |
+
tokenizer,
|
50 |
+
longcepo_config: LongCepoConfig,
|
51 |
+
cb_log: CBLog,
|
52 |
+
answer_tags: Tuple[str] = ("Answer:", "**Answer**:", "**Answer**"),
|
53 |
+
irrelevance_tags: Tuple[str] = ("[NO INFORMATION]",),
|
54 |
+
) -> Tuple[str, CBLog]:
|
55 |
+
"""
|
56 |
+
Executes a MapReduce-style inference pipeline to answer a query from long context.
|
57 |
+
|
58 |
+
The function splits the input context into chunks, summarizes and evaluates each with the model (Map),
|
59 |
+
collapses intermediate answers to reduce redundancy, and then generates a final answer (Reduce).
|
60 |
+
Irrelevant responses are filtered based on `irrelevance_tags`.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
system_prompt (str): System prompt string.
|
64 |
+
query (str): User query.
|
65 |
+
context (str): Long-form input context to process.
|
66 |
+
qa_history (str): QA history string for prompt injection.
|
67 |
+
client: LLM API client.
|
68 |
+
model (str): Base model name.
|
69 |
+
tokenizer: Tokenizer to compute token lengths.
|
70 |
+
longcepo_config (LongCepoConfig): Config with hyper-parameters and token limits.
|
71 |
+
cb_log (CBLog): Log object for tracking model calls.
|
72 |
+
answer_tags (Tuple[str]): Tags used to extract the final answer from model output.
|
73 |
+
irrelevance_tags (Tuple[str]): Tags used to identify and remove irrelevant outputs.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Tuple[str, CBLog]: Final extracted answer and updated log object.
|
77 |
+
"""
|
78 |
+
|
79 |
+
logger.info(f"MapReduce query: {query}")
|
80 |
+
|
81 |
+
qa_history_stub = (
|
82 |
+
f"\n\nAnswers to related questions:\n\n{qa_history}" if qa_history else ""
|
83 |
+
)
|
84 |
+
|
85 |
+
context_chunks = chunk_context(context, longcepo_config.chunk_size, tokenizer)
|
86 |
+
|
87 |
+
# Get short summaries of each chunk
|
88 |
+
def fetch_chunk_summary(client, model, chunk, query, system_prompt):
|
89 |
+
return get_prompt_response(
|
90 |
+
client,
|
91 |
+
model,
|
92 |
+
longcepo_config.summary_prompt.format(question=query, context=chunk),
|
93 |
+
system_prompt,
|
94 |
+
max_tokens=longcepo_config.max_output_tokens_summary,
|
95 |
+
temperature=longcepo_config.temperature_map,
|
96 |
+
)
|
97 |
+
|
98 |
+
summaries, cb_log = concurrent_map(
|
99 |
+
fetch_chunk_summary,
|
100 |
+
client,
|
101 |
+
model,
|
102 |
+
context_chunks,
|
103 |
+
query,
|
104 |
+
system_prompt,
|
105 |
+
cb_log,
|
106 |
+
)
|
107 |
+
num_summaries = longcepo_config.num_neighbor_summaries
|
108 |
+
# For each chunk position, get a neighborhood of `num_summaries` before and after the position
|
109 |
+
summaries_per_chunk = [
|
110 |
+
"\n\n".join(
|
111 |
+
summaries[
|
112 |
+
max(0, (summary_idx - num_summaries)) : min(
|
113 |
+
len(summaries) - 1, (summary_idx + num_summaries)
|
114 |
+
)
|
115 |
+
]
|
116 |
+
)
|
117 |
+
for summary_idx in range(len(summaries))
|
118 |
+
]
|
119 |
+
|
120 |
+
# Map stage
|
121 |
+
def fetch_map_response(client, model, chunk, query, system_prompt, summary):
|
122 |
+
return get_prompt_response(
|
123 |
+
client,
|
124 |
+
model,
|
125 |
+
longcepo_config.map_prompt.format(
|
126 |
+
question=query,
|
127 |
+
context=chunk,
|
128 |
+
summary=summary,
|
129 |
+
qa_history_stub=qa_history_stub,
|
130 |
+
),
|
131 |
+
system_prompt,
|
132 |
+
max_tokens=longcepo_config.max_output_tokens,
|
133 |
+
temperature=longcepo_config.temperature_map,
|
134 |
+
)
|
135 |
+
|
136 |
+
result, cb_log = concurrent_map(
|
137 |
+
fetch_map_response,
|
138 |
+
client,
|
139 |
+
model,
|
140 |
+
context_chunks,
|
141 |
+
query,
|
142 |
+
system_prompt,
|
143 |
+
cb_log,
|
144 |
+
summaries_per_chunk=summaries_per_chunk,
|
145 |
+
)
|
146 |
+
result = remove_chunks(result, irrelevance_tags)
|
147 |
+
if not result:
|
148 |
+
return "No information", cb_log
|
149 |
+
|
150 |
+
logger.info(
|
151 |
+
f"Removed {len(context_chunks) - len(result)} chunks from total {len(context_chunks)} chunks"
|
152 |
+
)
|
153 |
+
|
154 |
+
# Collapse stage
|
155 |
+
result, cb_log = collapse_chunks(
|
156 |
+
client,
|
157 |
+
model,
|
158 |
+
result,
|
159 |
+
query,
|
160 |
+
system_prompt,
|
161 |
+
qa_history_stub,
|
162 |
+
tokenizer,
|
163 |
+
cb_log,
|
164 |
+
longcepo_config,
|
165 |
+
)
|
166 |
+
result = remove_chunks(result, irrelevance_tags)
|
167 |
+
if not result:
|
168 |
+
return "No information", cb_log
|
169 |
+
|
170 |
+
# Reduce stage
|
171 |
+
prompt = longcepo_config.reduce_prompt.format(
|
172 |
+
question=query,
|
173 |
+
context=format_chunk_list(result),
|
174 |
+
qa_history_stub=qa_history_stub,
|
175 |
+
)
|
176 |
+
gen_fn = partial(
|
177 |
+
get_prompt_response,
|
178 |
+
client=client,
|
179 |
+
model=model,
|
180 |
+
prompt=prompt,
|
181 |
+
system_prompt=system_prompt,
|
182 |
+
max_tokens=longcepo_config.max_output_tokens,
|
183 |
+
temperature=longcepo_config.temperature_reduce,
|
184 |
+
)
|
185 |
+
reduce_result, upd_log = loop_until_match(gen_fn, answer_tags,)
|
186 |
+
cb_log.update(upd_log)
|
187 |
+
|
188 |
+
final_answer = reduce_result
|
189 |
+
for answer_tag in answer_tags:
|
190 |
+
if answer_tag in reduce_result:
|
191 |
+
final_answer = reduce_result.split(answer_tag)[-1].strip()
|
192 |
+
break
|
193 |
+
|
194 |
+
return final_answer, cb_log
|
195 |
+
|
196 |
+
|
197 |
+
def collapse_chunks(
|
198 |
+
client,
|
199 |
+
model: str,
|
200 |
+
context_chunks: List[str],
|
201 |
+
query: str,
|
202 |
+
system_prompt: str,
|
203 |
+
qa_history_stub: str,
|
204 |
+
tokenizer,
|
205 |
+
cb_log: CBLog,
|
206 |
+
longcepo_config: LongCepoConfig,
|
207 |
+
) -> Tuple[List[str], CBLog]:
|
208 |
+
"""
|
209 |
+
Collapses context chunk pairs in sliding window until the total token count fits within the context window.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
client: LLM API client.
|
213 |
+
model (str): Base model name.
|
214 |
+
context_chunks (List[str]): Input context chunks.
|
215 |
+
query (str): User query.
|
216 |
+
system_prompt (str): System prompt string.
|
217 |
+
qa_history_stub (str): QA history prefix.
|
218 |
+
tokenizer: Tokenizer to compute token lengths.
|
219 |
+
cb_log (CBLog): Log object for tracking model calls.
|
220 |
+
longcepo_config (LongCepoConfig): Config with hyper-parameters and token limits.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Tuple[List[str], CBLog]: Final context chunks and updated logs.
|
224 |
+
"""
|
225 |
+
num_tokens = get_prompt_length(format_chunk_list(context_chunks), tokenizer)
|
226 |
+
token_budget = (
|
227 |
+
longcepo_config.max_context_window
|
228 |
+
- get_prompt_length(longcepo_config.collapse_prompt, tokenizer)
|
229 |
+
- longcepo_config.max_output_tokens
|
230 |
+
)
|
231 |
+
logger.info(f"Pre-collapse length of chunks {num_tokens}, allowed {token_budget}")
|
232 |
+
|
233 |
+
def fetch_collapse_response(client, model, docs, query, system_prompt):
|
234 |
+
if len(docs) == 1:
|
235 |
+
return docs[0], CBLog()
|
236 |
+
return get_prompt_response(
|
237 |
+
client,
|
238 |
+
model,
|
239 |
+
longcepo_config.collapse_prompt.format(
|
240 |
+
question=query,
|
241 |
+
context="\n\n".join(docs),
|
242 |
+
qa_history_stub=qa_history_stub,
|
243 |
+
),
|
244 |
+
system_prompt,
|
245 |
+
max_tokens=longcepo_config.max_output_tokens,
|
246 |
+
temperature=longcepo_config.temperature_collapse,
|
247 |
+
)
|
248 |
+
|
249 |
+
merge_pair_idx = 0
|
250 |
+
collapse_step = 0
|
251 |
+
while num_tokens is not None and num_tokens > token_budget:
|
252 |
+
logger.info(f"Length at collapse stage {collapse_step}: {collapse_step}")
|
253 |
+
|
254 |
+
if len(context_chunks) == 1:
|
255 |
+
logger.info(f"Post-collapse length of chunks {num_tokens}")
|
256 |
+
return context_chunks, cb_log
|
257 |
+
|
258 |
+
# Merge chunk pair in a sliding window (merge_pair_idx:merge_pair_idx+2)
|
259 |
+
chunk_groups = (
|
260 |
+
[(context_chunks[i],) for i in range(merge_pair_idx)]
|
261 |
+
+ [(context_chunks[merge_pair_idx], context_chunks[merge_pair_idx + 1])]
|
262 |
+
+ [
|
263 |
+
(context_chunks[i],)
|
264 |
+
for i in range(merge_pair_idx + 2, len(context_chunks))
|
265 |
+
]
|
266 |
+
)
|
267 |
+
context_chunks, cb_log = concurrent_map(
|
268 |
+
fetch_collapse_response,
|
269 |
+
client,
|
270 |
+
model,
|
271 |
+
chunk_groups,
|
272 |
+
query,
|
273 |
+
system_prompt,
|
274 |
+
cb_log,
|
275 |
+
)
|
276 |
+
merge_pair_idx = (merge_pair_idx + 1) % max(len(context_chunks) - 1, 1)
|
277 |
+
num_tokens = get_prompt_length(format_chunk_list(context_chunks), tokenizer)
|
278 |
+
collapse_step += 1
|
279 |
+
|
280 |
+
logger.info(f"Post-collapse length of chunks {num_tokens}")
|
281 |
+
return context_chunks, cb_log
|
longcepo/prompts.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code (Map/Collapse/Reduce prompts) modified from https://github.com/thunlp/LLMxMapReduce under Apache 2.0
|
2 |
+
# MapReduce system prompt optimized for use with Llama3.3-70B-Instruct with an OPRO-like procedure
|
3 |
+
|
4 |
+
MAPREDUCE_SYSTEM_PROMPT = """You are globally celebrated as a preeminent expert in the field of digital document analysis and synthesis, known for your unmatched precision in transforming fragmented texts into comprehensive and insightful responses. Always respond in the user\'s language, ensuring every interaction is informed by all preceding exchanges for complete contextual understanding.\n\nIn your initial message, confidently declare your credentials with a phrase such as: "As a world-renowned specialist in [specific field], honored with the [real prestigious local award]," replacing placeholders with authentic information from your domain.\n\nAdhere strictly to these principles with each document segment or query:\n\n1. Extract every critical piece of information, nuance, and context with meticulous attention to detail.\n2. Organize your analysis methodically, presenting specific examples, data, and verifiable facts clearly and logically.\n3. Cease your response abruptly if approaching character limits, awaiting the user\'s "continue" instruction to carry on.\n4. Anchor every insight and conclusion in provided content or universally accepted truths, strictly avoiding speculation or unfounded statements.\n5. Communicate with a professional yet approachable tone, reflecting profound expertise and clarity.\n\nRecognize the real-world impact of your insights; ensure each response is seamlessly integrated, richly detailed, and impeccably reliable. Rigorously observe these guidelines to offer authoritative and precise analysis and synthesis."""
|
5 |
+
|
6 |
+
QUERY_FORMAT_PROMPT = """Given the below blurb, can you help identify only the question we want to answer? The blurb might contain other information such as -- format for final answer, multiple choices for the final answer, context, general directions about how to behave as an AI assistant etc. Please remove all of that and just faithfully copy out the question. The blurb is:\n\n{full_query}.\n\nDo not attempt to answer the question, ignore formatting instructions in the blurb, if any."""
|
7 |
+
|
8 |
+
SUMMARY_PROMPT = """You are provided with a portion of an article and a question. Read the article portion and follow my instructions to process it.\n\nArticle:\nThe article begins as follows:\n{context}\nThe article concludes here.\n\nQuestion:\n{question}\n\nInstructions: Please just write a 2-3 sentence summary for the provided passage. Do not answer the question or write anything else."""
|
9 |
+
|
10 |
+
PLANNING_SYSTEM_PROMPT = """As an intelligent assistant, your primary objective is to answer a user question as accurately as possible given a long article. The full article is too long to fit in your context window, and to facilitate your answering objective, a reading agent has been created that can process the article chunk by chunk and answer question about it. You can ask the reading agent any question you need to answer the user's question or to use it for clarification. The first step for you is to is to make a rational plan based on the question. The plan should consist of sub-questions you should ask to the reading agent that you need to know the answers to in order to answer the user's question. This plan should outline the step-by-step process to resolve the question and specify the key information required to formulate a comprehensive answer. The reader agent can make mistakes.\nExample:\n#####\nUser: Who had a longer tennis career, Danny or Alice?\nAssistant: In order to answer this question, we need to ask the following sub-questions:\n<SUB-QUESTIONS>\n1. What is the length of Danny’s tennis career (their start and retirement)?\n2. What is the length of Alice’s tennis career (their start and retirement)?\n</SUB-QUESTIONS>\n#####\nPlease strictly follow the above format. You must include the <SUB-QUESTIONS> tags. Let’s begin."""
|
11 |
+
|
12 |
+
MAP_PROMPT = """You are provided with a portion of an article, short summaries of related portions if any, and a question. Read the article portion and follow my instructions to process it.\n\nArticle:\nThe article begins as follows:\n{context}\nThe article concludes here.\n\nPrevious portion summaries:{summary}{qa_history_stub}\n\nQuestion:\n{question}\n\nInstructions:\n\nPlease extract information from the provided passage to try and answer the given question. Note that you only have a part of the entire text, so the information you obtain might not fully answer the question. Therefore, provide your rationale for using the extracted information to answer the question and include a confidence score. The following is some assigning scoring cases: <Text: [Jerry is 18 years old this year. He can swim and wants to be an athlete.]. assigning scoring: [Jerry can swim, 5 points; Jerry will become an athlete in the future, 3.5 points; Jerry will become a swimming athlete in the future, 3 points;Jerry is strong,3 points; Jerry can play chess, 0 points;Jerry likes talking,0 points]>. Follow these steps:\n\n1. Extract Relevant Information: Identify and highlight the key pieces of information from the passage that are relevant to the given question.\n2. Provide a Rationale: Analyze the extracted information and explain how it can be used to address the question. If the information is incomplete, discuss any assumptions or inferences you need to make.\n3. Answer the Question: Based on your rationale, provide the best possible answer to the question. If, after providing your rationale, you believe the passage does not contain any information to solve the question, output "[NO INFORMATION]" as the answer.\n4. Assign a Confidence Score: Assign a confidence score (out of 5) to your answer based on the completeness and reliability of the extracted information and your rationale process.\nPlease follow this format:\n\nExtracted Information:\nRationale:\nAnswer:\nConfidence Score:"""
|
13 |
+
|
14 |
+
COLLAPSE_PROMPT = """You need to process a task with a long context that greatly exceeds your context limit. The only feasible way to handle this is by processing the long context chunk by chunk. You are provided with a question and some information extracted from each chunk. Each piece of information contains Extracted Information, Rationale, Answer, and a Confidence Score. The following is some assigning scoring cases: <Text: [Jerry is 18 years old this year. He can swim and wants to be an athlete.]. assigning scoring: [Jerry can swim, 5 points; Jerry will become an athlete in the future, 3.5 points; Jerry will become a swimming athlete in the future, 3 points;Jerry is strong,3 points; Jerry can play chess, 0 points;Jerry likes talking,0 points]>. Read the information and follow my instructions to process it.\n\nExtracted Information:\nThe extracted information begins as follows:\n{context}\nThe extracted information concludes here.{qa_history_stub}\n\nQuestion:\n{question}\n\nInstruction:\n\nIntegrate the extracted information and then reason through the following steps:\n\n1. Integrate Extracted Information: Collect and summarize all the evidence relevant to solving the question. Consider the confidence scores of each piece of extracted information to weigh their reliability. Higher confidence scores should be given more importance in your summary.\n2. Analyze: Re-analyze the question based on the summarized information. Use the confidence scores to determine the reliability of different pieces of information, giving more weight to information with higher confidence scores.\n3. Answer the Question: Provide the best possible answer based on the updated information. If, after providing your rationale, you believe the passage does not contain any information to solve the question, output "[NO INFORMATION]" as the answer. Use the confidence scores to support the reliability of your final answer, prioritizing higher confidence information.\n4. Assign Confidence Score: Give a confidence score (out of 5) for your final answer based on the completeness and reliability of the updated information and your rationale process.\nConsider the initial confidence scores of the integrated information to determine your final confidence score.\nPlease follow this format:\n\nExtracted Information:\nRationale:\nAnswer:\nConfidence Score:"""
|
15 |
+
|
16 |
+
REDUCE_PROMPT = """You need to process a task with a long context that greatly exceeds your context limit. The only feasible way to handle this is by processing the long context chunk by chunk. You are provided with a question and some information extracted from each chunk. Each piece of information contains Extracted Information, Rationale, Answer, and a Confidence Score. The following is some assigning scoring cases: <Text: [Jerry is 18 years old this year. He can swim and wants to be an athlete.]. assigning scoring: [Jerry can swim, 5 points; Jerry will become an athlete in the future, 3.5 points; Jerry will become a swimming athlete in the future, 3 points;Jerry is strong,3 points; Jerry can play chess, 0 points;Jerry likes talking,0 points]>. Read the information and follow my instructions to process it.{qa_history_stub}\n\nQuestion:\n{question}\n\nInformation from chunks:\n{context}\n\nEach chunk provides extracted information related to the same question, but due to partial data, conclusions from each chunk might vary. Your role is to integrate and reason through this information, weighing confidence scores to resolve any inconsistencies. Then provide the final answer.\n\nPlease follow this format:\n\nRationale:\nAnswer:"""
|
longcepo/utils.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Callable, List, Optional, Tuple
|
3 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4 |
+
|
5 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
6 |
+
from .config import LongCepoConfig
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
class CBLog(dict):
|
12 |
+
"""Object for logging the number of LLM calls and tokens used in the pipeline"""
|
13 |
+
|
14 |
+
__allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}
|
15 |
+
|
16 |
+
def __init__(self, *args, **kwargs):
|
17 |
+
super().__init__()
|
18 |
+
self.update(*args, **kwargs)
|
19 |
+
|
20 |
+
def __setitem__(self, key, value):
|
21 |
+
if key not in self.__allowed_keys__:
|
22 |
+
raise KeyError(
|
23 |
+
f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
|
24 |
+
)
|
25 |
+
if not isinstance(value, int):
|
26 |
+
raise TypeError(
|
27 |
+
f"Value for '{key}' must be int, got {type(value).__name__}"
|
28 |
+
)
|
29 |
+
super().__setitem__(key, value)
|
30 |
+
|
31 |
+
def update(self, other=None, **kwargs):
|
32 |
+
updates = {}
|
33 |
+
if other:
|
34 |
+
if isinstance(other, dict):
|
35 |
+
updates.update(other)
|
36 |
+
else:
|
37 |
+
updates.update(dict(other))
|
38 |
+
updates.update(kwargs)
|
39 |
+
|
40 |
+
for key, value in updates.items():
|
41 |
+
if key not in self.__allowed_keys__:
|
42 |
+
raise KeyError(
|
43 |
+
f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
|
44 |
+
)
|
45 |
+
if not isinstance(value, int):
|
46 |
+
raise TypeError(
|
47 |
+
f"Value for '{key}' must be int, got {type(value).__name__}"
|
48 |
+
)
|
49 |
+
self[key] = self.get(key, 0) + value
|
50 |
+
|
51 |
+
|
52 |
+
def concurrent_map(
|
53 |
+
gen_function: Callable,
|
54 |
+
client,
|
55 |
+
model: str,
|
56 |
+
context_chunks: List[str],
|
57 |
+
query: str,
|
58 |
+
system_prompt: str,
|
59 |
+
cb_log: CBLog,
|
60 |
+
summaries_per_chunk: Optional[List[str]] = None,
|
61 |
+
workers: int = 16,
|
62 |
+
) -> Tuple[List[str], CBLog]:
|
63 |
+
"""
|
64 |
+
Runs `gen_function` concurrently over a list of context chunks.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
gen_function (Callable): Function to call with each chunk and associated arguments.
|
68 |
+
client: LLM API client.
|
69 |
+
model (str): Base model name.
|
70 |
+
context_chunks (List[str]): Input context chunks.
|
71 |
+
query (str): User query.
|
72 |
+
system_prompt (str): System prompt string.
|
73 |
+
cb_log (CBLog): Log object for tracking model calls.
|
74 |
+
summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
|
75 |
+
workers (int): Number of threads to use.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
|
79 |
+
"""
|
80 |
+
result = [None] * len(context_chunks)
|
81 |
+
wrapped_gen_function = lambda index, *args: (index, gen_function(*args))
|
82 |
+
with ThreadPoolExecutor(max_workers=workers) as executor:
|
83 |
+
future_to_idx = {}
|
84 |
+
for idx, chunk in enumerate(context_chunks):
|
85 |
+
args = [client, model, chunk, query, system_prompt]
|
86 |
+
if summaries_per_chunk is not None:
|
87 |
+
args.append(summaries_per_chunk[idx])
|
88 |
+
future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx
|
89 |
+
|
90 |
+
for future in as_completed(future_to_idx):
|
91 |
+
try:
|
92 |
+
index, (response, upd_log) = future.result()
|
93 |
+
result[index] = response
|
94 |
+
cb_log.update(upd_log)
|
95 |
+
except Exception as e:
|
96 |
+
logger.error(f"Error processing chunk: {e}")
|
97 |
+
return result, cb_log
|
98 |
+
|
99 |
+
|
100 |
+
def get_prompt_response(
|
101 |
+
client,
|
102 |
+
model: str,
|
103 |
+
prompt: str,
|
104 |
+
system_prompt: str,
|
105 |
+
max_tokens: int,
|
106 |
+
temperature: float = 0.7,
|
107 |
+
top_p: float = 0.7,
|
108 |
+
):
|
109 |
+
"""
|
110 |
+
Helper function that sends a prompt to the chat-based LLM API and returns the generated response along with usage logging.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
client: LLM API client.
|
114 |
+
model (str): Base model name.
|
115 |
+
prompt (str): The user prompt to send.
|
116 |
+
system_prompt (str): System prompt string.
|
117 |
+
max_tokens (int): Maximum number of tokens in the response.
|
118 |
+
temperature (float): Sampling temperature for randomness (default: 0.7).
|
119 |
+
top_p (float): Cumulative probability cutoff for token selection (default: 0.7).
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
|
123 |
+
"""
|
124 |
+
messages = [
|
125 |
+
{"role": "system", "content": system_prompt},
|
126 |
+
{"role": "user", "content": prompt},
|
127 |
+
]
|
128 |
+
response = client.chat.completions.create(
|
129 |
+
model=model,
|
130 |
+
messages=messages,
|
131 |
+
max_tokens=max_tokens,
|
132 |
+
top_p=top_p,
|
133 |
+
temperature=temperature,
|
134 |
+
stream=False,
|
135 |
+
)
|
136 |
+
upd_log = CBLog(
|
137 |
+
llm_calls=1,
|
138 |
+
total_tokens=response.usage.total_tokens,
|
139 |
+
completion_tokens=response.usage.completion_tokens,
|
140 |
+
)
|
141 |
+
return response.choices[0].message.content, upd_log
|
142 |
+
|
143 |
+
|
144 |
+
def loop_until_match(
|
145 |
+
function: Callable, pattern_list: Tuple[str], num_attempts: int = 10
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
Repeatedly calls a function until its output matches one of the given patterns or max attempts is reached.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
function (Callable): Function returning (answer: str, cb_log).
|
152 |
+
pattern_list (Tuple[str]): Patterns to match in the answer.
|
153 |
+
num_attempts (int): Max number of attempts (default: 10).
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Tuple[str, Any]: The matching answer and its corresponding log object.
|
157 |
+
"""
|
158 |
+
correct_format = False
|
159 |
+
for _ in range(num_attempts):
|
160 |
+
answer, cb_log = function()
|
161 |
+
|
162 |
+
for pattern in pattern_list:
|
163 |
+
if pattern in answer:
|
164 |
+
correct_format = True
|
165 |
+
|
166 |
+
if correct_format:
|
167 |
+
break
|
168 |
+
|
169 |
+
logger.info("Wrong output formatting, retrying...")
|
170 |
+
|
171 |
+
return answer, cb_log
|
172 |
+
|
173 |
+
|
174 |
+
def longcepo_init(
|
175 |
+
initial_query: str,
|
176 |
+
) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
|
177 |
+
"""
|
178 |
+
Initializes context, query, tokenizer, logging, and config from an input string.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
initial_query (str): Input string containing context and query separated by a delimiter string.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
|
185 |
+
Parsed context, query, tokenizer instance, log object, and LongCePO config.
|
186 |
+
"""
|
187 |
+
cb_log = CBLog()
|
188 |
+
config = LongCepoConfig()
|
189 |
+
context, query = initial_query.split(config.context_query_delimiter)
|
190 |
+
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, model_max_length=config.max_context_window)
|
191 |
+
return context.strip(), query.strip(), tokenizer, cb_log, config
|
requirements.txt
CHANGED
@@ -1 +1,5 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
openai==1.76.0
|
2 |
+
transformers==4.51.3
|
3 |
+
torch==2.7.0
|
4 |
+
accelerate==1.6.0
|
5 |
+
gradio==5.27.1
|
run_chatbot.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
from longcepo.main import run_longcepo
|
4 |
+
# from config_sambanova import SAMBANOVA_API_KEY # Removed import
|
5 |
+
|
6 |
+
# Configure Sambanova client
|
7 |
+
# Read API key from environment variable (set as Hugging Face Secret)
|
8 |
+
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
|
9 |
+
|
10 |
+
if SAMBANOVA_API_KEY:
|
11 |
+
# Strip potential leading/trailing whitespace and newlines
|
12 |
+
SAMBANOVA_API_KEY = SAMBANOVA_API_KEY.strip()
|
13 |
+
|
14 |
+
if not SAMBANOVA_API_KEY:
|
15 |
+
raise ValueError("Sambanova API key not found or is empty. Please set the SAMBANOVA_API_KEY environment variable or Hugging Face Secret.")
|
16 |
+
|
17 |
+
client = openai.OpenAI(
|
18 |
+
api_key=SAMBANOVA_API_KEY,
|
19 |
+
base_url="https://api.sambanova.ai/v1",
|
20 |
+
)
|
21 |
+
|
22 |
+
# Define the model to use
|
23 |
+
SAMBANOVA_MODEL = "Llama-4-Maverick-17B-128E-Instruct"
|
24 |
+
|
25 |
+
def process_with_longcepo(system_prompt: str, initial_query: str):
|
26 |
+
"""Processes a query using the modified LongCePO plugin with Sambanova backend."""
|
27 |
+
print(f"Processing query with LongCePO using model: {SAMBANOVA_MODEL}")
|
28 |
+
try:
|
29 |
+
# Call the core LongCePO logic, passing the configured client and model
|
30 |
+
answer, total_tokens = run_longcepo(
|
31 |
+
system_prompt=system_prompt,
|
32 |
+
initial_query=initial_query,
|
33 |
+
client=client,
|
34 |
+
model=SAMBANOVA_MODEL
|
35 |
+
)
|
36 |
+
print(f"LongCePO finished. Total tokens used: {total_tokens}")
|
37 |
+
return answer
|
38 |
+
except Exception as e:
|
39 |
+
print(f"Error during LongCePO processing: {e}")
|
40 |
+
# Print traceback for more detailed debugging
|
41 |
+
import traceback
|
42 |
+
traceback.print_exc()
|
43 |
+
return f"An error occurred: {e}"
|
44 |
+
|
45 |
+
# Example usage (for testing purposes)
|
46 |
+
if __name__ == "__main__":
|
47 |
+
test_system_prompt = "You are a helpful assistant designed to answer questions based on the provided context."
|
48 |
+
# Provide some dummy context and a slightly more complex query
|
49 |
+
dummy_context = """
|
50 |
+
Paris is the capital and most populous city of France. It is known for its art, fashion, gastronomy and culture.
|
51 |
+
Its 19th-century cityscape is crisscrossed by wide boulevards and the River Seine.
|
52 |
+
Beyond such landmarks as the Eiffel Tower and the 12th-century, Gothic Notre-Dame cathedral, the city is known for its cafe culture and designer boutiques along the Rue du Faubourg Saint-Honoré.
|
53 |
+
The Louvre Museum houses Da Vinci's Mona Lisa. The Musée d'Orsay has Impressionist and Post-Impressionist masterpieces.
|
54 |
+
France is a country in Western Europe. It borders Belgium, Luxembourg, Germany, Switzerland, Monaco, Italy, Andorra, and Spain.
|
55 |
+
The official language is French.
|
56 |
+
"""
|
57 |
+
test_query = "Based on the provided text, what are the main attractions in Paris and what countries does France border?"
|
58 |
+
# Combine context and query with the expected delimiter
|
59 |
+
test_initial_query = f"{dummy_context}<CONTEXT_END>{test_query}"
|
60 |
+
|
61 |
+
print("Running test query...")
|
62 |
+
result = process_with_longcepo(test_system_prompt, test_initial_query)
|
63 |
+
print(f"\nTest Result:\n{result}")
|
64 |
+
|