import logging
from typing import Callable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from .config import LongCepoConfig

logger = logging.getLogger(__name__)


class CBLog(dict):
    """Object for logging the number of LLM calls and tokens used in the pipeline"""

    __allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self.__allowed_keys__:
            raise KeyError(
                f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
            )
        if not isinstance(value, int):
            raise TypeError(
                f"Value for '{key}' must be int, got {type(value).__name__}"
            )
        super().__setitem__(key, value)

    def update(self, other=None, **kwargs):
        updates = {}
        if other:
            if isinstance(other, dict):
                updates.update(other)
            else:
                updates.update(dict(other))
        updates.update(kwargs)

        for key, value in updates.items():
            if key not in self.__allowed_keys__:
                raise KeyError(
                    f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
                )
            if not isinstance(value, int):
                raise TypeError(
                    f"Value for '{key}' must be int, got {type(value).__name__}"
                )
            self[key] = self.get(key, 0) + value
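
# Illustrative usage (a sketch, not part of the pipeline): CBLog accepts only
# the three counter keys and accumulates integer values on every update, e.g.
#
#     log = CBLog(llm_calls=1, total_tokens=120, completion_tokens=30)
#     log.update({"llm_calls": 1, "total_tokens": 80, "completion_tokens": 20})
#     # log == {"llm_calls": 2, "total_tokens": 200, "completion_tokens": 50}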


def concurrent_map(
    gen_function: Callable,
    client,
    model: str,
    context_chunks: List[str],
    query: str,
    system_prompt: str,
    cb_log: CBLog,
    summaries_per_chunk: Optional[List[str]] = None,
    workers: int = 16,
) -> Tuple[List[str], CBLog]:
    """
    Runs `gen_function` concurrently over a list of context chunks.

    Args:
        gen_function (Callable): Function to call with each chunk and associated arguments.
        client: LLM API client.
        model (str): Base model name.
        context_chunks (List[str]): Input context chunks.
        query (str): User query.
        system_prompt (str): System prompt string.
        cb_log (CBLog): Log object for tracking model calls.
        summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
        workers (int): Number of threads to use.

    Returns:
        Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
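
    Example (illustrative sketch; assumes an OpenAI-compatible ``client`` and a
    hypothetical ``summarize_chunk`` gen_function matching the signature
    described above)::

        cb_log = CBLog()
        chunks = ["chunk one ...", "chunk two ..."]
        responses, cb_log = concurrent_map(
            summarize_chunk,
            client,
            "model-name",
            chunks,
            "What is discussed?",
            "You are a helpful assistant.",
            cb_log,
            workers=4,
        )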
    """
    result = [None] * len(context_chunks)
    wrapped_gen_function = lambda index, *args: (index, gen_function(*args))
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_idx = {}
        for idx, chunk in enumerate(context_chunks):
            args = [client, model, chunk, query, system_prompt]
            if summaries_per_chunk is not None:
                args.append(summaries_per_chunk[idx])
            future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx

        for future in as_completed(future_to_idx):
            try:
                index, (response, upd_log) = future.result()
                result[index] = response
                cb_log.update(upd_log)
            except Exception as e:
                logger.error(f"Error processing chunk {future_to_idx[future]}: {e}")
    return result, cb_log


def get_prompt_response(
    client,
    model: str,
    prompt: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float = 0.7,
    top_p: float = 0.7,
):
    """
    Helper function that sends a prompt to the chat-based LLM API and returns the generated response along with usage logging.

    Args:
        client: LLM API client.
        model (str): Base model name.
        prompt (str): The user prompt to send.
        system_prompt (str): System prompt string.
        max_tokens (int): Maximum number of tokens in the response.
        temperature (float): Sampling temperature for randomness (default: 0.7).
        top_p (float): Cumulative probability cutoff for token selection (default: 0.7).

    Returns:
        Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
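
    Example (illustrative; assumes an OpenAI-compatible ``client``)::

        answer, upd_log = get_prompt_response(
            client,
            "model-name",
            "Summarize the passage in two sentences.",
            "You are a helpful assistant.",
            max_tokens=512,
        )
        cb_log.update(upd_log)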
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    print("max_tokens", max_tokens)
    print("messages", messages)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        top_p=top_p,
        temperature=temperature,
        stream=False,
    )
    print("response")
    print(response)
    upd_log = CBLog(
        llm_calls=1,
        total_tokens=response.usage.total_tokens,
        completion_tokens=response.usage.completion_tokens,
    )
    return response.choices[0].message.content, upd_log


def loop_until_match(
    function: Callable, pattern_list: Tuple[str, ...], num_attempts: int = 10
):
    """
    Repeatedly calls a function until its output matches one of the given patterns or max attempts is reached.

    Args:
        function (Callable): Function returning (answer: str, cb_log).
        pattern_list (Tuple[str, ...]): Patterns to match in the answer.
        num_attempts (int): Max number of attempts (default: 10).

    Returns:
        Tuple[str, Any]: The matching answer and its corresponding log object.
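
    Example (illustrative; ``ask`` is a hypothetical zero-argument closure over
    ``get_prompt_response``)::

        ask = lambda: get_prompt_response(
            client, "model-name", prompt, system_prompt, max_tokens=256
        )
        answer, cb_log = loop_until_match(ask, ("Final answer:",))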
    """
    for _ in range(num_attempts):
        answer, cb_log = function()

        if any(pattern in answer for pattern in pattern_list):
            break

        logger.info("Wrong output formatting, retrying...")

    return answer, cb_log


def longcepo_init(
    initial_query: str,
) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
    """
    Initializes context, query, tokenizer, logging, and config from an input string.

    Args:
        initial_query (str): Input string containing context and query separated by a delimiter string.

    Returns:
        Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
        Parsed context, query, tokenizer instance, log object, and LongCePO config.
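
    Example (illustrative; ``document_text`` and ``question`` are placeholder
    strings, and the delimiter comes from ``LongCepoConfig``)::

        raw_input = document_text + LongCepoConfig().context_query_delimiter + question
        context, query, tokenizer, cb_log, config = longcepo_init(raw_input)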
    """
    cb_log = CBLog()
    config = LongCepoConfig()
    context, query = initial_query.split(config.context_query_delimiter)
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name, model_max_length=config.max_context_window
    )
    return context.strip(), query.strip(), tokenizer, cb_log, config