import copy
from functools import lru_cache

import torch
import torch.nn as nn
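
# Memoized helper that returns a reusable zeros tensor, so forward() does not allocate a new
# timestep tensor on every solver call. With maxsize=1 only the most recent
# (numel, device, dtype) key is kept.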
@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
return torch.zeros(numel, device=device, dtype=dtype)
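
# Wraps the detokenizer network as the right-hand side of an ODE for streaming, prefix-conditioned
# decoding. It holds the conditioning tensors, position ids, per-sample sequence lengths, and the
# per-layer KV cache (incremental_state) that persist across chunks.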
class StreamingODEWrapperForPrefix(nn.Module):
    def __init__(self, net, x_mask, x_cond, use_cfg=False, use_cfg_rescale=True,
                 cfg_init=1.0, cfg_scale=4.0, cfg_schedule="linear", cfg_token_id=0):
        super().__init__()
self.net = net
self.x_mask = x_mask
self.x_cond = x_cond
        assert not use_cfg, "cfg is not supported in the streaming detokenizer"
self.use_cfg = use_cfg
self.use_cfg_rescale = use_cfg_rescale
self.cfg_init = cfg_init
self.cfg_scale = cfg_scale
self.cfg_token_id = cfg_token_id
self.cfg_schedule = cfg_schedule
self.position_ids = None
self.seq_len = None
self.incremental_state = {}
self.kv_cache_tokens = 0
self.cu_seqlens = None
self.cu_maxlen = None
self.cu_seqlens_k = None
self.cu_maxlen_k = None
self.previous_seqlen = None
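
    # Drop all streaming state (per-layer KV cache and cumulative sequence lengths) before starting a new stream.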
def clear_all_states(self):
self.incremental_state = {}
self.kv_cache_tokens = 0
self.cu_seqlens = None
self.cu_maxlen = None
self.cu_seqlens_k = None
self.cu_maxlen_k = None
self.previous_seqlen = None
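
    # Snapshot the streaming state so decoding can be checkpointed and later resumed via load_state_dict().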
def state_dict(self):
return {
"incremental_state": copy.deepcopy(self.incremental_state),
"kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
"cu_seqlens": copy.deepcopy(self.cu_seqlens),
"cu_maxlen": copy.deepcopy(self.cu_maxlen),
"cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
"cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
"previous_seqlen": copy.deepcopy(self.previous_seqlen)
}
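
    # Restore a previously captured streaming state.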
def load_state_dict(self, state_dict):
self.incremental_state = state_dict["incremental_state"]
self.kv_cache_tokens = state_dict["kv_cache_tokens"]
self.cu_seqlens = state_dict["cu_seqlens"]
self.cu_maxlen = state_dict["cu_maxlen"]
self.cu_seqlens_k = state_dict["cu_seqlens_k"]
self.cu_maxlen_k = state_dict["cu_maxlen_k"]
self.previous_seqlen = state_dict["previous_seqlen"]
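
    # Register the conditioning for the next chunk: builds position ids, per-sample sequence lengths,
    # and the cumulative offsets (cu_seqlens / cu_seqlens_k) that the wrapped network's attention expects.
    # Returns a small cache dict whose "previous_seqlen" must be passed back in for the following chunk.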
    def set_conditions(self, x_mask, x_cond, start_position_id, cache=None):
        # cache carries "previous_seqlen" from the last chunk; avoid a mutable default argument.
        if cache is None:
            cache = {}
if not self.use_cfg:
self.x_mask = x_mask
self.x_cond = x_cond
else:
self.x_cond = torch.cat((x_cond, x_cond), dim=0)
self.x_mask = torch.cat((x_mask, x_mask), dim=0)
        # Absolute position ids for the tokens in this chunk.
        position_ids_cur = list(range(start_position_id, self.x_cond.shape[1] + start_position_id))
        position_ids = torch.tensor([position_ids_cur])
if not self.use_cfg:
self.position_ids = position_ids.to(self.x_cond.device).long()
self.seq_len = torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
else:
self.position_ids = torch.cat((position_ids, position_ids), dim=0).to(self.x_cond.device).long()
self.seq_len = torch.Tensor([position_ids.shape[1], position_ids.shape[1]]).to(self.x_cond.device).long()
cu_seqlens = torch.cumsum(self.seq_len, dim=0)
self.cu_seqlens = torch.cat([torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0).int()
self.cu_maxlen = self.seq_len.cpu().max()
if self.cu_seqlens_k is None:
self.cu_seqlens_k = self.cu_seqlens
self.cu_maxlen_k = self.cu_maxlen
previous_seqlen = self.seq_len
else:
previous_seqlen_old = cache["previous_seqlen"]
previous_seqlen = previous_seqlen_old + self.seq_len
# calculate cu_seqlens_k
cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
self.cu_seqlens_k = torch.cat([torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0).int()
self.cu_maxlen_k = previous_seqlen.cpu().max()
self.previous_seqlen = previous_seqlen
ret_cache = {
"previous_seqlen": previous_seqlen
}
return ret_cache
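
    # After a chunk has been decoded, fold the freshly written keys/values ("cur_k"/"cur_v") into the
    # persistent cache ("prev_k"/"prev_v") and evict old tokens once the cache exceeds
    # max_kv_cache_tokens, always keeping the first reserve_kv_cache_tokens entries (e.g. the prompt).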
    def update_incremental_state(self, reserve_kv_cache_tokens=0, max_kv_cache_tokens=900, condition_cache=None):
        # condition_cache is updated in place with the new "previous_seqlen"; avoid a mutable default argument.
        if condition_cache is None:
            condition_cache = {}
        assert reserve_kv_cache_tokens <= max_kv_cache_tokens, \
            "reserve_kv_cache_tokens must be less than or equal to max_kv_cache_tokens"
for layer_idx, layer_cache in self.incremental_state.items():
# update attention kv cache
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
if self.kv_cache_tokens > max_kv_cache_tokens:
                # Evict old tokens: keep the first reserve_kv_cache_tokens entries (e.g. the prompt)
                # plus the most recent tokens, so the cache stays within max_kv_cache_tokens.
reserve_tokens_excludeprompt = max_kv_cache_tokens - reserve_kv_cache_tokens
if reserve_kv_cache_tokens == 0:
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
elif reserve_tokens_excludeprompt == 0:
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens]
else:
layer_cache["attn_kvcache"]["prev_k"] = torch.cat([
layer_cache["attn_kvcache"]["prev_k"][:, :reserve_kv_cache_tokens],
layer_cache["attn_kvcache"]["prev_k"][:, -reserve_tokens_excludeprompt:]
], dim=1)
layer_cache["attn_kvcache"]["prev_v"] = torch.cat([
layer_cache["attn_kvcache"]["prev_v"][:, :reserve_kv_cache_tokens],
layer_cache["attn_kvcache"]["prev_v"][:, -reserve_tokens_excludeprompt:]
], dim=1)
bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
self.previous_seqlen = torch.Tensor([layer_cache["attn_kvcache"]["prev_k"].shape[1] for i in range(bsz)]).to(layer_cache["attn_kvcache"]["prev_k"].device).long()
condition_cache["previous_seqlen"] = self.previous_seqlen
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
# clear current cache
layer_cache["attn_kvcache"].pop("cur_k")
layer_cache["attn_kvcache"].pop("cur_v")
    def forward(self, t, x, args=None):
        # Broadcast the scalar solver time t to a per-sample integer timestep in [0, 1000),
        # reusing a cached zeros tensor so no new allocation is made on every ODE step.
        t = get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) + (t * 1000).long()
if self.use_cfg:
raise NotImplementedError("cfg is not supported in streaming detokenizer.")
else:
pred_noise = self.net(x=x, condition=self.x_cond, t=t, position_ids=self.position_ids,
cu_seqlens=self.cu_seqlens, cu_maxlen=self.cu_maxlen,
cu_seqlens_k=self.cu_seqlens_k, cu_maxlen_k=self.cu_maxlen_k,
incremental_state=self.incremental_state, nopadding=True,
mask=None, seq_len=None
)
return pred_noise
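

# A minimal, hypothetical usage sketch (assumptions: `net` is an already-built detokenizer matching the
# call signature used in forward(); torchdiffeq.odeint is used as the solver; `chunks`, `latent_dim`,
# `num_steps`, and `prompt_len` are placeholders). The actual inference pipeline may differ.
#
#   from torchdiffeq import odeint
#
#   wrapper = StreamingODEWrapperForPrefix(net, x_mask=None, x_cond=None)
#   cache = {}
#   for chunk_cond, chunk_mask, start_pos in chunks:          # streamed conditioning chunks
#       cache = wrapper.set_conditions(chunk_mask, chunk_cond, start_pos, cache)
#       x0 = torch.randn(chunk_cond.shape[0], chunk_cond.shape[1], latent_dim, device=chunk_cond.device)
#       ts = torch.linspace(0.0, 1.0, num_steps, device=x0.device)
#       x1 = odeint(wrapper, x0, ts)[-1]                      # integrate this chunk's flow
#       wrapper.update_incremental_state(reserve_kv_cache_tokens=prompt_len, condition_cache=cache)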