prithivMLmods committed
Commit 335625e · verified · 1 Parent(s): a8f636c

Delete controlnet_flux.py

Files changed (1):
  controlnet_flux.py  +0  -418
controlnet_flux.py DELETED
@@ -1,418 +0,0 @@
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.loaders import PeftAdapterMixin
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.models.attention_processor import AttentionProcessor
- from diffusers.utils import (
-     USE_PEFT_BACKEND,
-     is_torch_version,
-     logging,
-     scale_lora_layers,
-     unscale_lora_layers,
- )
- from diffusers.models.controlnet import BaseOutput, zero_module
- from diffusers.models.embeddings import (
-     CombinedTimestepGuidanceTextProjEmbeddings,
-     CombinedTimestepTextProjEmbeddings,
- )
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
- from transformer_flux import (
-     EmbedND,
-     FluxSingleTransformerBlock,
-     FluxTransformerBlock,
- )
-
-
- logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
- @dataclass
- class FluxControlNetOutput(BaseOutput):
-     controlnet_block_samples: Tuple[torch.Tensor]
-     controlnet_single_block_samples: Tuple[torch.Tensor]
-
-
- class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
-     _supports_gradient_checkpointing = True
-
-     @register_to_config
-     def __init__(
-         self,
-         patch_size: int = 1,
-         in_channels: int = 64,
-         num_layers: int = 19,
-         num_single_layers: int = 38,
-         attention_head_dim: int = 128,
-         num_attention_heads: int = 24,
-         joint_attention_dim: int = 4096,
-         pooled_projection_dim: int = 768,
-         guidance_embeds: bool = False,
-         axes_dims_rope: List[int] = [16, 56, 56],
-         extra_condition_channels: int = 1 * 4,
-     ):
-         super().__init__()
-         self.out_channels = in_channels
-         self.inner_dim = num_attention_heads * attention_head_dim
-
-         self.pos_embed = EmbedND(
-             dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
-         )
-         text_time_guidance_cls = (
-             CombinedTimestepGuidanceTextProjEmbeddings
-             if guidance_embeds
-             else CombinedTimestepTextProjEmbeddings
-         )
-         self.time_text_embed = text_time_guidance_cls(
-             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-         )
-
-         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
-
-         self.transformer_blocks = nn.ModuleList(
-             [
-                 FluxTransformerBlock(
-                     dim=self.inner_dim,
-                     num_attention_heads=num_attention_heads,
-                     attention_head_dim=attention_head_dim,
-                 )
-                 for _ in range(num_layers)
-             ]
-         )
-
-         self.single_transformer_blocks = nn.ModuleList(
-             [
-                 FluxSingleTransformerBlock(
-                     dim=self.inner_dim,
-                     num_attention_heads=num_attention_heads,
-                     attention_head_dim=attention_head_dim,
-                 )
-                 for _ in range(num_single_layers)
-             ]
-         )
-
-         # controlnet_blocks
-         self.controlnet_blocks = nn.ModuleList([])
-         for _ in range(len(self.transformer_blocks)):
-             self.controlnet_blocks.append(
-                 zero_module(nn.Linear(self.inner_dim, self.inner_dim))
-             )
-
-         self.controlnet_single_blocks = nn.ModuleList([])
-         for _ in range(len(self.single_transformer_blocks)):
-             self.controlnet_single_blocks.append(
-                 zero_module(nn.Linear(self.inner_dim, self.inner_dim))
-             )
-
-         self.controlnet_x_embedder = zero_module(
-             torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
-         )
-
-         self.gradient_checkpointing = False
-
-     @property
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-     def attn_processors(self):
-         r"""
-         Returns:
-             `dict` of attention processors: A dictionary containing all attention processors used in the model with
-             indexed by its weight name.
-         """
-         # set recursively
-         processors = {}
-
-         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-             if hasattr(module, "get_processor"):
-                 processors[f"{name}.processor"] = module.get_processor()
-
-             for sub_name, child in module.named_children():
-                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-             return processors
-
-         for name, module in self.named_children():
-             fn_recursive_add_processors(name, module, processors)
-
-         return processors
-
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-     def set_attn_processor(self, processor):
-         r"""
-         Sets the attention processor to use to compute attention.
-
-         Parameters:
-             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                 for **all** `Attention` layers.
-
-                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                 processor. This is strongly recommended when setting trainable attention processors.
-
-         """
-         count = len(self.attn_processors.keys())
-
-         if isinstance(processor, dict) and len(processor) != count:
-             raise ValueError(
-                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-             )
-
-         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-             if hasattr(module, "set_processor"):
-                 if not isinstance(processor, dict):
-                     module.set_processor(processor)
-                 else:
-                     module.set_processor(processor.pop(f"{name}.processor"))
-
-             for sub_name, child in module.named_children():
-                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-         for name, module in self.named_children():
-             fn_recursive_attn_processor(name, module, processor)
-
-     def _set_gradient_checkpointing(self, module, value=False):
-         if hasattr(module, "gradient_checkpointing"):
-             module.gradient_checkpointing = value
-
-     @classmethod
-     def from_transformer(
-         cls,
-         transformer,
-         num_layers: int = 4,
-         num_single_layers: int = 10,
-         attention_head_dim: int = 128,
-         num_attention_heads: int = 24,
-         load_weights_from_transformer=True,
-     ):
-         config = transformer.config
-         config["num_layers"] = num_layers
-         config["num_single_layers"] = num_single_layers
-         config["attention_head_dim"] = attention_head_dim
-         config["num_attention_heads"] = num_attention_heads
-
-         controlnet = cls(**config)
-
-         if load_weights_from_transformer:
-             controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
-             controlnet.time_text_embed.load_state_dict(
-                 transformer.time_text_embed.state_dict()
-             )
-             controlnet.context_embedder.load_state_dict(
-                 transformer.context_embedder.state_dict()
-             )
-             controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
-             controlnet.transformer_blocks.load_state_dict(
-                 transformer.transformer_blocks.state_dict(), strict=False
-             )
-             controlnet.single_transformer_blocks.load_state_dict(
-                 transformer.single_transformer_blocks.state_dict(), strict=False
-             )
-
-             controlnet.controlnet_x_embedder = zero_module(
-                 controlnet.controlnet_x_embedder
-             )
-
-         return controlnet
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         controlnet_cond: torch.Tensor,
-         conditioning_scale: float = 1.0,
-         encoder_hidden_states: torch.Tensor = None,
-         pooled_projections: torch.Tensor = None,
-         timestep: torch.LongTensor = None,
-         img_ids: torch.Tensor = None,
-         txt_ids: torch.Tensor = None,
-         guidance: torch.Tensor = None,
-         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-         return_dict: bool = True,
-     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
-         """
-         The [`FluxTransformer2DModel`] forward method.
-
-         Args:
-             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
-                 Input `hidden_states`.
-             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
-                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                 from the embeddings of input conditions.
-             timestep ( `torch.LongTensor`):
-                 Used to indicate denoising step.
-             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                 A list of tensors that if specified are added to the residuals of transformer blocks.
-             joint_attention_kwargs (`dict`, *optional*):
-                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                 `self.processor` in
-                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                 tuple.
-
-         Returns:
-             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-             `tuple` where the first element is the sample tensor.
-         """
-         if joint_attention_kwargs is not None:
-             joint_attention_kwargs = joint_attention_kwargs.copy()
-             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-         else:
-             lora_scale = 1.0
-
-         if USE_PEFT_BACKEND:
-             # weight the lora layers by setting `lora_scale` for each PEFT layer
-             scale_lora_layers(self, lora_scale)
-         else:
-             if (
-                 joint_attention_kwargs is not None
-                 and joint_attention_kwargs.get("scale", None) is not None
-             ):
-                 logger.warning(
-                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                 )
-         hidden_states = self.x_embedder(hidden_states)
-
-         # add condition
-         hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
-
-         timestep = timestep.to(hidden_states.dtype) * 1000
-         if guidance is not None:
-             guidance = guidance.to(hidden_states.dtype) * 1000
-         else:
-             guidance = None
-         temb = (
-             self.time_text_embed(timestep, pooled_projections)
-             if guidance is None
-             else self.time_text_embed(timestep, guidance, pooled_projections)
-         )
-         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
-         txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
-         ids = torch.cat((txt_ids, img_ids), dim=1)
-         image_rotary_emb = self.pos_embed(ids)
-
-         block_samples = ()
-         for _, block in enumerate(self.transformer_blocks):
-             if self.training and self.gradient_checkpointing:
-
-                 def create_custom_forward(module, return_dict=None):
-                     def custom_forward(*inputs):
-                         if return_dict is not None:
-                             return module(*inputs, return_dict=return_dict)
-                         else:
-                             return module(*inputs)
-
-                     return custom_forward
-
-                 ckpt_kwargs: Dict[str, Any] = (
-                     {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                 )
-                 (
-                     encoder_hidden_states,
-                     hidden_states,
-                 ) = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(block),
-                     hidden_states,
-                     encoder_hidden_states,
-                     temb,
-                     image_rotary_emb,
-                     **ckpt_kwargs,
-                 )
-
-             else:
-                 encoder_hidden_states, hidden_states = block(
-                     hidden_states=hidden_states,
-                     encoder_hidden_states=encoder_hidden_states,
-                     temb=temb,
-                     image_rotary_emb=image_rotary_emb,
-                 )
-             block_samples = block_samples + (hidden_states,)
-
-         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-         single_block_samples = ()
-         for _, block in enumerate(self.single_transformer_blocks):
-             if self.training and self.gradient_checkpointing:
-
-                 def create_custom_forward(module, return_dict=None):
-                     def custom_forward(*inputs):
-                         if return_dict is not None:
-                             return module(*inputs, return_dict=return_dict)
-                         else:
-                             return module(*inputs)
-
-                     return custom_forward
-
-                 ckpt_kwargs: Dict[str, Any] = (
-                     {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                 )
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(block),
-                     hidden_states,
-                     temb,
-                     image_rotary_emb,
-                     **ckpt_kwargs,
-                 )
-
-             else:
-                 hidden_states = block(
-                     hidden_states=hidden_states,
-                     temb=temb,
-                     image_rotary_emb=image_rotary_emb,
-                 )
-             single_block_samples = single_block_samples + (
-                 hidden_states[:, encoder_hidden_states.shape[1] :],
-             )
-
-         # controlnet block
-         controlnet_block_samples = ()
-         for block_sample, controlnet_block in zip(
-             block_samples, self.controlnet_blocks
-         ):
-             block_sample = controlnet_block(block_sample)
-             controlnet_block_samples = controlnet_block_samples + (block_sample,)
-
-         controlnet_single_block_samples = ()
-         for single_block_sample, controlnet_block in zip(
-             single_block_samples, self.controlnet_single_blocks
-         ):
-             single_block_sample = controlnet_block(single_block_sample)
-             controlnet_single_block_samples = controlnet_single_block_samples + (
-                 single_block_sample,
-             )
-
-         # scaling
-         controlnet_block_samples = [
-             sample * conditioning_scale for sample in controlnet_block_samples
-         ]
-         controlnet_single_block_samples = [
-             sample * conditioning_scale for sample in controlnet_single_block_samples
-         ]
-
-         #
-         controlnet_block_samples = (
-             None if len(controlnet_block_samples) == 0 else controlnet_block_samples
-         )
-         controlnet_single_block_samples = (
-             None
-             if len(controlnet_single_block_samples) == 0
-             else controlnet_single_block_samples
-         )
-
-         if USE_PEFT_BACKEND:
-             # remove `lora_scale` from each PEFT layer
-             unscale_lora_layers(self, lora_scale)
-
-         if not return_dict:
-             return (controlnet_block_samples, controlnet_single_block_samples)
-
-         return FluxControlNetOutput(
-             controlnet_block_samples=controlnet_block_samples,
-             controlnet_single_block_samples=controlnet_single_block_samples,
-         )
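
For reference, below is a minimal sketch of how a ControlNet with this interface is typically exercised, based only on the constructor and forward signatures in the deleted file (as it existed before this commit). It assumes the companion transformer_flux.py module is importable, and the batch size, sequence lengths, and reduced layer/head counts are illustrative assumptions, not settings from this repository.

# Illustrative sketch only: exercises the deleted FluxControlNetModel with dummy inputs.
# Assumes controlnet_flux.py (pre-deletion) and transformer_flux.py are importable;
# all sizes below are made up for the example.
import torch
from controlnet_flux import FluxControlNetModel

# Small config for illustration: inner_dim = 4 * 128 = 512, and the default
# axes_dims_rope [16, 56, 56] sums to the head dim of 128.
controlnet = FluxControlNetModel(
    num_layers=1,
    num_single_layers=1,
    num_attention_heads=4,
    attention_head_dim=128,
    in_channels=64,
    extra_condition_channels=4,
)

batch, img_len, txt_len = 1, 256, 77                        # assumed sizes
hidden_states = torch.randn(batch, img_len, 64)             # packed image latents
controlnet_cond = torch.randn(batch, img_len, 64 + 4)       # latents + extra condition channels
encoder_hidden_states = torch.randn(batch, txt_len, 4096)   # prompt embeddings (joint_attention_dim)
pooled_projections = torch.randn(batch, 768)                # pooled text embedding
timestep = torch.tensor([0.5])                              # scaled by 1000 inside forward
img_ids = torch.zeros(batch, img_len, 3)                    # positional ids for RoPE
txt_ids = torch.zeros(1, txt_len, 3)

out = controlnet(
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=0.7,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
)
# out.controlnet_block_samples and out.controlnet_single_block_samples hold one
# residual per (single) transformer block, already scaled by conditioning_scale;
# the paired Flux transformer is expected to add them to its own block outputs.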