import logging
from typing import Callable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from .config import LongCepoConfig
logger = logging.getLogger(__name__)
class CBLog(dict):
"""Object for logging the number of LLM calls and tokens used in the pipeline"""
__allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}
def __init__(self, *args, **kwargs):
super().__init__()
self.update(*args, **kwargs)
def __setitem__(self, key, value):
if key not in self.__allowed_keys__:
raise KeyError(
f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
)
if not isinstance(value, int):
raise TypeError(
f"Value for '{key}' must be int, got {type(value).__name__}"
)
super().__setitem__(key, value)
def update(self, other=None, **kwargs):
updates = {}
if other:
if isinstance(other, dict):
updates.update(other)
else:
updates.update(dict(other))
updates.update(kwargs)
for key, value in updates.items():
if key not in self.__allowed_keys__:
raise KeyError(
f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
)
if not isinstance(value, int):
raise TypeError(
f"Value for '{key}' must be int, got {type(value).__name__}"
)
self[key] = self.get(key, 0) + value
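# Illustrative usage of CBLog (not part of the pipeline): only the three allowed
# integer counters are accepted, and `update` accumulates values rather than
# overwriting them, e.g.
#   log = CBLog(llm_calls=1, total_tokens=120, completion_tokens=30)
#   log.update(llm_calls=1, total_tokens=80, completion_tokens=20)
#   log["llm_calls"]  # -> 2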
def concurrent_map(
gen_function: Callable,
client,
model: str,
context_chunks: List[str],
query: str,
system_prompt: str,
cb_log: CBLog,
summaries_per_chunk: Optional[List[str]] = None,
workers: int = 16,
) -> Tuple[List[str], CBLog]:
"""
Runs `gen_function` concurrently over a list of context chunks.
Args:
gen_function (Callable): Function to call with each chunk and associated arguments.
client: LLM API client.
model (str): Base model name.
context_chunks (List[str]): Input context chunks.
query (str): User query.
system_prompt (str): System prompt string.
cb_log (CBLog): Log object for tracking model calls.
summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
workers (int): Number of threads to use.
Returns:
Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
"""
result = [None] * len(context_chunks)
    def wrapped_gen_function(index, *args):
        return index, gen_function(*args)
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_idx = {}
for idx, chunk in enumerate(context_chunks):
args = [client, model, chunk, query, system_prompt]
if summaries_per_chunk is not None:
args.append(summaries_per_chunk[idx])
future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx
for future in as_completed(future_to_idx):
try:
index, (response, upd_log) = future.result()
result[index] = response
cb_log.update(upd_log)
            except Exception as e:
                logger.error(f"Error processing chunk {future_to_idx[future]}: {e}")
return result, cb_log
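# Illustrative call (hypothetical `summarize_chunk` generation function that follows
# the (client, model, chunk, query, system_prompt) signature and returns a
# (response_text, CBLog) tuple; the model name is a placeholder):
#   responses, cb_log = concurrent_map(
#       summarize_chunk, client, "base-model-name", chunks, query, system_prompt, CBLog()
#   )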
def get_prompt_response(
client,
model: str,
prompt: str,
system_prompt: str,
max_tokens: int,
temperature: float = 0.7,
top_p: float = 0.7,
):
"""
Helper function that sends a prompt to the chat-based LLM API and returns the generated response along with usage logging.
Args:
client: LLM API client.
model (str): Base model name.
prompt (str): The user prompt to send.
system_prompt (str): System prompt string.
max_tokens (int): Maximum number of tokens in the response.
temperature (float): Sampling temperature for randomness (default: 0.7).
top_p (float): Cumulative probability cutoff for token selection (default: 0.7).
Returns:
Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
print("max_tokens", max_tokens)
print("messages", messages)
response = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
stream=False,
)
print("response")
print(response)
upd_log = CBLog(
llm_calls=1,
total_tokens=response.usage.total_tokens,
completion_tokens=response.usage.completion_tokens,
)
return response.choices[0].message.content, upd_log
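# Illustrative call, assuming `client` is an OpenAI-compatible chat completions client
# and the prompts are placeholders:
#   answer, upd_log = get_prompt_response(
#       client, model, "Summarize the passage.", "You are a helpful assistant.", max_tokens=512
#   )
#   cb_log.update(upd_log)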
def loop_until_match(
    function: Callable, pattern_list: Tuple[str, ...], num_attempts: int = 10
):
"""
Repeatedly calls a function until its output matches one of the given patterns or max attempts is reached.
Args:
        function (Callable): Zero-argument function returning (answer: str, cb_log).
        pattern_list (Tuple[str, ...]): Substring patterns to match in the answer.
        num_attempts (int): Maximum number of attempts (default: 10).
    Returns:
        Tuple[str, Any]: The first matching answer and its log object; if no pattern
            matches, the output of the last attempt is returned.
"""
correct_format = False
for _ in range(num_attempts):
answer, cb_log = function()
        for pattern in pattern_list:
            if pattern in answer:
                correct_format = True
                break
        if correct_format:
            break
        logger.info("Wrong output formatting, retrying...")
    if not correct_format:
        logger.warning("No pattern matched after %d attempts, returning last answer", num_attempts)
    return answer, cb_log
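# Illustrative retry loop, assuming `gen_answer` is built with functools.partial from
# get_prompt_response and the pattern string is a placeholder:
#   from functools import partial
#   gen_answer = partial(get_prompt_response, client, model, prompt, system_prompt, 1024)
#   answer, upd_log = loop_until_match(gen_answer, ("FINAL ANSWER:",))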
def longcepo_init(
initial_query: str,
) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
"""
Initializes context, query, tokenizer, logging, and config from an input string.
Args:
initial_query (str): Input string containing context and query separated by a delimiter string.
Returns:
Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
Parsed context, query, tokenizer instance, log object, and LongCePO config.
"""
cb_log = CBLog()
config = LongCepoConfig()
    # Split into context and query on the first occurrence of the configured delimiter
    context, query = initial_query.split(config.context_query_delimiter, 1)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, model_max_length=config.max_context_window)
return context.strip(), query.strip(), tokenizer, cb_log, config
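# Illustrative initialization, assuming the input string joins a document and a question
# with the delimiter configured in LongCepoConfig (placeholder variable names):
#   delim = LongCepoConfig().context_query_delimiter
#   context, query, tokenizer, cb_log, config = longcepo_init(f"{document}{delim}{question}")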