import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List, Optional, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .config import LongCepoConfig

logger = logging.getLogger(__name__)
class CBLog(dict):
    """Object for logging the number of LLM calls and tokens used in the pipeline."""

    __allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self.__allowed_keys__:
            raise KeyError(
                f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
            )
        if not isinstance(value, int):
            raise TypeError(
                f"Value for '{key}' must be int, got {type(value).__name__}"
            )
        super().__setitem__(key, value)

    def update(self, other=None, **kwargs):
        updates = {}
        if other:
            if isinstance(other, dict):
                updates.update(other)
            else:
                updates.update(dict(other))
        updates.update(kwargs)
        for key, value in updates.items():
            if key not in self.__allowed_keys__:
                raise KeyError(
                    f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
                )
            if not isinstance(value, int):
                raise TypeError(
                    f"Value for '{key}' must be int, got {type(value).__name__}"
                )
            self[key] = self.get(key, 0) + value
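# Illustrative sketch (not part of the pipeline): CBLog.update adds to the
# existing counters rather than overwriting them, so usage can be accumulated
# across multiple LLM calls.
def _example_cblog_usage() -> CBLog:
    log = CBLog(llm_calls=1, total_tokens=120, completion_tokens=40)
    log.update(llm_calls=1, total_tokens=80, completion_tokens=30)
    assert log["llm_calls"] == 2 and log["total_tokens"] == 200
    return log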
def concurrent_map(
    gen_function: Callable,
    client,
    model: str,
    context_chunks: List[str],
    query: str,
    system_prompt: str,
    cb_log: CBLog,
    summaries_per_chunk: Optional[List[str]] = None,
    workers: int = 16,
) -> Tuple[List[str], CBLog]:
    """
    Runs `gen_function` concurrently over a list of context chunks.

    Args:
        gen_function (Callable): Function to call with each chunk and associated arguments.
        client: LLM API client.
        model (str): Base model name.
        context_chunks (List[str]): Input context chunks.
        query (str): User query.
        system_prompt (str): System prompt string.
        cb_log (CBLog): Log object for tracking model calls.
        summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
        workers (int): Number of threads to use.

    Returns:
        Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
    """
    result = [None] * len(context_chunks)

    def wrapped_gen_function(index, *args):
        # Carry the chunk index through the call so results can be placed back
        # in the original order even when futures complete out of order.
        return index, gen_function(*args)

    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_idx = {}
        for idx, chunk in enumerate(context_chunks):
            args = [client, model, chunk, query, system_prompt]
            if summaries_per_chunk is not None:
                args.append(summaries_per_chunk[idx])
            future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx
        for future in as_completed(future_to_idx):
            try:
                index, (response, upd_log) = future.result()
                result[index] = response
                cb_log.update(upd_log)
            except Exception as e:
                logger.error(f"Error processing chunk {future_to_idx[future]}: {e}")
    return result, cb_log
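# Illustrative sketch (no real LLM calls): exercising concurrent_map with a
# stub generation function. The stub mimics the expected (response, CBLog)
# return shape; responses come back in the original chunk order.
def _example_concurrent_map_stub() -> Tuple[List[str], CBLog]:
    def fake_gen(client, model, chunk, query, system_prompt):
        return f"summary of: {chunk}", CBLog(llm_calls=1, total_tokens=10, completion_tokens=5)

    chunks = ["chunk A", "chunk B", "chunk C"]
    return concurrent_map(
        fake_gen,
        client=None,
        model="stub-model",
        context_chunks=chunks,
        query="example query",
        system_prompt="example system prompt",
        cb_log=CBLog(),
    )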
def get_prompt_response(
    client,
    model: str,
    prompt: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float = 0.7,
    top_p: float = 0.7,
):
    """
    Helper function that sends a prompt to the chat-based LLM API and returns the generated response along with usage logging.

    Args:
        client: LLM API client.
        model (str): Base model name.
        prompt (str): The user prompt to send.
        system_prompt (str): System prompt string.
        max_tokens (int): Maximum number of tokens in the response.
        temperature (float): Sampling temperature for randomness (default: 0.7).
        top_p (float): Cumulative probability cutoff for token selection (default: 0.7).

    Returns:
        Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    logger.debug("max_tokens: %s", max_tokens)
    logger.debug("messages: %s", messages)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        top_p=top_p,
        temperature=temperature,
        stream=False,
    )
    logger.debug("response: %s", response)
    upd_log = CBLog(
        llm_calls=1,
        total_tokens=response.usage.total_tokens,
        completion_tokens=response.usage.completion_tokens,
    )
    return response.choices[0].message.content, upd_log
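# Illustrative sketch (hypothetical prompt template): a per-chunk generation
# function with the signature that concurrent_map expects, built on top of
# get_prompt_response.
def _example_chunk_gen(client, model: str, chunk: str, query: str, system_prompt: str):
    prompt = f"Context:\n{chunk}\n\nQuestion: {query}"
    return get_prompt_response(client, model, prompt, system_prompt, max_tokens=256)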
def loop_until_match(
    function: Callable, pattern_list: Tuple[str, ...], num_attempts: int = 10
):
    """
    Repeatedly calls a function until its output matches one of the given patterns or max attempts is reached.

    Args:
        function (Callable): Function returning (answer: str, cb_log).
        pattern_list (Tuple[str, ...]): Patterns to match in the answer.
        num_attempts (int): Max number of attempts (default: 10).

    Returns:
        Tuple[str, Any]: The matching answer and its corresponding log object.
    """
    for _ in range(num_attempts):
        answer, cb_log = function()
        if any(pattern in answer for pattern in pattern_list):
            break
        logger.info("Wrong output formatting, retrying...")
    return answer, cb_log
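# Illustrative sketch (hypothetical prompt and marker): binding arguments with
# functools.partial so that loop_until_match can re-issue the same request
# until the answer contains an expected marker string.
def _example_loop_until_match(client, model: str) -> Tuple[str, CBLog]:
    from functools import partial

    bound_call = partial(
        get_prompt_response,
        client,
        model,
        prompt="What is 2 + 2? Reply in the form 'Answer: <number>'.",
        system_prompt="You are a helpful assistant.",
        max_tokens=32,
    )
    return loop_until_match(bound_call, pattern_list=("Answer:",))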
def longcepo_init(
    initial_query: str,
) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
    """
    Initializes context, query, tokenizer, logging, and config from an input string.

    Args:
        initial_query (str): Input string containing context and query separated by a delimiter string.

    Returns:
        Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
            Parsed context, query, tokenizer instance, log object, and LongCePO config.
    """
    cb_log = CBLog()
    config = LongCepoConfig()
    context, query = initial_query.split(config.context_query_delimiter)
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name, model_max_length=config.max_context_window
    )
    return context.strip(), query.strip(), tokenizer, cb_log, config
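# Illustrative sketch (hypothetical document and question): longcepo_init
# expects the context and the query joined by the delimiter defined in
# LongCepoConfig, and returns them split and stripped alongside the tokenizer,
# log object, and config.
def _example_longcepo_init() -> Tuple[str, str]:
    delimiter = LongCepoConfig().context_query_delimiter
    combined = f"Some long document text.{delimiter}What does the document say?"
    context, query, tokenizer, cb_log, config = longcepo_init(combined)
    return context, query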