listen2you003 committed on
Commit 36de41f · 1 Parent(s): 53df0d6

init commit

.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
1
- ---
2
- title: Step1X Edit
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.27.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Edit an image based on the given instruction.
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Step1X Edit
3
+ emoji: 💻
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Edit an image based on the given instruction.
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,469 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import itertools
5
+ import math
6
+ import os
7
+ import spaces
8
+ import time
9
+ from pathlib import Path
10
+
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+ from einops import rearrange, repeat
16
+ from huggingface_hub import snapshot_download
17
+ from PIL import Image, ImageOps
18
+ from safetensors.torch import load_file
19
+ from torchvision.transforms import functional as F
20
+ from tqdm import tqdm
21
+
22
+ import sampling
23
+ from modules.autoencoder import AutoEncoder
24
+ from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
25
+ from modules.model_edit import Step1XParams, Step1XEdit
26
+
27
+ print("TORCH_CUDA", torch.cuda.is_available())
28
+
29
+ def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
30
+ if Path(ckpt_path).suffix == ".safetensors":
31
+ state_dict = load_file(ckpt_path, device)
32
+ else:
33
+ state_dict = torch.load(ckpt_path, map_location="cpu")
34
+
35
+ missing, unexpected = model.load_state_dict(
36
+ state_dict, strict=strict, assign=assign
37
+ )
38
+ if len(missing) > 0 and len(unexpected) > 0:
39
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
40
+ print("\n" + "-" * 79 + "\n")
41
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
42
+ elif len(missing) > 0:
43
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
44
+ elif len(unexpected) > 0:
45
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
46
+ return model
47
+
48
+
49
+ def load_models(
50
+ dit_path=None,
51
+ ae_path=None,
52
+ qwen2vl_model_path=None,
53
+ device="cuda",
54
+ max_length=256,
55
+ dtype=torch.bfloat16,
56
+ ):
57
+ qwen2vl_encoder = Qwen2VLEmbedder(
58
+ qwen2vl_model_path,
59
+ device=device,
60
+ max_length=max_length,
61
+ dtype=dtype,
62
+ )
63
+
64
+ with torch.device("meta"):
65
+ ae = AutoEncoder(
66
+ resolution=256,
67
+ in_channels=3,
68
+ ch=128,
69
+ out_ch=3,
70
+ ch_mult=[1, 2, 4, 4],
71
+ num_res_blocks=2,
72
+ z_channels=16,
73
+ scale_factor=0.3611,
74
+ shift_factor=0.1159,
75
+ )
76
+
77
+ step1x_params = Step1XParams(
78
+ in_channels=64,
79
+ out_channels=64,
80
+ vec_in_dim=768,
81
+ context_in_dim=4096,
82
+ hidden_size=3072,
83
+ mlp_ratio=4.0,
84
+ num_heads=24,
85
+ depth=19,
86
+ depth_single_blocks=38,
87
+ axes_dim=[16, 56, 56],
88
+ theta=10_000,
89
+ qkv_bias=True,
90
+ )
91
+ dit = Step1XEdit(step1x_params)
92
+
93
+ ae = load_state_dict(ae, ae_path)
94
+ dit = load_state_dict(
95
+ dit, dit_path
96
+ )
97
+
98
+ dit = dit.to(device=device, dtype=dtype)
99
+ ae = ae.to(device=device, dtype=torch.float32)
100
+
101
+ return ae, dit, qwen2vl_encoder
102
+
103
+
104
+ class ImageGenerator:
105
+ def __init__(
106
+ self,
107
+ dit_path=None,
108
+ ae_path=None,
109
+ qwen2vl_model_path=None,
110
+ device="cuda",
111
+ max_length=640,
112
+ dtype=torch.bfloat16,
113
+ ) -> None:
114
+ self.device = torch.device(device)
115
+ self.ae, self.dit, self.llm_encoder = load_models(
116
+ dit_path=dit_path,
117
+ ae_path=ae_path,
118
+ qwen2vl_model_path=qwen2vl_model_path,
119
+ max_length=max_length,
120
+ dtype=dtype,
121
+ )
122
+
123
+ def prepare(self, prompt, img, ref_image, ref_image_raw):
124
+ bs, _, h, w = img.shape
125
+ bs, _, ref_h, ref_w = ref_image.shape
126
+
127
+ assert h == ref_h and w == ref_w
128
+
129
+ if bs == 1 and not isinstance(prompt, str):
130
+ bs = len(prompt)
131
+ elif bs >= 1 and isinstance(prompt, str):
132
+ prompt = [prompt] * bs
133
+
134
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
135
+ ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
136
+ if img.shape[0] == 1 and bs > 1:
137
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
138
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
139
+
140
+ img_ids = torch.zeros(h // 2, w // 2, 3)
141
+
142
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
143
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
144
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
145
+
146
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
147
+
148
+ ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
149
+ ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
150
+ ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)
151
+
152
+ if isinstance(prompt, str):
153
+ prompt = [prompt]
154
+
155
+ txt, mask = self.llm_encoder(prompt, ref_image_raw)
156
+
157
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
158
+
159
+ img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
160
+ img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)
161
+
162
+
163
+ return {
164
+ "img": img,
165
+ "mask": mask,
166
+ "img_ids": img_ids.to(img.device),
167
+ "llm_embedding": txt.to(img.device),
168
+ "txt_ids": txt_ids.to(img.device),
169
+ }
170
+
171
+ @staticmethod
172
+ def process_diff_norm(diff_norm, k):
173
+ pow_result = torch.pow(diff_norm, k)
174
+
175
+ result = torch.where(
176
+ diff_norm > 1.0,
177
+ pow_result,
178
+ torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
179
+ )
180
+ return result
181
+
182
+ def denoise(
183
+ self,
184
+ img: torch.Tensor,
185
+ img_ids: torch.Tensor,
186
+ llm_embedding: torch.Tensor,
187
+ txt_ids: torch.Tensor,
188
+ timesteps: list[float],
189
+ cfg_guidance: float = 4.5,
190
+ mask=None,
191
+ show_progress=False,
192
+ timesteps_truncate=1.0,
193
+ ):
194
+ if show_progress:
195
+ pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
196
+ else:
197
+ pbar = itertools.pairwise(timesteps)
198
+ for t_curr, t_prev in pbar:
199
+ if img.shape[0] == 1 and cfg_guidance != -1:
200
+ img = torch.cat([img, img], dim=0)
201
+ t_vec = torch.full(
202
+ (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
203
+ )
204
+
205
+ txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
206
+
207
+
208
+ pred = self.dit(
209
+ img=img,
210
+ img_ids=img_ids,
211
+ txt=txt,
212
+ txt_ids=txt_ids,
213
+ y=vec,
214
+ timesteps=t_vec,
215
+ )
216
+
217
+ if cfg_guidance != -1:
218
+ cond, uncond = (
219
+ pred[0 : pred.shape[0] // 2, :],
220
+ pred[pred.shape[0] // 2 :, :],
221
+ )
222
+ if t_curr > timesteps_truncate:
223
+ diff = cond - uncond
224
+ diff_norm = torch.norm(diff, dim=(2), keepdim=True)
225
+ pred = uncond + cfg_guidance * (
226
+ cond - uncond
227
+ ) / self.process_diff_norm(diff_norm, k=0.4)
228
+ else:
229
+ pred = uncond + cfg_guidance * (cond - uncond)
230
+ tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
231
+ img_input_length = img.shape[1] // 2
232
+ img = torch.cat(
233
+ [
234
+ tem_img[:, :img_input_length],
235
+ img[ : img.shape[0] // 2, img_input_length:],
236
+ ], dim=1
237
+ )
238
+
239
+ return img[:, :img.shape[1] // 2]
240
+
241
+ @staticmethod
242
+ def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
243
+ return rearrange(
244
+ x,
245
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
246
+ h=math.ceil(height / 16),
247
+ w=math.ceil(width / 16),
248
+ ph=2,
249
+ pw=2,
250
+ )
251
+
252
+ @staticmethod
253
+ def load_image(image):
254
+ from PIL import Image
255
+
256
+ if isinstance(image, np.ndarray):
257
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
258
+ image = image.unsqueeze(0)
259
+ return image
260
+ elif isinstance(image, Image.Image):
261
+ image = F.to_tensor(image.convert("RGB"))
262
+ image = image.unsqueeze(0)
263
+ return image
264
+ elif isinstance(image, torch.Tensor):
265
+ return image
266
+ elif isinstance(image, str):
267
+ image = F.to_tensor(Image.open(image).convert("RGB"))
268
+ image = image.unsqueeze(0)
269
+ return image
270
+ else:
271
+ raise ValueError(f"Unsupported image type: {type(image)}")
272
+
273
+ def output_process_image(self, resize_img, image_size):
274
+ res_image = resize_img.resize(image_size)
275
+ return res_image
276
+
277
+ def input_process_image(self, img, img_size=512):
278
+ # 1. Open the image
279
+ w, h = img.size
280
+ r = w / h
281
+
282
+ if w > h:
283
+ w_new = math.ceil(math.sqrt(img_size * img_size * r))
284
+ h_new = math.ceil(w_new / r)
285
+ else:
286
+ h_new = math.ceil(math.sqrt(img_size * img_size / r))
287
+ w_new = math.ceil(h_new * r)
288
+ h_new = math.ceil(h_new) // 16 * 16
289
+ w_new = math.ceil(w_new) // 16 * 16
290
+
291
+ img_resized = img.resize((w_new, h_new))
292
+ return img_resized, img.size
293
+
294
+ @torch.inference_mode()
295
+ def generate_image(
296
+ self,
297
+ prompt,
298
+ negative_prompt,
299
+ ref_images,
300
+ num_steps,
301
+ cfg_guidance,
302
+ seed,
303
+ num_samples=1,
304
+ init_image=None,
305
+ image2image_strength=0.0,
306
+ show_progress=False,
307
+ size_level=512,
308
+ ):
309
+ assert num_samples == 1, "num_samples > 1 is not supported yet."
310
+ ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)
311
+
312
+ width, height = ref_images_raw.width, ref_images_raw.height
313
+
314
+
315
+ ref_images_raw = self.load_image(ref_images_raw)
316
+ ref_images_raw = ref_images_raw.to(self.device)
317
+ ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
318
+
319
+ seed = int(seed)
320
+ seed = torch.Generator(device="cpu").seed() if seed < 0 else seed
321
+
322
+ t0 = time.perf_counter()
323
+
324
+ if init_image is not None:
325
+ init_image = self.load_image(init_image)
326
+ init_image = init_image.to(self.device)
327
+ init_image = torch.nn.functional.interpolate(init_image, (height, width))
328
+ init_image = self.ae.encode(init_image.to() * 2 - 1)
329
+
330
+ x = torch.randn(
331
+ num_samples,
332
+ 16,
333
+ height // 8,
334
+ width // 8,
335
+ device=self.device,
336
+ dtype=torch.bfloat16,
337
+ generator=torch.Generator(device=self.device).manual_seed(seed),
338
+ )
339
+
340
+ timesteps = sampling.get_schedule(
341
+ num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
342
+ )
343
+
344
+ if init_image is not None:
345
+ t_idx = int((1 - image2image_strength) * num_steps)
346
+ t = timesteps[t_idx]
347
+ timesteps = timesteps[t_idx:]
348
+ x = t * x + (1.0 - t) * init_image.to(x.dtype)
349
+
350
+ x = torch.cat([x, x], dim=0)
351
+ ref_images = torch.cat([ref_images, ref_images], dim=0)
352
+ ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
353
+ inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)
354
+
355
+ x = self.denoise(
356
+ **inputs,
357
+ cfg_guidance=cfg_guidance,
358
+ timesteps=timesteps,
359
+ show_progress=show_progress,
360
+ timesteps_truncate=1.0,
361
+ )
362
+ x = self.unpack(x.float(), height, width)
363
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
364
+ x = self.ae.decode(x)
365
+ x = x.clamp(-1, 1)
366
+ x = x.mul(0.5).add(0.5)
367
+
368
+ t1 = time.perf_counter()
369
+ print(f"Done in {t1 - t0:.1f}s.")
370
+ images_list = []
371
+ for img in x.float():
372
+ images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
373
+ return images_list
374
+
375
+
376
+ def prepare_infer_func():
377
+ # Model repo ID (e.g. "bert-base-uncased")
377
+ model_repo = "stepfun-ai/Step1X-Edit"
378
+ # Local path to save the weights
380
+ model_path = "./model_weights"
381
+ os.makedirs(model_path, exist_ok=True)
382
+
383
+
384
+ # Download the model (all files)
385
+ snapshot_download(
386
+ repo_id=model_repo,
387
+ local_dir=model_path,
388
+ local_dir_use_symlinks=False # avoid symlinks
389
+ )
390
+
391
+
392
+ image_edit = ImageGenerator(
393
+ ae_path=os.path.join(model_path, 'vae.safetensors'),
394
+ dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
395
+ qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
396
+ max_length=640,
397
+ )
398
+
399
+ return image_edit.generate_image
400
+
401
+ @spaces.GPU
402
+ def inference(prompt, ref_images, seed, size_level, infer_func=None):
403
+ start_time = time.time()
404
+
405
+ if seed == -1:
406
+ import random
407
+ random_seed = random.randint(0, 2**32 - 1)
408
+ else:
409
+ random_seed = seed
410
+
411
+ image = infer_func(
412
+ prompt,
413
+ negative_prompt="",
414
+ ref_images=ref_images.convert('RGB'),
415
+ num_samples=1,
416
+ num_steps=28,
417
+ cfg_guidance=6.0,
418
+ seed=random_seed,
419
+ show_progress=True,
420
+ size_level=size_level,
421
+ )[0]
422
+
423
+ print(f"Time taken: {time.time() - start_time:.2f} seconds")
424
+ return image, random_seed
425
+
426
+
427
+ def create_demo():
428
+ inference_func = prepare_infer_func()
429
+ with gr.Blocks() as demo:
430
+ gr.Markdown(
431
+ """
432
+ # Step1X-Edit
433
+ """
434
+ )
435
+ with gr.Row():
436
+ with gr.Column():
437
+ prompt = gr.Textbox(
438
+ label="Edit instruction",
439
+ value='Remove the person from the image.',
440
+ )
441
+ init_image = gr.Image(label="Input Image", type='pil')
442
+
443
+ random_seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
444
+
445
+ size_level = gr.Number(label="size level (recommend 512, 768, 1024, min 512)", value=512, minimum=512)
446
+
447
+ generate_btn = gr.Button("Generate")
448
+
449
+ with gr.Column():
450
+ output_image = gr.Image(label="Generated Image",type='pil',image_mode='RGB')
451
+ output_random_seed = gr.Textbox(label="Used Seed", lines=5)
452
+ from functools import partial
453
+ generate_btn.click(
454
+ fn=partial(inference, infer_func=inference_func),
455
+ inputs=[
456
+ prompt,
457
+ init_image,
458
+ random_seed,
459
+ size_level,
460
+ ],
461
+ outputs=[output_image, output_random_seed],
462
+ )
463
+
464
+ return demo
465
+
466
+
467
+ if __name__ == "__main__":
468
+ demo = create_demo()
469
+ demo.launch(server_name='0.0.0.0',server_port=7860)
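
For reference, a minimal sketch of how the pieces defined in app.py above could be driven outside the Gradio UI. It assumes the weight download in prepare_infer_func() succeeds, a CUDA device is present, and the spaces.GPU decorator behaves as a pass-through outside a Space; "example.png" is a hypothetical input path.

    from PIL import Image

    from app import inference, prepare_infer_func  # functions defined in app.py above

    infer_func = prepare_infer_func()   # downloads the Step1X-Edit weights and builds the ImageGenerator
    src = Image.open("example.png")     # hypothetical input image

    edited, used_seed = inference(
        prompt="Remove the person from the image.",
        ref_images=src,
        seed=-1,                        # -1 -> a random seed is drawn and returned
        size_level=512,
        infer_func=infer_func,
    )
    edited.save("edited.png")
    print("seed used:", used_seed)
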
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (128 Bytes).

modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (3.13 kB).

modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (8.78 kB).

modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (4.94 kB).

modules/__pycache__/connector_edit.cpython-310.pyc ADDED
Binary file (11.8 kB).

modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (19.1 kB).

modules/__pycache__/model_edit.cpython-310.pyc ADDED
Binary file (4.21 kB).
modules/attention.py ADDED
@@ -0,0 +1,133 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ try:
8
+ import flash_attn
9
+ from flash_attn.flash_attn_interface import (
10
+ _flash_attn_forward,
11
+ flash_attn_func,
12
+ flash_attn_varlen_func,
13
+ )
14
+ except ImportError:
15
+ flash_attn = None
16
+ flash_attn_varlen_func = None
17
+ _flash_attn_forward = None
18
+ flash_attn_func = None
19
+
20
+ MEMORY_LAYOUT = {
21
+ # flash mode:
22
+ # pre-process: input [batch_size, seq_len, num_heads, head_dim]
23
+ # post-process: keep the shape unchanged
24
+ "flash": (
25
+ lambda x: x, # keep shape
26
+ lambda x: x, # keep shape
27
+ ),
28
+ # torch/vanilla mode:
29
+ # pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
30
+ # post-process: swap back to the original layout [B,A,S,D] -> [B,S,A,D]
31
+ "torch": (
32
+ lambda x: x.transpose(1, 2), # (B,S,A,D) -> (B,A,S,D)
33
+ lambda x: x.transpose(1, 2), # (B,A,S,D) -> (B,S,A,D)
34
+ ),
35
+ "vanilla": (
36
+ lambda x: x.transpose(1, 2),
37
+ lambda x: x.transpose(1, 2),
38
+ ),
39
+ }
40
+
41
+
42
+ def attention(
43
+ q,
44
+ k,
45
+ v,
46
+ mode="torch",
47
+ drop_rate=0,
48
+ attn_mask=None,
49
+ causal=False,
50
+ ):
51
+ """
52
+ Compute QKV self-attention.
53
+
54
+ Args:
55
+ q (torch.Tensor): query tensor, shape [batch_size, seq_len, num_heads, head_dim]
56
+ k (torch.Tensor): key tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
57
+ v (torch.Tensor): value tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
58
+ mode (str): attention mode, one of 'flash', 'torch', 'vanilla'
59
+ drop_rate (float): dropout probability applied to the attention matrix
60
+ attn_mask (torch.Tensor): attention mask; its shape depends on the mode
61
+ causal (bool): whether to use causal attention (attend only to earlier positions)
62
+
63
+ Returns:
64
+ torch.Tensor: attention output, shape [batch_size, seq_len, num_heads * head_dim]
65
+ """
66
+ # Get the pre- and post-processing functions for this mode
67
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
68
+
69
+ # Apply the pre-processing transform
70
+ q = pre_attn_layout(q) # shape depends on the mode
71
+ k = pre_attn_layout(k)
72
+ v = pre_attn_layout(v)
73
+
74
+ if mode == "torch":
75
+ # Use PyTorch's native scaled_dot_product_attention
76
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
77
+ attn_mask = attn_mask.to(q.dtype)
78
+ x = F.scaled_dot_product_attention(
79
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
80
+ )
81
+ elif mode == "flash":
82
+ assert flash_attn_func is not None, "flash_attn_func is not available"
83
+ assert attn_mask is None, "attn_mask is not supported in flash mode"
84
+ x: torch.Tensor = flash_attn_func(
85
+ q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
86
+ ) # type: ignore
87
+ elif mode == "vanilla":
88
+ # Manual (vanilla) attention implementation
89
+ scale_factor = 1 / math.sqrt(q.size(-1)) # scaling factor 1/sqrt(d_k)
90
+
91
+ b, a, s, _ = q.shape # unpack shape parameters
92
+ s1 = k.size(2) # key/value sequence length
93
+
94
+ # Initialize the attention bias
95
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
96
+
97
+ # Handle the causal mask
98
+ if causal:
99
+ assert attn_mask is None, "causal mask and attn_mask cannot be used together"
100
+ # Build a lower-triangular causal mask
101
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
102
+ diagonal=0
103
+ )
104
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
105
+ attn_bias = attn_bias.to(q.dtype)
106
+
107
+ # Handle a user-provided attention mask
108
+ if attn_mask is not None:
109
+ if attn_mask.dtype == torch.bool:
110
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
111
+ else:
112
+ attn_bias += attn_mask # allows ALiBi-style positional bias
113
+
114
+ # Compute the attention matrix
115
+ attn = (q @ k.transpose(-2, -1)) * scale_factor # [B,A,S,S1]
116
+ attn += attn_bias
117
+
118
+ # softmax and dropout
119
+ attn = attn.softmax(dim=-1)
120
+ attn = torch.dropout(attn, p=drop_rate, train=True)
121
+
122
+ # Compute the output
123
+ x = attn @ v # [B,A,S,D]
124
+ else:
125
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
126
+
127
+ # Apply the post-processing transform
128
+ x = post_attn_layout(x) # restore the original dimension order
129
+
130
+ # Merge the attention-head dimension
131
+ b, s, a, d = x.shape
132
+ out = x.reshape(b, s, -1) # [B,S,A*D]
133
+ return out
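
A small shape-level sanity check for the attention() helper above; it exercises only the "torch" and "vanilla" paths, which need no flash-attn install, and the tensor sizes are arbitrary.

    import torch

    from modules.attention import attention

    # Q/K/V in the layout the helper expects: [batch, seq_len, num_heads, head_dim]
    q = torch.randn(2, 16, 8, 64)
    k = torch.randn(2, 16, 8, 64)
    v = torch.randn(2, 16, 8, 64)

    out_sdpa = attention(q, k, v, mode="torch")      # PyTorch scaled_dot_product_attention path
    out_manual = attention(q, k, v, mode="vanilla")  # manual softmax(QK^T / sqrt(d)) V path
    print(out_sdpa.shape, out_manual.shape)          # both torch.Size([2, 16, 512]); heads merged into the last dim
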
modules/autoencoder.py ADDED
@@ -0,0 +1,326 @@
1
+ # Modified from Flux
2
+ #
3
+ # Copyright 2024 Black Forest Labs
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # This source code is licensed under the license found in the
18
+ # LICENSE file in the root directory of this source tree.
19
+ import torch
20
+ from einops import rearrange
21
+ from torch import Tensor, nn
22
+
23
+
24
+ def swish(x: Tensor) -> Tensor:
25
+ return x * torch.sigmoid(x)
26
+
27
+
28
+ class AttnBlock(nn.Module):
29
+ def __init__(self, in_channels: int):
30
+ super().__init__()
31
+ self.in_channels = in_channels
32
+
33
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
+
35
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
37
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
38
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39
+
40
+ def attention(self, h_: Tensor) -> Tensor:
41
+ h_ = self.norm(h_)
42
+ q = self.q(h_)
43
+ k = self.k(h_)
44
+ v = self.v(h_)
45
+
46
+ b, c, h, w = q.shape
47
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
48
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
49
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
50
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
51
+
52
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
53
+
54
+ def forward(self, x: Tensor) -> Tensor:
55
+ return x + self.proj_out(self.attention(x))
56
+
57
+
58
+ class ResnetBlock(nn.Module):
59
+ def __init__(self, in_channels: int, out_channels: int):
60
+ super().__init__()
61
+ self.in_channels = in_channels
62
+ out_channels = in_channels if out_channels is None else out_channels
63
+ self.out_channels = out_channels
64
+
65
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
66
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
67
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
68
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
69
+ if self.in_channels != self.out_channels:
70
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
71
+
72
+ def forward(self, x):
73
+ h = x
74
+ h = self.norm1(h)
75
+ h = swish(h)
76
+ h = self.conv1(h)
77
+
78
+ h = self.norm2(h)
79
+ h = swish(h)
80
+ h = self.conv2(h)
81
+
82
+ if self.in_channels != self.out_channels:
83
+ x = self.nin_shortcut(x)
84
+
85
+ return x + h
86
+
87
+
88
+ class Downsample(nn.Module):
89
+ def __init__(self, in_channels: int):
90
+ super().__init__()
91
+ # no asymmetric padding in torch conv, must do it ourselves
92
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
93
+
94
+ def forward(self, x: Tensor):
95
+ pad = (0, 1, 0, 1)
96
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
97
+ x = self.conv(x)
98
+ return x
99
+
100
+
101
+ class Upsample(nn.Module):
102
+ def __init__(self, in_channels: int):
103
+ super().__init__()
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
105
+
106
+ def forward(self, x: Tensor):
107
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
108
+ x = self.conv(x)
109
+ return x
110
+
111
+
112
+ class Encoder(nn.Module):
113
+ def __init__(
114
+ self,
115
+ resolution: int,
116
+ in_channels: int,
117
+ ch: int,
118
+ ch_mult: list[int],
119
+ num_res_blocks: int,
120
+ z_channels: int,
121
+ ):
122
+ super().__init__()
123
+ self.ch = ch
124
+ self.num_resolutions = len(ch_mult)
125
+ self.num_res_blocks = num_res_blocks
126
+ self.resolution = resolution
127
+ self.in_channels = in_channels
128
+ # downsampling
129
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
130
+
131
+ curr_res = resolution
132
+ in_ch_mult = (1, *tuple(ch_mult))
133
+ self.in_ch_mult = in_ch_mult
134
+ self.down = nn.ModuleList()
135
+ block_in = self.ch
136
+ for i_level in range(self.num_resolutions):
137
+ block = nn.ModuleList()
138
+ attn = nn.ModuleList()
139
+ block_in = ch * in_ch_mult[i_level]
140
+ block_out = ch * ch_mult[i_level]
141
+ for _ in range(self.num_res_blocks):
142
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
143
+ block_in = block_out
144
+ down = nn.Module()
145
+ down.block = block
146
+ down.attn = attn
147
+ if i_level != self.num_resolutions - 1:
148
+ down.downsample = Downsample(block_in)
149
+ curr_res = curr_res // 2
150
+ self.down.append(down)
151
+
152
+ # middle
153
+ self.mid = nn.Module()
154
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
155
+ self.mid.attn_1 = AttnBlock(block_in)
156
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
157
+
158
+ # end
159
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
160
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
161
+
162
+ def forward(self, x: Tensor) -> Tensor:
163
+ # downsampling
164
+ hs = [self.conv_in(x)]
165
+ for i_level in range(self.num_resolutions):
166
+ for i_block in range(self.num_res_blocks):
167
+ h = self.down[i_level].block[i_block](hs[-1])
168
+ if len(self.down[i_level].attn) > 0:
169
+ h = self.down[i_level].attn[i_block](h)
170
+ hs.append(h)
171
+ if i_level != self.num_resolutions - 1:
172
+ hs.append(self.down[i_level].downsample(hs[-1]))
173
+
174
+ # middle
175
+ h = hs[-1]
176
+ h = self.mid.block_1(h)
177
+ h = self.mid.attn_1(h)
178
+ h = self.mid.block_2(h)
179
+ # end
180
+ h = self.norm_out(h)
181
+ h = swish(h)
182
+ h = self.conv_out(h)
183
+ return h
184
+
185
+
186
+ class Decoder(nn.Module):
187
+ def __init__(
188
+ self,
189
+ ch: int,
190
+ out_ch: int,
191
+ ch_mult: list[int],
192
+ num_res_blocks: int,
193
+ in_channels: int,
194
+ resolution: int,
195
+ z_channels: int,
196
+ ):
197
+ super().__init__()
198
+ self.ch = ch
199
+ self.num_resolutions = len(ch_mult)
200
+ self.num_res_blocks = num_res_blocks
201
+ self.resolution = resolution
202
+ self.in_channels = in_channels
203
+ self.ffactor = 2 ** (self.num_resolutions - 1)
204
+
205
+ # compute in_ch_mult, block_in and curr_res at lowest res
206
+ block_in = ch * ch_mult[self.num_resolutions - 1]
207
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
208
+ self.z_shape = (1, z_channels, curr_res, curr_res)
209
+
210
+ # z to block_in
211
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
212
+
213
+ # middle
214
+ self.mid = nn.Module()
215
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
216
+ self.mid.attn_1 = AttnBlock(block_in)
217
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
218
+
219
+ # upsampling
220
+ self.up = nn.ModuleList()
221
+ for i_level in reversed(range(self.num_resolutions)):
222
+ block = nn.ModuleList()
223
+ attn = nn.ModuleList()
224
+ block_out = ch * ch_mult[i_level]
225
+ for _ in range(self.num_res_blocks + 1):
226
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
227
+ block_in = block_out
228
+ up = nn.Module()
229
+ up.block = block
230
+ up.attn = attn
231
+ if i_level != 0:
232
+ up.upsample = Upsample(block_in)
233
+ curr_res = curr_res * 2
234
+ self.up.insert(0, up) # prepend to get consistent order
235
+
236
+ # end
237
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
238
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
239
+
240
+ def forward(self, z: Tensor) -> Tensor:
241
+ # z to block_in
242
+ h = self.conv_in(z)
243
+
244
+ # middle
245
+ h = self.mid.block_1(h)
246
+ h = self.mid.attn_1(h)
247
+ h = self.mid.block_2(h)
248
+
249
+ # upsampling
250
+ for i_level in reversed(range(self.num_resolutions)):
251
+ for i_block in range(self.num_res_blocks + 1):
252
+ h = self.up[i_level].block[i_block](h)
253
+ if len(self.up[i_level].attn) > 0:
254
+ h = self.up[i_level].attn[i_block](h)
255
+ if i_level != 0:
256
+ h = self.up[i_level].upsample(h)
257
+
258
+ # end
259
+ h = self.norm_out(h)
260
+ h = swish(h)
261
+ h = self.conv_out(h)
262
+ return h
263
+
264
+
265
+ class DiagonalGaussian(nn.Module):
266
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
267
+ super().__init__()
268
+ self.sample = sample
269
+ self.chunk_dim = chunk_dim
270
+
271
+ def forward(self, z: Tensor) -> Tensor:
272
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
273
+ if self.sample:
274
+ std = torch.exp(0.5 * logvar)
275
+ return mean + std * torch.randn_like(mean)
276
+ else:
277
+ return mean
278
+
279
+
280
+ class AutoEncoder(nn.Module):
281
+ def __init__(
282
+ self,
283
+ resolution: int,
284
+ in_channels: int,
285
+ ch: int,
286
+ out_ch: int,
287
+ ch_mult: list[int],
288
+ num_res_blocks: int,
289
+ z_channels: int,
290
+ scale_factor: float,
291
+ shift_factor: float,
292
+ ):
293
+ super().__init__()
294
+ self.encoder = Encoder(
295
+ resolution=resolution,
296
+ in_channels=in_channels,
297
+ ch=ch,
298
+ ch_mult=ch_mult,
299
+ num_res_blocks=num_res_blocks,
300
+ z_channels=z_channels,
301
+ )
302
+ self.decoder = Decoder(
303
+ resolution=resolution,
304
+ in_channels=in_channels,
305
+ ch=ch,
306
+ out_ch=out_ch,
307
+ ch_mult=ch_mult,
308
+ num_res_blocks=num_res_blocks,
309
+ z_channels=z_channels,
310
+ )
311
+ self.reg = DiagonalGaussian()
312
+
313
+ self.scale_factor = scale_factor
314
+ self.shift_factor = shift_factor
315
+
316
+ def encode(self, x: Tensor) -> Tensor:
317
+ z = self.reg(self.encoder(x))
318
+ z = self.scale_factor * (z - self.shift_factor)
319
+ return z
320
+
321
+ def decode(self, z: Tensor) -> Tensor:
322
+ z = z / self.scale_factor + self.shift_factor
323
+ return self.decoder(z)
324
+
325
+ def forward(self, x: Tensor) -> Tensor:
326
+ return self.decode(self.encode(x))
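
A shape-level sketch of the AutoEncoder above with the same hyperparameters app.py passes to it; it runs on randomly initialized weights, so it only illustrates the 8x spatial compression into 16 latent channels, not reconstruction quality.

    import torch

    from modules.autoencoder import AutoEncoder

    ae = AutoEncoder(
        resolution=256, in_channels=3, ch=128, out_ch=3,
        ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
        scale_factor=0.3611, shift_factor=0.1159,
    )

    x = torch.randn(1, 3, 512, 512)  # image tensor scaled to [-1, 1]
    z = ae.encode(x)                 # latent: [1, 16, 64, 64] (three downsampling stages -> 8x)
    rec = ae.decode(z)               # back to [1, 3, 512, 512]
    print(z.shape, rec.shape)
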
modules/conditioner.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ from qwen_vl_utils import process_vision_info
3
+ from transformers import (
4
+ AutoProcessor,
5
+ Qwen2VLForConditionalGeneration,
6
+ Qwen2_5_VLForConditionalGeneration,
7
+ )
8
+ from torchvision.transforms import ToPILImage
9
+
10
+ to_pil = ToPILImage()
11
+
12
+ Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
13
+ - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
14
+ - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
15
+ Here are examples of how to transform or refine prompts:
16
+ - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
17
+ - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
18
+ Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
19
+ User Prompt:'''
20
+
21
+
22
+ def split_string(s):
23
+ # Replace Chinese quotation marks with English quotes
24
+ s = s.replace("“", '"').replace("”", '"') # use english quotes
25
+ result = []
26
+ # Flag: are we currently inside a quoted span?
27
+ in_quotes = False
28
+ temp = ""
29
+
30
+ # Iterate over each character and its index
31
+ for idx, char in enumerate(s):
32
+ # If the character is a quote and its index is greater than 155
33
+ if char == '"' and idx > 155:
34
+ # Append the quote to the temporary string
35
+ temp += char
36
+ # If we are not inside quotes
37
+ if not in_quotes:
38
+ # Append the temporary string to the result list
39
+ result.append(temp)
40
+ # Reset the temporary string
41
+ temp = ""
42
+
43
+ # Toggle the in-quotes state
44
+ in_quotes = not in_quotes
45
+ continue
46
+ # If we are inside quotes
47
+ if in_quotes:
48
+ # If the character is whitespace
49
+ if char.isspace():
50
+ pass # have space token
51
+
52
+ # Wrap the character in Chinese quotes and append it to the result list
53
+ result.append("“" + char + "”")
54
+ else:
55
+ # Append the character to the temporary string
56
+ temp += char
57
+
58
+ # If the temporary string is not empty
59
+ if temp:
60
+ # Append the temporary string to the result list
61
+ result.append(temp)
62
+
63
+ return result
64
+
65
+
66
+ class Qwen25VL_7b_Embedder(torch.nn.Module):
67
+ def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
68
+ super(Qwen25VL_7b_Embedder, self).__init__()
69
+ self.max_length = max_length
70
+ self.dtype = dtype
71
+ self.device = device
72
+
73
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
74
+ model_path,
75
+ torch_dtype=dtype,
76
+ attn_implementation="eager",
77
+ ).to(torch.cuda.current_device())
78
+
79
+ self.model.requires_grad_(False)
80
+ self.processor = AutoProcessor.from_pretrained(
81
+ model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
82
+ )
83
+
84
+ self.prefix = Qwen25VL_7b_PREFIX
85
+
86
+ def forward(self, caption, ref_images):
87
+ text_list = caption
88
+ embs = torch.zeros(
89
+ len(text_list),
90
+ self.max_length,
91
+ self.model.config.hidden_size,
92
+ dtype=torch.bfloat16,
93
+ device=torch.cuda.current_device(),
94
+ )
95
+ hidden_states = torch.zeros(
96
+ len(text_list),
97
+ self.max_length,
98
+ self.model.config.hidden_size,
99
+ dtype=torch.bfloat16,
100
+ device=torch.cuda.current_device(),
101
+ )
102
+ masks = torch.zeros(
103
+ len(text_list),
104
+ self.max_length,
105
+ dtype=torch.long,
106
+ device=torch.cuda.current_device(),
107
+ )
108
+ input_ids_list = []
109
+ attention_mask_list = []
110
+ emb_list = []
111
+
112
+ def split_string(s):
113
+ s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
114
+ result = []
115
+ in_quotes = False
116
+ temp = ""
117
+
118
+ for idx,char in enumerate(s):
119
+ if char == '"' and idx>155:
120
+ temp += char
121
+ if not in_quotes:
122
+ result.append(temp)
123
+ temp = ""
124
+
125
+ in_quotes = not in_quotes
126
+ continue
127
+ if in_quotes:
128
+ if char.isspace():
129
+ pass # have space token
130
+
131
+ result.append("“" + char + "”")
132
+ else:
133
+ temp += char
134
+
135
+ if temp:
136
+ result.append(temp)
137
+
138
+ return result
139
+
140
+ for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
141
+
142
+ messages = [{"role": "user", "content": []}]
143
+
144
+ messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
145
+
146
+ messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
147
+
148
+ # Then append the text
149
+ messages[0]["content"].append({"type": "text", "text": f"{txt}"})
150
+
151
+ # Preparation for inference
152
+ text = self.processor.apply_chat_template(
153
+ messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
154
+ )
155
+
156
+ image_inputs, video_inputs = process_vision_info(messages)
157
+
158
+ inputs = self.processor(
159
+ text=[text],
160
+ images=image_inputs,
161
+ padding=True,
162
+ return_tensors="pt",
163
+ )
164
+
165
+ old_inputs_ids = inputs.input_ids
166
+ text_split_list = split_string(text)
167
+
168
+ token_list = []
169
+ for text_each in text_split_list:
170
+ txt_inputs = self.processor(
171
+ text=text_each,
172
+ images=None,
173
+ videos=None,
174
+ padding=True,
175
+ return_tensors="pt",
176
+ )
177
+ token_each = txt_inputs.input_ids
178
+ if token_each[0][0] == 2073 and token_each[0][-1] == 854:
179
+ token_each = token_each[:, 1:-1]
180
+ token_list.append(token_each)
181
+ else:
182
+ token_list.append(token_each)
183
+
184
+ new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
185
+
186
+ new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
187
+
188
+ idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
189
+ idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
190
+ inputs.input_ids = (
191
+ torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
192
+ .unsqueeze(0)
193
+ .to("cuda")
194
+ )
195
+ inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
196
+ outputs = self.model(
197
+ input_ids=inputs.input_ids,
198
+ attention_mask=inputs.attention_mask,
199
+ pixel_values=inputs.pixel_values.to("cuda"),
200
+ image_grid_thw=inputs.image_grid_thw.to("cuda"),
201
+ output_hidden_states=True,
202
+ )
203
+
204
+ emb = outputs["hidden_states"][-1]
205
+
206
+ embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
207
+ : self.max_length
208
+ ]
209
+
210
+ masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
211
+ (min(self.max_length, emb.shape[1] - 217)),
212
+ dtype=torch.long,
213
+ device=torch.cuda.current_device(),
214
+ )
215
+
216
+ return embs, masks
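
A call-signature sketch for the embedder above. It assumes the Qwen2.5-VL-7B-Instruct weights can be downloaded and a CUDA device is present, since the module hard-codes torch.cuda.current_device(); the caption and tensor sizes are arbitrary.

    import torch

    from modules.conditioner import Qwen25VL_7b_Embedder

    embedder = Qwen25VL_7b_Embedder("Qwen/Qwen2.5-VL-7B-Instruct", max_length=640)

    # One caption per reference image; each reference image is a CHW tensor in [0, 1]
    captions = ["Remove the person from the image."]
    ref_images = torch.rand(1, 3, 512, 512)

    embs, masks = embedder(captions, ref_images)
    print(embs.shape, masks.shape)  # [1, 640, hidden_size] embedding, [1, 640] validity mask
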
modules/connector_edit.py ADDED
@@ -0,0 +1,486 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
9
+
10
+
11
+ class RMSNorm(nn.Module):
12
+ def __init__(
13
+ self,
14
+ dim: int,
15
+ elementwise_affine=True,
16
+ eps: float = 1e-6,
17
+ device=None,
18
+ dtype=None,
19
+ ):
20
+ """
21
+ Initialize the RMSNorm normalization layer.
22
+
23
+ Args:
24
+ dim (int): The dimension of the input tensor.
25
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
26
+
27
+ Attributes:
28
+ eps (float): A small value added to the denominator for numerical stability.
29
+ weight (nn.Parameter): Learnable scaling parameter.
30
+
31
+ """
32
+ factory_kwargs = {"device": device, "dtype": dtype}
33
+ super().__init__()
34
+ self.eps = eps
35
+ if elementwise_affine:
36
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
37
+
38
+ def _norm(self, x):
39
+ """
40
+ Apply the RMSNorm normalization to the input tensor.
41
+
42
+ Args:
43
+ x (torch.Tensor): The input tensor.
44
+
45
+ Returns:
46
+ torch.Tensor: The normalized tensor.
47
+
48
+ """
49
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass through the RMSNorm layer.
54
+
55
+ Args:
56
+ x (torch.Tensor): The input tensor.
57
+
58
+ Returns:
59
+ torch.Tensor: The output tensor after applying RMSNorm.
60
+
61
+ """
62
+ output = self._norm(x.float()).type_as(x)
63
+ if hasattr(self, "weight"):
64
+ output = output * self.weight
65
+ return output
66
+
67
+
68
+ def get_norm_layer(norm_layer):
69
+ """
70
+ Get the normalization layer.
71
+
72
+ Args:
73
+ norm_layer (str): The type of normalization layer.
74
+
75
+ Returns:
76
+ norm_layer (nn.Module): The normalization layer.
77
+ """
78
+ if norm_layer == "layer":
79
+ return nn.LayerNorm
80
+ elif norm_layer == "rms":
81
+ return RMSNorm
82
+ else:
83
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
84
+
85
+
86
+ def get_activation_layer(act_type):
87
+ """get activation layer
88
+
89
+ Args:
90
+ act_type (str): the activation type
91
+
92
+ Returns:
93
+ torch.nn.functional: the activation layer
94
+ """
95
+ if act_type == "gelu":
96
+ return lambda: nn.GELU()
97
+ elif act_type == "gelu_tanh":
98
+ return lambda: nn.GELU(approximate="tanh")
99
+ elif act_type == "relu":
100
+ return nn.ReLU
101
+ elif act_type == "silu":
102
+ return nn.SiLU
103
+ else:
104
+ raise ValueError(f"Unknown activation type: {act_type}")
105
+
106
+ class IndividualTokenRefinerBlock(torch.nn.Module):
107
+ def __init__(
108
+ self,
109
+ hidden_size,
110
+ heads_num,
111
+ mlp_width_ratio: str = 4.0,
112
+ mlp_drop_rate: float = 0.0,
113
+ act_type: str = "silu",
114
+ qk_norm: bool = False,
115
+ qk_norm_type: str = "layer",
116
+ qkv_bias: bool = True,
117
+ need_CA: bool = False,
118
+ dtype: Optional[torch.dtype] = None,
119
+ device: Optional[torch.device] = None,
120
+ ):
121
+ factory_kwargs = {"device": device, "dtype": dtype}
122
+ super().__init__()
123
+ self.need_CA = need_CA
124
+ self.heads_num = heads_num
125
+ head_dim = hidden_size // heads_num
126
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
127
+
128
+ self.norm1 = nn.LayerNorm(
129
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
130
+ )
131
+ self.self_attn_qkv = nn.Linear(
132
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
133
+ )
134
+ qk_norm_layer = get_norm_layer(qk_norm_type)
135
+ self.self_attn_q_norm = (
136
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
137
+ if qk_norm
138
+ else nn.Identity()
139
+ )
140
+ self.self_attn_k_norm = (
141
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
142
+ if qk_norm
143
+ else nn.Identity()
144
+ )
145
+ self.self_attn_proj = nn.Linear(
146
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
147
+ )
148
+
149
+ self.norm2 = nn.LayerNorm(
150
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
151
+ )
152
+ act_layer = get_activation_layer(act_type)
153
+ self.mlp = MLP(
154
+ in_channels=hidden_size,
155
+ hidden_channels=mlp_hidden_dim,
156
+ act_layer=act_layer,
157
+ drop=mlp_drop_rate,
158
+ **factory_kwargs,
159
+ )
160
+
161
+ self.adaLN_modulation = nn.Sequential(
162
+ act_layer(),
163
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
164
+ )
165
+
166
+ if self.need_CA:
167
+ self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
168
+ heads_num=heads_num,
169
+ mlp_width_ratio=mlp_width_ratio,
170
+ mlp_drop_rate=mlp_drop_rate,
171
+ act_type=act_type,
172
+ qk_norm=qk_norm,
173
+ qk_norm_type=qk_norm_type,
174
+ qkv_bias=qkv_bias,
175
+ **factory_kwargs,)
176
+ # Zero-initialize the modulation
177
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
178
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
179
+
180
+ def forward(
181
+ self,
182
+ x: torch.Tensor,
183
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
184
+ attn_mask: torch.Tensor = None,
185
+ y: torch.Tensor = None,
186
+ ):
187
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
188
+
189
+ norm_x = self.norm1(x)
190
+ qkv = self.self_attn_qkv(norm_x)
191
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
192
+ # Apply QK-Norm if needed
193
+ q = self.self_attn_q_norm(q).to(v)
194
+ k = self.self_attn_k_norm(k).to(v)
195
+
196
+ # Self-Attention
197
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
198
+
199
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
200
+
201
+ if self.need_CA:
202
+ x = self.cross_attnblock(x, c, attn_mask, y)
203
+
204
+ # FFN Layer
205
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
206
+
207
+ return x
208
+
209
+
210
+
211
+
212
+ class CrossAttnBlock(torch.nn.Module):
213
+ def __init__(
214
+ self,
215
+ hidden_size,
216
+ heads_num,
217
+ mlp_width_ratio: str = 4.0,
218
+ mlp_drop_rate: float = 0.0,
219
+ act_type: str = "silu",
220
+ qk_norm: bool = False,
221
+ qk_norm_type: str = "layer",
222
+ qkv_bias: bool = True,
223
+ dtype: Optional[torch.dtype] = None,
224
+ device: Optional[torch.device] = None,
225
+ ):
226
+ factory_kwargs = {"device": device, "dtype": dtype}
227
+ super().__init__()
228
+ self.heads_num = heads_num
229
+ head_dim = hidden_size // heads_num
230
+
231
+ self.norm1 = nn.LayerNorm(
232
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
233
+ )
234
+ self.norm1_2 = nn.LayerNorm(
235
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
236
+ )
237
+ self.self_attn_q = nn.Linear(
238
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
239
+ )
240
+ self.self_attn_kv = nn.Linear(
241
+ hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
242
+ )
243
+ qk_norm_layer = get_norm_layer(qk_norm_type)
244
+ self.self_attn_q_norm = (
245
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
246
+ if qk_norm
247
+ else nn.Identity()
248
+ )
249
+ self.self_attn_k_norm = (
250
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
251
+ if qk_norm
252
+ else nn.Identity()
253
+ )
254
+ self.self_attn_proj = nn.Linear(
255
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
256
+ )
257
+
258
+ self.norm2 = nn.LayerNorm(
259
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
260
+ )
261
+ act_layer = get_activation_layer(act_type)
262
+
263
+ self.adaLN_modulation = nn.Sequential(
264
+ act_layer(),
265
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
266
+ )
267
+ # Zero-initialize the modulation
268
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
269
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
270
+
271
+ def forward(
272
+ self,
273
+ x: torch.Tensor,
274
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
275
+ attn_mask: torch.Tensor = None,
276
+ y: torch.Tensor=None,
277
+
278
+ ):
279
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
280
+
281
+ norm_x = self.norm1(x)
282
+ norm_y = self.norm1_2(y)
283
+ q = self.self_attn_q(norm_x)
284
+ q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
285
+ kv = self.self_attn_kv(norm_y)
286
+ k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
287
+ # Apply QK-Norm if needed
288
+ q = self.self_attn_q_norm(q).to(v)
289
+ k = self.self_attn_k_norm(k).to(v)
290
+
291
+ # Self-Attention
292
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
293
+
294
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
295
+
296
+ return x
297
+
298
+
299
+
300
+ class IndividualTokenRefiner(torch.nn.Module):
301
+ def __init__(
302
+ self,
303
+ hidden_size,
304
+ heads_num,
305
+ depth,
306
+ mlp_width_ratio: float = 4.0,
307
+ mlp_drop_rate: float = 0.0,
308
+ act_type: str = "silu",
309
+ qk_norm: bool = False,
310
+ qk_norm_type: str = "layer",
311
+ qkv_bias: bool = True,
312
+ need_CA:bool=False,
313
+ dtype: Optional[torch.dtype] = None,
314
+ device: Optional[torch.device] = None,
315
+ ):
316
+
317
+ factory_kwargs = {"device": device, "dtype": dtype}
318
+ super().__init__()
319
+ self.need_CA = need_CA
320
+ self.blocks = nn.ModuleList(
321
+ [
322
+ IndividualTokenRefinerBlock(
323
+ hidden_size=hidden_size,
324
+ heads_num=heads_num,
325
+ mlp_width_ratio=mlp_width_ratio,
326
+ mlp_drop_rate=mlp_drop_rate,
327
+ act_type=act_type,
328
+ qk_norm=qk_norm,
329
+ qk_norm_type=qk_norm_type,
330
+ qkv_bias=qkv_bias,
331
+ need_CA=self.need_CA,
332
+ **factory_kwargs,
333
+ )
334
+ for _ in range(depth)
335
+ ]
336
+ )
337
+
338
+
339
+ def forward(
340
+ self,
341
+ x: torch.Tensor,
342
+ c: torch.LongTensor,
343
+ mask: Optional[torch.Tensor] = None,
344
+ y:torch.Tensor=None,
345
+ ):
346
+ self_attn_mask = None
347
+ if mask is not None:
348
+ batch_size = mask.shape[0]
349
+ seq_len = mask.shape[1]
350
+ mask = mask.to(x.device)
351
+ # batch_size x 1 x seq_len x seq_len
352
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
353
+ 1, 1, seq_len, 1
354
+ )
355
+ # batch_size x 1 x seq_len x seq_len
356
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
357
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
358
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
359
+ # avoids self-attention weight being NaN for padding tokens
360
+ self_attn_mask[:, :, :, 0] = True
361
+
362
+
363
+ for block in self.blocks:
364
+ x = block(x, c, self_attn_mask,y)
365
+
366
+ return x
367
+
368
+
369
+ class SingleTokenRefiner(torch.nn.Module):
370
+ """
371
+ A single token refiner block for llm text embedding refine.
372
+ """
373
+ def __init__(
374
+ self,
375
+ in_channels,
376
+ hidden_size,
377
+ heads_num,
378
+ depth,
379
+ mlp_width_ratio: float = 4.0,
380
+ mlp_drop_rate: float = 0.0,
381
+ act_type: str = "silu",
382
+ qk_norm: bool = False,
383
+ qk_norm_type: str = "layer",
384
+ qkv_bias: bool = True,
385
+ need_CA:bool=False,
386
+ attn_mode: str = "torch",
387
+ dtype: Optional[torch.dtype] = None,
388
+ device: Optional[torch.device] = None,
389
+ ):
390
+ factory_kwargs = {"device": device, "dtype": dtype}
391
+ super().__init__()
392
+ self.attn_mode = attn_mode
393
+ self.need_CA = need_CA
394
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
395
+
396
+ self.input_embedder = nn.Linear(
397
+ in_channels, hidden_size, bias=True, **factory_kwargs
398
+ )
399
+ if self.need_CA:
400
+ self.input_embedder_CA = nn.Linear(
401
+ in_channels, hidden_size, bias=True, **factory_kwargs
402
+ )
403
+
404
+ act_layer = get_activation_layer(act_type)
405
+ # Build timestep embedding layer
406
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
407
+ # Build context embedding layer
408
+ self.c_embedder = TextProjection(
409
+ in_channels, hidden_size, act_layer, **factory_kwargs
410
+ )
411
+
412
+ self.individual_token_refiner = IndividualTokenRefiner(
413
+ hidden_size=hidden_size,
414
+ heads_num=heads_num,
415
+ depth=depth,
416
+ mlp_width_ratio=mlp_width_ratio,
417
+ mlp_drop_rate=mlp_drop_rate,
418
+ act_type=act_type,
419
+ qk_norm=qk_norm,
420
+ qk_norm_type=qk_norm_type,
421
+ qkv_bias=qkv_bias,
422
+ need_CA=need_CA,
423
+ **factory_kwargs,
424
+ )
425
+
426
+ def forward(
427
+ self,
428
+ x: torch.Tensor,
429
+ t: torch.LongTensor,
430
+ mask: Optional[torch.LongTensor] = None,
431
+ y: torch.LongTensor=None,
432
+ ):
433
+ timestep_aware_representations = self.t_embedder(t)
434
+
435
+ if mask is None:
436
+ context_aware_representations = x.mean(dim=1)
437
+ else:
438
+ mask_float = mask.unsqueeze(-1) # [b, s1, 1]
439
+ context_aware_representations = (x * mask_float).sum(
440
+ dim=1
441
+ ) / mask_float.sum(dim=1)
442
+ context_aware_representations = self.c_embedder(context_aware_representations)
443
+ c = timestep_aware_representations + context_aware_representations
444
+
445
+ x = self.input_embedder(x)
446
+ if self.need_CA:
447
+ y = self.input_embedder_CA(y)
448
+ x = self.individual_token_refiner(x, c, mask, y)
449
+ else:
450
+ x = self.individual_token_refiner(x, c, mask)
451
+
452
+ return x
453
+
454
+
455
+
456
+ class Qwen2Connector(torch.nn.Module):
457
+ def __init__(
458
+ self,
459
+ # biclip_dim=1024,
460
+ in_channels=3584,
461
+ hidden_size=4096,
462
+ heads_num=32,
463
+ depth=2,
464
+ need_CA=False,
465
+ device=None,
466
+ dtype=torch.bfloat16,
467
+ ):
468
+ super().__init__()
469
+ factory_kwargs = {"device": device, "dtype": dtype}
470
+
471
+ self.S = SingleTokenRefiner(in_channels=in_channels, hidden_size=hidden_size, heads_num=heads_num, depth=depth, need_CA=need_CA, **factory_kwargs)
472
+ self.global_proj_out = nn.Linear(in_channels, 768)
473
+
474
+ self.scale_factor = nn.Parameter(torch.zeros(1))
475
+ with torch.no_grad():
476
+ self.scale_factor.data += -(1 - 0.09)
477
+
478
+ def forward(self, x, t, mask):
479
+ mask_float = mask.unsqueeze(-1) # [b, s1, 1]
480
+ x_mean = (x * mask_float).sum(
481
+ dim=1
482
+ ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
483
+
484
+ global_out = self.global_proj_out(x_mean)
485
+ encoder_hidden_states = self.S(x, t, mask)
486
+ return encoder_hidden_states, global_out
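A minimal shape-level sketch of how this connector might be exercised, assuming the file lands at modules/connector_edit.py as in this commit; the batch size, token count, and timestep below are made up, and dtype=torch.float32 is requested only so the inputs match global_proj_out, which is built without the factory kwargs:

import torch
from modules.connector_edit import Qwen2Connector

# illustrative shapes: 2 prompts, 16 LLM tokens, 3584-dim Qwen2 hidden states (the in_channels default)
x = torch.randn(2, 16, 3584)
t = torch.full((2,), 500, dtype=torch.long)   # one timestep index per sample
mask = torch.ones(2, 16, dtype=torch.long)    # 1 = real token, 0 = padding

connector = Qwen2Connector(dtype=torch.float32)   # default is bfloat16
encoder_hidden_states, global_out = connector(x, t, mask)
print(encoder_hidden_states.shape)   # (2, 16, 4096): refined per-token embeddings
print(global_out.shape)              # (2, 768): pooled global conditioning vector

The per-token width (4096) and the pooled width (768) come directly from the hidden_size and global_proj_out defaults above.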
modules/layers.py ADDED
@@ -0,0 +1,640 @@
1
+ # Modified from Flux
2
+ #
3
+ # Copyright 2024 Black Forest Labs
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ # This source code is licensed under the license found in the
18
+ # LICENSE file in the root directory of this source tree.
19
+
20
+ import math # noqa: I001
21
+ from dataclasses import dataclass
22
+ from functools import partial
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from einops import rearrange
27
+ # from liger_kernel.ops.rms_norm import LigerRMSNormFunction
28
+ from torch import Tensor, nn
29
+
30
+
31
+ try:
32
+ import flash_attn
33
+ from flash_attn.flash_attn_interface import (
34
+ _flash_attn_forward,
35
+ flash_attn_varlen_func,
36
+ )
37
+ except ImportError:
38
+ flash_attn = None
39
+ flash_attn_varlen_func = None
40
+ _flash_attn_forward = None
41
+
42
+
43
+ MEMORY_LAYOUT = {
44
+ "flash": (
45
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
46
+ lambda x: x,
47
+ ),
48
+ "torch": (
49
+ lambda x: x.transpose(1, 2),
50
+ lambda x: x.transpose(1, 2),
51
+ ),
52
+ "vanilla": (
53
+ lambda x: x.transpose(1, 2),
54
+ lambda x: x.transpose(1, 2),
55
+ ),
56
+ }
57
+
58
+
59
+ def attention(
60
+ q,
61
+ k,
62
+ v,
63
+ mode="torch",
64
+ drop_rate=0,
65
+ attn_mask=None,
66
+ causal=False,
67
+ cu_seqlens_q=None,
68
+ cu_seqlens_kv=None,
69
+ max_seqlen_q=None,
70
+ max_seqlen_kv=None,
71
+ batch_size=1,
72
+ ):
73
+ """
74
+ Perform QKV self attention.
75
+
76
+ Args:
77
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
78
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
79
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
80
+ mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
81
+ drop_rate (float): Dropout rate in attention map. (default: 0)
82
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
83
+ (default: None)
84
+ causal (bool): Whether to use causal attention. (default: False)
85
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
86
+ used to index into q.
87
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
88
+ used to index into kv.
89
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
90
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
91
+
92
+ Returns:
93
+ torch.Tensor: Output tensor after self-attention with shape [b, s, a*d]
94
+ """
95
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
96
+ q = pre_attn_layout(q)
97
+ k = pre_attn_layout(k)
98
+ v = pre_attn_layout(v)
99
+
100
+ if mode == "torch":
101
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
102
+ attn_mask = attn_mask.to(q.dtype)
103
+ x = F.scaled_dot_product_attention(
104
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
105
+ )
106
+ elif mode == "flash":
107
+ assert flash_attn_varlen_func is not None
108
+ x: torch.Tensor = flash_attn_varlen_func(
109
+ q,
110
+ k,
111
+ v,
112
+ cu_seqlens_q,
113
+ cu_seqlens_kv,
114
+ max_seqlen_q,
115
+ max_seqlen_kv,
116
+ ) # type: ignore
117
+ # x with shape [(bxs), a, d]
118
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # type: ignore # reshape x to [b, s, a, d]
119
+ elif mode == "vanilla":
120
+ scale_factor = 1 / math.sqrt(q.size(-1))
121
+
122
+ b, a, s, _ = q.shape
123
+ s1 = k.size(2)
124
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
125
+ if causal:
126
+ # Only applied to self attention
127
+ assert attn_mask is None, (
128
+ "Causal mask and attn_mask cannot be used together"
129
+ )
130
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
131
+ diagonal=0
132
+ )
133
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
134
+ attn_bias = attn_bias.to(q.dtype)
135
+
136
+ if attn_mask is not None:
137
+ if attn_mask.dtype == torch.bool:
138
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
139
+ else:
140
+ attn_bias += attn_mask
141
+
142
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
143
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
144
+ attn += attn_bias
145
+ attn = attn.softmax(dim=-1)
146
+ attn = torch.dropout(attn, p=drop_rate, train=True)
147
+ x = attn @ v
148
+ else:
149
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
150
+
151
+ x = post_attn_layout(x)
152
+ b, s, a, d = x.shape
153
+ out = x.reshape(b, s, -1)
154
+ return out
155
+
156
+
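As a quick check of the helper above, here is a sketch in the default "torch" mode, the only path that needs no extra metadata ("flash" additionally requires flash-attn plus the cu_seqlens/max_seqlen arguments); all shapes are illustrative:

import torch
from modules.layers import attention

b, s, heads, head_dim = 2, 10, 8, 64
q = torch.randn(b, s, heads, head_dim)
k = torch.randn(b, s, heads, head_dim)
v = torch.randn(b, s, heads, head_dim)

out = attention(q, k, v, mode="torch")   # dispatches to F.scaled_dot_product_attention
print(out.shape)                         # (2, 10, 512) == [b, s, heads * head_dim]

# a boolean mask broadcast over heads: True = attend, False = ignore
attn_mask = torch.ones(b, 1, s, s, dtype=torch.bool)
out_masked = attention(q, k, v, mode="torch", attn_mask=attn_mask)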
157
+ def apply_gate(x, gate=None, tanh=False):
158
+ """Apply an optional gating tensor to the input.
159
+
160
+ Args:
161
+ x (torch.Tensor): input tensor.
162
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
163
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
164
+
165
+ Returns:
166
+ torch.Tensor: the output tensor after applying the gate.
167
+ """
168
+ if gate is None:
169
+ return x
170
+ if tanh:
171
+ return x * gate.unsqueeze(1).tanh()
172
+ else:
173
+ return x * gate.unsqueeze(1)
174
+
175
+
176
+ class MLP(nn.Module):
177
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
178
+
179
+ def __init__(
180
+ self,
181
+ in_channels,
182
+ hidden_channels=None,
183
+ out_features=None,
184
+ act_layer=nn.GELU,
185
+ norm_layer=None,
186
+ bias=True,
187
+ drop=0.0,
188
+ use_conv=False,
189
+ device=None,
190
+ dtype=None,
191
+ ):
192
+ super().__init__()
193
+ out_features = out_features or in_channels
194
+ hidden_channels = hidden_channels or in_channels
195
+ bias = (bias, bias)
196
+ drop_probs = (drop, drop)
197
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
198
+
199
+ self.fc1 = linear_layer(
200
+ in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
201
+ )
202
+ self.act = act_layer()
203
+ self.drop1 = nn.Dropout(drop_probs[0])
204
+ self.norm = (
205
+ norm_layer(hidden_channels, device=device, dtype=dtype)
206
+ if norm_layer is not None
207
+ else nn.Identity()
208
+ )
209
+ self.fc2 = linear_layer(
210
+ hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
211
+ )
212
+ self.drop2 = nn.Dropout(drop_probs[1])
213
+
214
+ def forward(self, x):
215
+ x = self.fc1(x)
216
+ x = self.act(x)
217
+ x = self.drop1(x)
218
+ x = self.norm(x)
219
+ x = self.fc2(x)
220
+ x = self.drop2(x)
221
+ return x
222
+
223
+
224
+ class TextProjection(nn.Module):
225
+ """
226
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
227
+
228
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
229
+ """
230
+
231
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
232
+ factory_kwargs = {"dtype": dtype, "device": device}
233
+ super().__init__()
234
+ self.linear_1 = nn.Linear(
235
+ in_features=in_channels,
236
+ out_features=hidden_size,
237
+ bias=True,
238
+ **factory_kwargs,
239
+ )
240
+ self.act_1 = act_layer()
241
+ self.linear_2 = nn.Linear(
242
+ in_features=hidden_size,
243
+ out_features=hidden_size,
244
+ bias=True,
245
+ **factory_kwargs,
246
+ )
247
+
248
+ def forward(self, caption):
249
+ hidden_states = self.linear_1(caption)
250
+ hidden_states = self.act_1(hidden_states)
251
+ hidden_states = self.linear_2(hidden_states)
252
+ return hidden_states
253
+
254
+
255
+ class TimestepEmbedder(nn.Module):
256
+ """
257
+ Embeds scalar timesteps into vector representations.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ hidden_size,
263
+ act_layer,
264
+ frequency_embedding_size=256,
265
+ max_period=10000,
266
+ out_size=None,
267
+ dtype=None,
268
+ device=None,
269
+ ):
270
+ factory_kwargs = {"dtype": dtype, "device": device}
271
+ super().__init__()
272
+ self.frequency_embedding_size = frequency_embedding_size
273
+ self.max_period = max_period
274
+ if out_size is None:
275
+ out_size = hidden_size
276
+
277
+ self.mlp = nn.Sequential(
278
+ nn.Linear(
279
+ frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
280
+ ),
281
+ act_layer(),
282
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
283
+ )
284
+ nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
285
+ nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
286
+
287
+ @staticmethod
288
+ def timestep_embedding(t, dim, max_period=10000):
289
+ """
290
+ Create sinusoidal timestep embeddings.
291
+
292
+ Args:
293
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
294
+ dim (int): the dimension of the output.
295
+ max_period (int): controls the minimum frequency of the embeddings.
296
+
297
+ Returns:
298
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
299
+
300
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
301
+ """
302
+ half = dim // 2
303
+ freqs = torch.exp(
304
+ -math.log(max_period)
305
+ * torch.arange(start=0, end=half, dtype=torch.float32)
306
+ / half
307
+ ).to(device=t.device)
308
+ args = t[:, None].float() * freqs[None]
309
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
310
+ if dim % 2:
311
+ embedding = torch.cat(
312
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
313
+ )
314
+ return embedding
315
+
316
+ def forward(self, t):
317
+ t_freq = self.timestep_embedding(
318
+ t, self.frequency_embedding_size, self.max_period
319
+ ).type(self.mlp[0].weight.dtype) # type: ignore
320
+ t_emb = self.mlp(t_freq)
321
+ return t_emb
322
+
323
+
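A small worked example of the sinusoidal embedding above (values arbitrary): each scalar timestep is mapped to dim/2 cosine and dim/2 sine features whose frequencies fall off geometrically from 1 toward roughly 1/max_period, and only the MLP on top carries learnable weights:

import torch
from modules.layers import TimestepEmbedder

t = torch.tensor([0.0, 250.0, 999.0])                   # three example timesteps
emb = TimestepEmbedder.timestep_embedding(t, dim=256)   # static method, no parameters involved
print(emb.shape)        # (3, 256): 128 cosine features followed by 128 sine features
print(emb[0, :3])       # t = 0 -> all cosine terms are exactly 1.0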
324
+ class EmbedND(nn.Module):
325
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
326
+ super().__init__()
327
+ self.dim = dim
328
+ self.theta = theta
329
+ self.axes_dim = axes_dim
330
+
331
+ def forward(self, ids: Tensor) -> Tensor:
332
+ n_axes = ids.shape[-1]
333
+ emb = torch.cat(
334
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
335
+ dim=-3,
336
+ )
337
+
338
+ return emb.unsqueeze(1)
339
+
340
+
341
+ class MLPEmbedder(nn.Module):
342
+ def __init__(self, in_dim: int, hidden_dim: int):
343
+ super().__init__()
344
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
345
+ self.silu = nn.SiLU()
346
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
347
+
348
+ def forward(self, x: Tensor) -> Tensor:
349
+ return self.out_layer(self.silu(self.in_layer(x)))
350
+
351
+
352
+ def rope(pos, dim: int, theta: int):
353
+ assert dim % 2 == 0
354
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
355
+ omega = 1.0 / (theta**scale)
356
+ out = torch.einsum("...n,d->...nd", pos, omega)
357
+ out = torch.stack(
358
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
359
+ )
360
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
361
+ return out.float()
362
+
363
+
364
+ def attention_after_rope(q, k, v, pe):
365
+ q, k = apply_rope(q, k, pe)
366
+
367
+ from .attention import attention
368
+
369
+ x = attention(q, k, v, mode="torch")
370
+ return x
371
+
372
+
373
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
374
+ def apply_rope(xq, xk, freqs_cis):
375
+ # swap the num_heads and seq_len dimensions back to the ordering the original function expects
376
+ xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
377
+ xk = xk.transpose(1, 2)
378
+
379
+ # split head_dim into complex components (real and imaginary parts)
380
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
381
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
382
+
383
+ # apply the rotary position embedding (complex multiplication)
384
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
385
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
386
+
387
+ # restore the tensor shapes and transpose back to the target dimension order
388
+ xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
389
+ xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
390
+
391
+ return xq_out, xk_out
392
+
393
+
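A shape-level sketch of the RoPE helpers, with illustrative sizes: rope builds one 2x2 rotation matrix per (position, frequency pair), and EmbedND concatenates the per-axis rotations so apply_rope can broadcast them over the attention heads:

import torch
from modules.layers import EmbedND, rope

pos = torch.arange(12, dtype=torch.float32).unsqueeze(0)   # (1, 12) positions along one axis
r = rope(pos, dim=32, theta=10_000)
print(r.shape)   # (1, 12, 16, 2, 2): one rotation matrix per position and frequency pair

# EmbedND stacks several axes (e.g. text index, latent row, latent column) into one table
pe_embedder = EmbedND(dim=64, theta=10_000, axes_dim=[16, 24, 24])   # axes_dim sums to head_dim
ids = torch.zeros(1, 12, 3)                                          # (batch, seq, n_axes) position ids
pe = pe_embedder(ids)
print(pe.shape)   # (1, 1, 12, 32, 2, 2), broadcastable over the head dimension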
394
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
395
+ def scale_add_residual(
396
+ x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
397
+ ) -> torch.Tensor:
398
+ return x * scale + residual
399
+
400
+
401
+ @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
402
+ def layernorm_and_scale_shift(
403
+ x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
404
+ ) -> torch.Tensor:
405
+ return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
406
+
407
+
408
+ class SelfAttention(nn.Module):
409
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
410
+ super().__init__()
411
+ self.num_heads = num_heads
412
+ head_dim = dim // num_heads
413
+
414
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
415
+ self.norm = QKNorm(head_dim)
416
+ self.proj = nn.Linear(dim, dim)
417
+
418
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
419
+ qkv = self.qkv(x)
420
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
421
+ q, k = self.norm(q, k, v)
422
+ x = attention_after_rope(q, k, v, pe=pe)
423
+ x = self.proj(x)
424
+ return x
425
+
426
+
427
+ @dataclass
428
+ class ModulationOut:
429
+ shift: Tensor
430
+ scale: Tensor
431
+ gate: Tensor
432
+
433
+
434
+ class RMSNorm(torch.nn.Module):
435
+ def __init__(self, dim: int):
436
+ super().__init__()
437
+ self.scale = nn.Parameter(torch.ones(dim))
438
+
439
+ # @staticmethod
440
+ # def rms_norm_fast(x, weight, eps):
441
+ # return LigerRMSNormFunction.apply(
442
+ # x,
443
+ # weight,
444
+ # eps,
445
+ # 0.0,
446
+ # "gemma",
447
+ # True,
448
+ # )
449
+
450
+ @staticmethod
451
+ def rms_norm(x, weight, eps):
452
+ x_dtype = x.dtype
453
+ x = x.float()
454
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
455
+ return (x * rrms).to(dtype=x_dtype) * weight
456
+
457
+ def forward(self, x: Tensor):
458
+ # return self.rms_norm_fast(x, self.scale, 1e-6)
459
+ return self.rms_norm(x, self.scale, 1e-6)
460
+
461
+
462
+ class QKNorm(torch.nn.Module):
463
+ def __init__(self, dim: int):
464
+ super().__init__()
465
+ self.query_norm = RMSNorm(dim)
466
+ self.key_norm = RMSNorm(dim)
467
+
468
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
469
+ q = self.query_norm(q)
470
+ k = self.key_norm(k)
471
+ return q.to(v), k.to(v)
472
+
473
+
474
+ class Modulation(nn.Module):
475
+ def __init__(self, dim: int, double: bool):
476
+ super().__init__()
477
+ self.is_double = double
478
+ self.multiplier = 6 if double else 3
479
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
480
+
481
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
482
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
483
+ self.multiplier, dim=-1
484
+ )
485
+
486
+ return (
487
+ ModulationOut(*out[:3]),
488
+ ModulationOut(*out[3:]) if self.is_double else None,
489
+ )
490
+
491
+
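A brief sketch of how the Modulation wrapper is consumed downstream (dimensions arbitrary): the conditioning vector is projected into three (single-stream) or six (double-stream) chunks that become the shift/scale/gate triplets applied around attention and the MLP:

import torch
from modules.layers import Modulation

dim = 128
mod = Modulation(dim, double=True)
vec = torch.randn(2, dim)        # pooled conditioning (timestep + global text vector)

mod1, mod2 = mod(vec)            # two ModulationOut triplets in the double-stream case
print(mod1.shift.shape, mod1.scale.shape, mod1.gate.shape)   # each (2, 1, 128)

x = torch.randn(2, 10, dim)      # a token sequence
x_mod = (1 + mod1.scale) * torch.nn.functional.layer_norm(x, (dim,)) + mod1.shift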
492
+ class DoubleStreamBlock(nn.Module):
493
+ def __init__(
494
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
495
+ ):
496
+ super().__init__()
497
+
498
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
499
+ self.num_heads = num_heads
500
+ self.hidden_size = hidden_size
501
+ self.img_mod = Modulation(hidden_size, double=True)
502
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
503
+ self.img_attn = SelfAttention(
504
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
505
+ )
506
+
507
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
508
+ self.img_mlp = nn.Sequential(
509
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
510
+ nn.GELU(approximate="tanh"),
511
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
512
+ )
513
+
514
+ self.txt_mod = Modulation(hidden_size, double=True)
515
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
516
+ self.txt_attn = SelfAttention(
517
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
518
+ )
519
+
520
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
521
+ self.txt_mlp = nn.Sequential(
522
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
523
+ nn.GELU(approximate="tanh"),
524
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
525
+ )
526
+
527
+ def forward(
528
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
529
+ ) -> tuple[Tensor, Tensor]:
530
+ img_mod1, img_mod2 = self.img_mod(vec)
531
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
532
+
533
+ # prepare image for attention
534
+ img_modulated = self.img_norm1(img)
535
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
536
+ img_qkv = self.img_attn.qkv(img_modulated)
537
+ img_q, img_k, img_v = rearrange(
538
+ img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
539
+ )
540
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
541
+
542
+ # prepare txt for attention
543
+ txt_modulated = self.txt_norm1(txt)
544
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
545
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
546
+ txt_q, txt_k, txt_v = rearrange(
547
+ txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
548
+ )
549
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
550
+
551
+ # run actual attention
552
+ q = torch.cat((txt_q, img_q), dim=1)
553
+ k = torch.cat((txt_k, img_k), dim=1)
554
+ v = torch.cat((txt_v, img_v), dim=1)
555
+
556
+ attn = attention_after_rope(q, k, v, pe=pe)
557
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
558
+
559
+ # calculate the img blocks
560
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
561
+ img_mlp = self.img_mlp(
562
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
563
+ )
564
+ img = scale_add_residual(img_mlp, img_mod2.gate, img)
565
+
566
+ # calculate the txt bloks
567
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
568
+ txt_mlp = self.txt_mlp(
569
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
570
+ )
571
+ txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
572
+ return img, txt
573
+
574
+
575
+ class SingleStreamBlock(nn.Module):
576
+ """
577
+ A DiT block with parallel linear layers as described in
578
+ https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ hidden_size: int,
584
+ num_heads: int,
585
+ mlp_ratio: float = 4.0,
586
+ qk_scale: float | None = None,
587
+ ):
588
+ super().__init__()
589
+ self.hidden_dim = hidden_size
590
+ self.num_heads = num_heads
591
+ head_dim = hidden_size // num_heads
592
+ self.scale = qk_scale or head_dim**-0.5
593
+
594
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
595
+ # qkv and mlp_in
596
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
597
+ # proj and mlp_out
598
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
599
+
600
+ self.norm = QKNorm(head_dim)
601
+
602
+ self.hidden_size = hidden_size
603
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
604
+
605
+ self.mlp_act = nn.GELU(approximate="tanh")
606
+ self.modulation = Modulation(hidden_size, double=False)
607
+
608
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
609
+ mod, _ = self.modulation(vec)
610
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
611
+ qkv, mlp = torch.split(
612
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
613
+ )
614
+
615
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
616
+ q, k = self.norm(q, k, v)
617
+
618
+ # compute attention
619
+ attn = attention_after_rope(q, k, v, pe=pe)
620
+ # compute activation in mlp stream, cat again and run second linear layer
621
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
622
+ return scale_add_residual(output, mod.gate, x)
623
+
624
+
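The "parallel linear layers" named in the docstring show up directly in the two fused projections; a constructor-only sketch with toy sizes makes the split visible (no forward pass is run here, since that would also need the rotary table and the attention helper):

from modules.layers import SingleStreamBlock

blk = SingleStreamBlock(hidden_size=256, num_heads=4)    # mlp_hidden_dim = 4.0 * 256 = 1024
print(blk.linear1)   # Linear(256 -> 3*256 + 1024): fused QKV projection and MLP input
print(blk.linear2)   # Linear(256 + 1024 -> 256): fused attention output and MLP output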
625
+ class LastLayer(nn.Module):
626
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
627
+ super().__init__()
628
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
629
+ self.linear = nn.Linear(
630
+ hidden_size, patch_size * patch_size * out_channels, bias=True
631
+ )
632
+ self.adaLN_modulation = nn.Sequential(
633
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
634
+ )
635
+
636
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
637
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
638
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
639
+ x = self.linear(x)
640
+ return x
modules/model_edit.py ADDED
@@ -0,0 +1,143 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from .connector_edit import Qwen2Connector
9
+ from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
10
+
11
+
12
+ @dataclass
13
+ class Step1XParams:
14
+ in_channels: int
15
+ out_channels: int
16
+ vec_in_dim: int
17
+ context_in_dim: int
18
+ hidden_size: int
19
+ mlp_ratio: float
20
+ num_heads: int
21
+ depth: int
22
+ depth_single_blocks: int
23
+ axes_dim: list[int]
24
+ theta: int
25
+ qkv_bias: bool
26
+
27
+
28
+ class Step1XEdit(nn.Module):
29
+ """
30
+ Transformer model for flow matching on sequences.
31
+ """
32
+
33
+ def __init__(self, params: Step1XParams):
34
+ super().__init__()
35
+
36
+ self.params = params
37
+ self.in_channels = params.in_channels
38
+ self.out_channels = params.out_channels
39
+ if params.hidden_size % params.num_heads != 0:
40
+ raise ValueError(
41
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
42
+ )
43
+ pe_dim = params.hidden_size // params.num_heads
44
+ if sum(params.axes_dim) != pe_dim:
45
+ raise ValueError(
46
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
47
+ )
48
+ self.hidden_size = params.hidden_size
49
+ self.num_heads = params.num_heads
50
+ self.pe_embedder = EmbedND(
51
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
52
+ )
53
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
57
+
58
+ self.double_blocks = nn.ModuleList(
59
+ [
60
+ DoubleStreamBlock(
61
+ self.hidden_size,
62
+ self.num_heads,
63
+ mlp_ratio=params.mlp_ratio,
64
+ qkv_bias=params.qkv_bias,
65
+ )
66
+ for _ in range(params.depth)
67
+ ]
68
+ )
69
+
70
+ self.single_blocks = nn.ModuleList(
71
+ [
72
+ SingleStreamBlock(
73
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
74
+ )
75
+ for _ in range(params.depth_single_blocks)
76
+ ]
77
+ )
78
+
79
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
80
+
81
+ self.connector = Qwen2Connector()
82
+
83
+ @staticmethod
84
+ def timestep_embedding(
85
+ t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
86
+ ):
87
+ """
88
+ Create sinusoidal timestep embeddings.
89
+ :param t: a 1-D Tensor of N indices, one per batch element.
90
+ These may be fractional.
91
+ :param dim: the dimension of the output.
92
+ :param max_period: controls the minimum frequency of the embeddings.
93
+ :return: an (N, D) Tensor of positional embeddings.
94
+ """
95
+ t = time_factor * t
96
+ half = dim // 2
97
+ freqs = torch.exp(
98
+ -math.log(max_period)
99
+ * torch.arange(start=0, end=half, dtype=torch.float32)
100
+ / half
101
+ ).to(t.device)
102
+
103
+ args = t[:, None].float() * freqs[None]
104
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
105
+ if dim % 2:
106
+ embedding = torch.cat(
107
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
108
+ )
109
+ if torch.is_floating_point(t):
110
+ embedding = embedding.to(t)
111
+ return embedding
112
+
113
+ def forward(
114
+ self,
115
+ img: Tensor,
116
+ img_ids: Tensor,
117
+ txt: Tensor,
118
+ txt_ids: Tensor,
119
+ timesteps: Tensor,
120
+ y: Tensor,
121
+ ) -> Tensor:
122
+ if img.ndim != 3 or txt.ndim != 3:
123
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
124
+
125
+ img = self.img_in(img)
126
+ vec = self.time_in(self.timestep_embedding(timesteps, 256))
127
+
128
+ vec = vec + self.vector_in(y)
129
+ txt = self.txt_in(txt)
130
+
131
+ ids = torch.cat((txt_ids, img_ids), dim=1)
132
+ pe = self.pe_embedder(ids)
133
+
134
+ for block in self.double_blocks:
135
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
136
+
137
+ img = torch.cat((txt, img), 1)
138
+ for block in self.single_blocks:
139
+ img = block(img, vec=vec, pe=pe)
140
+ img = img[:, txt.shape[1] :, ...]
141
+
142
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
143
+ return img
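The dataclass above only names the hyperparameters; the values below are an illustrative, Flux-style guess used to show how the consistency checks in __init__ fit together, and they are not the released Step1X-Edit configuration. The two widths this commit does pin are vec_in_dim (must match the 768-dim global_out of Qwen2Connector) and context_in_dim (must match the 4096-dim refined token embeddings):

from modules.model_edit import Step1XEdit, Step1XParams

params = Step1XParams(
    in_channels=64,              # e.g. 16 latent channels packed into 2x2 patches (illustrative)
    out_channels=64,
    vec_in_dim=768,              # width of Qwen2Connector.global_proj_out
    context_in_dim=4096,         # width of the refined LLM tokens
    hidden_size=768,             # toy size; must be divisible by num_heads
    mlp_ratio=4.0,
    num_heads=12,                # head_dim = 768 / 12 = 64
    depth=2,                     # toy double-stream depth
    depth_single_blocks=4,       # toy single-stream depth
    axes_dim=[16, 24, 24],       # must sum to head_dim (64)
    theta=10_000,
    qkv_bias=True,
)
model = Step1XEdit(params)       # note: this also builds the full Qwen2Connector defined above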
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ einops
2
+ transformers==4.49.0
3
+ qwen_vl_utils==0.0.10
4
+ safetensors==0.4.5
5
+ pillow==11.1.0
6
+ huggingface_hub
7
+ transformers
8
+ diffusers
9
+ peft
10
+ opencv-python
11
+ sentencepiece
12
+ boto3
13
+ torchvision
sampling.py ADDED
@@ -0,0 +1,47 @@
1
+ import math
2
+ from collections.abc import Callable
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
+ def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int):
9
+ return torch.randn(
10
+ num_samples,
11
+ 16,
12
+ # allow for packing
13
+ 2 * math.ceil(height / 16),
14
+ 2 * math.ceil(width / 16),
15
+ device=device,
16
+ dtype=dtype,
17
+ generator=torch.Generator(device=device).manual_seed(seed),
18
+ )
19
+
20
+
21
+ def time_shift(mu: float, sigma: float, t: Tensor):
22
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
23
+
24
+
25
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
26
+ m = (y2 - y1) / (x2 - x1)
27
+ b = y1 - m * x1
28
+ return lambda x: m * x + b
29
+
30
+
31
+ def get_schedule(
32
+ num_steps: int,
33
+ image_seq_len: int,
34
+ base_shift: float = 0.5,
35
+ max_shift: float = 1.15,
36
+ shift: bool = True,
37
+ ) -> list[float]:
38
+ # extra step for zero
39
+ timesteps = torch.linspace(1, 0, num_steps + 1)
40
+
41
+ # shifting the schedule to favor high timesteps for higher signal images
42
+ if shift:
43
+ # estimate mu based on linear estimation between two points
44
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
45
+ timesteps = time_shift(mu, 1.0, timesteps)
46
+
47
+ return timesteps.tolist()
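A short sketch tying the helpers above together, with made-up sizes: get_noise draws a 16-channel latent at 1/8 resolution (rounded so it can later be packed into 2x2 patches), and get_schedule returns the shifted timestep list whose shift grows with the image token count:

import torch
from sampling import get_noise, get_schedule

height, width = 512, 512
noise = get_noise(1, height, width, device=torch.device("cpu"),
                  dtype=torch.float32, seed=42)
print(noise.shape)   # (1, 16, 64, 64) for a 512x512 image

# packed 2x2 latent patches -> image_seq_len = (64 / 2) * (64 / 2) = 1024
timesteps = get_schedule(num_steps=28, image_seq_len=1024)
print(len(timesteps), timesteps[0], timesteps[-1])   # 29 values running from 1.0 down to 0.0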