# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import contextlib
import logging
import os
from pathlib import Path

import torch.distributed
import wandb
import xformers.profiler
from pydantic import BaseModel
from torch.profiler.profiler import profile
from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler

from bytelatent.distributed import get_is_master


class ProfilerArgs(BaseModel):
    run: bool = False
    trace_folder: str = "profiling"
    mem_warmup: int = 100
    mem_steps: int = 2
    profile_warmup: int = 102
    profile_steps: int = 2
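
# Note: maybe_run_profiler (below) turns these into step ranges, so with the
# defaults the memory snapshot profiler covers steps [100, 102) and the
# PyTorch profiler covers steps [102, 104).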


logger = logging.getLogger()


def perfetto_to_html(json_file, html_file):
    """Embed a (possibly gzipped) Perfetto/Chrome trace JSON in a standalone HTML viewer."""
    import gzip
    import string

    import viztracer

    # Reuse the trace-viewer HTML templates bundled with viztracer.
    root = os.path.dirname(viztracer.__file__)
    sub = {}
    json_file = gzip.open(json_file) if str(json_file).endswith(".gz") else open(json_file)
    with open(
        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
        sub["trace_viewer_full"] = f.read()
    with json_file as j:
        content = j.read()
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        # Escape closing script tags so the JSON can be inlined into the template.
        sub["json_data"] = content.replace("</script>", "<\\/script>")
    with open(html_file, "w+", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))


class PyTorchProfilerWandb(PyTorchProfiler):
    def __init__(self, main_profiler) -> None:
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # with_stack=True produces huge traces and can crash trace
            # export on a non-ASCII character somewhere inside PyTorch.
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        super()._on_trace(prof)
        if get_is_master() and wandb.run is not None:
            # The profiler writes the (possibly gzipped) Chrome trace under a
            # profile_CPU_CUDA*/ subdirectory; convert it to standalone HTML
            # so wandb can render it inline.
            filename = list(
                Path(self.main_profiler.output_dir).glob(
                    "profile_CPU_CUDA*/*.pt.trace.json*"
                )
            )[0]
            # Strip a trailing .gz so the output path ends in .html rather
            # than .html.gz.
            html_path = str(filename).removesuffix(".gz").replace(".json", ".html")
            perfetto_to_html(filename, html_path)
            wandb.log({"profile_trace": wandb.Html(html_path)})


class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        if get_is_master() and wandb.run is not None:
            # Upload the memory-trace HTML page emitted by the snapshot profiler.
            filename = list(
                Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
            )[0]
            with open(filename, encoding="utf-8") as f:
                wandb.log({"memory_trace": wandb.Html(f, inject=False)})


@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    # Wrap the step loop in xformers' profiler when config.run is set;
    # otherwise yield None so callers can skip profiler bookkeeping.

    if config.run:
        trace_dir = os.path.join(dump_dir, config.trace_folder)

        logger.info(f"Profiling active.  Traces will be saved at {trace_dir}")

        if get_is_master():
            os.makedirs(trace_dir, exist_ok=True)
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        with xformers.profiler.profile(
            output_dir=trace_dir,
            module=module,
            schedule=[
                (
                    MemSnapshotsProfilerWandb,
                    config.mem_warmup,
                    config.mem_warmup + config.mem_steps,
                ),
                (
                    PyTorchProfilerWandb,
                    config.profile_warmup,
                    config.profile_warmup + config.profile_steps,
                ),
            ],
        ) as profiler:
            yield profiler

    else:
        # Profiling disabled: yield None so the call site stays uniform.
        yield None
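

# Minimal usage sketch (hypothetical names: `model`, `dataloader`, `dump_dir`,
# and `train_step` are assumptions, not part of this module). The object
# yielded by xformers.profiler.profile exposes step() to advance the schedule:
#
#     profiler_args = ProfilerArgs(run=True)
#     with maybe_run_profiler(dump_dir, model, profiler_args) as profiler:
#         for batch in dataloader:
#             train_step(model, batch)
#             if profiler is not None:
#                 profiler.step()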