"""LLM text-encoder adapter: patches the diffusion transformer's context embedder
(and optionally clones transformer blocks) so a causal LLM (Qwen2 / Llama) can be
used as the prompt encoder for Flux or Lumina2."""

from functools import partial
import sys
import weakref
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import Transformer2DModel
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer

from toolkit import train_tools
from toolkit.prompt_utils import PromptEmbeds
from toolkit.dequantize import patch_dequantization_on_save

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
    from toolkit.custom_adapter import CustomAdapter

LLM = Union[Qwen2Model, LlamaModel]
LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]


def new_context_embedder_forward(self, x):
    # Route through the adapter's context embedder when the adapter is active,
    # otherwise fall back to the transformer's original embedder.
    if self._adapter_ref().is_active:
        x = self._context_embedder_ref()(x)
    else:
        x = self._orig_forward(x)
    return x


def new_block_forward(
    self: FluxTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Route through the adapter's cloned block when active, otherwise use the
    # original transformer block.
    if self._adapter_ref().is_active:
        return self._new_block_ref()(
            hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs
        )
    else:
        return self._orig_forward(
            hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs
        )


class LLMAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        llm: LLM,
        tokenizer: LLMTokenizer,
        num_cloned_blocks: int = 0,
    ):
        super(LLMAdapter, self).__init__()
        # Hold weak references so the adapter does not keep the full pipeline
        # alive or register it as a submodule.
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.llm_ref: weakref.ref = weakref.ref(llm)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.num_cloned_blocks = num_cloned_blocks
        self.apply_embedding_mask = False

        # make sure we can pad
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        self.system_prompt = (
            "You are an assistant designed to generate superior images with the superior "
            "degree of image-text alignment based on textual prompts or user prompts. "
        )

        # determine the token length of the system prompt so it can be stripped
        # from the encoder output later
        sys_prompt_tokenized = tokenizer(
            [self.system_prompt],
            padding="longest",
            return_tensors="pt",
        )
        sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
        self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
        print(f"System prompt length: {self.system_prompt_length}")

        self.hidden_size = llm.config.hidden_size
        blocks = []
        if sd.is_flux:
            self.apply_embedding_mask = True
            # project LLM hidden states into the Flux transformer's inner dim
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.inner_dim)
            self.sequence_length = 512

            # swap in our context embedder by patching the original forward
            sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
            sd.unet.context_embedder.forward = partial(
                new_context_embedder_forward, sd.unet.context_embedder)
            sd.unet.context_embedder._context_embedder_ref = weakref.ref(
                self.context_embedder)
            # give the context embedder access to the adapter's is_active flag
            sd.unet.context_embedder._adapter_ref = self.adapter_ref

            # clone the first num_cloned_blocks transformer blocks as trainable copies
            for idx in range(self.num_cloned_blocks):
                block = FluxTransformerBlock(
                    dim=sd.unet.inner_dim,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                # patch it in case it is quantized
                patch_dequantization_on_save(sd.unet.transformer_blocks[idx])
                state_dict = sd.unet.transformer_blocks[idx].state_dict()
                for key, value in state_dict.items():
                    block.state_dict()[key].copy_(value)
                blocks.append(block)

                orig_block = sd.unet.transformer_blocks[idx]
                orig_block._orig_forward = orig_block.forward
                orig_block.forward = partial(new_block_forward, orig_block)
                orig_block._new_block_ref = weakref.ref(block)
                orig_block._adapter_ref = self.adapter_ref
        elif sd.is_lumina2:
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.hidden_size)
            self.sequence_length = 256
        else:
            raise ValueError(
                "llm adapter currently only supports flux or lumina2")

        self.blocks = nn.ModuleList(blocks)

    def _get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tokenizer = self.tokenizer_ref()
        text_encoder = self.llm_ref()
        device = text_encoder.device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length + self.system_prompt_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = text_encoder(
            text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        )
        prompt_embeds = prompt_embeds.hidden_states[-1]

        # remove the system prompt tokens from the embeddings and attention mask
        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]

        dtype = text_encoder.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask

    # getter to see if the parent adapter is active
    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def encode_text(self, prompt):
        prompt = prompt if isinstance(prompt, list) else [prompt]
        # prepend the system prompt; it is stripped again in _get_prompt_embeds
        prompt = [self.system_prompt + p for p in prompt]

        prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
            prompt=prompt,
            max_sequence_length=self.sequence_length,
        )

        prompt_embeds = PromptEmbeds(
            prompt_embeds,
            attention_mask=prompt_attention_mask,
        ).detach()

        return prompt_embeds

    def forward(self, input):
        return input