import os

import torch
from transformers import DynamicCache


class FinchCache(DynamicCache):
    """Dynamic KV cache that can be compressed down to a set of important token positions."""

    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        # Rotate the two halves of the last dimension, as in rotary position embeddings.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Apply the rotary embedding to the gathered keys using the recomputed cos/sin terms.
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        # Build cos/sin tables that move the kept keys from their original positions
        # to their new, contiguous positions after compression.
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        # Gather the selected token positions along the sequence dimension (dim=2).
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        # Keep only the important positions for this layer and re-rotate the keys accordingly.
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU."""
        try:
            # Fall back to "." when `path` has no directory component.
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load the cache from disk and move tensors to the specified device."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
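

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original Space code): one way FinchCache might
# be driven end to end with a transformers causal LM. Assumptions are marked
# inline -- the rotary inverse frequencies are assumed reachable at
# `model.model.rotary_emb.inv_freq` (this attribute path varies across model
# families and transformers versions), and the position selection below is a
# placeholder that simply keeps the most recent tokens instead of the real
# importance scoring.
def compress_and_persist_cache(model, input_ids, keep: int = 256, path: str = "/tmp/finch_cache.pt"):
    cache = FinchCache()
    with torch.no_grad():
        # Prefill: DynamicCache.update() populates key_cache / value_cache per layer.
        model(input_ids, past_key_values=cache, use_cache=True)
    inv_freq = model.model.rotary_emb.inv_freq  # assumption: Llama-style attribute path
    B, H, L, _ = cache.key_cache[0].shape
    keep = min(keep, L)
    # Placeholder selection: indices of the `keep` most recent positions, shape (B, H, keep).
    important_pos = (
        torch.arange(L - keep, L, device=cache.key_cache[0].device)
        .view(1, 1, keep)
        .expand(B, H, keep)
    )
    for layer_idx in range(len(cache.key_cache)):
        cache.compress_cache(layer_idx, important_pos, inv_freq)
    cache.save(path)
    return FinchCache.load(path, device=str(input_ids.device))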