from typing import Optional, Union

import torch

from .attention import HiDreamAttention

# Prefer the FlashAttention-3 interface when it is installed; otherwise fall
# back to FlashAttention-2, whose `flash_attn_func` has a different signature
# and return type (see `attention` below).
try:
    from flash_attn_interface import flash_attn_func

    USE_FLASH_ATTN3 = True
except ImportError:
    from flash_attn import flash_attn_func

    USE_FLASH_ATTN3 = False


def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to queries and keys.

    `freqs_cis` stores each rotation as real-valued coefficient pairs, so the
    last dimension of `xq`/`xk` is viewed as 2-element pairs and rotated
    without complex arithmetic.
    """
    # View the head dimension as (..., head_dim // 2, 1, 2) pairs.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
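
# Shape sketch (hypothetical sizes, not from this module): with batch=1,
# seq=16, heads=4, head_dim=64, the pair view above is (1, 16, 4, 32, 1, 2),
# so a (1, 16, 1, 32, 2, 2) `freqs_cis` broadcasts across the head axis:
#
#   xq = torch.randn(1, 16, 4, 64)
#   xk = torch.randn(1, 16, 4, 64)
#   freqs_cis = torch.randn(1, 16, 1, 32, 2, 2)
#   xq_rot, xk_rot = apply_rope(xq, xk, freqs_cis)  # both (1, 16, 4, 64)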


def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # flash-attn expects (batch, seq_len, heads, head_dim) inputs.
    if USE_FLASH_ATTN3:
        # The FA3 interface returns an (output, softmax_lse) tuple; keep the output.
        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
    else:
        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
    # Merge heads: (batch, seq_len, heads, head_dim) -> (batch, seq_len, heads * head_dim).
    hidden_states = hidden_states.flatten(-2)
    hidden_states = hidden_states.to(query.dtype)
    return hidden_states
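
# A minimal fallback sketch (not part of the original module), assuming neither
# flash-attn package is available: PyTorch's built-in SDPA computes the same
# non-causal attention, but expects (batch, heads, seq, head_dim) inputs,
# hence the transposes.
def _sdpa_attention_fallback(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out = torch.nn.functional.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    )
    # Match attention()'s output layout: (batch, seq, heads * head_dim).
    return out.transpose(1, 2).flatten(-2).to(query.dtype)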


class HiDreamAttnProcessor_flashattn:
    """Attention processor typically used for SD3-like self-attention projections, backed by flash-attention."""

    def __call__(
        self,
        attn: HiDreamAttention,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        # Project image tokens to Q/K/V; Q and K are RMS-normalized.
        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split the projections into heads: (batch, seq, heads, head_dim).
        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            # Zero out the keys of padded image tokens so they contribute nothing.
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            # Joint (dual-stream) attention: project text tokens with their own
            # Q/K/V weights, then attend over the concatenated sequence.
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            # Single-stream attention over image tokens only.
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == rope.shape[-3] * 2:
            # The rotary table covers the full head dimension.
            query, key = apply_rope(query, key, rope)
        else:
            # Partial RoPE: rotate only the first half of the head dimension
            # and pass the second half through unchanged.
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            # Split the joint sequence back into image and text streams and
            # apply each stream's own output projection.
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
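
# Usage sketch (hypothetical wiring; the real call site is HiDreamAttention,
# which passes itself as `attn`):
#
#   processor = HiDreamAttnProcessor_flashattn()
#   img, txt = processor(attn, image_tokens, text_tokens=text_tokens, rope=rope)  # dual-stream
#   img = processor(attn, image_tokens, rope=rope)                                # single-stream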