# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
import readline # type: ignore # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
import model as fast
import mp_utils
import sample_utils
import torch
from stats import Stats
from tokenizer import Tokenizer
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
@dataclass
class GenArgs:
gen_length: int = 1000
use_sampling: bool = True
temperature: float = 0.6
top_p: float = 0.9
class FastGen:
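    # Number of eager decoding iterations to run (after the initial prompt
    # pass) before the per-step computation is captured into a CUDA graph.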
GRAPH_WARMUPS: int = 3
tokenizer: Tokenizer
@staticmethod
def build(
ckpt_dir: str,
gen_args: GenArgs,
device: Union[torch.device, str],
tokenizer_path: Optional[str] = None,
) -> "FastGen":
"""
Load a Llama or Code Llama checkpoint and return a new
generator for this model.
"""
start_time = time.time()
world_size = mp_utils.get_world_size()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"checkpoint for model parallelism {len(checkpoints)}"
f" but world size is {world_size}"
)
ckpt_path = checkpoints[mp_utils.get_rank()]
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args = fast.ModelArgs(**params)
if tokenizer_path is None:
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
if not os.path.isfile(tokenizer_path):
tokenizer_path = str(Path(ckpt_dir) / ".." / "tokenizer.model")
if not os.path.isfile(tokenizer_path):
raise RuntimeError("could not find the tokenizer model")
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_device(device)
torch.set_default_dtype(torch.bfloat16)
model = fast.Transformer(model_args)
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
print(f"loaded model in {time.time() - start_time:.2f} seconds")
return FastGen(gen_args, model_args, model, tokenizer)
def __init__(
self,
args: GenArgs,
model_args: fast.ModelArgs,
model: fast.Transformer,
tokenizer: Tokenizer,
):
self.gen_args = args
self.model_args = model_args
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode()
def generate_all(
self, prompts: list[list[int]], use_cuda_graphs: bool
) -> Tuple[Stats, list[list[int]]]:
bs = len(prompts)
prompt_lens = [len(p) for p in prompts]
max_prompt_length = max(prompt_lens)
gen_length = self.gen_args.gen_length
max_seq_length = max_prompt_length + gen_length
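        # Allocate one flat KV cache with room for every sequence in the
        # batch padded out to max_seq_length tokens.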
cache = fast.make_cache(
args=self.model_args,
length=bs * max_seq_length,
)
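        # Block-diagonal causal attention: each query token attends only to
        # its own sequence, whose keys/values occupy a slot padded to
        # max_seq_length entries in the cache.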
bias = AttnBias.from_seqlens(
q_seqlen=prompt_lens,
kv_seqlen=prompt_lens,
kv_padding=max_seq_length,
)
bias.q_seqinfo.to("cuda")
bias.k_seqinfo.to("cuda")
graph = torch.cuda.CUDAGraph()
# Input tensors to the cuda graph
q_seqstart = bias.q_seqinfo.seqstart
kv_seqlen = bias.k_seqinfo.seqlen
tokens = torch.IntTensor(sum(prompts, [])).cuda()
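        # Output buffer: one row per decoding step, one column per sequence.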
out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
stats = Stats()
stats.phase("warmup" if use_cuda_graphs else "total")
for niter in range(gen_length):
if niter <= self.GRAPH_WARMUPS or not use_cuda_graphs:
                # Keep the first iteration out of the
                # warmup: it processes the full prompts,
                # while all later iterations process
                # sequences of 0 or 1 token only
output = self.model.forward_with_attn_bias(
token_values=tokens,
attn_bias=bias,
cache=cache,
)
elif niter == self.GRAPH_WARMUPS + 1:
recording_kwargs = {}
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
# In PyTorch 2.1+ and nightlies from late Aug 2023,
# we can do this to maybe avoid watchdog-related crashes
recording_kwargs["capture_error_mode"] = "thread_local"
with torch.cuda.graph(graph, **recording_kwargs):
output = self.model.forward_with_attn_bias(
token_values=tokens,
attn_bias=bias,
cache=cache,
)
graph.replay()
# synchronize to get accurate timings
torch.cuda.synchronize()
stats.phase("graph", tokens=(niter + 1) * bs)
else:
graph.replay()
# output: (sum(token_lengths), vocab_size)
logits = output.view(bs, self.model_args.vocab_size)
if self.gen_args.use_sampling:
temp = self.gen_args.temperature
top_p = self.gen_args.top_p
probs = torch.softmax(logits / temp, dim=-1)
next_token = sample_utils.top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(bs)
out_tokens[niter, :] = next_token
# Update attention bias state for decoding rounds
if niter == 0:
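                # After the prompt pass every query sequence is one token
                # long, so seqstart becomes [0, 1, ..., bs] and the input
                # buffer shrinks to one token per sequence.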
q_seqstart.copy_(torch.arange(bs + 1, dtype=torch.int))
bias.q_seqinfo.min_seqlen = 1
bias.q_seqinfo.max_seqlen = 1
bias.q_seqinfo.seqstart_py = q_seqstart.tolist()
tokens = tokens[:bs]
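            # Advance each sequence's key/value length by one, saturating at
            # max_seq_length (the comparison contributes 0 or 1 per sequence).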
kv_seqlen.add_(kv_seqlen < max_seq_length)
tokens.copy_(next_token)
stats.end_phase(tokens=gen_length * bs)
def trim_answer(prompt, tokens):
"""Trim the answer to end it on an eos token."""
tokens = tokens[: max_seq_length - len(prompt)]
eos_id = self.tokenizer.eos_id
if eos_id in tokens:
return tokens[: tokens.index(eos_id) + 1]
else:
return tokens
answers = [
trim_answer(prompt, answer)
for prompt, answer in zip(prompts, out_tokens.t().tolist())
]
return stats, answers
def get_prompts(interactive: bool) -> Iterable[list[str]]:
if interactive:
while True:
try:
prompts = input("enter prompt: ").split("\n")
except EOFError:
print("exiting")
sys.exit(0)
yield prompts
else:
yield [
"abc",
"can you write a hello world program in C#",
"peux tu resoudre le probleme des tours de Hanoi en ocaml",
]
def main(ckpt_dir: str, interactive: bool, add_instruction_tags: bool):
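    # torchrun populates WORLD_SIZE and LOCAL_RANK; fall back to a single
    # process when the script is launched directly.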
if "WORLD_SIZE" in os.environ:
mp_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
else:
mp_size = 1
local_rank = 0
device = mp_utils.initialize(mp_size, local_rank)
g = FastGen.build(ckpt_dir, GenArgs(), device)
for prompts in get_prompts(interactive):
if add_instruction_tags:
prompts = [f"[INST]{prompt}[/INST]" for prompt in prompts]
tokens = [g.tokenizer.encode(x) for x in prompts]
stats, out_tokens = g.generate_all(
tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ
)
if mp_utils.get_rank() == 0:
for i, prompt in enumerate(prompts):
print(f"> {prompt}")
answer = g.tokenizer.decode(out_tokens[i])
print(answer)
print("---------------")
for phase_stats in stats.phases:
print(phase_stats.show())
if __name__ == "__main__":
parser = argparse.ArgumentParser("Llama inference")
parser.add_argument("ckpt_dir")
parser.add_argument(
"-i", "--interactive", action="store_true", help="ask for prompts"
)
parser.add_argument(
"--no-instruction-tags", action="store_true", help="do not add instruction tags"
)
args = parser.parse_args()
main(
ckpt_dir=args.ckpt_dir,
interactive=args.interactive,
add_instruction_tags=not args.no_instruction_tags,
)
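# Example invocations (a sketch only; the script name and checkpoint paths
# below are placeholders, assuming this file is saved as generation.py and
# that ckpt_dir holds the *.pth shards, params.json and tokenizer.model):
#
#   # single-shard checkpoint, interactive prompts
#   python generation.py /path/to/llama-checkpoint -i
#
#   # model-parallel checkpoint: launch one process per *.pth shard
#   torchrun --nproc_per_node 2 generation.py /path/to/llama-checkpoint
#
#   # run eagerly, without CUDA graphs
#   NO_CUDA_GRAPHS=1 python generation.py /path/to/llama-checkpoint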