Commit · 41ea791
1 Parent(s): 1b67cbe
Visualisation working on CPU via CUDA_VISIBLE_DEVICES=-1 python demo_patcher.py 'Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.'
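The CUDA_VISIBLE_DEVICES=-1 prefix in the command above hides every GPU from PyTorch, forcing CPU execution. As a minimal sketch (not part of this commit), the same effect can be had from Python, provided the variable is set before torch first queries CUDA:

import os

# Hide all CUDA devices so torch.cuda.is_available() returns False.
# Must run before the first CUDA query, hence before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import torch

print(torch.cuda.is_available())  # False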
bytelatent/entropy_model.py CHANGED

@@ -27,8 +27,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"
         max_seqlen=model_params["max_seqlen"],
         ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
         vocab_size=model_params["vocab_size"],
-        attn_bias_type="
-        attn_impl="
+        attn_bias_type="causal",
+        attn_impl="sdpa",
         sliding_window=512,
     )
 )
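The attn_impl="sdpa" switch is what unblocks CPU execution: torch.nn.functional.scaled_dot_product_attention has a portable fallback kernel, whereas the xformers-based path referenced in demo_patcher.py below is CUDA-only. A minimal sketch of the call being opted into, with illustrative shapes that are not the entropy model's real configuration:

import torch
import torch.nn.functional as F

# Illustrative shapes only: batch=1, heads=4, seq=16, head_dim=32.
q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

# is_causal=True applies the causal masking that attn_bias_type="causal" requests;
# this call runs fine on CPU.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 16, 32])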
bytelatent/generate_patcher.py ADDED

@@ -0,0 +1,130 @@

import logging
import os
from typing import Tuple

import torch

from bytelatent.args import EvalArgs
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import Patcher
from bytelatent.distributed import (
    DistributedArgs,
    dist_max,
    dist_min,
    dist_sum,
    get_device_mesh,
    setup_torch_distributed,
)
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

logger = logging.getLogger()


def get_max_length(input_tokens: list[list[int]] | None) -> int:
    # Reduce the max prompt length over all processes so that each process
    # makes an equal number of calls under FSDP.
    if input_tokens is None:
        max_length = 0
    else:
        max_length = max([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        max_length = int(dist_max(max_length))
    return max_length


def get_min_length(input_tokens: list[list[int]] | None) -> int:
    # Reduce the min prompt length over all processes so that each process
    # makes an equal number of calls under FSDP.
    if input_tokens is None:
        # TODO: Double check this change from int(1e9) is correct
        min_length = 0
    else:
        min_length = min([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        min_length = int(dist_min(min_length))
    return min_length


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = get_min_length(prompt_tokens)
    batch_max_prompt_length = get_max_length(prompt_tokens)
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


@torch.inference_mode()
def patcher_nocache(
    prompts: list[str] | None,
    *,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
    assert (
        patcher.realtime_patching
    ), "patcher_nocache requires patcher.realtime_patching=True"
    if prompts is None:
        prompt_tokens = None
        n_truncated_prompts = 0
        total_truncated_prompts = 0
    else:
        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
        if torch.distributed.is_initialized():
            total_truncated_prompts = dist_sum(n_truncated_prompts)
        else:
            total_truncated_prompts = n_truncated_prompts

        # Truncate on the left so each prompt fits within max_prompt_len
        prompt_tokens = [
            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
            for t in prompt_tokens
        ]

    if total_truncated_prompts > 0:
        logger.info(
            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
            f"length greater than max_prompt_len = {max_prompt_len}, "
            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
        )

    # Compute the generation range first: end_pos is needed below to
    # synthesize BOS-only prompts when none were given.
    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    if prompt_tokens is None:
        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]

    batch_size = len(prompt_tokens)
    tokens = (
        torch.full((batch_size, end_pos), tokenizer.pad_id).to(patcher.device).long()
    )

    # Copy inputs to tensor for generated tokens
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()

    for curr_pos in range(start_pos, end_pos):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, scores = patcher.patch(current_tokens, include_next_token=False)
        # Return immediately: we only want the patching of the prompt,
        # not generation of token t+1.
        return patch_lengths, scores, current_tokens
    return None
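sample_top_k and sample_top_p are not exercised by patcher_nocache itself (it returns after patching the prompt rather than generating token t+1), but a toy run with a made-up distribution shows what each keeps. Both helpers modify probs in place, hence the clones; this sketch is not part of the commit:

import torch

from bytelatent.generate_patcher import sample_top_k, sample_top_p

# A toy next-token distribution over a 5-symbol vocabulary (batch of 1).
probs = torch.tensor([[0.05, 0.10, 0.15, 0.30, 0.40]])

# Top-k keeps only the k most probable symbols before renormalizing and
# sampling, so with k=2 only indices 3 and 4 can ever be drawn.
print(sample_top_k(probs.clone(), k=2))

# Top-p keeps the smallest prefix of the sorted distribution covering mass p.
# With p=0.5 the sorted cumsum is [0.40, 0.70, ...]; index 4 is always kept
# and index 3 survives the `cumsum - prob > p` mask, so both remain candidates.
print(sample_top_p(probs.clone(), p=0.5))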
bytelatent/plotting/entropy_figure_via_matplot_lib.py ADDED

@@ -0,0 +1,67 @@

import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_entropies(
    patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str, threshold: float
):
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    scores_np = scores.cpu().float().numpy().flatten()
    chars = chars.replace(" ", "_")
    tokens_np = np.array([char for char in "<" + chars])

    if len(scores_np) != len(tokens_np):
        raise ValueError("Length of scores and tokens tensors must be the same.")
    if patch_lengths_np.sum() != len(tokens_np):
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match the length of tokens/scores ({len(tokens_np)})."
        )

    x_indices = np.arange(len(tokens_np))

    # Cumulative sums of patch lengths mark the *end* index of each patch;
    # these are where the vertical boundary lines are drawn.
    patch_boundaries = np.cumsum(patch_lengths_np)

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(15, 5))  # Adjust figure size as needed

    # Plot the scores as a blue line with markers
    ax.plot(x_indices, scores_np, marker=".", linestyle="-", color="steelblue", label="Scores")

    # Draw a dashed vertical line after each patch; the last boundary is
    # skipped since it coincides with the end of the data.
    for boundary in patch_boundaries[:-1]:
        ax.axvline(x=boundary, color="grey", linestyle="--", linewidth=1)

    # Horizontal line at the patching threshold
    ax.axhline(y=threshold, color="red", linestyle="--", linewidth=1)

    # Set x-axis ticks with one character label per position
    ax.set_xticks(x_indices)
    ax.set_xticklabels(tokens_np, rotation=0, fontsize=8)

    # Axis labels
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Tokens", fontsize=12)

    # Start the y-axis at 0 and pad the x-axis slightly
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=x_indices[0] - 1.0, right=x_indices[-1] + 1.0)

    # Remove the top and right spines for a cleaner look
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Adjust layout and save the figure
    plt.tight_layout()
    output_filename = "token_score_plot.png"
    fig.savefig(output_filename, dpi=300, bbox_inches="tight")
    print(f"Plot saved to {os.path.abspath(output_filename)}")

    # Close the figure to free memory
    plt.close(fig)
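A quick synthetic smoke test for the helper; the numbers below are invented solely to satisfy its shape checks and are not from the commit. Because the function prepends a leading "<", patch_lengths must sum to len(chars) + 1:

import torch

from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies

chars = "Hello"                                          # plotted as "<Hello" (6 symbols)
scores = torch.tensor([2.1, 0.4, 0.3, 0.2, 0.9, 0.5])   # one fake entropy per symbol
patch_lengths = torch.tensor([1, 4, 1])                  # sums to len("<" + chars) == 6

plot_entropies(patch_lengths, scores, chars, threshold=1.0)  # writes token_score_plot.png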
demo_patcher.py ADDED

@@ -0,0 +1,49 @@

import os

import torch
import typer

from bytelatent.data.file_util import get_fs
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    from bytelatent.args import TrainArgs

    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    tokenizer = train_args.data.tokenizer_args.build()
    assert isinstance(tokenizer, BltTokenizer)

    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    # NOTE: run the patcher on CPU; full generation is currently
    # unsupported on CPU due to its reliance on xformers.
    patcher_args.patching_device = "cpu"
    patcher_args.device = "cpu"
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        consolidated_path, "entropy_model"
    )
    patcher = patcher_args.build()

    prompts = [prompt]
    results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)
    if not results:
        raise Exception("Ruh roh")
    batch_patch_lengths, batch_scores, batch_tokens = results
    decoded_chars = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
    plot_entropies(
        batch_patch_lengths[0],
        batch_scores[0],
        decoded_chars[0],
        threshold=patcher.threshold,
    )


if __name__ == "__main__":
    typer.run(main)
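For intuition about the tensors the demo plots: entropy patching opens a new patch wherever the entropy of the next byte crosses the threshold. The real segmentation logic lives in bytelatent.data.patcher.Patcher; the helper below is a deliberately naive, hypothetical reimplementation of the thresholding rule, for illustration only:

import torch

def naive_entropy_patch_lengths(scores: torch.Tensor, threshold: float) -> list[int]:
    # Hypothetical sketch: a patch starts at position 0 and at every later
    # position whose next-byte entropy exceeds the threshold; patch lengths
    # are the gaps between consecutive starts. The library's Patcher
    # implements the real (and more nuanced) variants of this rule.
    starts = [0] + [i for i, s in enumerate(scores.tolist()) if i > 0 and s > threshold]
    ends = starts[1:] + [len(scores)]
    return [e - s for s, e in zip(starts, ends)]

scores = torch.tensor([2.1, 0.4, 0.3, 1.8, 0.2, 0.9])
print(naive_entropy_patch_lengths(scores, threshold=1.0))  # [3, 3]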