mebubo committed on
Commit
6f4449d
·
1 Parent(s): 51f63ae
Files changed (4) hide show
  1. expand.py +26 -16
  2. expand_llm.py +1 -1
  3. expand_test.py +29 -28
  4. run.py +3 -2
expand.py CHANGED
@@ -1,26 +1,33 @@
1
  from collections import defaultdict
2
- from dataclasses import dataclass
3
- from typing import Protocol
 
 
 
 
 
4
 
5
  @dataclass
6
  class Series:
7
  id: int
8
  tokens: list[int]
9
  budget: float
 
 
 
 
 
 
 
10
 
11
  @dataclass
12
  class Batch:
13
  items: list[Series]
14
 
15
- @dataclass
16
- class ExpansionOne:
17
- token: int
18
- cost: float
19
-
20
  @dataclass
21
  class ExpansionOneResult:
22
  series: Series
23
- expansions: list[ExpansionOne]
24
 
25
  @dataclass
26
  class ExpansionOneResultBatch:
@@ -33,7 +40,7 @@ class ExpanderOneBatch(Protocol):
33
  @dataclass
34
  class ExpansionResult:
35
  series: Series
36
- expansions: list[list[int]]
37
 
38
  @dataclass
39
  class ExpansionResultBatch:
@@ -42,7 +49,12 @@ class ExpansionResultBatch:
42
  def compute_new_series(result: ExpansionOneResult) -> list[Series]:
43
  results = []
44
  for expansion in result.expansions:
45
- results.append(Series(id=result.series.id, tokens=result.series.tokens + [expansion.token], budget=result.series.budget + expansion.cost))
 
 
 
 
 
46
  return results
47
 
48
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
@@ -51,16 +63,14 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
51
  # group original series by id
52
  original_series_by_id = {s.id: s for s in original_series}
53
  # group expanded series by id
54
- expanded_series_by_id: dict[int, list[list[int]]] = defaultdict(list)
55
  for s in expanded_series:
56
- expanded_series_by_id[s.id].append(s.tokens)
 
57
  results = []
58
  for id, s in original_series_by_id.items():
59
  expansions = expanded_series_by_id[id]
60
- # subtract the original series from each expansion
61
- l = len(s.tokens)
62
- trimmed_expansions = [e[l:] for e in expansions if len(e) > l]
63
- expansion_result = ExpansionResult(series=s, expansions=trimmed_expansions)
64
  results.append(expansion_result)
65
  return ExpansionResultBatch(items=results)
66
 
 
1
  from collections import defaultdict
2
+ from dataclasses import dataclass, field
3
+ from typing import Protocol, Self
4
+
5
+ @dataclass
6
+ class Expansion:
7
+ token: int
8
+ cost: float
9
 
10
  @dataclass
11
  class Series:
12
  id: int
13
  tokens: list[int]
14
  budget: float
15
+ expansions: list[Expansion] = field(default_factory=list)
16
+
17
+ def get_all_tokens(self) -> list[int]:
18
+ return self.tokens + [e.token for e in self.expansions]
19
+
20
+ def get_remaining_budget(self) -> float:
21
+ return self.budget + sum(e.cost for e in self.expansions)
22
 
23
  @dataclass
24
  class Batch:
25
  items: list[Series]
26
 
 
 
 
 
 
27
  @dataclass
28
  class ExpansionOneResult:
29
  series: Series
30
+ expansions: list[Expansion]
31
 
32
  @dataclass
33
  class ExpansionOneResultBatch:
 
40
  @dataclass
41
  class ExpansionResult:
42
  series: Series
43
+ expansions: list[list[Expansion]]
44
 
45
  @dataclass
46
  class ExpansionResultBatch:
 
49
  def compute_new_series(result: ExpansionOneResult) -> list[Series]:
50
  results = []
51
  for expansion in result.expansions:
52
+ results.append(Series(
53
+ id=result.series.id,
54
+ tokens=result.series.tokens,
55
+ expansions=result.series.expansions + [expansion],
56
+ budget=result.series.budget
57
+ ))
58
  return results
59
 
60
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
 
63
  # group original series by id
64
  original_series_by_id = {s.id: s for s in original_series}
65
  # group expanded series by id
66
+ expanded_series_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
67
  for s in expanded_series:
68
+ if len(s.expansions) != 0:
69
+ expanded_series_by_id[s.id].append(s.expansions)
70
  results = []
71
  for id, s in original_series_by_id.items():
72
  expansions = expanded_series_by_id[id]
73
+ expansion_result = ExpansionResult(series=s, expansions=expansions)
 
 
 
