File size: 3,123 Bytes
2fb63bf
6f4449d
 
 
 
 
 
 
2fb63bf
 
 
 
 
 
6f4449d
 
 
 
 
 
 
2fb63bf
 
 
 
 
 
 
 
6f4449d
2fb63bf
 
 
 
 
 
 
 
 
 
 
 
6f4449d
2fb63bf
 
 
 
 
 
 
 
6f4449d
 
 
 
 
 
2fb63bf
 
 
 
 
 
 
 
6f4449d
2fb63bf
6f4449d
 
2fb63bf
 
 
6f4449d
2fb63bf
 
 
 
 
 
 
 
bb48904
2fb63bf
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Protocol, Self

@dataclass
class Expansion:
    """A single expansion step: one token together with its cost.

    ``Series.get_all_tokens`` appends ``token`` after the series' base tokens,
    and ``Series.get_remaining_budget`` adds ``cost`` to the series' budget.
    """
    # Token appended to a series by this expansion step.
    token: int
    # Cost of this step; it is *added* to the budget by
    # Series.get_remaining_budget.
    # NOTE(review): presumably a signed (e.g. negative) quantity -- confirm.
    cost: float

@dataclass
class Series:
    """A token sequence plus the expansion steps applied to it so far.

    Attributes:
        id: Identifier of the series; derived series created by
            ``compute_new_series`` keep the id of their parent.
        tokens: The base tokens of the series.
        budget: The budget the series starts with.
        expansions: Expansion steps appended so far (empty by default).
    """

    id: int
    tokens: list[int]
    budget: float
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return the base tokens followed by each expansion's token, in order.

        A new list is returned; ``self.tokens`` is not modified.
        """
        appended_tokens = [step.token for step in self.expansions]
        return self.tokens + appended_tokens

    def get_remaining_budget(self) -> float:
        """Return ``budget`` plus the summed ``cost`` of all expansions.

        NOTE(review): costs are added, not subtracted, so this only shrinks
        the budget when costs are negative -- confirm the sign convention
        with whoever produces ``Expansion.cost``.
        """
        total_cost = sum(step.cost for step in self.expansions)
        return self.budget + total_cost

@dataclass
class Batch:
    """A group of series that is handed to an ``ExpanderOneBatch`` in one call."""
    # The series to expand; expand() stops looping once this list is empty.
    items: list[Series]

@dataclass
class ExpansionOneResult:
    """Result of one expansion step for a single series."""
    # The series that was expanded.
    series: Series
    # Candidate next steps for the series. compute_new_series() creates one
    # new Series per entry; expand() treats an empty list as "series is
    # complete".
    expansions: list[Expansion]

@dataclass
class ExpansionOneResultBatch:
    """The per-series results returned by ``ExpanderOneBatch.expand`` for one batch."""
    # One ExpansionOneResult per expanded series.
    # NOTE(review): presumably one entry per series of the input Batch --
    # nothing in this file enforces that.
    items: list[ExpansionOneResult]

# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
class ExpanderOneBatch(Protocol):
    """Structural interface for anything that can expand a batch by one step.

    The generic ``expand`` driver calls ``expand(batch)`` repeatedly and reads
    ``series`` and ``expansions`` from every item of the returned batch.
    """
    # Takes the current batch and returns one-step results for it. A result
    # with an empty ``expansions`` list marks its series as complete for the
    # generic driver.
    def expand(self, batch: Batch) -> ExpansionOneResultBatch: ...

@dataclass
class ExpansionResult:
    """All completed expansion paths found for one original series."""
    # The original (un-expanded) series.
    series: Series
    # One inner list per completed series sharing this series' id; each inner
    # list is that completed series' full ``expansions`` list (see
    # compute_expansions).
    expansions: list[list[Expansion]]

@dataclass
class ExpansionResultBatch:
    """Final results of a full expansion run, as built by ``compute_expansions``."""
    # One ExpansionResult per original series.
    items: list[ExpansionResult]

def compute_new_series(result: ExpansionOneResult) -> list[Series]:
    """Build one follow-up series per candidate expansion in ``result``.

    Each returned series keeps the id, tokens and budget of ``result.series``
    and carries the parent's expansions followed by exactly one of the
    candidates. The ``tokens`` list is shared with the parent (not copied);
    the ``expansions`` list is a new list for every returned series.

    Returns an empty list when ``result.expansions`` is empty.
    """
    parent = result.series
    return [
        Series(
            id=parent.id,
            tokens=parent.tokens,
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
    ]

def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
    """Group the expansion paths of completed series under their original series.

    Args:
        original_series: The series the expansion run started from. Their ids
            must be unique, because ids are used to match completed series
            back to their origin.
        expanded_series: Completed series. Entries without any expansions are
            ignored, and entries whose id matches no original series do not
            appear in the result.

    Returns:
        An ``ExpansionResultBatch`` with one ``ExpansionResult`` per original
        series, in the order of ``original_series``. Each result lists the
        ``expansions`` of every completed series with the same id; the list is
        empty when there are none.

    Raises:
        ValueError: If two original series share the same id.
    """
    # Index the originals by id. This is an explicit check rather than an
    # ``assert`` so it still runs under ``python -O``; otherwise a duplicate
    # id would silently overwrite the earlier series in the dict.
    originals_by_id: dict[int, Series] = {}
    for series in original_series:
        if series.id in originals_by_id:
            raise ValueError(f"duplicate series id in original_series: {series.id!r}")
        originals_by_id[series.id] = series

    # Collect the expansion path of every completed series under its id.
    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for series in expanded_series:
        if series.expansions:
            paths_by_id[series.id].append(series.expansions)

    # ``.get`` avoids inserting empty entries into the defaultdict on lookup.
    results = [
        ExpansionResult(series=series, expansions=paths_by_id.get(series_id, []))
        for series_id, series in originals_by_id.items()
    ]
    return ExpansionResultBatch(items=results)

# A compound operation that we can implement generically, relying on an ExpanderOneBatch
def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
    """Repeatedly expand ``batch`` until every series is complete.

    Each round passes the current batch to ``expander.expand``. A result with
    no expansions marks its series as complete; any other result is turned
    into follow-up series by ``compute_new_series``, and those form the next
    round's batch. The loop ends when a round yields no follow-up series, so
    termination depends on the expander eventually returning empty expansions.

    Args:
        batch: The series to start from; ``compute_expansions`` requires
            their ids to be unique.
        expander: The one-step expander used for every round.

    Returns:
        The completed series grouped by original series, as produced by
        ``compute_expansions(batch.items, <completed series>)``.

    Side effects:
        Prints one progress line to stdout per round.
    """
    finished: list[Series] = []
    current_batch = batch
    while current_batch.items:
        print(f"Expanding {len(current_batch.items)} series: {current_batch.items}")
        step_results = expander.expand(current_batch)
        next_round: list[Series] = []
        for outcome in step_results.items:
            if outcome.expansions:
                # Still expandable: fan out into one series per candidate.
                next_round.extend(compute_new_series(outcome))
            else:
                # Nothing more to add: this series is done.
                finished.append(outcome.series)
        current_batch = Batch(items=next_round)
    return compute_expansions(batch.items, finished)