Working expand_llm
Browse files- completions.py +13 -29
- expand.py +2 -5
- expand_llm.py +20 -0
- expand_test.py +12 -12
- run.py +49 -0
completions.py
CHANGED
@@ -95,7 +95,7 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
|
|
95 |
)
|
96 |
return outputs
|
97 |
|
98 |
-
def
|
99 |
input_ids = inputs["input_ids"]
|
100 |
attention_mask = inputs["attention_mask"]
|
101 |
with torch.no_grad():
|
@@ -109,6 +109,18 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
|
|
109 |
result.append([(i, tokenizer.convert_ids_to_tokens([i])[0], p) for i, p in enumerate(probs) if p > min_p])
|
110 |
return result
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
|
113 |
all_new_words = []
|
114 |
for i in range(num_inputs):
|
@@ -161,31 +173,3 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
|
|
161 |
else:
|
162 |
result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
|
163 |
return result
|
164 |
-
|
165 |
-
# %%
|
166 |
-
model, tokenizer, device = load_model()
|
167 |
-
|
168 |
-
#%%
|
169 |
-
input_text = "The quick brown fox jumpz over"
|
170 |
-
inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
|
171 |
-
|
172 |
-
#%%
|
173 |
-
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
|
174 |
-
|
175 |
-
#%%
|
176 |
-
words = split_into_words(token_probs, tokenizer)
|
177 |
-
log_prob_threshold = -5.0
|
178 |
-
low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
|
179 |
-
|
180 |
-
#%%
|
181 |
-
contexts = [word.context for _, word in low_prob_words]
|
182 |
-
inputs = prepare_inputs(contexts, tokenizer, device)
|
183 |
-
input_ids = inputs["input_ids"]
|
184 |
-
|
185 |
-
#%%
|
186 |
-
next_tokens = find_next_tokens(model, inputs, tokenizer, min_p=-5)
|
187 |
-
|
188 |
-
#%%
|
189 |
-
next_tokens
|
190 |
-
|
191 |
-
# %%
|
|
|
95 |
)
|
96 |
return outputs
|
97 |
|
98 |
+
def find_next_tokens_0(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer, min_p: float) -> list[list[tuple[int, str, float]]]:
|
99 |
input_ids = inputs["input_ids"]
|
100 |
attention_mask = inputs["attention_mask"]
|
101 |
with torch.no_grad():
|
|
|
109 |
result.append([(i, tokenizer.convert_ids_to_tokens([i])[0], p) for i, p in enumerate(probs) if p > min_p])
|
110 |
return result
|
111 |
|
112 |
+
def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
    """Score every vocabulary entry as the next token for each input row.

    Runs a single forward pass without gradient tracking, takes the logits at
    the last sequence position of every row and normalizes them with
    ``log_softmax``.

    Args:
        model: Callable invoked as ``model(input_ids=..., attention_mask=...)``
            whose result exposes a ``logits`` tensor indexed as
            ``[row, position, vocab]``.
        inputs: Mapping providing ``"input_ids"`` and ``"attention_mask"``.
        tokenizer: Not used here; kept so existing call sites that pass it
            keep working.

    Returns:
        One list per input row. Each list holds ``(token_id, log_prob)`` for
        every index of the last logits dimension, in ascending id order, with
        ``log_prob`` as a plain Python ``float``.
    """
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): position -1 is the real last token of a row only when
        # the batch is unpadded or left-padded -- confirm against prepare_inputs.
        logits: torch.Tensor = outputs.logits[:, -1, :]
        log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # Convert once with tolist(): iterating a tensor row in Python yields
    # 0-dim tensors (not the floats promised by the return type) and is far
    # slower than a single bulk conversion.
    return [list(enumerate(row)) for row in log_probs.tolist()]
