import torch
import torch.nn as nn
from functools import lru_cache
import copy


@lru_cache(maxsize=None)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
    # Memoize zero tensors by (numel, device, dtype) so repeated ODE steps
    # reuse the same buffer instead of re-allocating it every call.
    return torch.zeros(numel, device=device, dtype=dtype)


class StreamingODEWrapperForPrefix(nn.Module):
    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True, cfg_init=1.0,
                 cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super().__init__()
        self.net = net
        self.x_mask = x_mask
        self.x_cond = x_cond
        assert use_cfg is False, "cfg is not supported in streaming detokenizer"
        self.use_cfg = use_cfg
        self.use_cfg_rescale = use_cfg_rescale
        self.cfg_init = cfg_init
        self.cfg_scale = cfg_scale
        self.cfg_token_id = cfg_token_id
        self.cfg_schedule = cfg_schedule
        # Per-chunk position / length bookkeeping for streaming inference.
        self.position_ids = None
        self.seq_len = None
        # Incremental attention KV cache shared with `self.net`, plus the
        # cumulative sequence lengths (query side and key side) handed to the
        # network's variable-length attention.
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def clear_all_states(self):
        self.incremental_state = {}
        self.kv_cache_tokens = 0
        self.cu_seqlens = None
        self.cu_maxlen = None
        self.cu_seqlens_k = None
        self.cu_maxlen_k = None
        self.previous_seqlen = None

    def state_dict(self):
        return {
            "incremental_state": copy.deepcopy(self.incremental_state),
            "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
            "cu_seqlens": copy.deepcopy(self.cu_seqlens),
            "cu_maxlen": copy.deepcopy(self.cu_maxlen),
            "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
            "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
            "previous_seqlen": copy.deepcopy(self.previous_seqlen),
        }

    def load_state_dict(self, state_dict):
        self.incremental_state = state_dict["incremental_state"]
        self.kv_cache_tokens = state_dict["kv_cache_tokens"]
        self.cu_seqlens = state_dict["cu_seqlens"]
        self.cu_maxlen = state_dict["cu_maxlen"]
        self.cu_seqlens_k = state_dict["cu_seqlens_k"]
        self.cu_maxlen_k = state_dict["cu_maxlen_k"]
        self.previous_seqlen = state_dict["previous_seqlen"]
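
    # Note: the two overrides above snapshot and restore the streaming decode
    # state (incremental KV cache plus the cumulative-length bookkeeping), not
    # module parameters, so a caller can roll the decoder back to an earlier
    # chunk boundary. They shadow nn.Module's own state_dict()/load_state_dict()
    # and are therefore not suitable for checkpointing weights.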

    def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
        if not self.use_cfg:
            self.x_mask = x_mask
            self.x_cond = x_cond
        else:
            # With classifier-free guidance the batch is doubled
            # (conditional + unconditional branch).
            self.x_cond = torch.cat((x_cond, x_cond), dim=0)
            self.x_mask = torch.cat((x_mask, x_mask), dim=0)
        # Absolute positions of the new chunk within the full stream.
        position_ids_cur = list(range(start_position_id, self.x_cond.shape[1] + start_position_id))
        position_ids = torch.tensor([position_ids_cur])
        if not self.use_cfg:
            self.position_ids = position_ids.to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
        else:
            self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
            self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
        # Cumulative sequence lengths for the query side (this chunk only).
        cu_seqlens = torch.cumsum(self.seq_len, dim=0)
        self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
        self.cu_maxlen = self.seq_len.cpu().max()
        if self.cu_seqlens_k is None:
            # First chunk: keys/values cover exactly the current chunk.
            self.cu_seqlens_k = self.cu_seqlens
            self.cu_maxlen_k = self.cu_maxlen
            previous_seqlen = self.seq_len
        else:
            # Later chunks: keys/values also cover the cached prefix.
            previous_seqlen_old = cache["previous_seqlen"]
            previous_seqlen = previous_seqlen_old + self.seq_len
            cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
            self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
            self.cu_maxlen_k = previous_seqlen.cpu().max()
        self.previous_seqlen = previous_seqlen
        ret_cache = {
            "previous_seqlen": previous_seqlen
        }
        return ret_cache
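
    # Worked example (assumed chunk sizes): for a batch of 1, a first chunk of
    # 4 frames yields cu_seqlens = cu_seqlens_k = [0, 4]; feeding the returned
    # cache back in with a second chunk of 6 frames yields cu_seqlens = [0, 6]
    # for the queries while cu_seqlens_k grows to [0, 10], i.e. the new frames
    # attend over the cached prefix as well as themselves.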

    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache=None):
        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, \
            "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
        if condition_cache is None:
            # Caller-supplied dict that receives the updated previous_seqlen.
            condition_cache = {}
        for layer_idx, layer_cache in self.incremental_state.items():
            # Fold the keys/values produced by the last forward pass into the
            # persistent attention KV cache.
            layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
            layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
            self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            if self.kv_cache_tokens > max_kv_cache_tokens:
                # Cache is full: keep the first `reserve_kv_cache_tokens`
                # (e.g. the prompt) plus the most recent tokens, dropping the middle.
                reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
                if reserve_kv_cache_tokens == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                elif reserve_tokens_excludeprompt == 0:
                    layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
                    layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
                else:
                    layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                    layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
                        layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
                        layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
                    ], dim=1)
                # Re-sync the length bookkeeping with the trimmed cache.
                bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
                self.previous_seqlen = torch.Tensor(
                    [layer_cache["attn_kvcache"]["prev_k"].shape[1] for _ in range(bsz)]
                ).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
                condition_cache["previous_seqlen"] = self.previous_seqlen
                self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
            # Clear the per-step entries.
            layer_cache["attn_kvcache"].pop("cur_k")
            layer_cache["attn_kvcache"].pop("cur_v")

    def forward(self, t, x, args=None):
        # Broadcast the scalar ODE time to one integer timestep per sample
        # (t in [0, 1] is scaled to [0, 1000]).
        # t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long()
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
        if self.use_cfg:
            raise NotImplementedError("cfg is not supported in streaming detokenizer.")
        else:
            pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids,
                                  cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
                                  cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
                                  incremental_state=self.incremental_state, nopadding=True,
                                  mask=None, seq_len=None
                                  )
        return pred_noise
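

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The denoiser
# below is a stand-in with the keyword interface expected by `forward`; the
# real `net`, chunk sizes, and feature dimensions here are assumptions. The
# loop is a plain fixed-step Euler integration of the wrapped ODE over t=0..1.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyDenoiser(nn.Module):
        def forward(self, x, condition, t, position_ids, cu_seqlens, cu_maxlen,
                    cu_seqlens_k, cu_maxlen_k, incremental_state, nopadding,
                    mask, seq_len):
            # Pretend prediction with the right shape.
            return torch.zeros_like(x)

    x_cond = torch.randn(1, 16, 8)               # (batch, frames, cond_dim), assumed shapes
    x_mask = torch.ones(1, 16, dtype=torch.bool)
    wrapper = StreamingODEWrapperForPrefix(_DummyDenoiser(), x_mask, x_cond)
    cache = wrapper.set_conditions(x_mask, x_cond, start_position_id=0)

    x = torch.randn(1, 16, 8)                    # initial latent for this chunk
    n_steps = 8
    for i in range(n_steps):
        t = torch.tensor(i / n_steps)
        x = x + wrapper(t, x) / n_steps          # Euler step: x += f(t, x) * dt
    print(x.shape)                               # torch.Size([1, 16, 8])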