mebubo committed on
Commit
4537742
·
1 Parent(s): a4b9140

Better names for Expander etc

Browse files
Files changed (6) hide show
  1. README.md +1 -1
  2. completions.py +1 -2
  3. expand.py +14 -14
  4. expand_llm.py +4 -4
  5. expand_test.py +28 -28
  6. run.py +2 -2
README.md CHANGED
@@ -173,7 +173,7 @@ In my case, I stop when the budget is exhausted, and I also stop if the expansio
173
 
174
  Given the batch and the stopping criterion, we can call the expander:
175
  ```python
176
- expander = ExpanderOneBatchLLM(model, tokenizer)
177
  expanded = expand(batch, expander, stopping_criterion)
178
  ```
179
 
 
173
 
174
  Given the batch and the stopping criterion, we can call the expander:
175
  ```python
176
+ expander = LLMBatchExpander(model, tokenizer)
177
  expanded = expand(batch, expander, stopping_criterion)
178
  ```
179
 
completions.py CHANGED
@@ -92,8 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
92
 
93
  contexts = [word.context for _, word in low_prob_words]
94
 
95
-
96
- expander = ExpanderOneBatchLLM(model, tokenizer)
97
 
98
  #%%
99
  series = []
 
92
 
93
  contexts = [word.context for _, word in low_prob_words]
94
 
95
+ expander = LLMBatchExpander(model, tokenizer)
 
96
 
97
  #%%
98
  series = []
expand.py CHANGED
@@ -25,28 +25,28 @@ class Batch:
25
  items: list[Series]
26
 
27
  @dataclass
28
- class ExpansionOneResult:
29
  series: Series
30
  expansions: list[Expansion]
31
 
32
  @dataclass
33
- class ExpansionOneResultBatch:
34
- items: list[ExpansionOneResult]
35
 
36
  # A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
37
- class ExpanderOneBatch(Protocol):
38
- def expand(self, batch: Batch) -> ExpansionOneResultBatch: ...
39
 
40
  @dataclass
41
- class ExpansionResult:
42
  series: Series
43
  expansions: list[list[Expansion]]
44
 
45
  @dataclass
46
- class ExpansionResultBatch:
47
- items: list[ExpansionResult]
48
 
49
- def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
50
  new_series_batch = []
51
  for expansion in result.expansions:
52
  if not stopping_criterion(result.series, expansion):
