Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import torch | |
from cache import FinchCache | |
from utils import repeat_kv | |
from transformers.models.llama.modeling_llama import rotate_half | |
import spaces | |
def _split_context(context_ids, context_attention_mask, sink_tokens, num_chunks):
    """Split the context into contiguous chunks along the sequence dimension (dim 1).

    The first chunk holds the ``sink_tokens`` prefix plus an even share of the
    remaining tokens, middle chunks hold that share, and the last chunk also
    absorbs the division leftover. If the context is no longer than
    ``sink_tokens`` or a single chunk is requested, the whole context is one chunk.

    Returns:
        ``(ids_chunks, mask_chunks)``: equally long lists of tensors sliced from
        ``context_ids`` and ``context_attention_mask``.
    """
    total_len = context_ids.size(1)
    if total_len <= sink_tokens or num_chunks == 1:
        return [context_ids], [context_attention_mask]

    remainder_len = total_len - sink_tokens
    # Every chunk gets at least one non-sink token (no empty chunks that would only
    # cost a wasted forward pass), and a count of 0 or less cannot divide by zero.
    num_chunks = max(1, min(num_chunks, remainder_len))
    base, leftover = divmod(remainder_len, num_chunks)
    if num_chunks == 1:
        chunk_sizes = [total_len]
    else:
        chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2) + [base + leftover]
    # The sizes sum to total_len, so torch.split yields the same views as manual slicing.
    ids_chunks = list(torch.split(context_ids, chunk_sizes, dim=1))
    mask_chunks = list(torch.split(context_attention_mask, chunk_sizes, dim=1))
    return ids_chunks, mask_chunks


