File size: 3,811 Bytes
2fb63bf 6f4449d 3d1d657 6f4449d 2fb63bf 6f4449d 2fb63bf 6f4449d 2fb63bf 6f4449d 2fb63bf 3d1d657 2fb63bf 3d1d657 2fb63bf 6f4449d 2fb63bf 6f4449d 2fb63bf 6f4449d 2fb63bf 3d1d657 2fb63bf 3d1d657 2fb63bf bb48904 2fb63bf 3d1d657 2fb63bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Protocol, Self
@dataclass
class Expansion:
    """One candidate continuation of a series: a token together with its cost."""

    # Token proposed as the next element of a series.
    token: int
    # Cost of taking this token; Series.get_remaining_budget adds it to the budget.
    cost: float
@dataclass
class Series:
    """A token sequence together with the expansions appended to it so far."""

    # Identifier of the series.
    id: int
    # Initial tokens of the series.
    tokens: list[int]
    # Starting budget; expansion costs are added to it (see get_remaining_budget).
    budget: float
    # Expansions applied so far, in the order they were appended.
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return the initial tokens followed by the token of every expansion."""
        appended_tokens = [step.token for step in self.expansions]
        return self.tokens + appended_tokens

    def get_remaining_budget(self) -> float:
        """Return the starting budget plus the summed cost of all expansions."""
        accumulated_cost = sum(step.cost for step in self.expansions)
        return self.budget + accumulated_cost
@dataclass
class Batch:
    """A group of series that is handed to an expander in a single call."""

    # The series contained in this batch.
    items: list[Series]
@dataclass
class ExpansionOneResult:
    """The candidate expansions an expander produced for one series."""

    # The series that was expanded.
    series: Series
    # Candidate expansions for that series; may be empty.
    expansions: list[Expansion]
@dataclass
class ExpansionOneResultBatch:
    """The per-series results returned by one call to an expander."""

    # One ExpansionOneResult per expanded series.
    items: list[ExpansionOneResult]
# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
class ExpanderOneBatch(Protocol):
    """Structural interface for objects that can expand a whole batch at once."""

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Return, per series in ``batch``, the candidate expansions to consider."""
        ...
@dataclass
class ExpansionResult:
    """All expansion paths collected for one original series."""

    # The original series the paths belong to.
    series: Series
    # One list of expansions per collected path; may be empty.
    expansions: list[list[Expansion]]
@dataclass
class ExpansionResultBatch:
    """The final expansion results for every series of a batch."""

    # One ExpansionResult per original series.
    items: list[ExpansionResult]
def compute_new_series(
    result: ExpansionOneResult,
    stopping_criterion: Callable[[Series, Expansion], bool],
) -> tuple[list[Series], list[Series]]:
    """Split one expansion result into continued series and completed series.

    Args:
        result: The expanded series together with its candidate expansions.
        stopping_criterion: Called as ``stopping_criterion(series, expansion)``;
            a truthy return value discards that candidate.

    Returns:
        A pair ``(new_series, completed_series)``. ``new_series`` holds one
        series per candidate that was not discarded, each carrying the parent's
        id, tokens and budget plus the candidate appended to the parent's
        expansions. ``completed_series`` is ``[result.series]`` when no
        candidate survived and empty otherwise.
    """
    parent = result.series
    continued = [
        Series(
            id=parent.id,
            # The token list is shared with the parent, not copied.
            tokens=parent.tokens,
            budget=parent.budget,
            expansions=parent.expansions + [candidate],
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    # The parent only counts as completed when nothing continues from it.
    finished = [] if continued else [parent]
    return continued, finished
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
    """Group the expansion paths of ``expanded_series`` under their original series.

    Args:
        original_series: The series the expansion started from. Ids must be
            unique because expanded series are matched back to them by id.
        expanded_series: Series produced by the expansion. Entries whose
            ``expansions`` list is empty are ignored, and entries whose id
            matches no original series do not appear in the result.

    Returns:
        An ``ExpansionResultBatch`` with one ``ExpansionResult`` per original
        series, in input order. Each result's ``expansions`` holds one list per
        matching expanded series and is empty when there is none.

    Raises:
        ValueError: If two entries of ``original_series`` share an id.
    """
    # Index the originals by id, collecting duplicates along the way. This is
    # an explicit raise rather than an assert so that the check also runs
    # under ``python -O``.
    original_series_by_id = {}
    duplicate_ids = []
    for series in original_series:
        if series.id in original_series_by_id:
            duplicate_ids.append(series.id)
        original_series_by_id[series.id] = series
    if duplicate_ids:
        raise ValueError(f"original_series contains duplicate ids: {duplicate_ids}")

    # Collect the non-empty expansion paths of the expanded series by id.
    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for series in expanded_series:
        if series.expansions:
            paths_by_id[series.id].append(series.expansions)

    return ExpansionResultBatch(
        items=[
            ExpansionResult(series=series, expansions=paths_by_id.get(series_id, []))
            for series_id, series in original_series_by_id.items()
        ]
    )
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when adding ``expansion.cost`` drops the remaining budget below zero."""
    budget_after_expansion = series.get_remaining_budget() + expansion.cost
    return budget_after_expansion < 0
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
def expand(
    batch: Batch,
    expander: ExpanderOneBatch,
    completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion,
) -> ExpansionResultBatch:
    """Expand every series of ``batch`` round by round and collect the results.

    Each round passes the current batch to ``expander.expand``. A series for
    which the expander returns no expansions is recorded as completed. For the
    others, ``compute_new_series`` decides which candidates continue into the
    next round and whether the series itself is completed. The loop ends when
    a round produces no series to continue.

    A progress line is printed to stdout at the start of every round.

    Args:
        batch: The series to expand.
        expander: Produces candidate expansions for a whole batch.
        completion_criterion: Passed to ``compute_new_series`` as its stopping
            criterion; defaults to ``default_completion_criterion``.

    Returns:
        The result of ``compute_expansions`` applied to the original series and
        all series recorded as completed.
    """
    finished: list[Series] = []
    current_batch = batch
    while current_batch.items:
        print(f"Expanding {len(current_batch.items)} series: {current_batch.items}")
        round_result = expander.expand(current_batch)
        next_round: list[Series] = []
        for item in round_result.items:
            if not item.expansions:
                # The expander offered nothing further for this series.
                finished.append(item.series)
                continue
            continued, completed = compute_new_series(item, completion_criterion)
            finished += completed
            next_round += continued
        current_batch = Batch(items=next_round)
    return compute_expansions(batch.items, finished)
|