74
  results.append(expansion_result)
75
  return ExpansionResultBatch(items=results)
76
 
expand_llm.py CHANGED
@@ -15,6 +15,6 @@ class ExpanderOneBatchLLM:
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
- expansions = [ExpansionOne(token=token, cost=logprob) for token, logprob in next_tokens if logprob + s.budget >= 0]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
 
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
+ expansions = [Expansion(token=token, cost=logprob) for token, logprob in next_tokens if logprob + s.get_remaining_budget() >= 0]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
expand_test.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import dataclass
2
- from expand import Series, ExpanderOneBatch, ExpansionOne, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
3
 
4
  possible_sequences = [
5
  [1, 21, 31, 41],
@@ -9,11 +9,12 @@ possible_sequences = [
9
  [1, 22, 34, 41],
10
  ]
11
 
12
- def expand_series(series: Series) -> list[ExpansionOne]:
13
- l = len(series.tokens)
14
- items = [s[l] for s in possible_sequences if s[:l] == series.tokens and len(s) > l]
15
- candidates = [ExpansionOne(token=l, cost=-1.0) for l in dict.fromkeys(items)]
16
- return [c for c in candidates if c.cost + series.budget >= 0]
 
17
 
18
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
19
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
@@ -38,8 +39,8 @@ def test_expander_budget_one():
38
  expanded = expander.expand(Batch(items=[s]))
39
  expected = ExpansionOneResultBatch(
40
  items=[ExpansionOneResult(series=s, expansions=[
41
- ExpansionOne(token=21, cost=-1.0),
42
- ExpansionOne(token=22, cost=-1.0),
43
  ])]
44
  )
45
  assert expected == expanded
@@ -49,8 +50,8 @@ def test_expander_budget_two():
49
  expanded = expander.expand(Batch(items=[s]))
50
  expected = ExpansionOneResultBatch(
51
  items=[ExpansionOneResult(series=s, expansions=[
52
- ExpansionOne(token=21, cost=-1.0),
53
- ExpansionOne(token=22, cost=-1.0),
54
  ])]
55
  )
56
  assert expected == expanded
@@ -68,8 +69,8 @@ def test_expander_budget_one_two_tokens():
68
  expanded = expander.expand(Batch(items=[s]))
69
  expected = ExpansionOneResultBatch(
70
  items=[ExpansionOneResult(series=s, expansions=[
71
- ExpansionOne(token=33, cost=-1.0),
72
- ExpansionOne(token=34, cost=-1.0),
73
  ])]
74
  )
75
  assert expected == expanded
@@ -81,12 +82,12 @@ def test_expander_budget_one_two_tokens_two_series():
81
  expected = ExpansionOneResultBatch(
82
  items=[
83
  ExpansionOneResult(series=s1, expansions=[
84
- ExpansionOne(token=41, cost=-1.0),
85
- ExpansionOne(token=42, cost=-1.0),
86
  ]),
87
  ExpansionOneResult(series=s2, expansions=[
88
- ExpansionOne(token=33, cost=-1.0),
89
- ExpansionOne(token=34, cost=-1.0),
90
  ])
91
  ]
92
  )
@@ -102,15 +103,15 @@ def test_expand_01():
102
  ExpansionResult(
103
  series=Series(id=0, tokens=[1, 21], budget=1.0),
104
  expansions=[
105
- [31],
106
- [32],
107
  ]
108
  ),
109
  ExpansionResult(
110
  series=Series(id=1, tokens=[1, 22], budget=1.0),
111
  expansions=[
112
- [33],
113
- [34],
114
  ]
115
  ),
116
  ])
@@ -125,16 +126,16 @@ def test_expand_02():
125
  ExpansionResult(
126
  series=Series(id=0, tokens=[1, 21], budget=2.0),
127
  expansions=[
128
- [31, 41],
129
- [31, 42],
130
- [32, 41],
131
  ]
132
  ),
133
  ExpansionResult(
134
  series=Series(id=1, tokens=[1, 22], budget=1.0),
135
  expansions=[
136
- [33],
137
- [34],
138
  ]
139
  ),
140
  ])
@@ -149,9 +150,9 @@ def test_expand_03():
149
  ExpansionResult(
150
  series=Series(id=0, tokens=[1, 21], budget=3.0),
151
  expansions=[
152
- [31, 41],
153
- [31, 42],
154
- [32, 41, 51],
155
  ]
156
  ),
