File size: 994 Bytes
bcc039b
 
 
 
fc3399e
 
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c61ab5
 
fc3399e
 
 
 
8c61ab5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar

import pydantic

T = TypeVar("T")
C = TypeVar("C")


class StatefulIterator(Generic[T, C], abc.ABC):
    """Abstract interface for an iterator that can report its own state.

    Type parameters:
        T: type of the items yielded by the generator from ``create_iter``.
        C: type of the state object returned by ``get_state``.
    """

    @abc.abstractmethod
    def get_state(self) -> C:
        """Return the state object (of type ``C``) for this iterator."""

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        """Return a generator that yields items of type ``T``."""


class IteratorState(Generic[C], abc.ABC):
    """Abstract interface for a state object that can build an iterator.

    The class derives from ``abc.ABC`` so that ``@abc.abstractmethod`` is
    actually enforced: the decorator has no effect unless the metaclass is
    ``ABCMeta`` (or a subclass of it), which would let incomplete subclasses
    be instantiated and fail only later when ``build`` silently returns
    ``None``.

    Type parameters:
        C: the state type of the iterator returned by ``build``.
    """

    @abc.abstractmethod
    def build(self) -> StatefulIterator[Any, C]:
        """Construct and return the ``StatefulIterator`` for this state.

        The item type is annotated as ``Any`` because this class is generic
        only over the state type ``C``; the item type is not known here.
        """


class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    # Base for iterator states declared as pydantic models: it combines
    # pydantic.BaseModel with the IteratorState interface and declares no
    # fields of its own.
    # extra="forbid" makes pydantic reject fields the model does not declare
    # instead of silently ignoring them.
    model_config = pydantic.ConfigDict(extra="forbid")


def get_state_and_refresh(iterator: StatefulIterator):
    """Capture an iterator's state and rebuild a fresh iterator from it.

    Args:
        iterator: object whose ``get_state()`` returns a state exposing
            ``build()``.

    Returns:
        A 3-tuple ``(state, data_loader, py_iterator)``: the captured state,
        the object returned by ``state.build()``, and the result of calling
        ``create_iter()`` on that object.
    """
    # Calling get_state() on a multiprocessing iterator shuts its workers
    # down so the state can be persisted correctly; the data loader and the
    # Python iterator therefore have to be recreated from that state.
    snapshot = iterator.get_state()
    rebuilt_loader = snapshot.build()
    return snapshot, rebuilt_loader, rebuilt_loader.create_iter()