File size: 1,107 Bytes
63913e4
 
 
 
 
 
 
 
 
 
f84ee63
 
 
 
 
 
 
63913e4
f84ee63
 
63913e4
 
 
 
 
 
 
 
 
f84ee63
 
 
 
 
 
63913e4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json

import pyarrow
import typer
from rich.progress import track

from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
from bytelatent.logger import init_logger


def main(
    state_file: str,
    steps: int = 3_000,
    io_thread_count: int = 2,
    cpu_count: int = 2,
    log_freq: int = 100,
) -> None:
    """Replay a saved data-loader state and print the PyArrow memory pool as it runs.

    Loads the JSON training state in ``state_file``, rebuilds the iterator
    described by the base iterator state stored under its
    ``"data_loader_state"`` key, pulls ``steps`` batches from it, and prints
    ``pyarrow.default_memory_pool()`` periodically and once more at the end.

    Args:
        state_file: Path to a JSON file with a ``"data_loader_state"`` mapping.
        steps: Number of batches to pull from the iterator.
        io_thread_count: Value passed to ``pyarrow.set_io_thread_count``.
        cpu_count: Value passed to ``pyarrow.set_cpu_count``.
        log_freq: Print the memory pool whenever the step index is a multiple
            of this value; ``0`` disables the periodic print.
    """
    init_logger()
    pyarrow.set_io_thread_count(io_thread_count)
    pyarrow.set_cpu_count(cpu_count)

    # Only the JSON parse needs the file handle; release it before the
    # (potentially long) iteration loop instead of holding it open throughout.
    with open(state_file, encoding="utf-8") as f:
        train_state = json.load(f)

    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
    # Build from the wrapped base iterator state rather than from the
    # multiprocess state itself.
    # NOTE(review): presumably so iteration happens in this process - confirm.
    packing_iterator_state = dl_state.base_iterator_state
    print("building")
    packing_iterator = packing_iterator_state.build()
    print("iter")
    batch_iter = packing_iterator.create_iter()
    print("looping")

    # Track the last completed step explicitly: with steps <= 0 the loop body
    # never runs and the loop variable would be unbound after the loop.
    last_step = None
    for i in track(range(steps)):
        # The batch is deliberately discarded; only the iteration itself matters.
        _ = next(batch_iter)
        last_step = i
        # Guard against log_freq == 0, which would raise ZeroDivisionError.
        if log_freq and i % log_freq == 0:
            print(pyarrow.default_memory_pool())
    if last_step is not None:
        print(last_step)
    print(pyarrow.default_memory_pool())


if __name__ == "__main__":
    # Script entry point: hand `main` to Typer, which is given the function
    # itself and invokes it when the file is executed directly.
    typer.run(main)