157
  ExpansionResult(
 
1
  from dataclasses import dataclass
2
+ from expand import Series, ExpanderOneBatch, Expansion, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
3
 
4
  possible_sequences = [
5
  [1, 21, 31, 41],
 
9
  [1, 22, 34, 41],
10
  ]
11
 
12
+ def expand_series(series: Series) -> list[Expansion]:
13
+ all_tokens = series.get_all_tokens()
14
+ l = len(all_tokens)
15
+ items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
16
+ candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
17
+ return [c for c in candidates if c.cost + series.get_remaining_budget() >= 0]
18
 
19
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
20
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
 
39
  expanded = expander.expand(Batch(items=[s]))
40
  expected = ExpansionOneResultBatch(
41
  items=[ExpansionOneResult(series=s, expansions=[
42
+ Expansion(token=21, cost=-1.0),
43
+ Expansion(token=22, cost=-1.0),
44
  ])]
45
  )
46
  assert expected == expanded
 
50
  expanded = expander.expand(Batch(items=[s]))
51
  expected = ExpansionOneResultBatch(
52
  items=[ExpansionOneResult(series=s, expansions=[
53
+ Expansion(token=21, cost=-1.0),
54
+ Expansion(token=22, cost=-1.0),
55
  ])]
56
  )
57
  assert expected == expanded
 
69
  expanded = expander.expand(Batch(items=[s]))
70
  expected = ExpansionOneResultBatch(
71
  items=[ExpansionOneResult(series=s, expansions=[
72
+ Expansion(token=33, cost=-1.0),
73
+ Expansion(token=34, cost=-1.0),
74
  ])]
75
  )
76
  assert expected == expanded
 
82
  expected = ExpansionOneResultBatch(
83
  items=[
84
  ExpansionOneResult(series=s1, expansions=[
85
+ Expansion(token=41, cost=-1.0),
86
+ Expansion(token=42, cost=-1.0),
87
  ]),
88
  ExpansionOneResult(series=s2, expansions=[
89
+ Expansion(token=33, cost=-1.0),
90
+ Expansion(token=34, cost=-1.0),
91
  ])
92
  ]
93
  )
 
103
  ExpansionResult(
104
  series=Series(id=0, tokens=[1, 21], budget=1.0),
105
  expansions=[
106
+ [Expansion(token=31, cost=-1.0)],
107
+ [Expansion(token=32, cost=-1.0)],
108
  ]
109
  ),
110
  ExpansionResult(
111
  series=Series(id=1, tokens=[1, 22], budget=1.0),
112
  expansions=[
113
+ [Expansion(token=33, cost=-1.0)],
114
+ [Expansion(token=34, cost=-1.0)],
115
  ]
116
  ),
117
  ])
 
126
  ExpansionResult(
127
  series=Series(id=0, tokens=[1, 21], budget=2.0),
128
  expansions=[
129
+ [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
130
+ [Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
131
+ [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
132
  ]
133
  ),
134
  ExpansionResult(
135
  series=Series(id=1, tokens=[1, 22], budget=1.0),
136
  expansions=[
137
+ [Expansion(token=33, cost=-1.0)],
138
+ [Expansion(token=34, cost=-1.0)],
139
  ]
140
  ),
141
  ])
 
150
  ExpansionResult(
151
  series=Series(id=0, tokens=[1, 21], budget=3.0),
152
  expansions=[
153
+ [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
154
+ [Expansion(token=31, cost=-1.0), Expansion(token=42, cost=-1.0)],
155
+ [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
156
  ]
157
  ),
158
  ExpansionResult(
run.py CHANGED
@@ -29,7 +29,7 @@ expander = ExpanderOneBatchLLM(model, tokenizer)
29
  #%%
30
  series = []
31
  for i, x in enumerate(contexts):
32
- series.append(Series(id=i, tokens=x, budget=5.0))
33
 
34
  #%%
35
  batch = Batch(items=series)
@@ -42,7 +42,8 @@ def print_expansions(expansions: ExpansionResultBatch):
42
  for result in expansions.items:
43
  for expansion in result.expansions:
44
  # convert tokens to string
45
- s = tokenizer.decode(expansion)
 
46
  print(f"{result.series.id}: {expansion} {s}")
47
 
48
  print_expansions(expanded)
 
29
  #%%
30
  series = []
31
  for i, x in enumerate(contexts):
32
+ series.append(Series(id=i, tokens=x, budget=7.0))
33
 
34
  #%%
35
  batch = Batch(items=series)
 
42
  for result in expansions.items:
43
  for expansion in result.expansions:
44
  # convert tokens to string
45
+ tokens = [e.token for e in expansion]
46
+ s = tokenizer.decode(tokens)
47
  print(f"{result.series.id}: {expansion} {s}")
48
 
49
  print_expansions(expanded)