import torch
from expand import *
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from collections.abc import Callable
from dataclasses import dataclass
import time

type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast

def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
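    # For each sequence in the batch, return the (token_id, log_probability)
    # pairs whose log-probability at the next position exceeds `threshold`.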
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    print("Running inference")
    start_time = time.time()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    print(f"Inference done, took {time.time() - start_time} seconds")
    start_time = time.time()
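    # Take the logits at the final position of each sequence. Note: this assumes
    # left padding (or equal-length inputs); with right padding the last position
    # would correspond to a pad token rather than the newest real token.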
    logits: torch.Tensor = outputs.logits[:, -1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    print(f"Log probs done, took {time.time() - start_time} seconds")
    start_time = time.time()
    result = []
    print(f"Resulting tensor: {log_probs.shape}")
    for probs in log_probs:
        # Filter out low probability tokens for efficiency
        filtered_indices = torch.where(probs > threshold)[0]
        filtered_probs = probs[filtered_indices]
        result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])
    print(f"Result done, took {time.time() - start_time} seconds")
    return result

def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    return tokenizer(texts, return_tensors="pt", padding=True).to(device)

@dataclass
class LLMBatchExpander(BatchExpander):
    model: PreTrainedModel
    tokenizer: Tokenizer
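    # Log-probability cutoff: only next tokens with log-prob above this value are kept.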
    threshold: float
    chunk_size: int = 64

    def expand(self, batch: Batch) -> BatchCandidates:
        start_time = time.time()
        all_results = []

        # Split batch.items into chunks to avoid CUDA out of memory
        for i in range(0, len(batch.items), self.chunk_size):
            chunk_items = batch.items[i:i + self.chunk_size]
            print(f"Processing chunk {i//self.chunk_size + 1}/{(len(batch.items) + self.chunk_size - 1)//self.chunk_size} with {len(chunk_items)} items")

            # Process this chunk
            inputs = prepare_inputs([s.get_all_tokens() for s in chunk_items], self.tokenizer, self.model.device)
            chunk_next_tokens = find_next_tokens(self.model, inputs, self.threshold)

            # Create token candidates for this chunk
            for s, next_tokens in zip(chunk_items, chunk_next_tokens):
                expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
                all_results.append(TokenCandidates(series=s, expansions=expansions))

            # Clear CUDA cache to free up memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Total batch size: {len(batch.items)}, processed in {(len(batch.items) + self.chunk_size - 1)//self.chunk_size} chunks")
        print(f"Token candidates done, took {time.time() - start_time} seconds")
        return BatchCandidates(items=all_results)

def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
    def stopping_criterion(series: Series, expansion: Expansion) -> bool:
        d = default_completion_criterion(series, expansion)
        if d:
            return d
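        # Word-boundary heuristic: a decoded token that starts with a space
        # begins a new word. Complete if the first expansion does not begin a
        # new word, or if any later expansion does.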
        token_str = tokenizer.decode([expansion.token])
        starts_with_space = token_str.startswith(" ")
        # print(f"-----{token_str}-----, {starts_with_space=}")
        is_first_token = len(series.expansions) == 0
        if is_first_token and not starts_with_space:
            return True
        if not is_first_token and starts_with_space:
            return True
        return False
    return stopping_criterion
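
# Minimal usage sketch (illustrative only): "gpt2" and the -5.0 log-prob threshold
# are placeholder choices, not values prescribed by this module. Constructing a
# Batch/Series is left to the expand module and is not shown here.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    tokenizer.padding_side = "left"            # keeps logits[:, -1, :] on the last real token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    expander = LLMBatchExpander(model=model, tokenizer=tokenizer, threshold=-5.0, chunk_size=8)
    stopping_criterion = create_stopping_criterion_llm(tokenizer)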