File size: 3,811 Bytes
2fb63bf
6f4449d
3d1d657
6f4449d
 
 
 
 
2fb63bf
 
 
 
 
 
6f4449d
 
 
 
 
 
 
2fb63bf
 
 
 
 
 
 
 
6f4449d
2fb63bf
 
 
 
 
 
 
 
 
 
 
 
6f4449d
2fb63bf
 
 
 
 
3d1d657
 
2fb63bf
3d1d657
 
 
 
 
 
 
 
 
 
2fb63bf
 
 
 
 
 
 
6f4449d
2fb63bf
6f4449d
 
2fb63bf
 
 
6f4449d
2fb63bf
 
 
3d1d657
 
 
2fb63bf
3d1d657
2fb63bf
 
 
bb48904
2fb63bf
 
 
 
 
 
3d1d657
 
 
2fb63bf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Protocol, Self

@dataclass
class Expansion:
    """A single candidate continuation step: one token together with its cost."""

    # Token appended to a series when this expansion is taken.
    token: int
    # Added to the owning series' budget by Series.get_remaining_budget.
    cost: float

@dataclass
class Series:
    """A base token sequence plus the expansions appended to it so far.

    Every branch derived from the same original series keeps that series' ``id``,
    so finished branches can later be grouped back under their origin.
    """

    # Identifier shared by all branches derived from one original series.
    id: int
    # The base (prompt) tokens; expansion tokens are kept separately.
    tokens: list[int]
    # Starting budget; expansion costs are added to it.
    budget: float
    # Expansions taken so far, in order.
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return a new list: the base tokens followed by each expansion's token."""
        appended = [exp.token for exp in self.expansions]
        return self.tokens + appended

    def get_remaining_budget(self) -> float:
        """Return ``budget`` plus the summed cost of all expansions taken so far.

        NOTE(review): costs are *added*, so they are presumably negative
        (e.g. log-probabilities) for the budget to shrink. TODO confirm.
        """
        total_cost = sum(exp.cost for exp in self.expansions)
        return self.budget + total_cost

@dataclass
class Batch:
    """A group of series handed to an expander in a single call."""

    # Series to expand together.
    items: list[Series]

@dataclass
class ExpansionOneResult:
    """The candidate next expansions an expander proposed for one series."""

    # The series that was expanded.
    series: Series
    # Candidate next steps; an empty list means the series is finished.
    expansions: list[Expansion]

@dataclass
class ExpansionOneResultBatch:
    """Output of one ExpanderOneBatch.expand call: one-step results for a batch."""

    # Per-series one-step expansion results.
    items: list[ExpansionOneResult]

class ExpanderOneBatch(Protocol):
    """Structural interface for a single-step, batched expander.

    This is the fundamental operation: it can be implemented with an LLM, or with
    a list of hard-coded sequences for testing. The multi-step ``expand``
    function is built generically on top of it.
    """

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Propose candidate next expansions for the series in ``batch``.

        NOTE(review): callers appear to expect one result per input series.
        A series missing from the output is silently dropped from the search;
        nothing enforces this.
        """

@dataclass
class ExpansionResult:
    """All finished expansion paths that descend from one original series."""

    # The original (pre-expansion) series.
    series: Series
    # One list per finished branch: that branch's full sequence of expansions.
    expansions: list[list[Expansion]]

@dataclass
class ExpansionResultBatch:
    """Final output of a multi-step expansion, one entry per original series."""

    # One result per original series, in the original batch order.
    items: list[ExpansionResult]

def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
    """Branch a series on every candidate expansion the criterion does not reject.

    Args:
        result: A series and the candidate expansions proposed for it.
        stopping_criterion: Called as ``stopping_criterion(series, candidate)``;
            a True result discards that candidate.

    Returns:
        ``(continuing, finished)``. ``continuing`` holds one new Series per
        accepted candidate, each with that candidate appended to its expansions.
        ``finished`` is ``[result.series]`` when every candidate was rejected
        (or there were none), and ``[]`` otherwise.
    """
    parent = result.series
    continuing = [
        # The new branch shares ``parent.tokens`` (same list object) but gets
        # its own, extended expansions list.
        Series(
            id=parent.id,
            tokens=parent.tokens,
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    finished = [] if continuing else [parent]
    return continuing, finished

def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
    """Group finished branches' expansion paths under the series they came from.

    Args:
        original_series: The series the search started from; ids must be unique.
        expanded_series: Finished branches. Branches with no expansions are
            skipped. Branches whose id matches no original series are ignored.

    Returns:
        One ExpansionResult per original series, in the order of
        ``original_series``. Each result holds the expansion lists of all
        finished branches that share that series' id (possibly empty).

    Raises:
        ValueError: if two entries of ``original_series`` share an id.
    """
    original_by_id = {s.id: s for s in original_series}
    # An explicit check (not ``assert``) so it still runs under ``python -O``.
    # Duplicate ids would otherwise be silently collapsed by the dict above.
    if len(original_by_id) != len(original_series):
        raise ValueError("ids in original_series must be unique")

    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for branch in expanded_series:
        if branch.expansions:
            paths_by_id[branch.id].append(branch.expansions)

    results = [
        ExpansionResult(series=series, expansions=paths_by_id.get(series_id, []))
        for series_id, series in original_by_id.items()
    ]
    return ExpansionResultBatch(items=results)

def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when taking ``expansion`` would push the remaining budget below zero.

    This is the default stopping rule for ``expand``: a candidate for which it
    returns True is discarded rather than taken.
    """
    budget_after = series.get_remaining_budget() + expansion.cost
    return budget_after < 0

def expand(batch: Batch, expander: ExpanderOneBatch, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> ExpansionResultBatch:
    """Run a multi-step expansion generically on top of a single-step expander.

    Each round passes the current batch to ``expander``. For each returned item:

    * if the expander proposed no candidates, the series is finished;
    * otherwise every candidate rejected by ``completion_criterion`` is dropped,
      and each accepted candidate becomes a new branch for the next round. The
      series is finished only when all of its candidates were rejected.

    Rounds repeat until no branches remain.

    Args:
        batch: The original series to expand; ``compute_expansions`` rejects
            duplicate ids.
        expander: Single-step expander used for every round.
        completion_criterion: Returns True when a candidate must NOT be taken.

    Returns:
        One ExpansionResult per series in ``batch``, holding the expansion
        paths of every finished branch descended from it.

    NOTE(review): nothing bounds the number of rounds. The loop ends only once
    every branch either gets no candidates or has all of them rejected.
    """
    logger = logging.getLogger(__name__)
    finished: list[Series] = []
    current_batch = batch
    while current_batch.items:
        # Lazy %-args: the (possibly large) batch repr is only built when DEBUG
        # logging is enabled, instead of being printed to stdout every round.
        logger.debug("Expanding %d series: %s", len(current_batch.items), current_batch.items)
        next_items: list[Series] = []
        expanded = expander.expand(current_batch)
        for item in expanded.items:
            if not item.expansions:
                finished.append(item.series)
                continue
            new_series, completed = compute_new_series(item, completion_criterion)
            finished.extend(completed)
            next_items.extend(new_series)
        current_batch = Batch(items=next_items)
    return compute_expansions(batch.items, finished)