# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
# readline is imported for its side effect: line editing support in input()
import readline  # type: ignore # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

import model as fast
import mp_utils
import sample_utils
import torch
from stats import Stats
from tokenizer import Tokenizer

from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)


@dataclass
class GenArgs:
    gen_length: int = 1000

    use_sampling: bool = True
    temperature: float = 0.6
    top_p: float = 0.9


class FastGen:
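    # Number of eager single-token iterations run after the prompt pass
    # before a CUDA graph is captured (see generate_all).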
    GRAPH_WARMUPS: int = 3
    tokenizer: Tokenizer

    @staticmethod
    def build(
        ckpt_dir: str,
        gen_args: GenArgs,
        device: Union[torch.device, str],
        tokenizer_path: Optional[str] = None,
    ) -> "FastGen":
        """
        Load a Llama or Code Llama checkpoint and return a new
        generator for this model.
        """
        start_time = time.time()
        world_size = mp_utils.get_world_size()

        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))

        assert len(checkpoints) > 0, f"no checkpoint files in {ckpt_dir}"
        assert world_size == len(checkpoints), (
            f"checkpoint is sharded for model-parallel size {len(checkpoints)}"
            f" but world size is {world_size}"
        )

        ckpt_path = checkpoints[mp_utils.get_rank()]
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        model_args = fast.ModelArgs(**params)

        if tokenizer_path is None:
            tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
        if not os.path.isfile(tokenizer_path):
            tokenizer_path = str(Path(ckpt_dir) / ".." / "tokenizer.model")
        if not os.path.isfile(tokenizer_path):
            raise RuntimeError("could not find the tokenizer model")
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words

        torch.set_default_device(device)
        torch.set_default_dtype(torch.bfloat16)

        model = fast.Transformer(model_args)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
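        # strict=False tolerates checkpoint keys that this implementation
        # does not define (e.g., possibly precomputed rope frequency buffers).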
        model.load_state_dict(checkpoint, strict=False)
        print(f"loaded model in {time.time() - start_time:.2f} seconds")

        return FastGen(gen_args, model_args, model, tokenizer)

    def __init__(
        self,
        args: GenArgs,
        model_args: fast.ModelArgs,
        model: fast.Transformer,
        tokenizer: Tokenizer,
    ):
        self.gen_args = args
        self.model_args = model_args
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate_all(
        self, prompts: list[list[int]], use_cuda_graphs: bool
    ) -> Tuple[Stats, list[list[int]]]:
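        """
        Generate gen_length tokens for every prompt in the batch and
        return timing statistics along with the generated token ids.
        """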
        bs = len(prompts)
        prompt_lens = [len(p) for p in prompts]
        max_prompt_length = max(prompt_lens)
        gen_length = self.gen_args.gen_length
        max_seq_length = max_prompt_length + gen_length

        cache = fast.make_cache(
            args=self.model_args,
            length=bs * max_seq_length,
        )
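        # The cache reserves max_seq_length key/value slots for each of the
        # bs sequences.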

        bias = AttnBias.from_seqlens(
            q_seqlen=prompt_lens,
            kv_seqlen=prompt_lens,
            kv_padding=max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")

        graph = torch.cuda.CUDAGraph()

        # Input tensors to the cuda graph
        q_seqstart = bias.q_seqinfo.seqstart
        kv_seqlen = bias.k_seqinfo.seqlen
        tokens = torch.IntTensor(sum(prompts, [])).cuda()
        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
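        # tokens holds the flattened prompt ids (later shrunk to one entry
        # per sequence); out_tokens[i, j] is the token decoded at step i for
        # prompt j.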

        stats = Stats()
        stats.phase("warmup" if use_cuda_graphs else "total")

        for niter in range(gen_length):
            if niter <= self.GRAPH_WARMUPS or not use_cuda_graphs:
                # Keep the first iteration out of the warmup count: it
                # processes the full prompts, while all later iterations
                # decode a single token per sequence.
                output = self.model.forward_with_attn_bias(
                    token_values=tokens,
                    attn_bias=bias,
                    cache=cache,
                )
            elif niter == self.GRAPH_WARMUPS + 1:
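                # Capture one decode iteration into a CUDA graph; subsequent
                # iterations simply replay it, with the input tensors
                # (tokens, kv_seqlen) updated in place.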
                recording_kwargs = {}
                if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
                    # On PyTorch 2.1+ (and nightlies from late Aug 2023),
                    # passing this may avoid watchdog-related crashes
                    recording_kwargs["capture_error_mode"] = "thread_local"
                with torch.cuda.graph(graph, **recording_kwargs):
                    output = self.model.forward_with_attn_bias(
                        token_values=tokens,
                        attn_bias=bias,
                        cache=cache,
                    )
                graph.replay()
                # synchronize to get accurate timings
                torch.cuda.synchronize()
                stats.phase("graph", tokens=(niter + 1) * bs)
            else:
                graph.replay()

            # output: (sum(token_lengths), vocab_size)
            logits = output.view(bs, self.model_args.vocab_size)

            if self.gen_args.use_sampling:
                temp = self.gen_args.temperature
                top_p = self.gen_args.top_p
                probs = torch.softmax(logits / temp, dim=-1)
                next_token = sample_utils.top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)

            next_token = next_token.reshape(bs)
            out_tokens[niter, :] = next_token

            # Update attention bias state for decoding rounds
            if niter == 0:
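                # After the prompt pass every query is a single new token:
                # seqstart becomes [0, 1, ..., bs] and the token buffer
                # shrinks to one entry per sequence.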
                q_seqstart.copy_(torch.arange(bs + 1, dtype=torch.int))
                bias.q_seqinfo.min_seqlen = 1
                bias.q_seqinfo.max_seqlen = 1
                bias.q_seqinfo.seqstart_py = q_seqstart.tolist()
                tokens = tokens[:bs]

            kv_seqlen.add_(kv_seqlen < max_seq_length)
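            # Every sequence gained one cached token; the boolean add
            # saturates the lengths at max_seq_length.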

            tokens.copy_(next_token)
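            # The sampled tokens become the single-token queries of the
            # next iteration.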

        stats.end_phase(tokens=gen_length * bs)

        def trim_answer(prompt, tokens):
            """Trim the answer to end it on an eos token."""
            tokens = tokens[: max_seq_length - len(prompt)]
            eos_id = self.tokenizer.eos_id
            if eos_id in tokens:
                return tokens[: tokens.index(eos_id) + 1]
            else:
                return tokens

        answers = [
            trim_answer(prompt, answer)
            for prompt, answer in zip(prompts, out_tokens.t().tolist())
        ]
        return stats, answers


def get_prompts(interactive: bool) -> Iterable[list[str]]:
    if interactive:
        while True:
            try:
                prompts = input("enter prompt: ").split("\n")
            except EOFError:
                print("exiting")
                sys.exit(0)
            yield prompts
    else:
        yield [
            "abc",
            "can you write a hello world program in C#",
            "peux tu resoudre le probleme des tours de Hanoi en ocaml",
        ]


def main(ckpt_dir: str, interactive: bool, add_instruction_tags: bool):
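    # torchrun sets WORLD_SIZE and LOCAL_RANK; fall back to a single
    # process otherwise.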
    if "WORLD_SIZE" in os.environ:
        mp_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        mp_size = 1
        local_rank = 0

    device = mp_utils.initialize(mp_size, local_rank)

    g = FastGen.build(ckpt_dir, GenArgs(), device)

    for prompts in get_prompts(interactive):
        if add_instruction_tags:
            prompts = [f"[INST]{prompt}[/INST]" for prompt in prompts]

        tokens = [g.tokenizer.encode(x) for x in prompts]
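        # Define NO_CUDA_GRAPHS in the environment to run fully eagerly
        # (useful when debugging).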
        stats, out_tokens = g.generate_all(
            tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ
        )

        if mp_utils.get_rank() == 0:
            for i, prompt in enumerate(prompts):
                print(f"> {prompt}")
                answer = g.tokenizer.decode(out_tokens[i])
                print(answer)
                print("---------------")

            for phase_stats in stats.phases:
                print(phase_stats.show())


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Llama inference")
    parser.add_argument("ckpt_dir")
    parser.add_argument(
        "-i", "--interactive", action="store_true", help="ask for prompts"
    )
    parser.add_argument(
        "--no-instruction-tags", action="store_true", help="do not add instruction tags"
    )

    args = parser.parse_args()
    main(
        ckpt_dir=args.ckpt_dir,
        interactive=args.interactive,
        add_instruction_tags=not args.no_instruction_tags,
    )