Spaces:
Running
Running
import json | |
import pyarrow | |
import typer | |
from rich.progress import track | |
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState | |
from bytelatent.logger import init_logger | |
def main(
    state_file: str,
    steps: int = 3_000,
    io_thread_count: int = 2,
    cpu_count: int = 2,
    log_freq: int = 100,
) -> None:
    """Replay a saved data-loader state and print pyarrow memory-pool stats.

    Loads a JSON train-state file, rebuilds the iterator described by its
    ``data_loader_state`` entry, pulls ``steps`` items from it, and prints
    ``pyarrow.default_memory_pool()`` periodically so memory growth can be
    watched while iterating.

    Args:
        state_file: Path to a JSON file containing a ``data_loader_state``
            mapping whose keys are passed to ``MultiprocessIteratorState``.
        steps: Number of items to pull from the iterator.
        io_thread_count: Value passed to ``pyarrow.set_io_thread_count``.
        cpu_count: Value passed to ``pyarrow.set_cpu_count``.
        log_freq: Print the memory pool and the step index every
            ``log_freq`` steps. ``0`` disables the periodic print (it used
            to raise ``ZeroDivisionError``).
    """
    init_logger()
    pyarrow.set_io_thread_count(io_thread_count)
    pyarrow.set_cpu_count(cpu_count)

    # Only the JSON load needs the file handle; release it before the
    # (potentially long) iteration loop. JSON is UTF-8, so do not depend on
    # the locale's default encoding.
    with open(state_file, encoding="utf-8") as f:
        train_state = json.load(f)

    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
    # Iterate the wrapped base iterator directly in this process rather than
    # the multiprocess wrapper itself.
    packing_iterator_state = dl_state.base_iterator_state

    print("building")
    packing_iterator = packing_iterator_state.build()
    print("iter")
    batch_iter = packing_iterator.create_iter()
    print("looping")
    for i in track(range(steps)):
        # The item itself is discarded; only the side effect of producing it
        # (allocation inside pyarrow) is of interest here.
        next(batch_iter)
        # `log_freq and ...` guards the modulo against log_freq == 0.
        if log_freq and i % log_freq == 0:
            print(pyarrow.default_memory_pool())
            print(i)
    # Final snapshot after the loop finishes.
    print(pyarrow.default_memory_pool())
# Script entry point: hand `main` to typer, which runs it as the CLI command.
if __name__ == "__main__":
    typer.run(main)