Snapshot
Browse files- expand.py +26 -16
- expand_llm.py +1 -1
- expand_test.py +29 -28
- run.py +3 -2
expand.py
CHANGED
@@ -1,26 +1,33 @@
|
|
1 |
from collections import defaultdict
|
2 |
-
from dataclasses import dataclass
|
3 |
-
from typing import Protocol
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
@dataclass
|
6 |
class Series:
|
7 |
id: int
|
8 |
tokens: list[int]
|
9 |
budget: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
@dataclass
|
12 |
class Batch:
|
13 |
items: list[Series]
|
14 |
|
15 |
-
@dataclass
|
16 |
-
class ExpansionOne:
|
17 |
-
token: int
|
18 |
-
cost: float
|
19 |
-
|
20 |
@dataclass
|
21 |
class ExpansionOneResult:
|
22 |
series: Series
|
23 |
-
expansions: list[
|
24 |
|
25 |
@dataclass
|
26 |
class ExpansionOneResultBatch:
|
@@ -33,7 +40,7 @@ class ExpanderOneBatch(Protocol):
|
|
33 |
@dataclass
|
34 |
class ExpansionResult:
|
35 |
series: Series
|
36 |
-
expansions: list[list[
|
37 |
|
38 |
@dataclass
|
39 |
class ExpansionResultBatch:
|
@@ -42,7 +49,12 @@ class ExpansionResultBatch:
|
|
42 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
43 |
results = []
|
44 |
for expansion in result.expansions:
|
45 |
-
results.append(Series(
|
|
|
|
|
|
|
|
|
|
|
46 |
return results
|
47 |
|
48 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
@@ -51,16 +63,14 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
|
|
51 |
# group original series by id
|
52 |
original_series_by_id = {s.id: s for s in original_series}
|
53 |
# group expanded series by id
|
54 |
-
expanded_series_by_id: dict[int, list[list[
|
55 |
for s in expanded_series:
|
56 |
-
|
|
|
57 |
results = []
|
58 |
for id, s in original_series_by_id.items():
|
59 |
expansions = expanded_series_by_id[id]
|
60 |
-
|
61 |
-
l = len(s.tokens)
|
62 |
-
trimmed_expansions = [e[l:] for e in expansions if len(e) > l]
|
63 |
-
expansion_result = ExpansionResult(series=s, expansions=trimmed_expansions)
|
64 |
results.append(expansion_result)
|
65 |
return ExpansionResultBatch(items=results)
|
66 |
|
|
|
1 |
from collections import defaultdict
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Protocol, Self
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class Expansion:
|
7 |
+
token: int
|
8 |
+
cost: float
|
9 |
|
10 |
@dataclass
|
11 |
class Series:
|
12 |
id: int
|
13 |
tokens: list[int]
|
14 |
budget: float
|
15 |
+
expansions: list[Expansion] = field(default_factory=list)
|
16 |
+
|
17 |
+
def get_all_tokens(self) -> list[int]:
|
18 |
+
return self.tokens + [e.token for e in self.expansions]
|
19 |
+
|
20 |
+
def get_remaining_budget(self) -> float:
|
21 |
+
return self.budget + sum(e.cost for e in self.expansions)
|
22 |
|
23 |
@dataclass
|
24 |
class Batch:
|
25 |
items: list[Series]
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
@dataclass
|
28 |
class ExpansionOneResult:
|
29 |
series: Series
|
30 |
+
expansions: list[Expansion]
|
31 |
|
32 |
@dataclass
|
33 |
class ExpansionOneResultBatch:
|
|
|
40 |
@dataclass
|
41 |
class ExpansionResult:
|
42 |
series: Series
|
43 |
+
expansions: list[list[Expansion]]
|
44 |
|
45 |
@dataclass
|
46 |
class ExpansionResultBatch:
|
|
|
49 |
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
50 |
results = []
|
51 |
for expansion in result.expansions:
|
52 |
+
results.append(Series(
|
53 |
+
id=result.series.id,
|
54 |
+
tokens=result.series.tokens,
|
55 |
+
expansions=result.series.expansions + [expansion],
|
56 |
+
budget=result.series.budget
|
57 |
+
))
|
58 |
return results
|
59 |
|
60 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
|
|
63 |
# group original series by id
|
64 |
original_series_by_id = {s.id: s for s in original_series}
|
65 |
# group expanded series by id
|
66 |
+
expanded_series_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
|
67 |
for s in expanded_series:
|
68 |
+
if len(s.expansions) != 0:
|
69 |
+
expanded_series_by_id[s.id].append(s.expansions)
|
70 |
results = []
|
71 |
for id, s in original_series_by_id.items():
|
72 |
expansions = expanded_series_by_id[id]
|
73 |
+
expansion_result = ExpansionResult(series=s, expansions=expansions)
|
|
|
|
|
|
|
74 |
results.append(expansion_result)
|
75 |
return ExpansionResultBatch(items=results)
|
76 |
|
expand_llm.py
CHANGED
@@ -15,6 +15,6 @@ class ExpanderOneBatchLLM:
|
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
-
expansions = [
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
|
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
+
expansions = [Expansion(token=token, cost=logprob) for token, logprob in next_tokens if logprob + s.get_remaining_budget() >= 0]
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
expand_test.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
-
from expand import Series, ExpanderOneBatch,
|
3 |
|
4 |
possible_sequences = [
|
5 |
[1, 21, 31, 41],
|
@@ -9,11 +9,12 @@ possible_sequences = [
|
|
9 |
[1, 22, 34, 41],
|
10 |
]
|
11 |
|
12 |
-
def expand_series(series: Series) -> list[
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
17 |
|
18 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
19 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
@@ -38,8 +39,8 @@ def test_expander_budget_one():
|
|
38 |
expanded = expander.expand(Batch(items=[s]))
|
39 |
expected = ExpansionOneResultBatch(
|
40 |
items=[ExpansionOneResult(series=s, expansions=[
|
41 |
-
|
42 |
-
|
43 |
])]
|
44 |
)
|
45 |
assert expected == expanded
|
@@ -49,8 +50,8 @@ def test_expander_budget_two():
|
|
49 |
expanded = expander.expand(Batch(items=[s]))
|
50 |
expected = ExpansionOneResultBatch(
|
51 |
items=[ExpansionOneResult(series=s, expansions=[
|
52 |
-
|
53 |
-
|
54 |
])]
|
55 |
)
|
56 |
assert expected == expanded
|
@@ -68,8 +69,8 @@ def test_expander_budget_one_two_tokens():
|
|
68 |
expanded = expander.expand(Batch(items=[s]))
|
69 |
expected = ExpansionOneResultBatch(
|
70 |
items=[ExpansionOneResult(series=s, expansions=[
|
71 |
-
|
72 |
-
|
73 |
])]
|
74 |
)
|
75 |
assert expected == expanded
|
@@ -81,12 +82,12 @@ def test_expander_budget_one_two_tokens_two_series():
|
|
81 |
expected = ExpansionOneResultBatch(
|
82 |
items=[
|
83 |
ExpansionOneResult(series=s1, expansions=[
|
84 |
-
|
85 |
-
|
86 |
]),
|
87 |
ExpansionOneResult(series=s2, expansions=[
|
88 |
-
|
89 |
-
|
90 |
])
|
91 |
]
|
92 |
)
|
@@ -102,15 +103,15 @@ def test_expand_01():
|
|
102 |
ExpansionResult(
|
103 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
104 |
expansions=[
|
105 |
-
[31],
|
106 |
-
[32],
|
107 |
]
|
108 |
),
|
109 |
ExpansionResult(
|
110 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
111 |
expansions=[
|
112 |
-
[33],
|
113 |
-
[34],
|
114 |
]
|
115 |
),
|
116 |
])
|
@@ -125,16 +126,16 @@ def test_expand_02():
|
|
125 |
ExpansionResult(
|
126 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
127 |
expansions=[
|
128 |
-
[31, 41],
|
129 |
-
[31, 42],
|
130 |
-
[32, 41],
|
131 |
]
|
132 |
),
|
133 |
ExpansionResult(
|
134 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
135 |
expansions=[
|
136 |
-
[33],
|
137 |
-
[34],
|
138 |
]
|
139 |
),
|
140 |
])
|
@@ -149,9 +150,9 @@ def test_expand_03():
|
|
149 |
ExpansionResult(
|
150 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
151 |
expansions=[
|
152 |
-
[31, 41],
|
153 |
-
[31, 42],
|
154 |
-
[32, 41, 51],
|
155 |
]
|
156 |
),
|
157 |
ExpansionResult(
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
from expand import Series, ExpanderOneBatch, Expansion, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
|
3 |
|
4 |
possible_sequences = [
|
5 |
[1, 21, 31, 41],
|
|
|
9 |
[1, 22, 34, 41],
|
10 |
]
|
11 |
|
12 |
+
def expand_series(series: Series) -> list[Expansion]:
|
13 |
+
all_tokens = series.get_all_tokens()
|
14 |
+
l = len(all_tokens)
|
15 |
+
items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
|
16 |
+
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
17 |
+
return [c for c in candidates if c.cost + series.get_remaining_budget() >= 0]
|
18 |
|
19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
|
39 |
expanded = expander.expand(Batch(items=[s]))
|
40 |
expected = ExpansionOneResultBatch(
|
41 |
items=[ExpansionOneResult(series=s, expansions=[
|
42 |
+
Expansion(token=21, cost=-1.0),
|
43 |
+
Expansion(token=22, cost=-1.0),
|
44 |
])]
|
45 |
)
|
46 |
assert expected == expanded
|
|
|
50 |
expanded = expander.expand(Batch(items=[s]))
|
51 |
expected = ExpansionOneResultBatch(
|
52 |
items=[ExpansionOneResult(series=s, expansions=[
|
53 |
+
Expansion(token=21, cost=-1.0),
|
54 |
+
Expansion(token=22, cost=-1.0),
|
55 |
])]
|
56 |
)
|
57 |
assert expected == expanded
|
|
|
69 |
expanded = expander.expand(Batch(items=[s]))
|
70 |
expected = ExpansionOneResultBatch(
|
71 |
items=[ExpansionOneResult(series=s, expansions=[
|
72 |
+
Expansion(token=33, cost=-1.0),
|
73 |
+
Expansion(token=34, cost=-1.0),
|
74 |
])]
|
75 |
)
|
76 |
assert expected == expanded
|
|
|
82 |
expected = ExpansionOneResultBatch(
|
83 |
items=[
|
84 |
ExpansionOneResult(series=s1, expansions=[
|
85 |
+
Expansion(token=41, cost=-1.0),
|
86 |
+
Expansion(token=42, cost=-1.0),
|
87 |
]),
|
88 |
ExpansionOneResult(series=s2, expansions=[
|
89 |
+
Expansion(token=33, cost=-1.0),
|
90 |
+
Expansion(token=34, cost=-1.0),
|
91 |
])
|
92 |
]
|
93 |
)
|
|
|
103 |
ExpansionResult(
|
104 |
series=Series(id=0, tokens=[1, 21], budget=1.0),
|
105 |
expansions=[
|
106 |
+
[Expansion(token=31, cost=-1.0)],
|
107 |
+
[Expansion(token=32, cost=-1.0)],
|
108 |
]
|
109 |
),
|
110 |
ExpansionResult(
|
111 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
112 |
expansions=[
|
113 |
+
[Expansion(token=33, cost=-1.0)],
|
114 |
+
[Expansion(token=34, cost=-1.0)],
|
115 |
]
|
116 |
),
|
117 |
])
|
|
|
126 |
ExpansionResult(
|
127 |
series=Series(id=0, tokens=[1, 21], budget=2.0),
|
128 |
expansions=[
|
129 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
130 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
|
131 |
+
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
132 |
]
|
133 |
),
|
134 |
ExpansionResult(
|
135 |
series=Series(id=1, tokens=[1, 22], budget=1.0),
|
136 |
expansions=[
|
137 |
+
[Expansion(token=33, cost=-1.0)],
|
138 |
+
[Expansion(token=34, cost=-1.0)],
|
139 |
]
|
140 |
),
|
141 |
])
|
|
|
150 |
ExpansionResult(
|
151 |
series=Series(id=0, tokens=[1, 21], budget=3.0),
|
152 |
expansions=[
|
153 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
|
154 |
+
[Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
|
155 |
+
[Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
|
156 |
]
|
157 |
),
|
158 |
ExpansionResult(
|
run.py
CHANGED
@@ -29,7 +29,7 @@ expander = ExpanderOneBatchLLM(model, tokenizer)
|
|
29 |
#%%
|
30 |
series = []
|
31 |
for i, x in enumerate(contexts):
|
32 |
-
series.append(Series(id=i, tokens=x, budget=
|
33 |
|
34 |
#%%
|
35 |
batch = Batch(items=series)
|
@@ -42,7 +42,8 @@ def print_expansions(expansions: ExpansionResultBatch):
|
|
42 |
for result in expansions.items:
|
43 |
for expansion in result.expansions:
|
44 |
# convert tokens to string
|
45 |
-
|
|
|
46 |
print(f"{result.series.id}: {expansion} {s}")
|
47 |
|
48 |
print_expansions(expanded)
|
|
|
29 |
#%%
|
30 |
series = []
|
31 |
for i, x in enumerate(contexts):
|
32 |
+
series.append(Series(id=i, tokens=x, budget=7.0))
|
33 |
|
34 |
#%%
|
35 |
batch = Batch(items=series)
|
|
|
42 |
for result in expansions.items:
|
43 |
for expansion in result.expansions:
|
44 |
# convert tokens to string
|
45 |
+
tokens = [e.token for e in expansion]
|
46 |
+
s = tokenizer.decode(tokens)
|
47 |
print(f"{result.series.id}: {expansion} {s}")
|
48 |
|
49 |
print_expansions(expanded)
|