|
123 |
+
|
124 |
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
|
125 |
all_new_words = []
|
126 |
for i in range(num_inputs):
|
|
|
173 |
else:
|
174 |
result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
|
175 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
expand.py
CHANGED
@@ -2,10 +2,6 @@ from collections import defaultdict
|
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Protocol
|
4 |
|
5 |
-
# import torch
|
6 |
-
# from transformers import PreTrainedModel
|
7 |
-
# from completions import find_next_tokens, Tokenizer
|
8 |
-
|
9 |
@dataclass
|
10 |
class Series:
|
11 |
id: int
|
@@ -46,7 +42,7 @@ class ExpansionResultBatch:
|
|
46 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
47 |
results = []
|
48 |
for expansion in result.expansions:
|
49 |
-
results.append(Series(id=result.series.id, tokens=result.series.tokens + [expansion.token], budget=result.series.budget
|
50 |
return results
|
51 |
|
52 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
@@ -73,6 +69,7 @@ def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
|
|
73 |
completed_series: list[Series] = []
|
74 |
current_batch = batch
|
75 |
while len(current_batch.items) > 0:
|
|
|
76 |
current_batch_items = []
|
77 |
expanded = expander.expand(current_batch)
|
78 |
for item in expanded.items:
|
|
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Protocol
|
4 |
|
|
|
|
|
|
|
|
|
5 |
@dataclass
|
6 |
class Series:
|
7 |
id: int
|
|
|
42 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
    """Build one extended series for each expansion in ``result``.

    Every new series keeps the parent's id, appends the expansion's token to
    the parent's tokens, and adds the expansion's cost to the parent's budget.
    """
    return [
        Series(
            id=result.series.id,
            tokens=result.series.tokens + [step.token],
            budget=result.series.budget + step.cost,
        )
        for step in result.expansions
    ]
|
47 |
|
48 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
|
|
69 |
completed_series: list[Series] = []
|
70 |
current_batch = batch
|
71 |
while len(current_batch.items) > 0:
|
72 |
+
print(f"Expanding {len(current_batch.items)} series: {current_batch.items}")
|
73 |
current_batch_items = []
|
74 |
expanded = expander.expand(current_batch)
|
75 |
for item in expanded.items:
|
expand_llm.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from expand import *
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from completions import prepare_inputs, find_next_tokens
|
5 |
+
|
6 |
+
type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
|
7 |
+
|
8 |
+
@dataclass
class ExpanderOneBatchLLM:
    """One-step batch expander that scores candidate tokens with a model."""

    model: PreTrainedModel
    tokenizer: Tokenizer

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Expand every series in ``batch`` by one token.

        The token lists of all series are prepared as a single model input and
        scored with ``find_next_tokens``. For each series, a candidate is kept
        when its score plus the series' budget is still non-negative; the
        score is stored as the expansion's cost.
        """
        token_lists = [series.tokens for series in batch.items]
        inputs = prepare_inputs(token_lists, self.tokenizer, self.model.device)
        scores_per_series = find_next_tokens(self.model, inputs, self.tokenizer)

        results = [
            ExpansionOneResult(
                series=series,
                expansions=[
                    ExpansionOne(token=token, cost=logprob)
                    for token, logprob in scored_tokens
                    if logprob + series.budget >= 0
                ],
            )
            for series, scored_tokens in zip(batch.items, scores_per_series)
        ]
        return ExpansionOneResultBatch(items=results)
|
expand_test.py
CHANGED
@@ -12,8 +12,8 @@ possible_sequences = [
|
|
12 |
def expand_series(series: Series) -> list[ExpansionOne]:
|
13 |
l = len(series.tokens)
|
14 |
items = [s[l] for s in possible_sequences if s[:l] == series.tokens and len(s) > l]
|
15 |
-
candidates = [ExpansionOne(token=l, cost
|
16 |
-
return [c for c in candidates if c.cost
|
17 |
|
18 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
19 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
@@ -38,8 +38,8 @@ def test_expander_budget_one():
|
|
38 |
expanded = expander.expand(Batch(items=[s]))
|
39 |
expected = ExpansionOneResultBatch(
|
40 |
items=[ExpansionOneResult(series=s, expansions=[
|
41 |
-
ExpansionOne(token=21, cost
|
42 |
-
ExpansionOne(token=22, cost
|
43 |
])]
|
44 |
)
|
45 |
assert expected == expanded
|
@@ -49,8 +49,8 @@ def test_expander_budget_two():
|
|
49 |
expanded = expander.expand(Batch(items=[s]))
|
50 |
expected = ExpansionOneResultBatch(
|
51 |
items=[ExpansionOneResult(series=s, expansions=[
|
52 |
-
ExpansionOne(token=21, cost
|
53 |
-
ExpansionOne(token=22, cost
|
54 |
])]
|
55 |
)
|
56 |
assert expected == expanded
|
@@ -68,8 +68,8 @@ def test_expander_budget_one_two_tokens():
|
|
68 |
expanded = expander.expand(Batch(items=[s]))
|
69 |
expected = ExpansionOneResultBatch(
|
70 |
items=[ExpansionOneResult(series=s, expansions=[
|
71 |
-
ExpansionOne(token=33, cost
|
72 |
-
ExpansionOne(token=34, cost
|
73 |
])]
|
74 |
)
|
75 |
assert expected == expanded
|
@@ -81,12 +81,12 @@ def test_expander_budget_one_two_tokens_two_series():
|
|
81 |
expected = ExpansionOneResultBatch(
|
82 |
items=[
|
83 |
ExpansionOneResult(series=s1, expansions=[
|
84 |
-
ExpansionOne(token=41, cost
|
85 |
-
ExpansionOne(token=42, cost
|
86 |
]),
|
87 |
ExpansionOneResult(series=s2, expansions=[
|
88 |
-
ExpansionOne(token=33, cost
|
89 |
-
ExpansionOne(token=34, cost
|
90 |
])
|
91 |
]
|
92 |
)
|
|
|
12 |
def expand_series(series: Series) -> list[ExpansionOne]:
    """Return the affordable one-token expansions of ``series``.

    Candidates are the tokens that follow ``series.tokens`` in any entry of
    ``possible_sequences`` that starts with those tokens. Each candidate costs
    -1.0 and is kept only if cost plus the series' budget is non-negative.
    """
    prefix_len = len(series.tokens)
    following = [
        seq[prefix_len]
        for seq in possible_sequences
        if seq[:prefix_len] == series.tokens and len(seq) > prefix_len
    ]
    # dict.fromkeys drops duplicate tokens while keeping first-seen order.
    unique_tokens = dict.fromkeys(following)
    candidates = [ExpansionOne(token=token, cost=-1.0) for token in unique_tokens]
    return [
        candidate
        for candidate in candidates
        if candidate.cost + series.budget >= 0
    ]
|
17 |
|
18 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
19 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
|
38 |
expanded = expander.expand(Batch(items=[s]))
|
39 |
expected = ExpansionOneResultBatch(
|
40 |
items=[ExpansionOneResult(series=s, expansions=[
|
41 |
+
ExpansionOne(token=21, cost=-1.0),
|
42 |
+
ExpansionOne(token=22, cost=-1.0),
|
43 |
])]
|
44 |
)
|
45 |
assert expected == expanded
|
|
|
49 |
expanded = expander.expand(Batch(items=[s]))
|
50 |
expected = ExpansionOneResultBatch(
|
51 |
items=[ExpansionOneResult(series=s, expansions=[
|
52 |
+
ExpansionOne(token=21, cost=-1.0),
|
53 |
+
ExpansionOne(token=22, cost=-1.0),
|
54 |
])]
|
55 |
)
|
56 |
assert expected == expanded
|
|
|
68 |
expanded = expander.expand(Batch(items=[s]))
|
69 |
expected = ExpansionOneResultBatch(
|
70 |
items=[ExpansionOneResult(series=s, expansions=[
|
71 |
+
ExpansionOne(token=33, cost=-1.0),
|
72 |
+
ExpansionOne(token=34, cost=-1.0),
|
73 |
])]
|
74 |
)
|
75 |
assert expected == expanded
|
|
|
81 |
expected = ExpansionOneResultBatch(
|
82 |
items=[
|
83 |
ExpansionOneResult(series=s1, expansions=[
|
84 |
+
ExpansionOne(token=41, cost=-1.0),
|
85 |
+
ExpansionOne(token=42, cost=-1.0),
|
86 |
]),
|
87 |
ExpansionOneResult(series=s2, expansions=[
|
88 |
+
ExpansionOne(token=33, cost=-1.0),
|
89 |
+
ExpansionOne(token=34, cost=-1.0),
|
90 |
])
|
91 |
]
|
92 |
)
|
run.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
from completions import *
|
3 |
+
from expand_llm import *
|
4 |
+
from expand import *
|
5 |
+
|
6 |
+
# %%
|
7 |
+
model, tokenizer, device = load_model()

#%%
# Sample inputs; only the uncommented one is used.
# input_text = "The quick brown fox jumpz over"
# input_text = "He asked me to prostate myself before the king"
input_text = "Здравствуйте, я хочу предвыполнить заказ"
inputs: BatchEncoding = tokenize(input_text, tokenizer, device)

#%%
# Per-token log probabilities of the input text.
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)

#%%
# Group tokens into words and keep those scoring below the threshold,
# together with their position in the word list.
words = split_into_words(token_probs, tokenizer)
log_prob_threshold = -5.0
low_prob_words = [
    (position, word)
    for position, word in enumerate(words)
    if word.logprob < log_prob_threshold
]

#%%
# The context attached to each low-probability word seeds one series.
contexts = [entry[1].context for entry in low_prob_words]

#%%
expander = ExpanderOneBatchLLM(model, tokenizer)

#%%
# One series per context, numbered in order, each with a budget of 5.0.
series = [
    Series(id=index, tokens=context, budget=5.0)
    for index, context in enumerate(contexts)
]

#%%
batch = Batch(items=series)

#%%
expanded = expand(batch, expander)
|
39 |
+
|
40 |
+
# %%
|
41 |
+
def print_expansions(expansions: ExpansionResultBatch):
    """Print each expansion of each result next to its decoded text.

    Every printed line holds the id of the result's series, the raw expansion
    and the string returned by the module-level ``tokenizer.decode`` for it.
    """
    flattened = (
        (result, expansion)
        for result in expansions.items
        for expansion in result.expansions
    )
    for result, expansion in flattened:
        # Turn the expansion into readable text before printing.
        s = tokenizer.decode(expansion)
        print(f"{result.series.id}: {expansion} {s}")
|
47 |
+
|
48 |
+
print_expansions(expanded)
|
49 |
+
# %%
|