from functools import partial
import weakref
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import torch
import torch.nn as nn

from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from transformers import Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer

from toolkit.prompt_utils import PromptEmbeds
from toolkit.dequantize import patch_dequantization_on_save


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
    from toolkit.custom_adapter import CustomAdapter

LLM = Union[Qwen2Model, LlamaModel]
LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]


def new_context_embedder_forward(self, x):
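    """Patched forward for the transformer's context embedder: routes text tokens
    through the adapter's replacement embedder when the adapter is active, and
    falls back to the original forward otherwise."""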
    if self._adapter_ref().is_active:
        x = self._context_embedder_ref()(x)
    else:
        x = self._orig_forward(x)
    return x

def new_block_forward(
    self: FluxTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
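    """Patched forward for a FluxTransformerBlock: dispatches to the adapter's
    cloned block when the adapter is active, otherwise runs the original block's
    forward unchanged."""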
    if self._adapter_ref().is_active:
        return self._new_block_ref()(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
    else:
        return self._orig_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)


class LLMAdapter(torch.nn.Module):
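    """Adapter that feeds a causal LLM (Qwen2 or Llama) into a diffusion
    transformer as its text encoder.

    A linear layer projects the LLM's last hidden states into the transformer's
    context embedding space, and for FLUX an optional number of transformer
    blocks are cloned and swapped in whenever the adapter is active.
    """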
    def __init__(
            self,
            adapter: 'CustomAdapter',
            sd: 'StableDiffusion',
            llm: LLM,
            tokenizer: LLMTokenizer,
            num_cloned_blocks: int = 0,
    ):
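        """Set up the LLM-to-transformer projection and patch the diffusion model.

        Only weak references to the parent adapter, the StableDiffusion wrapper,
        the LLM, and its tokenizer are stored, so none of them are registered as
        submodules or end up in this module's state dict.
        """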
        super(LLMAdapter, self).__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.llm_ref: weakref.ref = weakref.ref(llm)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.num_cloned_blocks = num_cloned_blocks
        self.apply_embedding_mask = False
        # make sure we can pad
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # self.system_prompt = ""
        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
        
        # determine length of system prompt
        sys_prompt_tokenized = tokenizer(
            [self.system_prompt],
            padding="longest",
            return_tensors="pt",
        )

        sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
        
        self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
        
        print(f"System prompt length: {self.system_prompt_length}")

        self.hidden_size = llm.config.hidden_size
        
        blocks = []

        if sd.is_flux:
            self.apply_embedding_mask = True
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.inner_dim)
            self.sequence_length = 512
            # monkey-patch the context embedder so it routes through the adapter's
            # projection whenever the adapter is active
            sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
            sd.unet.context_embedder.forward = partial(
                new_context_embedder_forward, sd.unet.context_embedder)
            # stash weak references on the patched module so it can reach the adapter
            # and the replacement embedder without creating reference cycles
            sd.unet.context_embedder._context_embedder_ref = weakref.ref(self.context_embedder)
            sd.unet.context_embedder._adapter_ref = self.adapter_ref
            
            for idx in range(self.num_cloned_blocks):
                # fresh block with FLUX's attention config (24 heads, head dim 128)
                block = FluxTransformerBlock(
                    dim=sd.unet.inner_dim,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                # patch the original block in case it is quantized so a usable state
                # dict can be pulled from it, then copy its weights into the clone
                patch_dequantization_on_save(sd.unet.transformer_blocks[idx])
                state_dict = sd.unet.transformer_blocks[idx].state_dict()
                block_state_dict = block.state_dict()
                for key, value in state_dict.items():
                    block_state_dict[key].copy_(value)
                blocks.append(block)
                orig_block = sd.unet.transformer_blocks[idx]
                # patch the original block so it dispatches to the clone while the
                # adapter is active and falls back to its own forward otherwise
                orig_block._orig_forward = orig_block.forward
                orig_block.forward = partial(
                    new_block_forward, orig_block)
                orig_block._new_block_ref = weakref.ref(block)
                orig_block._adapter_ref = self.adapter_ref
            
        elif sd.is_lumina2:
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.hidden_size)
            self.sequence_length = 256
        else:
            raise ValueError(
                "llm adapter currently only supports flux or lumina2")
        
        self.blocks = nn.ModuleList(blocks)

    def _get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
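        """Tokenize the already system-prompt-prefixed prompts, encode them with
        the LLM, and return the last hidden states and attention mask with the
        system prompt tokens stripped off the front."""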
        tokenizer = self.tokenizer_ref()
        text_encoder = self.llm_ref()
        device = text_encoder.device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length + self.system_prompt_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)
        
        # encode with the LLM, then strip the system prompt tokens from the
        # resulting embeddings and attention mask below
        
        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        prompt_embeds = prompt_embeds.hidden_states[-1]
        
        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]

        dtype = text_encoder.dtype

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask

    # expose whether the parent adapter is currently active

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def encode_text(self, prompt):
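        """Prefix each prompt with the system prompt, encode it with the LLM, and
        return detached PromptEmbeds carrying the embeddings and attention mask."""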

        prompt = prompt if isinstance(prompt, list) else [prompt]

        prompt = [self.system_prompt + p for p in prompt]

        prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
            prompt=prompt,
            max_sequence_length=self.sequence_length,
        )

        prompt_embeds = PromptEmbeds(
            prompt_embeds,
            attention_mask=prompt_attention_mask,
        ).detach()

        return prompt_embeds

    def forward(self, input):
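        """No-op passthrough; the adapter does its work through encode_text and the
        patched transformer forwards rather than through this module's forward."""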
        return input