par-meta committed · Commit b79eb3e · unverified · 1 parent: 2dcf48b

Get generation working for BLT (#86)

Summary:

Create a script for simple generation from BLT.

Test Plan:

```
python -m bytelatent.generate_blt config=../internal-blt/configs/eval_blt.yaml
```
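
The referenced `eval_blt.yaml` lives in an internal config directory and is not part of this commit. As a rough, hypothetical sketch of the same run done programmatically — field names come from the `EvalArgs` changes below, the paths and prompts are placeholders, and the remaining `EvalArgs` fields are assumed to keep their defaults:

```python
# Hypothetical programmatic equivalent of the test-plan command above.
# Paths and prompts are placeholders; only the field names come from this commit.
from bytelatent.args import EvalArgs
from bytelatent.generate_blt import launch_generate

eval_args = EvalArgs(
    ckpt_dir="/path/to/blt/checkpoint",              # consolidated BLT checkpoint (placeholder)
    entropy_ckpt_dir="/path/to/entropy/checkpoint",  # entropy model for realtime patching (placeholder)
    dump_dir="/tmp/blt_generate",                    # must be non-None for launch_generate
    prompts=["my name is", "the capital of france is"],
)
launch_generate(eval_args)
```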

bytelatent/args.py CHANGED

```diff
@@ -7,7 +7,7 @@ import numpy as np
 import yaml
 from pydantic import BaseModel, ConfigDict
 
-from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, CheckpointArgs
 from bytelatent.data.data_types import Batch
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
@@ -270,8 +270,11 @@ class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     dump_dir: str | None = None
     ckpt_dir: str | None = None
+    entropy_ckpt_dir: str | None = None
     metric_log_dir: str | None = None
 
+    prompts: list[str] | None = None
+
     run_ppl: bool = True
     run_tasks: bool = False
 
@@ -284,6 +287,8 @@ class EvalArgs(BaseModel):
 
     global_step: int | None = None  # for in-training evaluation
     s3_profile: str | None = None
+    consolidate_if_needed: bool = False
+    consolidate_folder: str = CONSOLIDATE_FOLDER
 
 
 class TrainArgs(BaseModel):
```

bytelatent/data/patcher.py CHANGED

```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import math
+import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
@@ -476,7 +477,11 @@ class Patcher:
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
             entropy_model = load_entropy_model(
-                patcher_args.entropy_model_checkpoint_dir
+                patcher_args.entropy_model_checkpoint_dir,
+                os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir,
+                    "consolidated/consolidated.pth",
+                ),
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
```

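For context, a minimal sketch of how the new generation script exercises this code path; `train_cfg` is assumed to be the training config returned by `load_consolidated_model_and_tokenizer`, and the directory layout (`consolidated/consolidated.pth` under the entropy checkpoint dir) is the one hard-coded above:

```python
from bytelatent.data.patcher import Patcher


def build_realtime_patcher(train_cfg, entropy_ckpt_dir: str) -> Patcher:
    # Copy the training-time patcher config, then enable realtime patching
    # and point it at the entropy model checkpoint directory.
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    patcher_args.entropy_model_checkpoint_dir = entropy_ckpt_dir
    # The resulting Patcher loads the entropy model from
    # <entropy_ckpt_dir>/consolidated/consolidated.pth, per the diff above.
    return patcher_args.build()
```
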
bytelatent/distributed.py CHANGED

```diff
@@ -162,6 +162,12 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_min(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MIN, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_sum(
     x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
 ):
```

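`dist_min` mirrors the existing `dist_max` and is used by the new generation script to agree on prompt-length bounds across data-parallel ranks, so every rank runs the same number of decoding steps. A small sketch of that usage pattern (single-process callers simply skip the reduction):

```python
import torch

from bytelatent.distributed import dist_max, dist_min


def prompt_length_bounds(prompt_tokens: list[list[int]] | None) -> tuple[int, int]:
    # Local bounds; absent prompts contribute 0, matching generate_blt.py below.
    lo = 0 if not prompt_tokens else min(len(t) for t in prompt_tokens)
    hi = 0 if not prompt_tokens else max(len(t) for t in prompt_tokens)
    if torch.distributed.is_initialized():
        # all_reduce so every rank sees the same global min/max prompt length.
        lo, hi = int(dist_min(lo)), int(dist_max(hi))
    return lo, hi
```
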
bytelatent/eval.py CHANGED

```diff
@@ -243,9 +243,20 @@ def launch_eval(eval_args: EvalArgs):
     ):
         consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
-        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+        if eval_args.consolidate_if_needed:
+            logger.info(
+                "Found a model checkpoint, but it has not been consolidated.... so consolidating the checkpoint"
+            )
+            consolidate_path = os.path.join(
+                eval_args.ckpt_dir, eval_args.consolidate_folder
+            )
+            if not fs.exists(consolidate_path) and get_global_rank() == 0:
+                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+            logger.info("Model consolidated to: %s", consolidate_path)
+        else:
+            raise ValueError(
+                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
+            )
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
```

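The upshot for eval callers: on-the-fly consolidation of a sharded checkpoint is now opt-in. A minimal sketch, with a placeholder checkpoint path and the remaining `EvalArgs` fields assumed to keep their defaults:

```python
from bytelatent.args import EvalArgs
from bytelatent.eval import launch_eval

# Opt in: consolidate into <ckpt_dir>/<consolidate_folder> on rank 0 if needed.
launch_eval(
    EvalArgs(
        dump_dir="/tmp/blt_eval",           # placeholder
        ckpt_dir="/path/to/raw/checkpoint",  # placeholder, not yet consolidated
        consolidate_if_needed=True,
    )
)

# With the default consolidate_if_needed=False, launch_eval raises ValueError
# when no consolidated checkpoint is found.
```
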
bytelatent/generate.py CHANGED

```diff
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
-from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
+from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -18,8 +18,14 @@ from bytelatent.base_transformer import (
     lengths_to_local_ids,
     lengths_to_start_ids,
 )
-from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.checkpoint import (
+    CONSOLIDATE_FOLDER,
+    CONSOLIDATE_NAME,
+    consolidate_checkpoints,
+)
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import get_global_rank
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -411,15 +417,25 @@ def load_consolidated_model_and_tokenizer(
 
 def main():
     # Load CLI arguments (overrides) and combine with a YAML config
-    cfg = OmegaConf.from_cli()
-    gen_cfg = dataclass_from_dict(
-        PackedCausalTransformerGeneratorArgs, cfg, strict=False
-    )
-    print(cfg)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
-    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path
+    )
 
-    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     # Allow multiple prompts
     prompts = []
```

bytelatent/generate_blt.py ADDED

```python
import logging
import os

import torch

from bytelatent.args import EvalArgs
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import Patcher
from bytelatent.distributed import (
    DistributedArgs,
    dist_max,
    dist_min,
    dist_sum,
    get_device_mesh,
    setup_torch_distributed,
)
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

logger = logging.getLogger()


def get_max_length(input_tokens: list[list[int]] | None) -> int:
    # reduce max length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        max_length = 0
    else:
        max_length = max([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        max_length = int(dist_max(max_length))
    return max_length


def get_min_length(input_tokens: list[list[int]] | None) -> int:
    # reduce min length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        # TODO: Double check this change from int(1e9) is correct
        min_length = 0
    else:
        min_length = min([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        min_length = int(dist_min(min_length))
    return min_length


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = get_min_length(prompt_tokens)
    batch_max_prompt_length = get_max_length(prompt_tokens)
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


@torch.inference_mode()
def generate_nocache(
    prompts: list[str] | None,
    *,
    model: ByteLatentTransformer,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    remove_prompts: bool = True,
) -> list[list[int]]:
    assert (
        patcher.realtime_patching
    ), "generate_nocache requires patcher.realtime_patching=True"
    model.eval()
    if prompts is None:
        prompt_tokens = None
        n_truncated_prompts = 0
        total_truncated_prompts = 0
    else:
        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
        total_truncated_prompts = dist_sum(n_truncated_prompts)

        # Truncation
        prompt_tokens = [
            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
            for t in prompt_tokens
        ]

    if total_truncated_prompts > 0:
        logger.info(
            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
            f"length greater than max_prompt_len = {max_prompt_len}, "
            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
        )

    if prompt_tokens is None:
        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]

    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    batch_size = len(prompt_tokens)
    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

    # Copy inputs to tensor for generated tokens
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
    input_text_mask = tokens != tokenizer.pad_id

    for i, curr_pos in enumerate(range(start_pos, end_pos)):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

    if remove_prompts:
        generated_tokens = [
            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    else:
        generated_tokens = [
            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    return generated_tokens


def launch_generate(eval_args: EvalArgs):
    assert eval_args.dump_dir is not None
    assert eval_args.ckpt_dir is not None
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    world_mesh = get_device_mesh(distributed_args)
    dp_mesh = world_mesh["dp_replicate"]
    assert distributed_args.dp_shard == 1
    world_size = dp_mesh.size()
    world_rank = dp_mesh.get_local_rank()

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir")

    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir
    patcher = patcher_args.build()
    outputs = generate_nocache(
        eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(eval_args.prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


def main():
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_generate(eval_args)


if __name__ == "__main__":
    main()
```
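
A self-contained toy check of the sampling helpers defined in this file (torch only; the distribution and cutoffs are chosen purely for illustration):

```python
import torch

from bytelatent.generate_blt import sample_top_k, sample_top_p

# A peaked 5-way distribution over one "batch" row.
probs = torch.tensor([[0.70, 0.15, 0.10, 0.04, 0.01]])

# sample_top_k zeroes everything outside the top-k and renormalizes in place,
# so pass a clone; with k=2 only ids 0 or 1 can come back.
print(sample_top_k(probs.clone(), k=2))

# sample_top_p keeps each token whose preceding cumulative mass is at most p;
# with p=0.9 that leaves ids {0, 1, 2}.
print(sample_top_p(probs.clone(), p=0.9))
```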