import logging
from typing import Callable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .config import LongCepoConfig

logger = logging.getLogger(__name__)


class CBLog(dict):
    """Object for logging the number of LLM calls and tokens used in the pipeline."""

    __allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self.__allowed_keys__:
            raise KeyError(
                f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
            )
        if not isinstance(value, int):
            raise TypeError(
                f"Value for '{key}' must be int, got {type(value).__name__}"
            )
        super().__setitem__(key, value)

    def update(self, other=None, **kwargs):
        # Accumulate counts rather than overwrite, so `cb_log.update(upd_log)`
        # adds the token/call totals from each LLM response.
        updates = {}
        if other:
            updates.update(other if isinstance(other, dict) else dict(other))
        updates.update(kwargs)
        for key, value in updates.items():
            if key not in self.__allowed_keys__:
                raise KeyError(
                    f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
                )
            if not isinstance(value, int):
                raise TypeError(
                    f"Value for '{key}' must be int, got {type(value).__name__}"
                )
            self[key] = self.get(key, 0) + value


def concurrent_map(
    gen_function: Callable,
    client,
    model: str,
    context_chunks: List[str],
    query: str,
    system_prompt: str,
    cb_log: CBLog,
    summaries_per_chunk: Optional[List[str]] = None,
    workers: int = 16,
) -> Tuple[List[str], CBLog]:
    """
    Runs `gen_function` concurrently over a list of context chunks.

    Args:
        gen_function (Callable): Function to call with each chunk and associated arguments.
        client: LLM API client.
        model (str): Base model name.
        context_chunks (List[str]): Input context chunks.
        query (str): User query.
        system_prompt (str): System prompt string.
        cb_log (CBLog): Log object for tracking model calls.
        summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
        workers (int): Number of threads to use.

    Returns:
        Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
    """
    result = [None] * len(context_chunks)

    def wrapped_gen_function(index, *args):
        # Tag each result with its chunk index so responses can be restored to input order.
        return index, gen_function(*args)

    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_idx = {}
        for idx, chunk in enumerate(context_chunks):
            args = [client, model, chunk, query, system_prompt]
            if summaries_per_chunk is not None:
                args.append(summaries_per_chunk[idx])
            future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx

        for future in as_completed(future_to_idx):
            try:
                index, (response, upd_log) = future.result()
                result[index] = response
                cb_log.update(upd_log)
            except Exception as e:
                logger.error(f"Error processing chunk: {e}")

    return result, cb_log


def get_prompt_response(
    client,
    model: str,
    prompt: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float = 0.7,
    top_p: float = 0.7,
):
    """
    Helper function that sends a prompt to the chat-based LLM API and returns the
    generated response along with usage logging.

    Args:
        client: LLM API client.
        model (str): Base model name.
        prompt (str): The user prompt to send.
        system_prompt (str): System prompt string.
        max_tokens (int): Maximum number of tokens in the response.
        temperature (float): Sampling temperature for randomness (default: 0.7).
        top_p (float): Cumulative probability cutoff for token selection (default: 0.7).

    Returns:
        Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
""" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, ] print("max_tokens", max_tokens) print("messages", messages) response = client.chat.completions.create( model=model, messages=messages, max_tokens=max_tokens, top_p=top_p, temperature=temperature, stream=False, ) print("response") print(response) upd_log = CBLog( llm_calls=1, total_tokens=response.usage.total_tokens, completion_tokens=response.usage.completion_tokens, ) return response.choices[0].message.content, upd_log def loop_until_match( function: Callable, pattern_list: Tuple[str], num_attempts: int = 10 ): """ Repeatedly calls a function until its output matches one of the given patterns or max attempts is reached. Args: function (Callable): Function returning (answer: str, cb_log). pattern_list (Tuple[str]): Patterns to match in the answer. num_attempts (int): Max number of attempts (default: 10). Returns: Tuple[str, Any]: The matching answer and its corresponding log object. """ correct_format = False for _ in range(num_attempts): answer, cb_log = function() for pattern in pattern_list: if pattern in answer: correct_format = True if correct_format: break logger.info("Wrong output formatting, retrying...") return answer, cb_log def longcepo_init( initial_query: str, ) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]: """ Initializes context, query, tokenizer, logging, and config from an input string. Args: initial_query (str): Input string containing context and query separated by a delimiter string. Returns: Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]: Parsed context, query, tokenizer instance, log object, and LongCePO config. """ cb_log = CBLog() config = LongCepoConfig() context, query = initial_query.split(config.context_query_delimiter) tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, model_max_length=config.max_context_window) return context.strip(), query.strip(), tokenizer, cb_log, config