import torch
from expand import *
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from dataclasses import dataclass
from collections.abc import Callable  # used by the stopping-criterion factory below
import time

type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast


def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
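    """Return, for each sequence in the batch, the next-token candidates as
    (token_id, log_prob) pairs, keeping only tokens whose log-probability
    exceeds ``threshold``."""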
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    print("Running inference")
    start_time = time.time()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    print(f"Inference done, took {time.time() - start_time} seconds")
    start_time = time.time()
    logits: torch.Tensor = outputs.logits[:, -1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    print(f"Log probs done, took {time.time() - start_time} seconds")
    start_time = time.time()
    result = []
    print(f"Resulting tensor: {log_probs.shape}")
    for probs in log_probs:
        # Filter out low probability tokens for efficiency
        above_threshold = torch.where(probs > threshold)
        filtered_indices = above_threshold[0]
        filtered_probs = probs[filtered_indices]
        result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])
    print(f"Result done, took {time.time() - start_time} seconds")
    return result


def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
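    """Decode each token-id context back to text and re-tokenize the texts as one
    padded batch, moving the resulting tensors to ``device``."""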
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    return tokenizer(texts, return_tensors="pt", padding=True).to(device)


@dataclass
class LLMBatchExpander(BatchExpander):
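    """BatchExpander that proposes expansions with a Hugging Face language model,
    keeping only next tokens whose log-probability clears ``threshold``."""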
    model: PreTrainedModel
    tokenizer: Tokenizer
    threshold: float

    def expand(self, batch: Batch) -> BatchCandidates:
        inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
        start_time = time.time()
        results = []
        print(f"Batch size: {len(batch.items)}, next tokens size: {len(next_tokens)}")
        for s, tokens in zip(batch.items, next_tokens):
            expansions = [Expansion(token=token, cost=cost) for token, cost in tokens]
            results.append(TokenCandidates(series=s, expansions=expansions))
        print()
        print(f"Token candidates done, took {time.time() - start_time} seconds")
        return BatchCandidates(items=results)


def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
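    """Build a criterion that, beyond ``default_completion_criterion``, stops a series
    at a word boundary: stop if the first expansion token does not start with a space
    (it does not begin a new word), or if a later token does start with a space
    (the next word has begun)."""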
    def stopping_criterion(series: Series, expansion: Expansion) -> bool:
        d = default_completion_criterion(series, expansion)
        if d:
            return d
        token_str = tokenizer.decode([expansion.token])
        starts_with_space = token_str.startswith(" ")
        # print(f"-----{token_str}-----, {starts_with_space=}")
        is_first_token = len(series.expansions) == 0
        if is_first_token and not starts_with_space:
            return True
        if not is_first_token and starts_with_space:
            return True
        return False

    return stopping_criterion
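

# Minimal usage sketch (an assumption, not part of this module): load a model and
# tokenizer with the standard transformers APIs and plug them into LLMBatchExpander.
# Constructing the Batch/Series objects is left to the surrounding expand module.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
#   expander = LLMBatchExpander(model=model, tokenizer=tokenizer, threshold=-10.0)
#   candidates = expander.expand(batch)  # batch: a Batch built from expand's Series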