import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeats key-value hidden states along the key-value head dimension.

    Args:
        hidden_states (torch.Tensor): Input tensor with shape either
            (batch, num_key_value_heads, seqlen, head_dim) or
            (num_layers, batch, num_key_value_heads, seqlen, head_dim).
        n_rep (int): Number of repetitions for key-value heads.

    Returns:
        torch.Tensor: The repeated tensor with shape either
            (batch, num_attention_heads, seqlen, head_dim) or
            (num_layers, batch, num_attention_heads, seqlen, head_dim).
    """
    if hidden_states.dim() == 4:  # (batch, num_key_value_heads, seqlen, head_dim)
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    elif hidden_states.dim() == 5:  # (num_layers, batch, num_key_value_heads, seqlen, head_dim)
        num_layers, batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states.unsqueeze(3).expand(num_layers, batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(num_layers, batch, num_key_value_heads * n_rep, slen, head_dim)
    else:
        raise ValueError("Input tensor must have 4 or 5 dimensions.")
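
# Minimal usage sketch (illustrative shapes, not from the original file): with
# grouped-query attention, 8 key-value heads repeated 4 times line up with 32
# query heads; head counts and dimensions below are assumed example values.
_example_kv = torch.randn(2, 8, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
assert repeat_kv(_example_kv, n_rep=4).shape == (2, 32, 16, 64)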
import math


def calculate_tokens_suggest_compression_ratio(text, tokenizer, model):
    """
    Tokenizes the text and returns:
      - token_count: the number of tokens in the input text.
      - suggestions: a list of 6 candidate compression ratios.
      - tokenized: a dictionary containing 'input_ids' and 'attention_mask'.

    The suggestions are chosen so that compressing the token count by these ratios
    would (in the worst case) bring the count within the model's maximum context
    length (model.config.max_position_embeddings, e.g. 128k tokens).
    If the text already fits within the context, the default suggestions
    [1, 2, 4, 8, 16, 32] are returned. If the text is too long, six values are
    generated in logarithmic space between max(required_ratio, 1) and 32
    (or a higher upper bound if required_ratio itself exceeds 32).
    """
    tokenized = tokenizer(text, return_tensors="pt", truncation=False)
    token_ids = tokenized["input_ids"][0]
    token_count = token_ids.size(0)
    max_context = model.config.max_position_embeddings

    if token_count <= max_context:
        required_ratio = 1.0
    else:
        required_ratio = token_count / max_context

    if required_ratio <= 1.0:
        suggestions = [1, 2, 4, 8, 16, 32]
    else:
        lower_bound = max(required_ratio, 1)
        if required_ratio < 32:
            upper_bound = 32
        else:
            upper_bound = required_ratio * 32
        # Six logarithmically spaced ratios between lower_bound and upper_bound.
        suggestions = [
            round(math.exp(math.log(lower_bound) + i * (math.log(upper_bound) - math.log(lower_bound)) / (6 - 1)), 2)
            for i in range(6)
        ]
    return token_count, suggestions, tokenized
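
# Worked example of the suggestion math above (illustrative values, not from the
# original file): a 300,000-token input against a 131,072-token context window
# needs a compression ratio of at least ~2.29, and the six suggestions are spaced
# logarithmically between that ratio and 32.
_example_ratio = 300_000 / 131_072  # ~2.29
_example_suggestions = [
    round(math.exp(math.log(_example_ratio) + i * (math.log(32) - math.log(_example_ratio)) / 5), 2)
    for i in range(6)
]
# _example_suggestions is roughly [2.29, 3.88, 6.57, 11.14, 18.88, 32.0]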
def update_retrieval_context(token_count, compression_ratio):
    """Returns a status string with the number of tokens left after compression."""
    retrieval_tokens = int(token_count / compression_ratio)
    return f"Retrieval context tokens (after compression): {retrieval_tokens}"