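"""Expand batches of token series with next-token candidates from a Hugging Face
language model, and provide an LLM-aware stopping criterion for expansions."""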
import torch
from collections.abc import Callable
from expand import Batch, BatchCandidates, BatchExpander, Expansion, Series, TokenCandidates, default_completion_criterion
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from dataclasses import dataclass
import time

type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast

def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
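    """For each sequence in the batch, return the candidate next tokens whose
    log-probability exceeds `threshold`, as (token_id, log_prob) pairs.

    `threshold` is compared against log-probabilities, so it is typically negative.
    """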
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    print("Running inference")
    start_time = time.time()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    print(f"Inference done, took {time.time() - start_time} seconds")
    start_time = time.time()
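    # Keep only the logits at the final position: the distribution over the next token.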
    logits: torch.Tensor = outputs.logits[:, -1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    print(f"Log probs done, took {time.time() - start_time} seconds")
    start_time = time.time()
    result = []
    print(f"Resulting tensor: {log_probs.shape}")
    for probs in log_probs:
        # Filter out low probability tokens for efficiency
        above_threshold = torch.where(probs > threshold)
        filtered_indices = above_threshold[0]
        filtered_probs = probs[filtered_indices]
        result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])
    print(f"Result done, took {time.time() - start_time} seconds")
    return result

def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
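    """Decode each token-id context back to text and re-tokenize the texts as a
    single padded batch on the given device."""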
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    return tokenizer(texts, return_tensors="pt", padding=True).to(device)

@dataclass
class LLMBatchExpander(BatchExpander):
    model: PreTrainedModel
    tokenizer: Tokenizer
    threshold: float

    def expand(self, batch: Batch) -> BatchCandidates:
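        """Run the model once over the whole batch and attach the surviving
        next-token candidates to each series, with each candidate's log-probability
        used as the expansion cost."""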
        inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
        start_time = time.time()
        results = []
        print(f"Batch size: {len(batch.items)}, next tokens size: {len(next_tokens)}")
        for s, tokens in zip(batch.items, next_tokens):
            expansions = [Expansion(token=token, cost=cost) for token, cost in tokens]
            results.append(TokenCandidates(series=s, expansions=expansions))
        print()
        print(f"Token candidates done, took {time.time() - start_time} seconds")
        return BatchCandidates(items=results)

def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
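    """Build a stopping criterion for LLM token series.

    A series is considered complete when the default completion criterion fires,
    when its first expansion token does not start with a space, or when a later
    expansion token does start with a space (i.e. begins a new word).
    """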
    def stopping_criterion(series: Series, expansion: Expansion) -> bool:
        d = default_completion_criterion(series, expansion)
        if d:
            return d
        token_str = tokenizer.decode([expansion.token])
        starts_with_space = token_str.startswith(" ")
        # print(f"-----{token_str}-----, {starts_with_space=}")
        is_first_token = len(series.expansions) == 0
        if is_first_token and not starts_with_space:
            return True
        if not is_first_token and starts_with_space:
            return True
        return False
    return stopping_criterion