luca-peric committed
Commit 41ea791 · 1 Parent(s): 1b67cbe

Visualisation working on CPU via CUDA_VISIBLE_DEVICES=-1 python demo_patcher.py 'Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.'

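For context, CUDA_VISIBLE_DEVICES=-1 hides all GPUs from PyTorch so everything falls back to CPU. The same effect can be had in-process; a minimal sketch (not part of this commit), where the variable must be set before CUDA is first initialised:

import os

# Hide all GPUs so PyTorch falls back to CPU. This must run before anything
# initialises CUDA (simplest: before importing torch).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import torch

print(torch.cuda.is_available())  # False when no GPU is visible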
bytelatent/entropy_model.py CHANGED
@@ -27,8 +27,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
             max_seqlen=model_params["max_seqlen"],
             ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
             vocab_size=model_params["vocab_size"],
-            attn_bias_type="local_block_causal",
-            attn_impl="xformers",
+            attn_bias_type="causal",
+            attn_impl="sdpa",
             sliding_window=512,
         )
     )
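Switching the entropy model from attn_impl="xformers" to "sdpa" with a plain "causal" bias appears to be what lets it run on CPU, since PyTorch's built-in scaled_dot_product_attention has a CPU path while the xformers block-causal kernels are GPU-only. A minimal sketch of the equivalent attention call; the tensor names and shapes below are illustrative, not the repo's:

import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim); values are illustrative only.
q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

# "sdpa" + "causal" roughly corresponds to letting PyTorch apply the causal
# mask itself; this works on CPU as well as GPU.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 16, 32])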
bytelatent/generate_patcher.py ADDED
@@ -0,0 +1,130 @@
+import logging
+import os
+from typing import Tuple
+
+import torch
+
+from bytelatent.args import EvalArgs
+from bytelatent.config_parser import parse_args_to_pydantic_model
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import Patcher
+from bytelatent.distributed import (
+    DistributedArgs,
+    dist_max,
+    dist_min,
+    dist_sum,
+    get_device_mesh,
+    setup_torch_distributed,
+)
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+logger = logging.getLogger()
+
+
+def get_max_length(input_tokens: list[list[int]] | None) -> int:
+    # Reduce the max prompt length over all processes so that each process
+    # makes an equal number of calls under FSDP.
+    if input_tokens is None:
+        max_length = 0
+    else:
+        max_length = max([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        max_length = int(dist_max(max_length))
+    return max_length
+
+
+def get_min_length(input_tokens: list[list[int]] | None) -> int:
+    # Reduce the min prompt length over all processes so that each process
+    # makes an equal number of calls under FSDP.
+    if input_tokens is None:
+        # TODO: Double check this change from int(1e9) is correct
+        min_length = 0
+    else:
+        min_length = min([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        min_length = int(dist_min(min_length))
+    return min_length
+
+
+def get_generation_range(
+    prompt_tokens: list[list[int]] | None, max_gen_len: int
+) -> tuple[int, int]:
+    batch_min_prompt_length = get_min_length(prompt_tokens)
+    batch_max_prompt_length = get_max_length(prompt_tokens)
+    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len
+
+
+def sample_top_k(probs, k):
+    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
+    min_value_top_k = topk_value[:, [-1]]
+    probs[probs < min_value_top_k] = 0.0
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+@torch.inference_mode()
+def patcher_nocache(
+    prompts: list[str] | None,
+    *,
+    tokenizer: BltTokenizer,
+    patcher: Patcher,
+    max_prompt_len: int = 256,
+    max_gen_len: int = 256,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+    assert (
+        patcher.realtime_patching
+    ), "generate_nocache requires patcher.realtime_patching=True"
+    if prompts is None:
+        prompt_tokens = None
+        n_truncated_prompts = 0
+        total_truncated_prompts = 0
+    else:
+        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
+        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
+        if torch.distributed.is_initialized():
+            total_truncated_prompts = dist_sum(n_truncated_prompts)
+        else:
+            total_truncated_prompts = n_truncated_prompts
+
+        # Truncation: keep the rightmost max_prompt_len tokens of each prompt
+        prompt_tokens = [
+            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
+            for t in prompt_tokens
+        ]
+
+    if total_truncated_prompts > 0:
+        logger.info(
+            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
+            f"length greater than max_prompt_len = {max_prompt_len}, "
+            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
+        )
+
+    if prompt_tokens is None:
+        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]
+
+    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
+    batch_size = len(prompt_tokens)
+    tokens = (
+        torch.full((batch_size, end_pos), tokenizer.pad_id).to(patcher.device).long()
+    )
+
+    # Copy inputs into the tensor that would hold generated tokens
+    for i, row_tokens in enumerate(prompt_tokens):
+        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
+
+    for i, curr_pos in enumerate(range(start_pos, end_pos)):
+        current_tokens = tokens[:, :curr_pos]
+        patch_lengths, scores = patcher.patch(current_tokens, include_next_token=False)
+        # Return immediately since we are not generating token t+1
+        return patch_lengths, scores, current_tokens
+    return None
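sample_top_k and sample_top_p are carried over from the original generation loop and are not called by patcher_nocache, which returns after the first patching step. (Note also that the prompt_tokens-is-None fallback references end_pos before it is assigned; the demo below always supplies a prompt, so that branch is never taken.) A minimal sketch of what the two samplers do on a toy distribution; the tensor values are illustrative, not from the repo:

import torch

from bytelatent.generate_patcher import sample_top_k, sample_top_p

# One "batch" row with most of the probability mass on token 2.
probs = torch.tensor([[0.10, 0.20, 0.60, 0.10]])

# Top-k: zero out everything outside the k most likely tokens, renormalise,
# then sample. Both helpers mutate their input, hence the .clone() calls.
print(sample_top_k(probs.clone(), k=2))   # index 1 or 2

# Top-p (nucleus): keep the smallest prefix of sorted tokens whose cumulative
# mass covers p, renormalise, then sample from that nucleus.
print(sample_top_p(probs.clone(), p=0.8))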
bytelatent/plotting/entropy_figure_via_matplot_lib.py ADDED
@@ -0,0 +1,67 @@
+import os
+
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def plot_entropies(
+    patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str, threshold: float
+):
+    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
+    scores_np = scores.cpu().float().numpy().flatten()
+    chars = chars.replace(" ", "_")
+    tokens_np = np.array([char for char in "<" + chars])
+
+    if len(scores_np) != len(tokens_np):
+        raise ValueError("Length of scores and tokens tensors must be the same.")
+    if patch_lengths_np.sum() != len(tokens_np):
+        raise ValueError(
+            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
+            f"does not match the length of tokens/scores ({len(tokens_np)})."
+        )
+
+    x_indices = np.arange(len(tokens_np))
+
+    # Cumulative sums of patch lengths give the *end* index of each patch;
+    # these are used as the positions of the vertical boundary lines.
+    patch_boundaries = np.cumsum(patch_lengths_np)
+
+    # --- Plotting ---
+    fig, ax = plt.subplots(figsize=(15, 5))  # Adjust figure size as needed
+
+    # Plot the scores as a blue line with markers
+    ax.plot(x_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores')
+
+    # Plot a vertical dashed line after each patch; the last boundary is the
+    # end of the data, so it is skipped.
+    for boundary in patch_boundaries[:-1]:
+        ax.axvline(x=boundary, color='grey', linestyle='--', linewidth=1)
+
+    # Horizontal line marking the patching entropy threshold
+    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
+
+    # Set x-axis ticks and labels (one tick per byte/character)
+    ax.set_xticks(x_indices)
+    ax.set_xticklabels(tokens_np, rotation=0, fontsize=8)
+
+    # Axis labels
+    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
+    ax.set_xlabel("Tokens", fontsize=12)
+
+    # Axis limits: start y at 0 and pad the x-axis slightly
+    ax.set_ylim(bottom=0)
+    ax.set_xlim(left=x_indices[0] - 1.0, right=x_indices[-1] + 1.0)
+
+    # Optional grid lines
+    # ax.grid(True, axis='y', linestyle=':', color='lightgrey')
+
+    # Remove the top and right spines for a cleaner look
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+
+    # Adjust layout and save the figure
+    plt.tight_layout()
+    output_filename = "token_score_plot.png"
+    fig.savefig(output_filename, dpi=300, bbox_inches='tight')
+    print(f"Plot saved to {os.path.abspath(output_filename)}")
+
+    # Close the figure to free memory
+    plt.close(fig)
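A minimal sketch of calling plot_entropies outside the demo, with hand-made tensors; the values are illustrative only (in the demo they come from patcher_nocache and tokenizer.decode):

import torch

from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies

text = "Hi!"  # decoded characters for one prompt
# plot_entropies prepends a "<" placeholder, so patch lengths and scores
# must cover len(text) + 1 positions.
patch_lengths = torch.tensor([1, 2, 1])       # sums to 4 == len("<" + text)
scores = torch.tensor([1.2, 0.4, 2.1, 0.7])   # one entropy value per byte
plot_entropies(patch_lengths, scores, text, threshold=1.33)  # writes token_score_plot.png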
demo_patcher.py ADDED
@@ -0,0 +1,49 @@
+import os
+
+import torch
+import typer
+
+from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.generate_patcher import patcher_nocache
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
+
+
+def main(prompt: str, model_name: str = "blt-1b"):
+    from bytelatent.args import TrainArgs
+
+    consolidated_path = os.path.join("hf-weights", model_name)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
+
+    tokenizer = train_args.data.tokenizer_args.build()
+    assert isinstance(tokenizer, BltTokenizer)
+    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    # NOTE: CPU currently unsupported due to reliance on xformers
+    patcher_args.patching_device = "cpu"
+    patcher_args.device = "cpu"
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        consolidated_path, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    prompts = [prompt]
+    results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)
+    if not results:
+        raise Exception("Ruh roh")
+    batch_patch_lengths, batch_scores, batch_tokens = results
+    decoded_chars = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
+    plot_entropies(
+        batch_patch_lengths[0],
+        batch_scores[0],
+        decoded_chars[0],
+        threshold=patcher.threshold,
+    )
+
+
+if __name__ == "__main__":
+    typer.run(main)