Better names for Expander etc
The expansion types and the expander protocol get clearer names; for example, ExpanderOneBatchLLM becomes LLMBatchExpander and ExpansionOneResult becomes TokenCandidates.

Files changed:
- README.md (+1 / -1)
- completions.py (+1 / -2)
- expand.py (+14 / -14)
- expand_llm.py (+4 / -4)
- expand_test.py (+28 / -28)
- run.py (+2 / -2)
README.md

````diff
@@ -173,7 +173,7 @@ In my case, I stop when the budget is exhausted, and I also stop if the expansio
 
 Given the batch and the stopping criterion, we can call the expander:
 ```python
-expander =
+expander = LLMBatchExpander(model, tokenizer)
 expanded = expand(batch, expander, stopping_criterion)
 ```
 
````
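The hunk context above mentions stopping both when the budget is exhausted and when the expansion finishes a word; expand_llm.py's create_stopping_criterion_llm packages that logic. As a rough sketch only, a hand-rolled criterion with the `(Series, Expansion) -> bool` shape that `expand` expects could look like this; the space-prefix check is a guess at what "finishes a word" means and is not taken from the repository:

```python
from expand import Series, Expansion

def make_stopping_criterion(tokenizer):
    def stopping_criterion(series: Series, expansion: Expansion) -> bool:
        # Stop when accepting this expansion would overdraw the budget
        # (same arithmetic as default_completion_criterion in expand.py).
        out_of_budget = series.get_remaining_budget() + expansion.cost < 0
        # Hypothetical word-boundary check: a token that starts with a space
        # begins a new word, so the current word is complete.
        starts_new_word = tokenizer.decode([expansion.token]).startswith(" ")
        return out_of_budget or starts_new_word
    return stopping_criterion
```

Such a function can be passed as the third argument of `expand`, exactly as `stopping_criterion` is in the README snippet above.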
completions.py

```diff
@@ -92,8 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
 
 contexts = [word.context for _, word in low_prob_words]
 
-
-expander = ExpanderOneBatchLLM(model, tokenizer)
+expander = LLMBatchExpander(model, tokenizer)
 
 #%%
 series = []
```
expand.py

```diff
@@ -25,28 +25,28 @@ class Batch:
     items: list[Series]
 
 @dataclass
-class
+class TokenCandidates:
     series: Series
     expansions: list[Expansion]
 
 @dataclass
-class
-    items: list[
+class BatchCandidates:
+    items: list[TokenCandidates]
 
 # A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
-class
-    def expand(self, batch: Batch) ->
+class BatchExpander(Protocol):
+    def expand(self, batch: Batch) -> BatchCandidates: ...
 
 @dataclass
-class
+class CompletedSequence:
     series: Series
     expansions: list[list[Expansion]]
 
 @dataclass
-class
-    items: list[
+class CompletedBatch:
+    items: list[CompletedSequence]
 
-def compute_new_series(result:
+def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
     new_series_batch = []
     for expansion in result.expansions:
         if not stopping_criterion(result.series, expansion):
@@ -60,7 +60,7 @@ def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[
     completed_series = [result.series] if len(new_series_batch) == 0 else []
     return new_series_batch, completed_series
 
-def compute_expansions(original_series: list[Series], expanded_series: list[Series]) ->
+def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
     # check that ids in original_series are unique
     assert len(original_series) == len({s.id for s in original_series})
     # group original series by id
@@ -73,15 +73,15 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
     results = []
     for id, s in original_series_by_id.items():
         expansions = expanded_series_by_id[id]
-        expansion_result =
+        expansion_result = CompletedSequence(series=s, expansions=expansions)
         results.append(expansion_result)
-    return
+    return CompletedBatch(items=results)
 
 def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
     return series.get_remaining_budget() + expansion.cost < 0
 
-# A compound operation that we can implement generically, relying on
-def expand(batch: Batch, expander:
+# A compound operation that we can implement generically, relying on a BatchExpander
+def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
     completed_series: list[Series] = []
     current_batch = batch
     while len(current_batch.items) > 0:
```
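Because BatchExpander is a Protocol, anything with a matching expand method can drive the generic expand loop. Below is a minimal sketch under that assumption; SingleTokenExpander and its fixed token/cost values are invented for illustration, while the imported names and the Series/Batch constructors come from this file and its tests:

```python
from expand import (
    Batch,
    BatchCandidates,
    BatchExpander,
    Expansion,
    Series,
    TokenCandidates,
    expand,
)

# Hypothetical expander: proposes one fixed candidate token for every series,
# always at cost -1.0, purely to show the shape of the protocol.
class SingleTokenExpander(BatchExpander):
    def expand(self, batch: Batch) -> BatchCandidates:
        items = [
            TokenCandidates(series=s, expansions=[Expansion(token=0, cost=-1.0)])
            for s in batch.items
        ]
        return BatchCandidates(items=items)

# With default_completion_criterion, a series with budget 2.0 should absorb two
# cost -1.0 expansions before it is treated as complete (cf. expand_test.py).
batch = Batch(items=[Series(id=0, tokens=[1], budget=2.0)])
completed = expand(batch, SingleTokenExpander())
```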
expand_llm.py

```diff
@@ -22,18 +22,18 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
     return tokenizer(texts, return_tensors="pt", padding=True).to(device)
 
 @dataclass
-class
+class LLMBatchExpander(BatchExpander):
     model: PreTrainedModel
     tokenizer: Tokenizer
 
-    def expand(self, batch: Batch) ->
+    def expand(self, batch: Batch) -> BatchCandidates:
         inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
         next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
         results = []
         for s, next_tokens in zip(batch.items, next_tokens):
             expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
-            results.append(
-        return
+            results.append(TokenCandidates(series=s, expansions=expansions))
+        return BatchCandidates(items=results)
 
 def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
     def stopping_criterion(series: Series, expansion: Expansion) -> bool:
```
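For orientation, here is a sketch of how LLMBatchExpander might be wired together with the stopping criterion and the generic expand loop, along the lines of run.py. The model name, prompt, budget, and pad-token handling are placeholder assumptions; the Series constructor call mirrors the one in expand_test.py, and only the classes and functions named in the diff are taken from the repository:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from expand import Batch, Series, expand
from expand_llm import LLMBatchExpander, create_stopping_criterion_llm

model_name = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2 has no pad token; prepare_inputs pads batches, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

expander = LLMBatchExpander(model, tokenizer)
stopping_criterion = create_stopping_criterion_llm(tokenizer)

# One Series per context to expand; the budget bounds the total cost that an
# expansion chain may accumulate (see default_completion_criterion in expand.py).
tokens = tokenizer.encode("The quick brown", add_special_tokens=False)
batch = Batch(items=[Series(id=0, tokens=tokens, budget=5.0)])

expanded = expand(batch, expander, stopping_criterion)
```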
expand_test.py

```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from expand import Series,
+from expand import Series, BatchExpander, Expansion, Batch, TokenCandidates, BatchCandidates, CompletedSequence, CompletedBatch, expand
 
 possible_sequences = [
     [1, 21, 31, 41],
@@ -16,21 +16,21 @@ def expand_series(series: Series) -> list[Expansion]:
     candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
     return candidates
 
-class
-    def expand(self, batch: Batch) ->
+class PredefinedSequenceExpander(BatchExpander):
+    def expand(self, batch: Batch) -> BatchCandidates:
         result = []
         for s in batch.items:
             expansions = expand_series(s)
-            result.append(
-        return
+            result.append(TokenCandidates(series=s, expansions=expansions))
+        return BatchCandidates(items=result)
 
-expander =
+expander = PredefinedSequenceExpander()
 
 def test_expander_zero_budget():
     s = Series(id=0, tokens=[1], budget=0.0)
     expanded = expander.expand(Batch(items=[s]))
-    expected =
-        items=[
+    expected = BatchCandidates(
+        items=[TokenCandidates(series=s, expansions=[
             Expansion(token=21, cost=-1.0),
             Expansion(token=22, cost=-1.0),
         ])]
@@ -40,8 +40,8 @@ def test_expander_zero_budget():
 def test_expander_budget_one():
     s = Series(id=0, tokens=[1], budget=1.0)
     expanded = expander.expand(Batch(items=[s]))
-    expected =
-        items=[
+    expected = BatchCandidates(
+        items=[TokenCandidates(series=s, expansions=[
             Expansion(token=21, cost=-1.0),
             Expansion(token=22, cost=-1.0),
         ])]
@@ -51,8 +51,8 @@ def test_expander_budget_one():
 def test_expander_budget_two():
     s = Series(id=0, tokens=[1], budget=2.0)
     expanded = expander.expand(Batch(items=[s]))
-    expected =
-        items=[
+    expected = BatchCandidates(
+        items=[TokenCandidates(series=s, expansions=[
             Expansion(token=21, cost=-1.0),
             Expansion(token=22, cost=-1.0),
         ])]
@@ -62,16 +62,16 @@ def test_expander_budget_two():
 def test_expander_budget_one_no_expansion():
     s = Series(id=0, tokens=[1, 20], budget=1.0)
     expanded = expander.expand(Batch(items=[s]))
-    expected =
-        items=[
+    expected = BatchCandidates(
+        items=[TokenCandidates(series=s, expansions=[])]
     )
     assert expected == expanded
 
 def test_expander_budget_one_two_tokens():
     s = Series(id=0, tokens=[1, 22], budget=1.0)
     expanded = expander.expand(Batch(items=[s]))
-    expected =
-        items=[
+    expected = BatchCandidates(
+        items=[TokenCandidates(series=s, expansions=[
             Expansion(token=33, cost=-1.0),
             Expansion(token=34, cost=-1.0),
         ])]
@@ -82,13 +82,13 @@ def test_expander_budget_one_two_tokens_two_series():
     s1 = Series(id=0, tokens=[1, 21, 31], budget=1.0)
     s2 = Series(id=1, tokens=[1, 22], budget=1.0)
     expanded = expander.expand(Batch(items=[s1, s2]))
-    expected =
+    expected = BatchCandidates(
         items=[
-
+            TokenCandidates(series=s1, expansions=[
                 Expansion(token=41, cost=-1.0),
                 Expansion(token=42, cost=-1.0),
             ]),
-
+            TokenCandidates(series=s2, expansions=[
                 Expansion(token=33, cost=-1.0),
                 Expansion(token=34, cost=-1.0),
             ])
@@ -102,15 +102,15 @@ def test_expand_01():
         Series(id=1, tokens=[1, 22], budget=1.0),
     ])
     expanded = expand(batch, expander)
-    assert expanded ==
-
+    assert expanded == CompletedBatch(items=[
+        CompletedSequence(
             series=Series(id=0, tokens=[1, 21], budget=1.0),
             expansions=[
                 [Expansion(token=31, cost=-1.0)],
                 [Expansion(token=32, cost=-1.0)],
             ]
         ),
-
+        CompletedSequence(
             series=Series(id=1, tokens=[1, 22], budget=1.0),
             expansions=[
                 [Expansion(token=33, cost=-1.0)],
@@ -125,8 +125,8 @@ def test_expand_02():
         Series(id=1, tokens=[1, 22], budget=1.0),
     ])
     expanded = expand(batch, expander)
-    assert expanded ==
-
+    assert expanded == CompletedBatch(items=[
+        CompletedSequence(
             series=Series(id=0, tokens=[1, 21], budget=2.0),
             expansions=[
                 [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
@@ -134,7 +134,7 @@ def test_expand_02():
                 [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
             ]
         ),
-
+        CompletedSequence(
             series=Series(id=1, tokens=[1, 22], budget=1.0),
             expansions=[
                 [Expansion(token=33, cost=-1.0)],
@@ -149,8 +149,8 @@ def test_expand_03():
         Series(id=1, tokens=[1, 22], budget=0.0),
     ])
     expanded = expand(batch, expander)
-    assert expanded ==
-
+    assert expanded == CompletedBatch(items=[
+        CompletedSequence(
             series=Series(id=0, tokens=[1, 21], budget=3.0),
             expansions=[
                 [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
@@ -158,7 +158,7 @@ def test_expand_03():
                 [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
             ]
         ),
-
+        CompletedSequence(
             series=Series(id=1, tokens=[1, 22], budget=0.0),
             expansions=[],
         ),
```
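The body of expand_series above these hunks is not part of the diff; only its deduplicating return is visible. One plausible shape for it, assuming it treats each series' tokens as a prefix into possible_sequences and collects the next token of every matching sequence (a guess for illustration, not code from the repository):

```python
def expand_series(series: Series) -> list[Expansion]:
    # Guess: take every hardcoded sequence that starts with exactly the
    # series' tokens and collect the token that follows that prefix.
    n = len(series.tokens)
    items = [
        seq[n]
        for seq in possible_sequences
        if len(seq) > n and seq[:n] == series.tokens
    ]
    # Visible part of the function: deduplicate while preserving order.
    candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
    return candidates
```

That behaviour would at least be consistent with the expectations in the tests, for example [1, 22] expanding to tokens 33 and 34.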
run.py

```diff
@@ -24,7 +24,7 @@ low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < lo
 contexts = [word.context for _, word in low_prob_words]
 
 #%%
-expander =
+expander = LLMBatchExpander(model, tokenizer)
 
 #%%
 series = []
@@ -41,7 +41,7 @@ stopping_criterion = create_stopping_criterion_llm(tokenizer)
 expanded = expand(batch, expander, stopping_criterion)
 
 # %%
-def print_expansions(expansions:
+def print_expansions(expansions: CompletedBatch):
     for result in expansions.items:
         for expansion in result.expansions:
            # convert tokens to string
```