def _tokens_to_keep(chunk_idx, num_chunks, current_seq_length, past_cache_len,
                    full_context_size, sink_tokens, compression_factor,
                    target_token_size, len_question):
    """Return how many cache positions one layer keeps after processing a chunk.

    - First chunk: the sink tokens plus the chunk's non-sink tokens divided by
      ``compression_factor``.
    - Middle chunks: everything kept so far plus ``current_seq_length //
      compression_factor`` new positions.
    - Last chunk: ``sink_tokens + target_token_size + len_question``.

    Each value is capped at ``full_context_size`` and floored at ``sink_tokens``.
    """
    # NOTE(review): the first-chunk rule is checked first, so a single-chunk run uses
    # it even though the question tokens are force-kept on the last chunk; the
    # question then eats into that budget instead of being added on top of it as in
    # the multi-chunk case. Confirm whether this is intended.
    if chunk_idx == 0:
        k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
        k = min(k + past_cache_len, full_context_size)
    elif chunk_idx < num_chunks - 1:
        to_keep_new = int(current_seq_length // compression_factor)
        k = min(past_cache_len + to_keep_new, full_context_size)
    else:
        # The final budget includes the question tokens.
        k = min(sink_tokens + target_token_size + len_question, full_context_size)
    return max(k, sink_tokens)


def get_compressed_kv_cache(model, sink_tokens, step_size, target_token_size,
                            context_ids, context_attention_mask,
                            question_ids, question_attention_mask):
    """Build a compressed ``FinchCache`` for a context, conditioned on a question.

    The context is processed chunk by chunk. For each chunk the model runs on
    ``[chunk, question]`` on top of the already-compressed cache. Forward hooks on
    every layer's ``q_proj`` capture the question tokens' query states. These are
    rotated with the model's rotary embedding and scored against that layer's
    cached keys. The summed attention mass per cache position selects the positions
    passed to ``cache.compress_cache``. The first ``sink_tokens`` positions are
    always kept, and on the last chunk the question tokens are kept too.

    Args:
        model: model exposing ``model.model.layers``, ``model.model.rotary_emb`` and
            ``model.model._prepare_4d_causal_attention_mask_with_cache_position``
            (presumably a transformers Llama-style model, given the ``rotate_half``
            import).
        sink_tokens: number of leading context tokens that are never evicted.
        step_size: requested number of context chunks. It is raised if the context
            plus question exceeds ``max_position_embeddings``, and clamped so no
            chunk is empty.
        target_token_size: budget of non-sink context tokens to keep at the end.
        context_ids, context_attention_mask: context tokens and mask, ``[batch, seq]``.
        question_ids, question_attention_mask: question tokens and mask, ``[batch, seq]``.

    Returns:
        The compressed cache (the ``past_key_values`` returned by the last forward
        pass, after compression).
    """
    device = model.device
    dtype = model.dtype

    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)

    len_question = question_ids.size(1)
    total_len = context_ids.size(1)

    # Use more chunks when the whole context plus the question would not fit
    # into the model's position-embedding range.
    num_chunks = step_size
    max_context_tokens_allowed = model.config.max_position_embeddings - len_question
    if total_len > max_context_tokens_allowed:
        num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))

    context_ids_list, context_attention_mask_list = _split_context(
        context_ids, context_attention_mask, sink_tokens, num_chunks
    )
    num_chunks = len(context_ids_list)
    print("Number of chunks: ", num_chunks)

    # Ratio of non-sink context tokens to the target budget, never below 1.
    compression_factor = max(max(total_len - sink_tokens, 1) // target_token_size, 1)

    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    batch_size = question_ids.size(0)
    mask_dtype = question_attention_mask.dtype
    ones_mask = torch.ones(batch_size, 1, dtype=mask_dtype, device=device)
    num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
    num_layers = len(model.model.layers)

    cache = FinchCache()
    past_cache_len = 0
    past_attention_mask = torch.zeros(batch_size, 0, dtype=mask_dtype, device=device)

    # Filled by the hooks on every forward pass: layer index -> question-token query
    # states shaped (bsz, num_heads, len_question, head_dim).
    query_context_matrices = {}
    # Number of leading (context-chunk) positions the hooks must skip.
    hook_state = {"chunk_len": 0}

    def query_hook_fn(module, inputs, output):
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is None:
            return
        query_states = output.detach()
        bsz, seq_len, hidden_dim = query_states.size()
        num_query_heads = module.num_query_heads
        head_dim = hidden_dim // num_query_heads
        query_states = (
            query_states.view(bsz, seq_len, num_query_heads, head_dim)
            .transpose(1, 2)
            .contiguous()
        )
        # The input is [chunk, question]; keep only the question positions.
        query_context_matrices[layer_idx] = (
            query_states[:, :, hook_state["chunk_len"]:, :].clone()
        )

    hooks = []
    try:
        for i, layer in enumerate(model.model.layers):
            q_proj = layer.self_attn.q_proj
            q_proj.layer_idx = i  # tells the shared hook which layer fired
            q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hooks.append(q_proj.register_forward_hook(query_hook_fn))

        for j, (chunk_input_ids, chunk_attention_mask) in enumerate(
            zip(context_ids_list, context_attention_mask_list)
        ):
            current_seq_length = chunk_input_ids.size(1)
            hook_state["chunk_len"] = current_seq_length
            query_context_matrices.clear()
            is_last_chunk = j == num_chunks - 1

            chunk_input_ids = chunk_input_ids.contiguous()
            chunk_attention_mask = chunk_attention_mask.contiguous()
            # NOTE(review): ones_mask inserts one extra column before the question mask.
            # The causal mask is later sliced to the key length, so the question mask's
            # last column is dropped. That is harmless when question_attention_mask is
            # all ones; confirm the intent.
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            ).contiguous()
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat(
                [segment_attention_mask, question_attention_mask], dim=-1
            ).contiguous()

            # Cache positions of the question tokens in this pass.
            past_seen_tokens = cache.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens + current_seq_length,
                past_seen_tokens + current_input_ids.shape[1],
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=len_question,
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()

            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values

            # Positions of the question tokens; the same for every layer of this chunk.
            position_ids = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + len_question,
                device=device,
            ).unsqueeze(0)
            # Candidate columns: the kept cache plus this chunk, and on the last chunk
            # also the question tokens.
            score_cols = past_cache_len + current_seq_length
            if is_last_chunk:
                score_cols += len_question

            for layer_idx in range(num_layers):
                key_matrix = repeat_kv(cache.key_cache[layer_idx], num_repeats)
                query_matrix = query_context_matrices[layer_idx]

                # q_proj outputs carry no rotary embedding yet; apply it at the
                # question positions before scoring against the cached keys.
                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)

                scaling = math.sqrt(query_matrix.size(-1))  # head_dim
                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(
                    attention_matrix, dim=-1, dtype=torch.float32
                ).to(query_matrix.dtype)

                # Divide each question token's row by its count of unmasked keys
                # (mask entries ~0); the clamp avoids division by zero.
                tol = 1e-8
                binary_mask = (
                    torch.abs(causal_mask_sliced.to(torch.float32)) < tol
                ).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts

                attention_matrix = attention_matrix[:, :, :, :score_cols].clone().contiguous()
                # Sum over question tokens, then over each group of num_repeats query
                # heads (one group per KV head).
                attention_matrix = torch.sum(attention_matrix, dim=-2)
                attention_matrix = attention_matrix.view(
                    attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
                ).sum(dim=2)
                full_context_size = attention_matrix.size(-1)

                # Force-keep the sink tokens, and the question tokens on the last chunk.
                # Slicing from full_context_size - len_question (not -len_question)
                # keeps an empty question from marking every column.
                attention_matrix[..., :sink_tokens] = float("inf")
                if is_last_chunk:
                    attention_matrix[..., full_context_size - len_question:] = float("inf")

                k = _tokens_to_keep(
                    j, num_chunks, current_seq_length, past_cache_len, full_context_size,
                    sink_tokens, compression_factor, target_token_size, len_question,
                )
                selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
                selected_indices, _ = torch.sort(selected_indices, dim=-1)
                cache.compress_cache(layer_idx, selected_indices, inv_freq)

            past_cache_len = cache._seen_tokens
            # Every surviving cache position is visible in the next pass; match the
            # batch size and dtype of the other masks so torch.cat stays consistent.
            past_attention_mask = torch.ones(
                batch_size, past_cache_len, dtype=mask_dtype, device=device
            )
    finally:
        # Always detach the hooks, even if a forward pass fails, so they do not
        # accumulate on the model across calls.
        for hook in hooks:
            hook.remove()
    return cache