@@ -60,7 +60,7 @@ def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[
60
  completed_series = [result.series] if len(new_series_batch) == 0 else []
61
  return new_series_batch, completed_series
62
 
63
- def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
64
  # check that ids in original_series are unique
65
  assert len(original_series) == len({s.id for s in original_series})
66
  # group original series by id
@@ -73,15 +73,15 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
73
  results = []
74
  for id, s in original_series_by_id.items():
75
  expansions = expanded_series_by_id[id]
76
- expansion_result = ExpansionResult(series=s, expansions=expansions)
77
  results.append(expansion_result)
78
- return ExpansionResultBatch(items=results)
79
 
80
  def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
81
  return series.get_remaining_budget() + expansion.cost < 0
82
 
83
- # A compound operation that we can implement generically, relying on an ExpanderOneBatch
84
- def expand(batch: Batch, expander: ExpanderOneBatch, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> ExpansionResultBatch:
85
  completed_series: list[Series] = []
86
  current_batch = batch
87
  while len(current_batch.items) > 0:
 
25
  items: list[Series]
26
 
27
  @dataclass
28
+ class TokenCandidates:
29
  series: Series
30
  expansions: list[Expansion]
31
 
32
  @dataclass
33
+ class BatchCandidates:
34
+ items: list[TokenCandidates]
35
 
36
  # A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
37
+ class BatchExpander(Protocol):
38
+ def expand(self, batch: Batch) -> BatchCandidates: ...
39
 
40
  @dataclass
41
+ class CompletedSequence:
42
  series: Series
43
  expansions: list[list[Expansion]]
44
 
45
  @dataclass
46
+ class CompletedBatch:
47
+ items: list[CompletedSequence]
48
 
49
+ def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
50
  new_series_batch = []
51
  for expansion in result.expansions:
52
  if not stopping_criterion(result.series, expansion):
 
60
  completed_series = [result.series] if len(new_series_batch) == 0 else []
61
  return new_series_batch, completed_series
62
 
63
+ def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
64
  # check that ids in original_series are unique
65
  assert len(original_series) == len({s.id for s in original_series})
66
  # group original series by id
 
73
  results = []
74
  for id, s in original_series_by_id.items():
75
  expansions = expanded_series_by_id[id]
76
+ expansion_result = CompletedSequence(series=s, expansions=expansions)
77
  results.append(expansion_result)
78
+ return CompletedBatch(items=results)
79
 
80
  def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
81
  return series.get_remaining_budget() + expansion.cost < 0
82
 
83
+ # A compound operation that we can implement generically, relying on a BatchExpander
84
+ def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
85
  completed_series: list[Series] = []
86
  current_batch = batch
87
  while len(current_batch.items) > 0:
expand_llm.py CHANGED
@@ -22,18 +22,18 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
22
  return tokenizer(texts, return_tensors="pt", padding=True).to(device)
23
 
24
  @dataclass
25
- class ExpanderOneBatchLLM:
26
  model: PreTrainedModel
27
  tokenizer: Tokenizer
28
 
29
- def expand(self, batch: Batch) -> ExpansionOneResultBatch:
30
  inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
31
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
32
  results = []
33
  for s, next_tokens in zip(batch.items, next_tokens):
34
  expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
35
- results.append(ExpansionOneResult(series=s, expansions=expansions))
36
- return ExpansionOneResultBatch(items=results)
37
 
38
  def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
39
  def stopping_criterion(series: Series, expansion: Expansion) -> bool:
 
22
  return tokenizer(texts, return_tensors="pt", padding=True).to(device)
23
 
24
  @dataclass
25
+ class LLMBatchExpander(BatchExpander):
26
  model: PreTrainedModel
27
  tokenizer: Tokenizer
28
 
29
+ def expand(self, batch: Batch) -> BatchCandidates:
30
  inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
31
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
32
  results = []
33
  for s, next_tokens in zip(batch.items, next_tokens):
34
  expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
35
+ results.append(TokenCandidates(series=s, expansions=expansions))
36
+ return BatchCandidates(items=results)
37
 
38
  def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
39
  def stopping_criterion(series: Series, expansion: Expansion) -> bool:
expand_test.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import dataclass
2
- from expand import Series, ExpanderOneBatch, Expansion, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
3
 
4
  possible_sequences = [
5
  [1, 21, 31, 41],
@@ -16,21 +16,21 @@ def expand_series(series: Series) -> list[Expansion]:
16
  candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
17
  return candidates
18
 
19
- class HardcodedExpanderOneBatch(ExpanderOneBatch):
20
- def expand(self, batch: Batch) -> ExpansionOneResultBatch:
21
  result = []
22
  for s in batch.items:
23
  expansions = expand_series(s)
24
- result.append(ExpansionOneResult(series=s, expansions=expansions))
25
- return ExpansionOneResultBatch(items=result)
26
 
27
- expander = HardcodedExpanderOneBatch()
28
 
29
  def test_expander_zero_budget():
30
  s = Series(id=0, tokens=[1], budget=0.0)
31
  expanded = expander.expand(Batch(items=[s]))
32
- expected = ExpansionOneResultBatch(
33
- items=[ExpansionOneResult(series=s, expansions=[
34
  Expansion(token=21, cost=-1.0),
35
  Expansion(token=22, cost=-1.0),
36
  ])]
@@ -40,8 +40,8 @@ def test_expander_zero_budget():
40
  def test_expander_budget_one():
41
  s = Series(id=0, tokens=[1], budget=1.0)
42
  expanded = expander.expand(Batch(items=[s]))
43
- expected = ExpansionOneResultBatch(
44
- items=[ExpansionOneResult(series=s, expansions=[
45
  Expansion(token=21, cost=-1.0),
46
  Expansion(token=22, cost=-1.0),
47
  ])]
@@ -51,8 +51,8 @@ def test_expander_budget_one():
51
  def test_expander_budget_two():
52
  s = Series(id=0, tokens=[1], budget=2.0)
53
  expanded = expander.expand(Batch(items=[s]))
54
- expected = ExpansionOneResultBatch(
55
- items=[ExpansionOneResult(series=s, expansions=[
56
  Expansion(token=21, cost=-1.0),
57
  Expansion(token=22, cost=-1.0),
58
  ])]
@@ -62,16 +62,16 @@ def test_expander_budget_two():
62
  def test_expander_budget_one_no_expansion():
63
  s = Series(id=0, tokens=[1, 20], budget=1.0)
64
  expanded = expander.expand(Batch(items=[s]))
65
- expected = ExpansionOneResultBatch(
66
- items=[ExpansionOneResult(series=s, expansions=[])]
67
  )
68
  assert expected == expanded
69
 
70
  def test_expander_budget_one_two_tokens():
71
  s = Series(id=0, tokens=[1, 22], budget=1.0)
72
  expanded = expander.expand(Batch(items=[s]))
73
- expected = ExpansionOneResultBatch(
74
- items=[ExpansionOneResult(series=s, expansions=[
75
  Expansion(token=33, cost=-1.0),
76
  Expansion(token=34, cost=-1.0),
77
  ])]
@@ -82,13 +82,13 @@ def test_expander_budget_one_two_tokens_two_series():
82
  s1 = Series(id=0, tokens=[1, 21, 31], budget=1.0)
83
  s2 = Series(id=1, tokens=[1, 22], budget=1.0)
84
  expanded = expander.expand(Batch(items=[s1, s2]))
85
- expected = ExpansionOneResultBatch(
86
  items=[
87
- ExpansionOneResult(series=s1, expansions=[
88
  Expansion(token=41, cost=-1.0),
89
  Expansion(token=42, cost=-1.0),
90
  ]),
91
- ExpansionOneResult(series=s2, expansions=[
92
  Expansion(token=33, cost=-1.0),
93
  Expansion(token=34, cost=-1.0),
94
  ])
@@ -102,15 +102,15 @@ def test_expand_01():
102
  Series(id=1, tokens=[1, 22], budget=1.0),
103
  ])
104
  expanded = expand(batch, expander)
105
- assert expanded == ExpansionResultBatch(items=[
106
- ExpansionResult(
107
  series=Series(id=0, tokens=[1, 21], budget=1.0),
108
  expansions=[
109
  [Expansion(token=31, cost=-1.0)],
110
  [Expansion(token=32, cost=-1.0)],
111
  ]
112
  ),
113
- ExpansionResult(
114
  series=Series(id=1, tokens=[1, 22], budget=1.0),
115
  expansions=[
116
  [Expansion(token=33, cost=-1.0)],
@@ -125,8 +125,8 @@ def test_expand_02():
125
  Series(id=1, tokens=[1, 22], budget=1.0),
126
  ])
127
  expanded = expand(batch, expander)
128
- assert expanded == ExpansionResultBatch(items=[
129
- ExpansionResult(
130
  series=Series(id=0, tokens=[1, 21], budget=2.0),
131
  expansions=[
132
  [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
@@ -134,7 +134,7 @@ def test_expand_02():
134
  [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
135
  ]
136
  ),
137
- ExpansionResult(
138
  series=Series(id=1, tokens=[1, 22], budget=1.0),
139
  expansions=[
140
  [Expansion(token=33, cost=-1.0)],
@@ -149,8 +149,8 @@ def test_expand_03():
149
  Series(id=1, tokens=[1, 22], budget=0.0),
150
  ])
151
  expanded = expand(batch, expander)
152
- assert expanded == ExpansionResultBatch(items=[
153
- ExpansionResult(
154
  series=Series(id=0, tokens=[1, 21], budget=3.0),
155
  expansions=[
156
  [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
@@ -158,7 +158,7 @@ def test_expand_03():
158
  [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
159
  ]
160
  ),
161
- ExpansionResult(
162
  series=Series(id=1, tokens=[1, 22], budget=0.0),
163
  expansions=[],
164
  ),
 
1
  from dataclasses import dataclass
2
+ from expand import Series, BatchExpander, Expansion, Batch, TokenCandidates, BatchCandidates, CompletedSequence, CompletedBatch, expand
3
 
4
  possible_sequences = [
5
  [1, 21, 31, 41],
 
16
  candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
17
  return candidates
18
 
19
+ class PredefinedSequenceExpander(BatchExpander):
20
+ def expand(self, batch: Batch) -> BatchCandidates:
21
  result = []
22
  for s in batch.items:
23
  expansions = expand_series(s)
24
+ result.append(TokenCandidates(series=s, expansions=expansions))
25
+ return BatchCandidates(items=result)
26
 
27
+ expander = PredefinedSequenceExpander()
28
 
29
  def test_expander_zero_budget():
30
  s = Series(id=0, tokens=[1], budget=0.0)
31
  expanded = expander.expand(Batch(items=[s]))
32
+ expected = BatchCandidates(
33
+ items=[TokenCandidates(series=s, expansions=[
34
  Expansion(token=21, cost=-1.0),
35
  Expansion(token=22, cost=-1.0),
36
  ])]
 
40
  def test_expander_budget_one():
41
  s = Series(id=0, tokens=[1], budget=1.0)
42
  expanded = expander.expand(Batch(items=[s]))
43
+ expected = BatchCandidates(
44
+ items=[TokenCandidates(series=s, expansions=[
45
  Expansion(token=21, cost=-1.0),
46
  Expansion(token=22, cost=-1.0),
47
  ])]
 
51
  def test_expander_budget_two():
52
  s = Series(id=0, tokens=[1], budget=2.0)
53
  expanded = expander.expand(Batch(items=[s]))
54
+ expected = BatchCandidates(
55
+ items=[TokenCandidates(series=s, expansions=[
56
  Expansion(token=21, cost=-1.0),
57
  Expansion(token=22, cost=-1.0),
58
  ])]
 
62
  def test_expander_budget_one_no_expansion():
63
  s = Series(id=0, tokens=[1, 20], budget=1.0)
64
  expanded = expander.expand(Batch(items=[s]))
65
+ expected = BatchCandidates(
66
+ items=[TokenCandidates(series=s, expansions=[])]
67
  )
68
  assert expected == expanded
69
 
70
  def test_expander_budget_one_two_tokens():
71
  s = Series(id=0, tokens=[1, 22], budget=1.0)
72
  expanded = expander.expand(Batch(items=[s]))
73
+ expected = BatchCandidates(
74
+ items=[TokenCandidates(series=s, expansions=[
75
  Expansion(token=33, cost=-1.0),
76
  Expansion(token=34, cost=-1.0),
77
  ])]
 
82
  s1 = Series(id=0, tokens=[1, 21, 31], budget=1.0)
83
  s2 = Series(id=1, tokens=[1, 22], budget=1.0)
84
  expanded = expander.expand(Batch(items=[s1, s2]))
85
+ expected = BatchCandidates(
86
  items=[
87
+ TokenCandidates(series=s1, expansions=[
88
  Expansion(token=41, cost=-1.0),
89
  Expansion(token=42, cost=-1.0),
90
  ]),
91
+ TokenCandidates(series=s2, expansions=[
92
  Expansion(token=33, cost=-1.0),
93
  Expansion(token=34, cost=-1.0),
94
  ])
 
102
  Series(id=1, tokens=[1, 22], budget=1.0),
103
  ])
104
  expanded = expand(batch, expander)
105
+ assert expanded == CompletedBatch(items=[
106
+ CompletedSequence(
107
  series=Series(id=0, tokens=[1, 21], budget=1.0),
108
  expansions=[
109
  [Expansion(token=31, cost=-1.0)],
110
  [Expansion(token=32, cost=-1.0)],
111
  ]
112
  ),
113
+ CompletedSequence(
114
  series=Series(id=1, tokens=[1, 22], budget=1.0),
115
  expansions=[
116
  [Expansion(token=33, cost=-1.0)],
 
125
  Series(id=1, tokens=[1, 22], budget=1.0),
126
  ])
127
  expanded = expand(batch, expander)
128
+ assert expanded == CompletedBatch(items=[
129
+ CompletedSequence(
130
  series=Series(id=0, tokens=[1, 21], budget=2.0),
131
  expansions=[
132
  [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
 
134
  [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0)],
135
  ]
136
  ),
137
+ CompletedSequence(
138
  series=Series(id=1, tokens=[1, 22], budget=1.0),
139
  expansions=[
140
  [Expansion(token=33, cost=-1.0)],
 
149
  Series(id=1, tokens=[1, 22], budget=0.0),
150
  ])
151
  expanded = expand(batch, expander)
152
+ assert expanded == CompletedBatch(items=[
153
+ CompletedSequence(
154
  series=Series(id=0, tokens=[1, 21], budget=3.0),
155
  expansions=[
156
  [Expansion(token=31, cost=-1.0), Expansion(token=41, cost=-1.0)],
 
158
  [Expansion(token=32, cost=-1.0), Expansion(token=41, cost=-1.0), Expansion(token=51, cost=-1.0)],
159
  ]
160
  ),
161
+ CompletedSequence(
162
  series=Series(id=1, tokens=[1, 22], budget=0.0),
163
  expansions=[],
164
  ),
run.py CHANGED
@@ -24,7 +24,7 @@ low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < lo
24
  contexts = [word.context for _, word in low_prob_words]
25
 
26
  #%%
27
- expander = ExpanderOneBatchLLM(model, tokenizer)
28
 
29
  #%%
30
  series = []
@@ -41,7 +41,7 @@ stopping_criterion = create_stopping_criterion_llm(tokenizer)
41
  expanded = expand(batch, expander, stopping_criterion)
42
 
43
  # %%
44
- def print_expansions(expansions: ExpansionResultBatch):
45
  for result in expansions.items:
46
  for expansion in result.expansions:
47
  # convert tokens to string
 
24
  contexts = [word.context for _, word in low_prob_words]
25
 
26
  #%%
27
+ expander = LLMBatchExpander(model, tokenizer)
28
 
29
  #%%
30
  series = []
 
41
  expanded = expand(batch, expander, stopping_criterion)
42
 
43
  # %%
44
+ def print_expansions(expansions: CompletedBatch):
45
  for result in expansions.items:
46
  for expansion in result.expansions:
47
  # convert tokens to string