# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import multiprocessing as mp
from enum import Enum
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full

import numpy as np
from pydantic import ConfigDict

from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import (
    IteratorState,
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.packing_iterator import PackingIteratorState

logger = logging.getLogger()


class PersistType(str, Enum):
    EXACT = "exact"
    APPROXIMATE = "approximate"
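

# A note on the two persistence modes (a summary of the logic below, not original
# documentation): EXACT halts the worker, drains every in-flight batch from
# batch_queue into the prefetch buffer, and tears down the process, so resuming
# replays the stream with no loss or duplication. APPROXIMATE snapshots the
# worker's iterator state via an event handshake while batches keep flowing, so
# the saved position may drift by roughly up to n_batches_to_prefetch in-flight
# batches.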


class MultiprocessIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    base_iterator_state: PackingIteratorState
    n_batches_to_prefetch: int
    serialized_prefetch_buffer: str
    persist_type: PersistType

    def build(self):
        base_iterator = self.base_iterator_state.build()
        data = json.loads(self.serialized_prefetch_buffer)
        prefetch_buffer = [Batch.from_python_dict(item) for item in data]
        return MultiprocessIterator(
            base_iterator,
            n_batches_to_prefetch=self.n_batches_to_prefetch,
            prefetch_buffer=prefetch_buffer,
            persist_type=self.persist_type,
        )
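

# Illustrative round-trip (a sketch, not part of the original file): since the
# state classes appear to be plain pydantic models, a captured state should
# serialize and rebuild like this. `state` is assumed to come from
# MultiprocessIterator.get_state().
#
#     payload = state.model_dump_json()
#     restored = MultiprocessIteratorState.model_validate_json(payload)
#     iterator = restored.build()  # rebuilds the iterator and prefetch buffer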


def start_work_from_state(
    batch_queue: mp.Queue,
    state_queue: mp.Queue,
    approximate_state_queue: mp.Queue,
    stop_event: EventClass,
    state_dumped_event: EventClass,
    trigger_approximate_send_state_event: EventClass,
    sent_approximate_state_event: EventClass,
    received_approximate_state_event: EventClass,
    state: IteratorState,
):
    logging.info("Worker thread: Starting base_iterator work")
    stateful_iterator = state.build()
    iterator = stateful_iterator.create_iter()
    for item in iterator:
        while not stop_event.is_set():
            try:
                if trigger_approximate_send_state_event.is_set():
                    logger.info("WT: trigger_approximate_send ack")
                    # Since this can be triggered again (but only after the state
                    # is received by the main process), clean up as soon as possible.
                    trigger_approximate_send_state_event.clear()
                    logging.info("WT: Computing approximate state")
                    approximate_state = stateful_iterator.get_state()
                    # At this point, there should always be exactly one free slot
                    # in the queue; blocking here would be a bug.
                    logger.info("WT: Attempting to send approximate state")
                    approximate_state_queue.put(
                        approximate_state, block=True, timeout=None
                    )
                    sent_approximate_state_event.set()
                    logger.info("WT: Approximate state sent")
                    # Same here: clear events as soon as we no longer need them.
                    received_approximate_state_event.wait()
                    received_approximate_state_event.clear()
                    logger.info("WT: State received by MT, resuming batch iteration")
                # Attempt to put on the queue, or time out and try again
                # (the main thread may be busy).
                batch_queue.put(item, timeout=0.1)
                # On success, stop trying.
                break
            except Full:
                pass
        if stop_event.is_set():
            # Signal the end of output. This ensures that even if the queue takes
            # a while to buffer, the main thread receives everything (and tosses
            # this fake batch).
            logging.info(
                "Worker thread: Stop event detected, outputting is_final=True batch"
            )
            logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
            batch_queue.put(
                Batch(
                    x=np.zeros((1, 1)),
                    y=np.zeros((1, 1)),
                    is_final=True,
                    mask=None,
                    patch_lengths=None,
                    ngram_ids=None,
                )
            )
            logging.info(
                "Worker thread: is_final=True batch put in queue, breaking from loop."
            )
            break

    try:
        logging.info("Worker thread: outputting state")
        state_queue.put(stateful_iterator.get_state(), timeout=1)
        logging.info("Worker thread: state dump complete")
        state_dumped_event.set()
        logging.info("Worker thread: set state_dump_event")
    except Full:
        raise ValueError(
            "Attempted to dump state into the state queue, but it was full"
        )
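

# Summary of the approximate-state handshake implemented above
# (MT = main process, WT = worker process):
#
#   1. MT sets trigger_approximate_send_state_event.
#   2. WT clears the trigger, snapshots stateful_iterator.get_state(), puts the
#      state on approximate_state_queue, and sets sent_approximate_state_event.
#   3. MT wakes on sent_approximate_state_event, pops the state from the queue,
#      and sets received_approximate_state_event.
#   4. WT wakes on received_approximate_state_event, clears it, and resumes
#      batch iteration.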


FETCH_STATE_TIMEOUT = 120


class MultiprocessIterator(StatefulIterator):
    """
    Design sketch of the multiprocess iterator:

    Given the base_iterator, the only thing we do with it is call get_state()
    so that we can pass that through to the background worker process. The
    background process receives this state, rebuilds the iterator, then starts
    yielding from it.

    However, in order to implement MultiprocessIterator.get_state(), we need to
    be able to accurately get:

    (1) the state of the iterator in the worker process
    (2) the currently buffered items in the Queue

    To do this, we use:

    - batch_queue: The prefetch buffer the worker yields to and the main loop
        yields from.
    - state_queue: This size-1 queue is how the worker sends the iterator state
        once it has halted iterating. It must hold the state in addition to the
        last batch, if the queue was full at the time the stop event was sent.
    - stop_iterating_event: Once this is issued from the main loop, the worker
        stops iterating and enters cleanup. During cleanup, the worker sends the
        current iterator state to the main loop, in addition to possibly the
        last batch if the batch_queue was full at the time.
    - state_dumped_event: When the main loop issues the stop_iterating_event, it
        waits for the state_dumped_event before attempting to get state from the
        state_queue. It must do this since the worker may take some time to
        create and send the state.

    Once received, the main loop can safely store the queue contents (plus maybe
    the last batch) as the prefetch buffer, store the worker iterator's state,
    and terminate the background process plus delete the associated objects. At
    this point the old python iterator will raise an error since it will not
    iterate anymore, so the caller must call create_iter() again to bootstrap
    everything from the stored state and obtain a new python iterator.
    """

    def __init__(
        self,
        base_iterator: StatefulIterator,
        *,
        n_batches_to_prefetch: int,
        prefetch_buffer: list | None = None,
        persist_type: PersistType = PersistType.EXACT,
    ):
        self.base_iterator = base_iterator
        self.n_batches_to_prefetch = n_batches_to_prefetch
        self.persist_type = persist_type
        if prefetch_buffer is None:
            prefetch_buffer = []
        self.prefetch_buffer = prefetch_buffer
        self.batch_queue = None
        self.state_queue = None
        self.approximate_state_queue = None
        self.producer = None
        self.stop_iterating_event = None
        self.state_dumped_event = None
        self.trigger_approximate_send_state_event = None
        self.sent_approximate_state_event = None
        self.received_approximate_state_event = None
        self.force_shutdown = False

    def shutdown(self):
        if self.producer is not None:
            # Forcibly terminate the background process. State cannot be
            # persisted after this (see the force_shutdown checks below).
            self.producer.kill()
            self.force_shutdown = True

    def _get_state_exact(self):
        logging.info("Main thread: Sending stop iteration event")
        self.stop_iterating_event.set()
        logging.info(
            "Main thread: Emptying the batch_queue until batch.is_final=True is found."
        )
        if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
            buffer = self.prefetch_buffer
        else:
            buffer = []
        self.prefetch_buffer = buffer
        final_batch_received = False
        while True:
            try:
                batch = self.batch_queue.get(timeout=1)
                if batch.is_final:
                    logging.info(
                        "Main thread: is_final=True batch found, stopping fetch from batch_queue"
                    )
                    final_batch_received = True
                    break
                self.prefetch_buffer.append(batch)
            except Empty:
                logging.warning("Main thread: batch_queue is abnormally empty")
        assert final_batch_received
        logging.info("Main thread: Waiting for state_dumped event")
        self.state_dumped_event.wait()
        try:
            logging.info(
                "Main thread: state_dumped_event received, waiting for state from queue"
            )
            base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
            logging.info("Main thread: received state from queue")
            assert isinstance(base_iterator_state, IteratorState)
        except Empty:
            raise ValueError(
                "Attempted to get the state, but it was unexpectedly missing"
            )
        self.base_iterator = base_iterator_state.build()
        self.producer.close()
        self.producer = None
        self.batch_queue = None
        self.state_queue = None
        self.approximate_state_queue = None
        self.stop_iterating_event = None
        self.state_dumped_event = None
        self.trigger_approximate_send_state_event = None
        self.sent_approximate_state_event = None
        self.received_approximate_state_event = None
        return MultiprocessIteratorState(
            base_iterator_state=self.base_iterator.get_state(),
            n_batches_to_prefetch=self.n_batches_to_prefetch,
            serialized_prefetch_buffer=json.dumps(
                [b.to_python_dict() for b in self.prefetch_buffer]
            ),
            persist_type=self.persist_type,
        )

    def _get_state_approximate(self):
        logging.info("MT: Sending approximate get_state request")
        self.trigger_approximate_send_state_event.set()
        logging.info("MT: Waiting for sent_approximate_state_event")
        self.sent_approximate_state_event.wait()
        # Clear the event once consumed so a later request cannot see a stale signal.
        self.sent_approximate_state_event.clear()
        logging.info("MT: sent_approximate_state_event ack")
        try:
            logging.info("MT: waiting for approximate state in queue")
            base_iterator_state = self.approximate_state_queue.get(
                timeout=FETCH_STATE_TIMEOUT
            )
            logging.info("MT: approximate state received")
            assert isinstance(base_iterator_state, IteratorState)
            assert self.approximate_state_queue.empty()
        except Empty:
            raise ValueError(
                "Attempted to get approximate state, but queue was erroneously empty."
            )
        self.received_approximate_state_event.set()
        if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
            buffer = [b.to_python_dict() for b in self.prefetch_buffer]
        else:
            buffer = []
        return MultiprocessIteratorState(
            base_iterator_state=base_iterator_state,
            n_batches_to_prefetch=self.n_batches_to_prefetch,
            serialized_prefetch_buffer=json.dumps(buffer),
            persist_type=self.persist_type,
        )

    def get_state(self) -> MultiprocessIteratorState:
        """
        This is slightly unusual in effectively destroying the current iterator;
        it's necessary to halt the background process and allow it to write its
        state back to the main loop so that no data is lost.
        """
        if self.force_shutdown:
            raise ValueError(
                "State will be invalid if shutdown was forced before state persisted."
            )
        if self.producer is None:
            if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
                serialized_prefetch_buffer = json.dumps(
                    [b.to_python_dict() for b in self.prefetch_buffer]
                )
            else:
                serialized_prefetch_buffer = json.dumps([])
            return MultiprocessIteratorState(
                base_iterator_state=self.base_iterator.get_state(),
                n_batches_to_prefetch=self.n_batches_to_prefetch,
                serialized_prefetch_buffer=serialized_prefetch_buffer,
                persist_type=self.persist_type,
            )
        else:
            if self.persist_type == PersistType.EXACT:
                return self._get_state_exact()
            elif self.persist_type == PersistType.APPROXIMATE:
                return self._get_state_approximate()
            else:
                raise ValueError("invalid persist_type")

    def create_iter(self):
        if self.force_shutdown:
            raise ValueError(
                "Iterator may be invalid if shutdown was forced before state persisted."
            )
        logging.info("Main thread: Creating MP iterator")
        assert (
            self.producer is None
        ), "Cannot create two parallel iterators at once; call get_state() then remake to have two."
        # Using the forkserver start method avoids excessive CPU load.
        ctx = mp.get_context("forkserver")
        self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)
        # There should only ever be one state, output when the stop event is detected.
        self.state_queue = ctx.Manager().Queue(maxsize=1)
        # Similarly, there should only ever be one state in flight, due to the event signals.
        self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
        self.stop_iterating_event = ctx.Event()
        self.state_dumped_event = ctx.Event()
        self.trigger_approximate_send_state_event = ctx.Event()
        self.sent_approximate_state_event = ctx.Event()
        self.received_approximate_state_event = ctx.Event()
        # Create the worker from the same forkserver context as the queues and
        # events above (mp.Process would use the default start method instead).
        self.producer = ctx.Process(
            name="blt_data_loader",
            target=start_work_from_state,
            args=(
                self.batch_queue,
                self.state_queue,
                self.approximate_state_queue,
                self.stop_iterating_event,
                self.state_dumped_event,
                self.trigger_approximate_send_state_event,
                self.sent_approximate_state_event,
                self.received_approximate_state_event,
                self.base_iterator.get_state(),
            ),
        )
        logger.info("Async dataloader started")
        self.producer.start()
        # First yield from the stored prefetch buffer.
        if self.prefetch_buffer is not None:
            while len(self.prefetch_buffer) > 0:
                item = self.prefetch_buffer.pop(0)
                yield item
            self.prefetch_buffer = None
        while True:
            if self.producer.exitcode is not None:
                raise RuntimeError(
                    "Data loader quit unexpectedly; the real error was raised earlier."
                )
            try:
                batch = self.batch_queue.get(timeout=0.1)
                assert isinstance(batch, Batch)
                assert (
                    not batch.is_final
                ), "is_final should only be used while get_state() is being called"
                yield batch
            except Empty:
                pass
            if self.producer is None:
                raise ValueError(
                    "Attempted to call this iterator after calling get_state(). "
                    "You must call create_iter() to make a new iterator instead."
                )
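

# Checkpoint/resume sketch across process restarts (illustrative only;
# `packing_iterator`, `train_step`, `write_file`, and `read_file` are
# hypothetical stand-ins, and the state models are assumed to be plain
# pydantic models):
#
#     loader = MultiprocessIterator(packing_iterator, n_batches_to_prefetch=4)
#     for step, batch in enumerate(loader.create_iter()):
#         train_step(batch)
#         if step == 1000:
#             ckpt = loader.get_state().model_dump_json()  # halts the worker
#             write_file("ckpt.json", ckpt)
#             break
#
#     # Later, possibly in a fresh process:
#     state = MultiprocessIteratorState.model_validate_json(read_file("ckpt.json"))
#     loader = state.build()  # restores iterator position and prefetch buffer
#     for batch in loader.create_iter():
#         train_step(batch)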