File size: 2,335 Bytes
bcc039b
 
 
 
fc3399e
bcc039b
fc3399e
 
 
 
bcc039b
 
 
fc3399e
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any

import numpy as np
from pydantic import ConfigDict

from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


class SamplingIteratorState(PydanticIteratorState):
    """Serializable snapshot of a ``SamplingIterator``.

    ``build()`` reconstructs a ``SamplingIterator`` from the stored RNG
    state, the per-source weights, and the per-source iterator states.
    """

    # Reject unknown fields when this model is validated.
    model_config = ConfigDict(extra="forbid")
    # Value of a numpy ``bit_generator.state`` for the sampling RNG.
    rng_state: dict[str, Any]
    # Sampling weight per source name; SamplingIterator normalizes these
    # into probabilities when it draws.
    source_to_weight: dict[str, float]
    # Saved state of each source's underlying iterator, keyed by source name.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Rebuild the ``SamplingIterator`` described by this snapshot."""
        rebuilt_iterators = {}
        for name, sub_state in self.source_to_iterator_state.items():
            rebuilt_iterators[name] = sub_state.build()
        return SamplingIterator(
            source_to_iterator=rebuilt_iterators,
            source_to_weight=self.source_to_weight,
            rng_state=self.rng_state,
        )


class SamplingIterator(StatefulIterator):
    """Mix several source iterators by sampling one source per item.

    Each step draws a source name with probability proportional to its
    weight (using a numpy ``Generator`` whose state can be saved and
    restored). It then yields the next item from that source's iterator.
    """

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """Create the sampler.

        Args:
            rng_state: value to assign to ``bit_generator.state`` of a fresh
                ``np.random.default_rng()`` generator. NOTE(review): this
                presumably must come from the same bit-generator type that
                ``default_rng`` builds (numpy raises otherwise). Confirm
                callers only pass states captured via ``get_state``.
            source_to_weight: sampling weight per source name. Stored by
                reference, not copied.
            source_to_iterator: iterator per source name. Every key of
                ``source_to_weight`` must be present here. Stored by
                reference.
        """
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Snapshot the current RNG state, the weights and each source's state."""
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Yield an endless stream of items sampled across the sources.

        At each step a source is chosen with probability
        ``weight / sum(weights)`` via ``self.rng``. The next item from that
        source's iterator is then yielded.

        Raises:
            ValueError: if there are no sources, any weight is negative, or
                the weights do not sum to a positive value. Because this is a
                generator, the error surfaces on the first ``next()``.
            RuntimeError: if the chosen source's iterator is exhausted.
            KeyError: if a weighted source has no entry in
                ``source_to_iterator``.
        """
        possible_sources = list(self.source_to_weight)
        weight_array = np.array(
            [self.source_to_weight[source] for source in possible_sources]
        )
        if not possible_sources:
            raise ValueError("SamplingIterator requires at least one source")
        total_weight = weight_array.sum()
        # `not total_weight > 0` also catches a NaN total.
        if np.any(weight_array < 0) or not total_weight > 0:
            raise ValueError(
                "Source weights must be non-negative with a positive sum, "
                f"got {self.source_to_weight}"
            )
        # The weights never change inside this loop, so normalize once
        # instead of re-allocating two arrays for every sample.
        norm_weights = weight_array / total_weight
        n_sources = len(possible_sources)

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }
        while True:
            source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
            try:
                item = next(source_to_python_iter[source_choice])
            except StopIteration as err:
                # PEP 479 would otherwise turn this into an opaque
                # "generator raised StopIteration" RuntimeError.
                raise RuntimeError(
                    f"Source {source_choice!r} was exhausted while sampling"
                ) from err
            yield item