wangmengchao committed · Commit 282b272 · 1 parent: a893799
app.py ADDED
@@ -0,0 +1,312 @@
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ import argparse
4
+ from datetime import datetime
5
+ import librosa
6
+ from infer import load_models,main
7
+
8
+
9
+ pipe,fantasytalking,wav2vec_processor,wav2vec = None,None,None,None
10
+
11
+ def generate_video(
12
+ image_path,
13
+ audio_path,
14
+ prompt,
15
+ prompt_cfg_scale,
16
+ audio_cfg_scale,
17
+ audio_weight,
18
+ image_size,
19
+ max_num_frames,
20
+ inference_steps,
21
+ seed,
22
+ ):
23
+ # Create the temp directory if it doesn't exist
24
+ output_dir = Path("./output")
25
+ output_dir.mkdir(parents=True, exist_ok=True)
26
+
27
+ # Convert paths to absolute Path objects and normalize them
28
+ print(image_path)
29
+ image_path = Path(image_path).absolute().as_posix()
30
+ audio_path = Path(audio_path).absolute().as_posix()
31
+
32
+ # Parse the arguments
33
+
34
+ args = create_args(
35
+ image_path=image_path,
36
+ audio_path=audio_path,
37
+ prompt=prompt,
38
+ output_dir=str(output_dir),
39
+ audio_weight=audio_weight,
40
+ prompt_cfg_scale=prompt_cfg_scale,
41
+ audio_cfg_scale=audio_cfg_scale,
42
+ image_size=image_size,
43
+ max_num_frames=max_num_frames,
44
+ inference_steps=inference_steps,
45
+ seed=seed,
46
+ )
47
+
48
+ try:
49
+ global pipe, fantasytalking, wav2vec_processor, wav2vec
50
+ if pipe is None:
51
+ pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args)
52
+ output_path=main(
53
+ args,pipe,fantasytalking,wav2vec_processor,wav2vec
54
+ )
55
+ return output_path # Ensure the output path is returned
56
+ except Exception as e:
57
+ print(f"Error during processing: {str(e)}")
58
+ raise gr.Error(f"Error during processing: {str(e)}")
59
+
60
+
61
+ def create_args(
62
+ image_path: str,
63
+ audio_path: str,
64
+ prompt: str,
65
+ output_dir: str,
66
+ audio_weight: float,
67
+ prompt_cfg_scale: float,
68
+ audio_cfg_scale: float,
69
+ image_size: int,
70
+ max_num_frames: int,
71
+ inference_steps: int,
72
+ seed: int,
73
+ ) -> argparse.Namespace:
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument(
76
+ "--wan_model_dir",
77
+ type=str,
78
+ default="./models/Wan2.1-I2V-14B-720P",
79
+ required=False,
80
+ help="The dir of the Wan I2V 14B model.",
81
+ )
82
+ parser.add_argument(
83
+ "--fantasytalking_model_path",
84
+ type=str,
85
+ default="./models/fantasytalking_model.ckpt",
86
+ required=False,
87
+ help="The .ckpt path of fantasytalking model.",
88
+ )
89
+ parser.add_argument(
90
+ "--wav2vec_model_dir",
91
+ type=str,
92
+ default="./models/wav2vec2-base-960h",
93
+ required=False,
94
+ help="The dir of wav2vec model.",
95
+ )
96
+ parser.add_argument(
97
+ "--image_path",
98
+ type=str,
99
+ default="./assets/images/woman.png",
100
+ required=False,
101
+ help="The path of the image.",
102
+ )
103
+ parser.add_argument(
104
+ "--audio_path",
105
+ type=str,
106
+ default="./assets/audios/woman.wav",
107
+ required=False,
108
+ help="The path of the audio.",
109
+ )
110
+ parser.add_argument(
111
+ "--prompt",
112
+ type=str,
113
+ default="A woman is talking.",
114
+ required=False,
115
+ help="prompt.",
116
+ )
117
+ parser.add_argument(
118
+ "--output_dir",
119
+ type=str,
120
+ default="./output",
121
+ help="Dir to save the video.",
122
+ )
123
+ parser.add_argument(
124
+ "--image_size",
125
+ type=int,
126
+ default=512,
127
+ help="The image will be resized proportionally to this size.",
128
+ )
129
+ parser.add_argument(
130
+ "--audio_scale",
131
+ type=float,
132
+ default=1.0,
133
+ help="Image width.",
134
+ )
135
+ parser.add_argument(
136
+ "--prompt_cfg_scale",
137
+ type=float,
138
+ default=5.0,
139
+ required=False,
140
+ help="prompt cfg scale",
141
+ )
142
+ parser.add_argument(
143
+ "--audio_cfg_scale",
144
+ type=float,
145
+ default=5.0,
146
+ required=False,
147
+ help="audio cfg scale",
148
+ )
149
+ parser.add_argument(
150
+ "--max_num_frames",
151
+ type=int,
152
+ default=81,
153
+ required=False,
154
+ help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
155
+ )
156
+ parser.add_argument(
157
+ "--inference_steps",
158
+ type=int,
159
+ default=20,
160
+ required=False,
161
+ )
162
+ parser.add_argument(
163
+ "--fps",
164
+ type=int,
165
+ default=23,
166
+ required=False,
167
+ )
168
+ parser.add_argument(
169
+ "--num_persistent_param_in_dit",
170
+ type=int,
171
+ default=None,
172
+ required=False,
173
+ help="Maximum number of parameters kept resident in GPU memory; use a smaller value to reduce the VRAM required."
174
+ )
175
+ parser.add_argument(
176
+ "--seed",
177
+ type=int,
178
+ default=1111,
179
+ required=False,
180
+ )
181
+ args = parser.parse_args(
182
+ [
183
+ "--image_path",
184
+ image_path,
185
+ "--audio_path",
186
+ audio_path,
187
+ "--prompt",
188
+ prompt,
189
+ "--output_dir",
190
+ output_dir,
191
+ "--image_size",
192
+ str(image_size),
193
+ "--audio_scale",
194
+ str(audio_weight),
195
+ "--prompt_cfg_scale",
196
+ str(prompt_cfg_scale),
197
+ "--audio_cfg_scale",
198
+ str(audio_cfg_scale),
199
+ "--max_num_frames",
200
+ str(max_num_frames),
201
+ "--inference_steps",
202
+ str(inference_steps),
203
+ "--seed",
204
+ str(seed),
205
+ ]
206
+ )
207
+ print(args)
208
+ return args
209
+
210
+
211
+ # Create Gradio interface
212
+ with gr.Blocks(title="FantasyTalking Video Generation") as demo:
213
+ gr.Markdown(
214
+ """
215
+ # FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis
216
+
217
+ <div align="center">
218
+ <strong> Mengchao Wang1* Qiang Wang1* Fan Jiang1†
219
+ Yaqi Fan2 Yunpeng Zhang1,2 YongGang Qi2‡
220
+ Kun Zhao1. Mu Xu1 </strong>
221
+ </div>
222
+
223
+ <div align="center">
224
+ <strong>1AMAP,Alibaba Group 2Beijing University of Posts and Telecommunications</strong>
225
+ </div>
226
+
227
+ <div style="display:flex;justify-content:center;column-gap:4px;">
228
+ <a href="https://github.com/Fantasy-AMAP/fantasy-talking">
229
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
230
+ </a>
231
+ <a href="https://arxiv.org/abs/2504.04842">
232
+ <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
233
+ </a>
234
+ </div>
235
+ """
236
+ )
237
+
238
+ with gr.Row():
239
+ with gr.Column():
240
+ image_input = gr.Image(label="Input Image", type="filepath")
241
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
242
+ prompt_input = gr.Text(label="Input Prompt")
243
+ with gr.Row():
244
+ prompt_cfg_scale = gr.Slider(
245
+ minimum=1.0,
246
+ maximum=9.0,
247
+ value=5.0,
248
+ step=0.5,
249
+ label="Prompt CFG Scale",
250
+ )
251
+ audio_cfg_scale = gr.Slider(
252
+ minimum=1.0,
253
+ maximum=9.0,
254
+ value=5.0,
255
+ step=0.5,
256
+ label="Audio CFG Scale",
257
+ )
258
+ audio_weight = gr.Slider(
259
+ minimum=0.1,
260
+ maximum=3.0,
261
+ value=1.0,
262
+ step=0.1,
263
+ label="Audio Weight",
264
+ )
265
+ with gr.Row():
266
+ image_size = gr.Number(
267
+ value=512, label="Max Width/Height", precision=0
268
+ )
269
+ max_num_frames = gr.Number(
270
+ value=81, label="Maximum Frames", precision=0
271
+ )
272
+ inference_steps = gr.Slider(
273
+ minimum=1, maximum=50, value=20, step=1, label="Inference Steps"
274
+ )
275
+
276
+ with gr.Row():
277
+ seed = gr.Number(value=1247, label="Random Seed", precision=0)
278
+
279
+ process_btn = gr.Button("Generate Video")
280
+
281
+ with gr.Column():
282
+ video_output = gr.Video(label="Output Video")
283
+
284
+ gr.Examples(
285
+ examples=[
286
+ [
287
+ "./assets/images/woman.png",
288
+ "./assets/audios/woman.wav",
289
+ ],
290
+ ],
291
+ inputs=[image_input, audio_input],
292
+ )
293
+
294
+ process_btn.click(
295
+ fn=generate_video,
296
+ inputs=[
297
+ image_input,
298
+ audio_input,
299
+ prompt_input,
300
+ prompt_cfg_scale,
301
+ audio_cfg_scale,
302
+ audio_weight,
303
+ image_size,
304
+ max_num_frames,
305
+ inference_steps,
306
+ seed,
307
+ ],
308
+ outputs=video_output,
309
+ )
310
+
311
+ if __name__ == "__main__":
312
+ demo.launch(inbrowser=True, share=True)
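For reference, the same pipeline can be driven without the Gradio UI by reusing create_args together with infer.load_models/main, exactly as generate_video does above. A minimal headless sketch, assuming the infer module and the checkpoints referenced by the argparse defaults are present under ./models:

```python
# Headless sketch: build args, load the models once, then render a clip.
from app import create_args            # reuses the helper defined above
from infer import load_models, main    # same entry points generate_video calls

args = create_args(
    image_path="./assets/images/woman.png",
    audio_path="./assets/audios/woman.wav",
    prompt="A woman is talking.",
    output_dir="./output",
    audio_weight=1.0,
    prompt_cfg_scale=5.0,
    audio_cfg_scale=5.0,
    image_size=512,
    max_num_frames=81,
    inference_steps=20,
    seed=1111,
)
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)  # load once, reuse across calls
output_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
print(f"Video written to {output_path}")
```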
diffsynth/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .prompters import *
4
+ from .schedulers import *
5
+ from .pipelines import *
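Because of the wildcard re-exports above, downstream code can import the public helpers straight from the package root. For example, using names added elsewhere in this commit:

```python
# Names re-exported by the wildcard imports are available at the package root.
from diffsynth import VideoData, save_video, save_frames, WanVideoPipeline
```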
diffsynth/configs/__init__.py ADDED
File without changes
diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,650 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.wan_video_dit import WanModel
4
+ from ..models.wan_video_text_encoder import WanTextEncoder
5
+ from ..models.wan_video_image_encoder import WanImageEncoder
6
+ from ..models.wan_video_vae import WanVideoVAE
7
+
8
+
9
+ model_loader_configs = [
10
+ # These configs are provided for detecting model type automatically.
11
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
12
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
13
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
14
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
15
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
16
+ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
17
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
18
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
19
+ ]
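The comment above documents the tuple layout, but the detection code itself is not part of this commit. As an illustration only (the function names and hashing scheme below are assumptions, not DiffSynth's actual loader), such a table could be consulted by hashing the sorted state-dict keys and matching the digest:

```python
# Hypothetical sketch: match a loaded state dict against model_loader_configs.
import hashlib

def hash_state_dict_keys(state_dict):
    keys_text = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(keys_text.encode("utf-8")).hexdigest()

def detect_model(state_dict, configs=model_loader_configs):
    digest = hash_state_dict_keys(state_dict)
    for keys_hash, keys_hash_with_shape, model_names, model_classes, resource in configs:
        if digest in (keys_hash, keys_hash_with_shape):
            return model_names, model_classes, resource
    return None  # unknown model
```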
20
+ huggingface_model_loader_configs = [
21
+ # These configs are provided for detecting model type automatically.
22
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
23
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
24
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
25
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
26
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
27
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
28
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
29
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
30
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
31
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
32
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
33
+ ]
34
+ patch_model_loader_configs = [
35
+ # These configs are provided for detecting model type automatically.
36
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
37
+ # ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
38
+ ]
39
+
40
+ preset_models_on_huggingface = {
41
+ "HunyuanDiT": [
42
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
43
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
44
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
45
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
46
+ ],
47
+ "stable-video-diffusion-img2vid-xt": [
48
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
49
+ ],
50
+ "ExVideo-SVD-128f-v1": [
51
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
52
+ ],
53
+ # Stable Diffusion
54
+ "StableDiffusion_v15": [
55
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
56
+ ],
57
+ "DreamShaper_8": [
58
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
59
+ ],
60
+ # Textual Inversion
61
+ "TextualInversion_VeryBadImageNegative_v1.3": [
62
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
63
+ ],
64
+ # Stable Diffusion XL
65
+ "StableDiffusionXL_v1": [
66
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
67
+ ],
68
+ "BluePencilXL_v200": [
69
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
70
+ ],
71
+ "StableDiffusionXL_Turbo": [
72
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
73
+ ],
74
+ # Stable Diffusion 3
75
+ "StableDiffusion3": [
76
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
77
+ ],
78
+ "StableDiffusion3_without_T5": [
79
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
80
+ ],
81
+ # ControlNet
82
+ "ControlNet_v11f1p_sd15_depth": [
83
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
84
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
85
+ ],
86
+ "ControlNet_v11p_sd15_softedge": [
87
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
88
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
89
+ ],
90
+ "ControlNet_v11f1e_sd15_tile": [
91
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
92
+ ],
93
+ "ControlNet_v11p_sd15_lineart": [
94
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
95
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
96
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
97
+ ],
98
+ "ControlNet_union_sdxl_promax": [
99
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
100
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
101
+ ],
102
+ # AnimateDiff
103
+ "AnimateDiff_v2": [
104
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
105
+ ],
106
+ "AnimateDiff_xl_beta": [
107
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
108
+ ],
109
+
110
+ # Qwen Prompt
111
+ "QwenPrompt": [
112
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
113
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
114
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
115
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
116
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
117
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
118
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
119
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
120
+ ],
121
+ # Beautiful Prompt
122
+ "BeautifulPrompt": [
123
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
124
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
125
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
126
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
127
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
128
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
129
+ ],
130
+ # Omost prompt
131
+ "OmostPrompt":[
132
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
133
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
134
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
135
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
136
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
137
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
138
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
139
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
140
+ ],
141
+ # Translator
142
+ "opus-mt-zh-en": [
143
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
144
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
145
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
146
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
147
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
148
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
149
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
150
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
151
+ ],
152
+ # IP-Adapter
153
+ "IP-Adapter-SD": [
154
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
155
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
156
+ ],
157
+ "IP-Adapter-SDXL": [
158
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
159
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
160
+ ],
161
+ "SDXL-vae-fp16-fix": [
162
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
163
+ ],
164
+ # Kolors
165
+ "Kolors": [
166
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
167
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
168
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
169
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
170
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
171
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
172
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
173
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
174
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
175
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
176
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
177
+ ],
178
+ # FLUX
179
+ "FLUX.1-dev": [
180
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
181
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
182
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
183
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
184
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
185
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
186
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
187
+ ],
188
+ "InstantX/FLUX.1-dev-IP-Adapter": {
189
+ "file_list": [
190
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
191
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
192
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
193
+ ],
194
+ "load_path": [
195
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
196
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
197
+ ],
198
+ },
199
+ # RIFE
200
+ "RIFE": [
201
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
202
+ ],
203
+ # CogVideo
204
+ "CogVideoX-5B": [
205
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
206
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
207
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
208
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
209
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
210
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
211
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
212
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
213
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
214
+ ],
215
+ # Stable Diffusion 3.5
216
+ "StableDiffusion3.5-large": [
217
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
218
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
219
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
220
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
221
+ ],
222
+ }
223
+ preset_models_on_modelscope = {
224
+ # Hunyuan DiT
225
+ "HunyuanDiT": [
226
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
227
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
228
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
229
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
230
+ ],
231
+ # Stable Video Diffusion
232
+ "stable-video-diffusion-img2vid-xt": [
233
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
234
+ ],
235
+ # ExVideo
236
+ "ExVideo-SVD-128f-v1": [
237
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
238
+ ],
239
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
240
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
241
+ ],
242
+ # Stable Diffusion
243
+ "StableDiffusion_v15": [
244
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
245
+ ],
246
+ "DreamShaper_8": [
247
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
248
+ ],
249
+ "AingDiffusion_v12": [
250
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
251
+ ],
252
+ "Flat2DAnimerge_v45Sharp": [
253
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
254
+ ],
255
+ # Textual Inversion
256
+ "TextualInversion_VeryBadImageNegative_v1.3": [
257
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
258
+ ],
259
+ # Stable Diffusion XL
260
+ "StableDiffusionXL_v1": [
261
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
262
+ ],
263
+ "BluePencilXL_v200": [
264
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
265
+ ],
266
+ "StableDiffusionXL_Turbo": [
267
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
268
+ ],
269
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
270
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
271
+ ],
272
+ # Stable Diffusion 3
273
+ "StableDiffusion3": [
274
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
275
+ ],
276
+ "StableDiffusion3_without_T5": [
277
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
278
+ ],
279
+ # ControlNet
280
+ "ControlNet_v11f1p_sd15_depth": [
281
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
282
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
283
+ ],
284
+ "ControlNet_v11p_sd15_softedge": [
285
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
286
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
287
+ ],
288
+ "ControlNet_v11f1e_sd15_tile": [
289
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
290
+ ],
291
+ "ControlNet_v11p_sd15_lineart": [
292
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
293
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
294
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
295
+ ],
296
+ "ControlNet_union_sdxl_promax": [
297
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
298
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
299
+ ],
300
+ "Annotators:Depth": [
301
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
302
+ ],
303
+ "Annotators:Softedge": [
304
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
305
+ ],
306
+ "Annotators:Lineart": [
307
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
308
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
309
+ ],
310
+ "Annotators:Normal": [
311
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
312
+ ],
313
+ "Annotators:Openpose": [
314
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
315
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
316
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
317
+ ],
318
+ # AnimateDiff
319
+ "AnimateDiff_v2": [
320
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
321
+ ],
322
+ "AnimateDiff_xl_beta": [
323
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
324
+ ],
325
+ # RIFE
326
+ "RIFE": [
327
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
328
+ ],
329
+ # Qwen Prompt
330
+ "QwenPrompt": {
331
+ "file_list": [
332
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
333
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
334
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
335
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
336
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
337
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
338
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
339
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
340
+ ],
341
+ "load_path": [
342
+ "models/QwenPrompt/qwen2-1.5b-instruct",
343
+ ],
344
+ },
345
+ # Beautiful Prompt
346
+ "BeautifulPrompt": {
347
+ "file_list": [
348
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
349
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
350
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
351
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
352
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
353
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
354
+ ],
355
+ "load_path": [
356
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
357
+ ],
358
+ },
359
+ # Omost prompt
360
+ "OmostPrompt": {
361
+ "file_list": [
362
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
363
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
364
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
365
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
366
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
367
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
368
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
369
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
370
+ ],
371
+ "load_path": [
372
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
373
+ ],
374
+ },
375
+ # Translator
376
+ "opus-mt-zh-en": {
377
+ "file_list": [
378
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
379
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
380
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
381
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
382
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
383
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
384
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
385
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
386
+ ],
387
+ "load_path": [
388
+ "models/translator/opus-mt-zh-en",
389
+ ],
390
+ },
391
+ # IP-Adapter
392
+ "IP-Adapter-SD": [
393
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
394
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
395
+ ],
396
+ "IP-Adapter-SDXL": [
397
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
398
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
399
+ ],
400
+ # Kolors
401
+ "Kolors": {
402
+ "file_list": [
403
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
404
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
405
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
406
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
407
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
408
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
409
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
410
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
411
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
412
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
413
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
414
+ ],
415
+ "load_path": [
416
+ "models/kolors/Kolors/text_encoder",
417
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
418
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
419
+ ],
420
+ },
421
+ "SDXL-vae-fp16-fix": [
422
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
423
+ ],
424
+ # FLUX
425
+ "FLUX.1-dev": {
426
+ "file_list": [
427
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
428
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
429
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
430
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
431
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
432
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
433
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
434
+ ],
435
+ "load_path": [
436
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
437
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
438
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
439
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
440
+ ],
441
+ },
442
+ "FLUX.1-schnell": {
443
+ "file_list": [
444
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
445
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
446
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
447
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
448
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
449
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
450
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
451
+ ],
452
+ "load_path": [
453
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
454
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
455
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
456
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
457
+ ],
458
+ },
459
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
460
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
461
+ ],
462
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
463
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
464
+ ],
465
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
466
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
467
+ ],
468
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
469
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
470
+ ],
471
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
472
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
473
+ ],
474
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
475
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
476
+ ],
477
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
478
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
479
+ ],
480
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
481
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
482
+ ],
483
+ "InstantX/FLUX.1-dev-IP-Adapter": {
484
+ "file_list": [
485
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
486
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
487
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
488
+ ],
489
+ "load_path": [
490
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
491
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
492
+ ],
493
+ },
494
+ # ESRGAN
495
+ "ESRGAN_x4": [
496
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
497
+ ],
498
+ # RIFE
499
+ "RIFE": [
500
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
501
+ ],
502
+ # Omnigen
503
+ "OmniGen-v1": {
504
+ "file_list": [
505
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
506
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
507
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
508
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
509
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
510
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
511
+ ],
512
+ "load_path": [
513
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
514
+ "models/OmniGen/OmniGen-v1/model.safetensors",
515
+ ]
516
+ },
517
+ # CogVideo
518
+ "CogVideoX-5B": {
519
+ "file_list": [
520
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
521
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
522
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
523
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
524
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
525
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
526
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
527
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
528
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
529
+ ],
530
+ "load_path": [
531
+ "models/CogVideo/CogVideoX-5b/text_encoder",
532
+ "models/CogVideo/CogVideoX-5b/transformer",
533
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
534
+ ],
535
+ },
536
+ # Stable Diffusion 3.5
537
+ "StableDiffusion3.5-large": [
538
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
539
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
540
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
541
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
542
+ ],
543
+ "StableDiffusion3.5-medium": [
544
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
545
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
546
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
547
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
548
+ ],
549
+ "StableDiffusion3.5-large-turbo": [
550
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
551
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
552
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
553
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
554
+ ],
555
+ "HunyuanVideo":{
556
+ "file_list": [
557
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
558
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
559
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
560
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
561
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
562
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
563
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
564
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
565
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
566
+ ],
567
+ "load_path": [
568
+ "models/HunyuanVideo/text_encoder/model.safetensors",
569
+ "models/HunyuanVideo/text_encoder_2",
570
+ "models/HunyuanVideo/vae/pytorch_model.pt",
571
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
572
+ ],
573
+ },
574
+ "HunyuanVideo-fp8":{
575
+ "file_list": [
576
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
577
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
578
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
579
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
580
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
581
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
582
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
583
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
584
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
585
+ ],
586
+ "load_path": [
587
+ "models/HunyuanVideo/text_encoder/model.safetensors",
588
+ "models/HunyuanVideo/text_encoder_2",
589
+ "models/HunyuanVideo/vae/pytorch_model.pt",
590
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
591
+ ],
592
+ },
593
+ }
594
+ Preset_model_id: TypeAlias = Literal[
595
+ "HunyuanDiT",
596
+ "stable-video-diffusion-img2vid-xt",
597
+ "ExVideo-SVD-128f-v1",
598
+ "ExVideo-CogVideoX-LoRA-129f-v1",
599
+ "StableDiffusion_v15",
600
+ "DreamShaper_8",
601
+ "AingDiffusion_v12",
602
+ "Flat2DAnimerge_v45Sharp",
603
+ "TextualInversion_VeryBadImageNegative_v1.3",
604
+ "StableDiffusionXL_v1",
605
+ "BluePencilXL_v200",
606
+ "StableDiffusionXL_Turbo",
607
+ "ControlNet_v11f1p_sd15_depth",
608
+ "ControlNet_v11p_sd15_softedge",
609
+ "ControlNet_v11f1e_sd15_tile",
610
+ "ControlNet_v11p_sd15_lineart",
611
+ "AnimateDiff_v2",
612
+ "AnimateDiff_xl_beta",
613
+ "RIFE",
614
+ "BeautifulPrompt",
615
+ "opus-mt-zh-en",
616
+ "IP-Adapter-SD",
617
+ "IP-Adapter-SDXL",
618
+ "StableDiffusion3",
619
+ "StableDiffusion3_without_T5",
620
+ "Kolors",
621
+ "SDXL-vae-fp16-fix",
622
+ "ControlNet_union_sdxl_promax",
623
+ "FLUX.1-dev",
624
+ "FLUX.1-schnell",
625
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
626
+ "jasperai/Flux.1-dev-Controlnet-Depth",
627
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
628
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
629
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
630
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
631
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
632
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
633
+ "InstantX/FLUX.1-dev-IP-Adapter",
634
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
635
+ "QwenPrompt",
636
+ "OmostPrompt",
637
+ "ESRGAN_x4",
638
+ "RIFE",
639
+ "OmniGen-v1",
640
+ "CogVideoX-5B",
641
+ "Annotators:Depth",
642
+ "Annotators:Softedge",
643
+ "Annotators:Lineart",
644
+ "Annotators:Normal",
645
+ "Annotators:Openpose",
646
+ "StableDiffusion3.5-large",
647
+ "StableDiffusion3.5-medium",
648
+ "HunyuanVideo",
649
+ "HunyuanVideo-fp8",
650
+ ]
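The preset tables above use two shapes per entry: either a plain list of (repo_id, path_in_repo, local_dir) tuples, or a dict carrying that list under "file_list" plus explicit "load_path" entries. The downloader that consumes them is not part of this diff; a small hedged helper shows how either shape could be normalized:

```python
# Sketch only: flatten a preset entry into (repo_id, path_in_repo, local_dir) triples.
def iter_preset_files(preset_id, presets=preset_models_on_modelscope):
    entry = presets[preset_id]
    file_list = entry["file_list"] if isinstance(entry, dict) else entry
    for repo_id, path_in_repo, local_dir in file_list:
        yield repo_id, path_in_repo, local_dir

for repo_id, path_in_repo, local_dir in iter_preset_files("HunyuanVideo"):
    print(f"{repo_id}: {path_in_repo} -> {local_dir}")
```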
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_video, save_frames
diffsynth/data/video.py ADDED
@@ -0,0 +1,173 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number*10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
+ file_list = [i[1] for i in sorted(file_list)]
44
+ file_list = [os.path.join(folder, i) for i in file_list]
45
+ return file_list
46
+
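split_file_name converts a filename into a tuple in which runs of digits become integers, so sorting by that key yields natural ordering; search_for_images relies on this to keep frame files in sequence. A quick illustration, run in the same module context:

```python
# Natural ordering: digit runs compare as integers, so "frame2" sorts before "frame10".
names = ["frame10.png", "frame2.png", "frame1.png"]
ordered = [name for _, name in sorted((split_file_name(n), n) for n in names)]
assert ordered == ["frame1.png", "frame2.png", "frame10.png"]
```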
47
+
48
+ class LowMemoryImageFolder:
49
+ def __init__(self, folder, file_list=None):
50
+ if file_list is None:
51
+ self.file_list = search_for_images(folder)
52
+ else:
53
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
+
55
+ def __len__(self):
56
+ return len(self.file_list)
57
+
58
+ def __getitem__(self, item):
59
+ return Image.open(self.file_list[item]).convert("RGB")
60
+
61
+ def __del__(self):
62
+ pass
63
+
64
+
65
+ def crop_and_resize(image, height, width):
66
+ image = np.array(image)
67
+ image_height, image_width, _ = image.shape
68
+ if image_height / image_width < height / width:
69
+ cropped_width = int(image_height / height * width)
70
+ left = (image_width - cropped_width) // 2
71
+ image = image[:, left: left+cropped_width]
72
+ image = Image.fromarray(image).resize((width, height))
73
+ else:
74
+ cropped_height = int(image_width / width * height)
75
+ top = (image_height - cropped_height) // 2
76
+ image = image[top: top+cropped_height, :]
77
+ image = Image.fromarray(image).resize((width, height))
78
+ return image
79
+
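crop_and_resize center-crops the frame to the target aspect ratio before resizing, so nothing is stretched. A small worked example in the same module context (1920x1080 input, 512x512 target):

```python
# 1080/1920 < 512/512, so the width is center-cropped to 1080 pixels first,
# then the resulting 1080x1080 square is resized down to 512x512.
frame = Image.fromarray(np.zeros((1080, 1920, 3), dtype=np.uint8))
square = crop_and_resize(frame, height=512, width=512)
assert square.size == (512, 512)   # PIL reports (width, height)
```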
80
+
81
+ class VideoData:
82
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
+ if video_file is not None:
84
+ self.data_type = "video"
85
+ self.data = LowMemoryVideo(video_file, **kwargs)
86
+ elif image_folder is not None:
87
+ self.data_type = "images"
88
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
+ else:
90
+ raise ValueError("Cannot open video or image folder")
91
+ self.length = None
92
+ self.set_shape(height, width)
93
+
94
+ def raw_data(self):
95
+ frames = []
96
+ for i in range(self.__len__()):
97
+ frames.append(self.__getitem__(i))
98
+ return frames
99
+
100
+ def set_length(self, length):
101
+ self.length = length
102
+
103
+ def set_shape(self, height, width):
104
+ self.height = height
105
+ self.width = width
106
+
107
+ def __len__(self):
108
+ if self.length is None:
109
+ return len(self.data)
110
+ else:
111
+ return self.length
112
+
113
+ def shape(self):
114
+ if self.height is not None and self.width is not None:
115
+ return self.height, self.width
116
+ else:
117
+ height, width, _ = self.__getitem__(0).shape
118
+ return height, width
119
+
120
+ def __getitem__(self, item):
121
+ frame = self.data.__getitem__(item)
122
+ width, height = frame.size
123
+ if self.height is not None and self.width is not None:
124
+ if self.height != height or self.width != width:
125
+ frame = crop_and_resize(frame, self.height, self.width)
126
+ return frame
127
+
128
+ def __del__(self):
129
+ pass
130
+
131
+ def save_images(self, folder):
132
+ os.makedirs(folder, exist_ok=True)
133
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
134
+ frame = self.__getitem__(i)
135
+ frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
138
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
139
+ for frame in tqdm(frames, desc="Saving video"):
140
+ frame = np.array(frame)
141
+ writer.append_data(frame)
142
+ writer.close()
143
+ # def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
144
+ # writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=["-crf", "0", "-preset", "veryslow"])
145
+ # for frame in tqdm(frames, desc="Saving video"):
146
+ # frame = np.array(frame)
147
+ # writer.append_data(frame)
148
+ # writer.close()
149
+
150
+ # def save_video_h264(frames, save_path, fps, ffmpeg_params=None):
151
+ # import imageio.v3 as iio
152
+ # from tqdm import tqdm
153
+ # import numpy as np
154
+
155
+ # if ffmpeg_params is None:
156
+ # ffmpeg_params = ["-crf", "0", "-preset", "ultrafast"]  # lossless H.264
157
+
158
+ # writer = iio.get_writer(save_path, fps=fps, codec="libx264", ffmpeg_params=ffmpeg_params)
159
+ # for frame in tqdm(frames, desc="Saving video"):
160
+ # writer.append_data(np.array(frame))
161
+ # writer.close()
162
+
163
+
164
+
165
+ def save_frames(frames, save_path):
166
+ os.makedirs(save_path, exist_ok=True)
167
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
168
+ frame.save(os.path.join(save_path, f"{i}.png"))
169
+
170
+
171
+ if __name__=='__main__':
172
+ frames = [Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)) for i in range(81)]
173
+ save_video(frames,"haha.mp4",23,5)
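Putting the pieces together: VideoData handles reading and per-frame crop/resize, and save_video re-encodes the result. A short sketch assuming an ./input.mp4 exists (the paths and fps here are placeholders):

```python
# Re-encode an existing clip at 512x512 using the helpers defined above.
from diffsynth.data.video import VideoData, save_video

video = VideoData(video_file="./input.mp4", height=512, width=512)
frames = video.raw_data()                    # list of PIL images, cropped and resized
save_video(frames, "./resized.mp4", fps=23, quality=9)
```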
diffsynth/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .wan_video import WanVideoPipeline
diffsynth/pipelines/base.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision.transforms import GaussianBlur
5
+
6
+
7
+
8
+ class BasePipeline(torch.nn.Module):
9
+
10
+ def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
11
+ super().__init__()
12
+ self.device = device
13
+ self.torch_dtype = torch_dtype
14
+ self.height_division_factor = height_division_factor
15
+ self.width_division_factor = width_division_factor
16
+ self.cpu_offload = False
17
+ self.model_names = []
18
+
19
+
20
+ def check_resize_height_width(self, height, width):
21
+ if height % self.height_division_factor != 0:
22
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
23
+ print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
24
+ if width % self.width_division_factor != 0:
25
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
26
+ print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
27
+ return height, width
28
+
29
+
30
+ def preprocess_image(self, image):
31
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
32
+ return image
33
+
34
+
35
+ def preprocess_images(self, images):
36
+ return [self.preprocess_image(image) for image in images]
37
+
38
+
39
+ def vae_output_to_image(self, vae_output):
40
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
41
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
42
+ return image
43
+
44
+
45
+ def vae_output_to_video(self, vae_output):
46
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
47
+ video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
48
+ return video
49
+
50
+
51
+ def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
52
+ if len(latents) > 0:
53
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
54
+ height, width = value.shape[-2:]
55
+ weight = torch.ones_like(value)
56
+ for latent, mask, scale in zip(latents, masks, scales):
57
+ mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
58
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
59
+ mask = blur(mask)
60
+ value += latent * mask * scale
61
+ weight += mask * scale
62
+ value /= weight
63
+ return value
64
+
65
+
66
+ def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
67
+ if special_kwargs is None:
68
+ noise_pred_global = inference_callback(prompt_emb_global)
69
+ else:
70
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
71
+ if special_local_kwargs_list is None:
72
+ noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
73
+ else:
74
+ noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
75
+ noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
76
+ return noise_pred
77
+
78
+
79
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
80
+ local_prompts = local_prompts or []
81
+ masks = masks or []
82
+ mask_scales = mask_scales or []
83
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
84
+ prompt = extended_prompt_dict.get("prompt", prompt)
85
+ local_prompts += extended_prompt_dict.get("prompts", [])
86
+ masks += extended_prompt_dict.get("masks", [])
87
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
88
+ return prompt, local_prompts, masks, mask_scales
89
+
90
+
91
+ def enable_cpu_offload(self):
92
+ self.cpu_offload = True
93
+
94
+
95
+ def load_models_to_device(self, loadmodel_names=[]):
96
+ # only load models to device if cpu_offload is enabled
97
+ if not self.cpu_offload:
98
+ return
99
+ # offload the unneeded models to cpu
100
+ for model_name in self.model_names:
101
+ if model_name not in loadmodel_names:
102
+ model = getattr(self, model_name)
103
+ if model is not None:
104
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
105
+ for module in model.modules():
106
+ if hasattr(module, "offload"):
107
+ module.offload()
108
+ else:
109
+ model.cpu()
110
+ # load the needed models to device
111
+ for model_name in loadmodel_names:
112
+ model = getattr(self, model_name)
113
+ if model is not None:
114
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
115
+ for module in model.modules():
116
+ if hasattr(module, "onload"):
117
+ module.onload()
118
+ else:
119
+ model.to(self.device)
120
+ # fresh the cuda cache
121
+ torch.cuda.empty_cache()
122
+
123
+
124
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
125
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
126
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
127
+ return noise
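The rounding in check_resize_height_width above is plain ceiling arithmetic: the requested size is bumped up to the next multiple of the division factor. A standalone sketch (the 720/64 and 480/16 values are illustrative, not taken from the repo):

def round_up(value, factor):
    # Smallest multiple of `factor` that is >= `value`.
    return (value + factor - 1) // factor * factor

assert round_up(720, 64) == 768   # 720 is not divisible by 64, so it is bumped up
assert round_up(480, 16) == 480   # already a multiple, left unchanged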
diffsynth/pipelines/wan_video.py ADDED
@@ -0,0 +1,290 @@
1
+ from ..models import ModelManager
2
+ from ..models.wan_video_dit import WanModel
3
+ from ..models.wan_video_text_encoder import WanTextEncoder
4
+ from ..models.wan_video_vae import WanVideoVAE
5
+ from ..models.wan_video_image_encoder import WanImageEncoder
6
+ from ..schedulers.flow_match import FlowMatchScheduler
7
+ from .base import BasePipeline
8
+ from ..prompters import WanPrompter
9
+ import torch, os
10
+ from einops import rearrange
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+
15
+ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
16
+ from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
17
+ from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
18
+ from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
19
+
20
+
21
+ class WanVideoPipeline(BasePipeline):
22
+
23
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
24
+ super().__init__(device=device, torch_dtype=torch_dtype)
25
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
26
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
27
+ self.text_encoder: WanTextEncoder = None
28
+ self.image_encoder: WanImageEncoder = None
29
+ self.dit: WanModel = None
30
+ self.vae: WanVideoVAE = None
31
+ self.model_names = ['text_encoder', 'dit', 'vae']
32
+ self.height_division_factor = 16
33
+ self.width_division_factor = 16
34
+
35
+
36
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
37
+ dtype = next(iter(self.text_encoder.parameters())).dtype
38
+ enable_vram_management(
39
+ self.text_encoder,
40
+ module_map = {
41
+ torch.nn.Linear: AutoWrappedLinear,
42
+ torch.nn.Embedding: AutoWrappedModule,
43
+ T5RelativeEmbedding: AutoWrappedModule,
44
+ T5LayerNorm: AutoWrappedModule,
45
+ },
46
+ module_config = dict(
47
+ offload_dtype=dtype,
48
+ offload_device="cpu",
49
+ onload_dtype=dtype,
50
+ onload_device="cpu",
51
+ computation_dtype=self.torch_dtype,
52
+ computation_device=self.device,
53
+ ),
54
+ )
55
+ dtype = next(iter(self.dit.parameters())).dtype
56
+ enable_vram_management(
57
+ self.dit,
58
+ module_map = {
59
+ torch.nn.Linear: AutoWrappedLinear,
60
+ torch.nn.Conv3d: AutoWrappedModule,
61
+ torch.nn.LayerNorm: AutoWrappedModule,
62
+ WanLayerNorm: AutoWrappedModule,
63
+ WanRMSNorm: AutoWrappedModule,
64
+ },
65
+ module_config = dict(
66
+ offload_dtype=dtype,
67
+ offload_device="cpu",
68
+ onload_dtype=dtype,
69
+ onload_device=self.device,
70
+ computation_dtype=self.torch_dtype,
71
+ computation_device=self.device,
72
+ ),
73
+ max_num_param=num_persistent_param_in_dit,
74
+ overflow_module_config = dict(
75
+ offload_dtype=dtype,
76
+ offload_device="cpu",
77
+ onload_dtype=dtype,
78
+ onload_device="cpu",
79
+ computation_dtype=self.torch_dtype,
80
+ computation_device=self.device,
81
+ ),
82
+ )
83
+ dtype = next(iter(self.vae.parameters())).dtype
84
+ enable_vram_management(
85
+ self.vae,
86
+ module_map = {
87
+ torch.nn.Linear: AutoWrappedLinear,
88
+ torch.nn.Conv2d: AutoWrappedModule,
89
+ RMS_norm: AutoWrappedModule,
90
+ CausalConv3d: AutoWrappedModule,
91
+ Upsample: AutoWrappedModule,
92
+ torch.nn.SiLU: AutoWrappedModule,
93
+ torch.nn.Dropout: AutoWrappedModule,
94
+ },
95
+ module_config = dict(
96
+ offload_dtype=dtype,
97
+ offload_device="cpu",
98
+ onload_dtype=dtype,
99
+ onload_device=self.device,
100
+ computation_dtype=self.torch_dtype,
101
+ computation_device=self.device,
102
+ ),
103
+ )
104
+ if self.image_encoder is not None:
105
+ dtype = next(iter(self.image_encoder.parameters())).dtype
106
+ enable_vram_management(
107
+ self.image_encoder,
108
+ module_map = {
109
+ torch.nn.Linear: AutoWrappedLinear,
110
+ torch.nn.Conv2d: AutoWrappedModule,
111
+ torch.nn.LayerNorm: AutoWrappedModule,
112
+ },
113
+ module_config = dict(
114
+ offload_dtype=dtype,
115
+ offload_device="cpu",
116
+ onload_dtype=dtype,
117
+ onload_device="cpu",
118
+ computation_dtype=self.torch_dtype,
119
+ computation_device=self.device,
120
+ ),
121
+ )
122
+ self.enable_cpu_offload()
123
+
124
+
125
+ def fetch_models(self, model_manager: ModelManager):
126
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
127
+ if text_encoder_model_and_path is not None:
128
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
129
+ self.prompter.fetch_models(self.text_encoder)
130
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
131
+ self.dit = model_manager.fetch_model("wan_video_dit")
132
+ self.vae = model_manager.fetch_model("wan_video_vae")
133
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
134
+
135
+
136
+ @staticmethod
137
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
138
+ if device is None: device = model_manager.device
139
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
140
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
141
+ pipe.fetch_models(model_manager)
142
+ return pipe
143
+
144
+
145
+ def denoising_model(self):
146
+ return self.dit
147
+
148
+
149
+ def encode_prompt(self, prompt, positive=True):
150
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
151
+ return {"context": prompt_emb}
152
+
153
+
154
+ def encode_image(self, image, num_frames, height, width):
155
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
156
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
157
+ clip_context = self.image_encoder.encode_image([image])
158
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
159
+ msk[:, 1:] = 0
160
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
161
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
162
+ msk = msk.transpose(1, 2)[0]
163
+ y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
164
+ y = torch.concat([msk, y])
165
+ return {"clip_fea": clip_context, "y": [y]}
166
+
167
+
168
+ def tensor2video(self, frames):
169
+ frames = rearrange(frames, "C T H W -> T H W C")
170
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
171
+ frames = [Image.fromarray(frame) for frame in frames]
172
+ return frames
173
+
174
+
175
+ def prepare_extra_input(self, latents=None):
176
+ return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
177
+
178
+
179
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
180
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
181
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
182
+ return latents
183
+
184
+
185
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
186
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
187
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
188
+ return frames
189
+
190
+ def set_ip(self, local_path):
191
+ pass
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ prompt,
196
+ negative_prompt="",
197
+ input_image=None,
198
+ input_video=None,
199
+ denoising_strength=1.0,
200
+ seed=None,
201
+ rand_device="cpu",
202
+ height=480,
203
+ width=832,
204
+ num_frames=81,
205
+ cfg_scale=5.0,
206
+ audio_cfg_scale=None,
207
+ num_inference_steps=50,
208
+ sigma_shift=5.0,
209
+ tiled=True,
210
+ tile_size=(30, 52),
211
+ tile_stride=(15, 26),
212
+ progress_bar_cmd=tqdm,
213
+ progress_bar_st=None,
214
+ **kwargs,
215
+ ):
216
+ # Parameter check
217
+ height, width = self.check_resize_height_width(height, width)
218
+ if num_frames % 4 != 1:
219
+ num_frames = (num_frames + 2) // 4 * 4 + 1
220
+ print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
221
+
222
+ # Tiler parameters
223
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
224
+
225
+ # Scheduler
226
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
227
+
228
+ # Initialize noise
229
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
230
+ if input_video is not None:
231
+ self.load_models_to_device(['vae'])
232
+ input_video = self.preprocess_images(input_video)
233
+ input_video = torch.stack(input_video, dim=2)
234
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
235
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
236
+ else:
237
+ latents = noise
238
+
239
+ # Encode prompts
240
+ self.load_models_to_device(["text_encoder"])
241
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
242
+ if cfg_scale != 1.0:
243
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
244
+
245
+ # Encode image
246
+ if input_image is not None and self.image_encoder is not None:
247
+ self.load_models_to_device(["image_encoder", "vae"])
248
+ image_emb = self.encode_image(input_image, num_frames, height, width)
249
+ else:
250
+ image_emb = {}
251
+
252
+ # Extra input
253
+ extra_input = self.prepare_extra_input(latents)
254
+
255
+ # Denoise
256
+ self.load_models_to_device(["dit"])
257
+ with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
258
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
259
+ timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
260
+
261
+ # Inference
262
+ noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs) # (zt,audio,prompt)
263
+ if audio_cfg_scale is not None:
264
+ audio_scale = kwargs['audio_scale']
265
+ kwargs['audio_scale'] = 0.0
266
+ noise_pred_noaudio = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs) #(zt,0,prompt)
267
+ # kwargs['ip_scale'] = ip_scale
268
+ if cfg_scale != 1.0: #prompt cfg
269
+ noise_pred_no_cond = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs) # (zt,0,0)
270
+ noise_pred = noise_pred_no_cond + cfg_scale * (noise_pred_noaudio - noise_pred_no_cond) + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
271
+ else:
272
+ noise_pred = noise_pred_noaudio + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
273
+ kwargs['audio_scale'] = audio_scale
274
+ else:
275
+ if cfg_scale != 1.0:
276
+ noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs) #(zt,audio,0)
277
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
278
+ else:
279
+ noise_pred = noise_pred_posi
280
+
281
+ # Scheduler
282
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
283
+
284
+ # Decode
285
+ self.load_models_to_device(['vae'])
286
+ frames = self.decode_video(latents, **tiler_kwargs)
287
+ self.load_models_to_device([])
288
+ frames = self.tensor2video(frames[0])
289
+
290
+ return frames
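When audio_cfg_scale is set, the denoising loop above combines three DiT predictions with two nested classifier-free-guidance terms: prompt guidance scaled by cfg_scale and audio guidance scaled by audio_cfg_scale. A minimal sketch of that combination on dummy tensors standing in for the three predictions:

import torch

# Stand-ins for the predictions computed in the loop: (zt, 0, 0), (zt, 0, prompt), (zt, audio, prompt).
noise_pred_no_cond = torch.zeros(2, 4)
noise_pred_noaudio = torch.ones(2, 4)
noise_pred_posi = 2 * torch.ones(2, 4)
cfg_scale, audio_cfg_scale = 5.0, 5.0

noise_pred = (noise_pred_no_cond
              + cfg_scale * (noise_pred_noaudio - noise_pred_no_cond)
              + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio))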
diffsynth/prompters/__init__.py ADDED
@@ -0,0 +1 @@
+ from .wan_prompter import WanPrompter
diffsynth/prompters/base_prompter.py ADDED
@@ -0,0 +1,70 @@
1
+ from ..models.model_manager import ModelManager
2
+ import torch
3
+
4
+
5
+
6
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
7
+ # Get model_max_length from self.tokenizer
8
+ length = tokenizer.model_max_length if max_length is None else max_length
9
+
10
+ # To avoid the truncation warning, temporarily set tokenizer.model_max_length to a huge value.
11
+ tokenizer.model_max_length = 99999999
12
+
13
+ # Tokenize it!
14
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15
+
16
+ # Determine the real length.
17
+ max_length = (input_ids.shape[1] + length - 1) // length * length
18
+
19
+ # Restore tokenizer.model_max_length
20
+ tokenizer.model_max_length = length
21
+
22
+ # Tokenize it again with fixed length.
23
+ input_ids = tokenizer(
24
+ prompt,
25
+ return_tensors="pt",
26
+ padding="max_length",
27
+ max_length=max_length,
28
+ truncation=True
29
+ ).input_ids
30
+
31
+ # Reshape input_ids to fit the text encoder.
32
+ num_sentence = input_ids.shape[1] // length
33
+ input_ids = input_ids.reshape((num_sentence, length))
34
+
35
+ return input_ids
36
+
37
+
38
+
39
+ class BasePrompter:
40
+ def __init__(self):
41
+ self.refiners = []
42
+ self.extenders = []
43
+
44
+
45
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
46
+ for refiner_class in refiner_classes:
47
+ refiner = refiner_class.from_model_manager(model_manager)
48
+ self.refiners.append(refiner)
49
+
50
+ def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
51
+ for extender_class in extender_classes:
52
+ extender = extender_class.from_model_manager(model_manager)
53
+ self.extenders.append(extender)
54
+
55
+
56
+ @torch.no_grad()
57
+ def process_prompt(self, prompt, positive=True):
58
+ if isinstance(prompt, list):
59
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
60
+ else:
61
+ for refiner in self.refiners:
62
+ prompt = refiner(prompt, positive=positive)
63
+ return prompt
64
+
65
+ @torch.no_grad()
66
+ def extend_prompt(self, prompt:str, positive=True):
67
+ extended_prompt = dict(prompt=prompt)
68
+ for extender in self.extenders:
69
+ extended_prompt = extender(extended_prompt)
70
+ return extended_prompt
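tokenize_long_prompt pads the token ids up to the next multiple of model_max_length and reshapes them into one row per chunk, so an arbitrarily long prompt becomes several fixed-length sentences for the text encoder. A worked example of the length arithmetic (300 ids and a 128-token limit are illustrative values):

length = 128                                             # tokenizer.model_max_length
num_ids = 300                                            # ids produced by the first, unpadded pass
max_length = (num_ids + length - 1) // length * length   # 384
num_sentence = max_length // length                      # 3 chunks of 128 ids each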
diffsynth/prompters/wan_prompter.py ADDED
@@ -0,0 +1,108 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.wan_video_text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import os, torch
5
+ import ftfy
6
+ import html
7
+ import string
8
+ import regex as re
9
+
10
+
11
+ def basic_clean(text):
12
+ text = ftfy.fix_text(text)
13
+ text = html.unescape(html.unescape(text))
14
+ return text.strip()
15
+
16
+
17
+ def whitespace_clean(text):
18
+ text = re.sub(r'\s+', ' ', text)
19
+ text = text.strip()
20
+ return text
21
+
22
+
23
+ def canonicalize(text, keep_punctuation_exact_string=None):
24
+ text = text.replace('_', ' ')
25
+ if keep_punctuation_exact_string:
26
+ text = keep_punctuation_exact_string.join(
27
+ part.translate(str.maketrans('', '', string.punctuation))
28
+ for part in text.split(keep_punctuation_exact_string))
29
+ else:
30
+ text = text.translate(str.maketrans('', '', string.punctuation))
31
+ text = text.lower()
32
+ text = re.sub(r'\s+', ' ', text)
33
+ return text.strip()
34
+
35
+
36
+ class HuggingfaceTokenizer:
37
+
38
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
39
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
40
+ self.name = name
41
+ self.seq_len = seq_len
42
+ self.clean = clean
43
+
44
+ # init tokenizer
45
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
46
+ self.vocab_size = self.tokenizer.vocab_size
47
+
48
+ def __call__(self, sequence, **kwargs):
49
+ return_mask = kwargs.pop('return_mask', False)
50
+
51
+ # arguments
52
+ _kwargs = {'return_tensors': 'pt'}
53
+ if self.seq_len is not None:
54
+ _kwargs.update({
55
+ 'padding': 'max_length',
56
+ 'truncation': True,
57
+ 'max_length': self.seq_len
58
+ })
59
+ _kwargs.update(**kwargs)
60
+
61
+ # tokenization
62
+ if isinstance(sequence, str):
63
+ sequence = [sequence]
64
+ if self.clean:
65
+ sequence = [self._clean(u) for u in sequence]
66
+ ids = self.tokenizer(sequence, **_kwargs)
67
+
68
+ # output
69
+ if return_mask:
70
+ return ids.input_ids, ids.attention_mask
71
+ else:
72
+ return ids.input_ids
73
+
74
+ def _clean(self, text):
75
+ if self.clean == 'whitespace':
76
+ text = whitespace_clean(basic_clean(text))
77
+ elif self.clean == 'lower':
78
+ text = whitespace_clean(basic_clean(text)).lower()
79
+ elif self.clean == 'canonicalize':
80
+ text = canonicalize(basic_clean(text))
81
+ return text
82
+
83
+
84
+ class WanPrompter(BasePrompter):
85
+
86
+ def __init__(self, tokenizer_path=None, text_len=512):
87
+ super().__init__()
88
+ self.text_len = text_len
89
+ self.text_encoder = None
90
+ self.fetch_tokenizer(tokenizer_path)
91
+
92
+ def fetch_tokenizer(self, tokenizer_path=None):
93
+ if tokenizer_path is not None:
94
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
95
+
96
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
97
+ self.text_encoder = text_encoder
98
+
99
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
100
+ prompt = self.process_prompt(prompt, positive=positive)
101
+
102
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
103
+ ids = ids.to(device)
104
+ mask = mask.to(device)
105
+ seq_lens = mask.gt(0).sum(dim=1).long()
106
+ prompt_emb = self.text_encoder(ids, mask)
107
+ prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
108
+ return prompt_emb
diffsynth/schedulers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .ddim import EnhancedDDIMScheduler
+ from .continuous_ode import ContinuousODEScheduler
+ from .flow_match import FlowMatchScheduler
diffsynth/schedulers/continuous_ode.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ 
+ 
+ class ContinuousODEScheduler():
+ 
+     def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
+         self.sigma_max = sigma_max
+         self.sigma_min = sigma_min
+         self.rho = rho
+         self.set_timesteps(num_inference_steps)
+ 
+     def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
+         ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
+         min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
+         max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
+         self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
+         self.timesteps = torch.log(self.sigmas) * 0.25
+ 
+     def step(self, model_output, timestep, sample, to_final=False):
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         sample *= (sigma*sigma + 1).sqrt()
+         estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
+         if to_final or timestep_id + 1 >= len(self.timesteps):
+             prev_sample = estimated_sample
+         else:
+             sigma_ = self.sigmas[timestep_id + 1]
+             derivative = 1 / sigma * (sample - estimated_sample)
+             prev_sample = sample + derivative * (sigma_ - sigma)
+             prev_sample /= (sigma_*sigma_ + 1).sqrt()
+         return prev_sample
+ 
+     def return_to_timestep(self, timestep, sample, sample_stablized):
+         # This scheduler doesn't support this function.
+         pass
+ 
+     def add_noise(self, original_samples, noise, timestep):
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
+         return sample
+ 
+     def training_target(self, sample, noise, timestep):
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
+         return target
+ 
+     def training_weight(self, timestep):
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         weight = (1 + sigma*sigma).sqrt() / sigma
+         return weight
diffsynth/schedulers/ddim.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch, math
2
+
3
+
4
+ class EnhancedDDIMScheduler():
5
+
6
+ def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
7
+ self.num_train_timesteps = num_train_timesteps
8
+ if beta_schedule == "scaled_linear":
9
+ betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
10
+ elif beta_schedule == "linear":
11
+ betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
12
+ else:
13
+ raise NotImplementedError(f"{beta_schedule} is not implemented")
14
+ self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
15
+ if rescale_zero_terminal_snr:
16
+ self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
17
+ self.alphas_cumprod = self.alphas_cumprod.tolist()
18
+ self.set_timesteps(10)
19
+ self.prediction_type = prediction_type
20
+
21
+
22
+ def rescale_zero_terminal_snr(self, alphas_cumprod):
23
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
24
+
25
+ # Store old values.
26
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
27
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
28
+
29
+ # Shift so the last timestep is zero.
30
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
31
+
32
+ # Scale so the first timestep is back to the old value.
33
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
34
+
35
+ # Square to recover the rescaled alphas_cumprod
36
+ alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
37
+
38
+ return alphas_bar
39
+
40
+
41
+ def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
42
+ # The timesteps are aligned to 999...0, which is different from other implementations,
43
+ # but I think this implementation is more reasonable in theory.
44
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
45
+ num_inference_steps = min(num_inference_steps, max_timestep + 1)
46
+ if num_inference_steps == 1:
47
+ self.timesteps = torch.Tensor([max_timestep])
48
+ else:
49
+ step_length = max_timestep / (num_inference_steps - 1)
50
+ self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
51
+
52
+
53
+ def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
54
+ if self.prediction_type == "epsilon":
55
+ weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
56
+ weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
57
+ prev_sample = sample * weight_x + model_output * weight_e
58
+ elif self.prediction_type == "v_prediction":
59
+ weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
60
+ weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
61
+ prev_sample = sample * weight_x + model_output * weight_e
62
+ else:
63
+ raise NotImplementedError(f"{self.prediction_type} is not implemented")
64
+ return prev_sample
65
+
66
+
67
+ def step(self, model_output, timestep, sample, to_final=False):
68
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
69
+ if isinstance(timestep, torch.Tensor):
70
+ timestep = timestep.cpu()
71
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
72
+ if to_final or timestep_id + 1 >= len(self.timesteps):
73
+ alpha_prod_t_prev = 1.0
74
+ else:
75
+ timestep_prev = int(self.timesteps[timestep_id + 1])
76
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
77
+
78
+ return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
79
+
80
+
81
+ def return_to_timestep(self, timestep, sample, sample_stablized):
82
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
83
+ noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
84
+ return noise_pred
85
+
86
+
87
+ def add_noise(self, original_samples, noise, timestep):
88
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
89
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
90
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
91
+ return noisy_samples
92
+
93
+
94
+ def training_target(self, sample, noise, timestep):
95
+ if self.prediction_type == "epsilon":
96
+ return noise
97
+ else:
98
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
99
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
100
+ target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
101
+ return target
102
+
103
+
104
+ def training_weight(self, timestep):
105
+ return 1.0
diffsynth/schedulers/flow_match.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ 
+ 
+ class FlowMatchScheduler():
+ 
+     def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+         self.num_train_timesteps = num_train_timesteps
+         self.shift = shift
+         self.sigma_max = sigma_max
+         self.sigma_min = sigma_min
+         self.inverse_timesteps = inverse_timesteps
+         self.extra_one_step = extra_one_step
+         self.reverse_sigmas = reverse_sigmas
+         self.set_timesteps(num_inference_steps)
+ 
+     def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+         if shift is not None:
+             self.shift = shift
+         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+         if self.extra_one_step:
+             self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+         else:
+             self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+         if self.inverse_timesteps:
+             self.sigmas = torch.flip(self.sigmas, dims=[0])
+         self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+         if self.reverse_sigmas:
+             self.sigmas = 1 - self.sigmas
+         self.timesteps = self.sigmas * self.num_train_timesteps
+         if training:
+             x = self.timesteps
+             y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+             y_shifted = y - y.min()
+             bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+             self.linear_timesteps_weights = bsmntw_weighing
+ 
+     def step(self, model_output, timestep, sample, to_final=False):
+         if isinstance(timestep, torch.Tensor):
+             timestep = timestep.cpu()
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         if to_final or timestep_id + 1 >= len(self.timesteps):
+             sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+         else:
+             sigma_ = self.sigmas[timestep_id + 1]
+         prev_sample = sample + model_output * (sigma_ - sigma)
+         return prev_sample
+ 
+     def return_to_timestep(self, timestep, sample, sample_stablized):
+         if isinstance(timestep, torch.Tensor):
+             timestep = timestep.cpu()
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         model_output = (sample - sample_stablized) / sigma
+         return model_output
+ 
+     def add_noise(self, original_samples, noise, timestep):
+         if isinstance(timestep, torch.Tensor):
+             timestep = timestep.cpu()
+         timestep_id = torch.argmin((self.timesteps - timestep).abs())
+         sigma = self.sigmas[timestep_id]
+         sample = (1 - sigma) * original_samples + sigma * noise
+         return sample
+ 
+     def training_target(self, sample, noise, timestep):
+         target = noise - sample
+         return target
+ 
+     def training_weight(self, timestep):
+         timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+         weights = self.linear_timesteps_weights[timestep_id]
+         return weights
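The shift applied in set_timesteps warps the noise schedule as sigma' = shift * sigma / (1 + (shift - 1) * sigma), which concentrates sampling steps at high noise levels (WanVideoPipeline constructs this scheduler with shift=5). A small sketch of how the sigmas and timesteps come out, assuming the class defaults for sigma_min and an extra_one_step grid:

import torch

shift, num_inference_steps = 5.0, 30
sigmas = torch.linspace(1.0, 0.003 / 1.002, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)   # warp toward sigma = 1
timesteps = sigmas * 1000                              # num_train_timesteps = 1000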
diffsynth/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
+ from .layers import *
diffsynth/vram_management/layers.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch, copy
2
+ from ..models.utils import init_weights_on_device
3
+
4
+
5
+ def cast_to(weight, dtype, device):
6
+ r = torch.empty_like(weight, dtype=dtype, device=device)
7
+ r.copy_(weight)
8
+ return r
9
+
10
+
11
+ class AutoWrappedModule(torch.nn.Module):
12
+ def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
13
+ super().__init__()
14
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
15
+ self.offload_dtype = offload_dtype
16
+ self.offload_device = offload_device
17
+ self.onload_dtype = onload_dtype
18
+ self.onload_device = onload_device
19
+ self.computation_dtype = computation_dtype
20
+ self.computation_device = computation_device
21
+ self.state = 0
22
+
23
+ def offload(self):
24
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
25
+ self.module.to(dtype=self.offload_dtype, device=self.offload_device)
26
+ self.state = 0
27
+
28
+ def onload(self):
29
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
30
+ self.module.to(dtype=self.onload_dtype, device=self.onload_device)
31
+ self.state = 1
32
+
33
+ def forward(self, *args, **kwargs):
34
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
35
+ module = self.module
36
+ else:
37
+ module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
38
+ return module(*args, **kwargs)
39
+
40
+
41
+ class AutoWrappedLinear(torch.nn.Linear):
42
+ def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
43
+ with init_weights_on_device(device=torch.device("meta")):
44
+ super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
45
+ self.weight = module.weight
46
+ self.bias = module.bias
47
+ self.offload_dtype = offload_dtype
48
+ self.offload_device = offload_device
49
+ self.onload_dtype = onload_dtype
50
+ self.onload_device = onload_device
51
+ self.computation_dtype = computation_dtype
52
+ self.computation_device = computation_device
53
+ self.state = 0
54
+
55
+ def offload(self):
56
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
57
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
58
+ self.state = 0
59
+
60
+ def onload(self):
61
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
62
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
63
+ self.state = 1
64
+
65
+ def forward(self, x, *args, **kwargs):
66
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
67
+ weight, bias = self.weight, self.bias
68
+ else:
69
+ weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
70
+ bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
71
+ return torch.nn.functional.linear(x, weight, bias)
72
+
73
+
74
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
75
+ for name, module in model.named_children():
76
+ for source_module, target_module in module_map.items():
77
+ if isinstance(module, source_module):
78
+ num_param = sum(p.numel() for p in module.parameters())
79
+ if max_num_param is not None and total_num_param + num_param > max_num_param:
80
+ module_config_ = overflow_module_config
81
+ else:
82
+ module_config_ = module_config
83
+ module_ = target_module(module, **module_config_)
84
+ setattr(model, name, module_)
85
+ total_num_param += num_param
86
+ break
87
+ else:
88
+ total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
89
+ return total_num_param
90
+
91
+
92
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
93
+ enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
94
+ model.vram_management_enabled = True
95
+
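enable_vram_management walks the module tree and swaps every layer whose type appears in module_map for its wrapper, which then casts weights to the computation dtype and device on the fly inside forward. A toy, CPU-only sketch of how it might be called (the small Sequential model is illustrative, not from the repo):

import torch

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
enable_vram_management(
    toy,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",
        onload_dtype=torch.float32, onload_device="cpu",
        computation_dtype=torch.float32, computation_device="cpu",
    ),
)
print(type(toy[0]).__name__)   # AutoWrappedLinear; toy.vram_management_enabled is now True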
infer.py ADDED
@@ -0,0 +1,214 @@
1
+ import torch
2
+ from diffsynth import ModelManager, WanVideoPipeline
3
+ from PIL import Image
4
+ import argparse
5
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
6
+ import librosa
7
+ import os
8
+ import subprocess
9
+ import cv2
10
+ from model import FantasyTalkingAudioConditionModel
11
+ from utils import save_video, get_audio_features, resize_image_by_longest_edge
12
+ from pathlib import Path
13
+ from datetime import datetime
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(description="Inference script for FantasyTalking audio-driven video generation.")
17
+
18
+ parser.add_argument(
19
+ "--wan_model_dir",
20
+ type=str,
21
+ default="./models/Wan2.1-I2V-14B-720P",
22
+ required=False,
23
+ help="The dir of the Wan I2V 14B model.",
24
+ )
25
+ parser.add_argument(
26
+ "--fantasytalking_model_path",
27
+ type=str,
28
+ default="./models/fantasytalking_model.ckpt",
29
+ required=False,
30
+ help="The .ckpt path of fantasytalking model.",
31
+ )
32
+ parser.add_argument(
33
+ "--wav2vec_model_dir",
34
+ type=str,
35
+ default="./models/wav2vec2-base-960h",
36
+ required=False,
37
+ help="The dir of wav2vec model.",
38
+ )
39
+
40
+ parser.add_argument(
41
+ "--image_path",
42
+ type=str,
43
+ default="./assets/images/woman.png",
44
+ required=False,
45
+ help="The path of the image.",
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--audio_path",
50
+ type=str,
51
+ default="./assets/audios/woman.wav",
52
+ required=False,
53
+ help="The path of the audio.",
54
+ )
55
+ parser.add_argument(
56
+ "--prompt",
57
+ type=str,
58
+ default="A woman is talking.",
59
+ required=False,
60
+ help="prompt.",
61
+ )
62
+ parser.add_argument(
63
+ "--output_dir",
64
+ type=str,
65
+ default="./output",
66
+ help="Dir to save the model.",
67
+ )
68
+ parser.add_argument(
69
+ "--image_size",
70
+ type=int,
71
+ default=512,
72
+ help="The image will be resized proportionally to this size.",
73
+ )
74
+ parser.add_argument(
75
+ "--audio_scale",
76
+ type=float,
77
+ default=1.0,
78
+ help="Audio condition injection weight",
79
+ )
80
+ parser.add_argument(
81
+ "--prompt_cfg_scale",
82
+ type=float,
83
+ default=5.0,
84
+ required=False,
85
+ help="Prompt cfg scale",
86
+ )
87
+ parser.add_argument(
88
+ "--audio_cfg_scale",
89
+ type=float,
90
+ default=5.0,
91
+ required=False,
92
+ help="Audio cfg scale",
93
+ )
94
+ parser.add_argument(
95
+ "--max_num_frames",
96
+ type=int,
97
+ default=81,
98
+ required=False,
99
+ help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated."
100
+ )
101
+ parser.add_argument(
102
+ "--fps",
103
+ type=int,
104
+ default=23,
105
+ required=False,
106
+ )
107
+ parser.add_argument(
108
+ "--num_persistent_param_in_dit",
109
+ type=int,
110
+ default=None,
111
+ required=False,
112
+ help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required"
113
+ )
114
+ parser.add_argument(
115
+ "--seed",
116
+ type=int,
117
+ default=1111,
118
+ required=False,
119
+ )
120
+ args = parser.parse_args()
121
+ return args
122
+
123
+ def load_models(args):
124
+ # Load Wan I2V models
125
+ model_manager = ModelManager(device="cpu")
126
+ model_manager.load_models(
127
+ [
128
+ [
129
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
130
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
131
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
132
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
133
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
134
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
135
+ f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
136
+ ],
137
+ f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
138
+ f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
139
+ f"{args.wan_model_dir}/Wan2.1_VAE.pth",
140
+ ],
141
+ # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
142
+ torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
143
+ )
144
+ pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
145
+
146
+ # Load FantasyTalking weights
147
+ fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
148
+ fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
149
+
150
+ # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
151
+ pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
152
+
153
+ # Load wav2vec models
154
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
155
+ wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
156
+
157
+ return pipe,fantasytalking,wav2vec_processor,wav2vec
158
+
159
+
160
+
161
+ def main(args,pipe,fantasytalking,wav2vec_processor,wav2vec):
162
+ os.makedirs(args.output_dir,exist_ok=True)
163
+
164
+ duration = librosa.get_duration(filename=args.audio_path)
165
+ num_frames = min(int(args.fps*duration//4)*4+5,args.max_num_frames)
166
+
167
+ audio_wav2vec_fea = get_audio_features(wav2vec,wav2vec_processor,args.audio_path,args.fps,num_frames)
168
+ image = resize_image_by_longest_edge(args.image_path,args.image_size)
169
+ width, height = image.size
170
+
171
+ audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
172
+ pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1),num_frames=num_frames)
173
+ audio_proj_split,audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea,pos_idx_ranges,expand_length=4) # [b,21,9+8,768]
174
+
175
+ # Image-to-video
176
+ video_audio = pipe(
177
+ prompt=args.prompt,
178
+ negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
179
+ input_image=image,
180
+ width=width,
181
+ height=height,
182
+ num_frames=num_frames,
183
+ num_inference_steps=30,
184
+ seed=args.seed, tiled=True,
185
+ audio_scale=args.audio_scale,
186
+ cfg_scale = args.prompt_cfg_scale,
187
+ audio_cfg_scale=args.audio_cfg_scale,
188
+ audio_proj=audio_proj_split,
189
+ audio_context_lens=audio_context_lens,
190
+ latents_num_frames=(num_frames-1)//4+1
191
+ )
192
+ current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
193
+ save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
194
+ save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
195
+
196
+ save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
197
+ final_command = [
198
+ "ffmpeg", "-y",
199
+ "-i", save_path_tmp,
200
+ "-i", args.audio_path,
201
+ "-c:v", "libx264",
202
+ "-c:a", "aac",
203
+ "-shortest",
204
+ save_path
205
+ ]
206
+ subprocess.run(final_command, check=True)
207
+ os.remove(save_path_tmp)
208
+ return save_path
209
+
210
+ if __name__ == "__main__":
211
+ args = parse_args()
212
+ pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args)
213
+
214
+ main(args,pipe,fantasytalking,wav2vec_processor,wav2vec)
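main() derives the frame count from the audio duration and then clamps it: int(fps * duration // 4) * 4 + 5 always lands on a 4k+1 value, which is what the pipeline's latent packing expects, and anything longer than max_num_frames is truncated. A worked example with illustrative numbers (2 s of audio at 23 fps):

fps, duration, max_num_frames = 23, 2.0, 81
num_frames = min(int(fps * duration // 4) * 4 + 5, max_num_frames)
assert num_frames == 49 and num_frames % 4 == 1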
model.py ADDED
@@ -0,0 +1,229 @@
1
+ from diffsynth.models.wan_video_dit import flash_attention, WanModel
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import torch
5
+ import os
6
+ from safetensors import safe_open
7
+
8
+
9
+ class AudioProjModel(nn.Module):
10
+ def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
11
+ super().__init__()
12
+ self.cross_attention_dim = cross_attention_dim
13
+ self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
14
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
15
+
16
+ def forward(self, audio_embeds):
17
+ context_tokens = self.proj(audio_embeds)
18
+ context_tokens = self.norm(context_tokens)
19
+ return context_tokens # [B,L,C]
20
+
21
+
22
+ class WanCrossAttentionProcessor(nn.Module):
23
+ def __init__(self, context_dim, hidden_dim):
24
+ super().__init__()
25
+
26
+ self.context_dim = context_dim
27
+ self.hidden_dim = hidden_dim
28
+
29
+ self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
30
+ self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
31
+
32
+ nn.init.zeros_(self.k_proj.weight)
33
+ nn.init.zeros_(self.v_proj.weight)
34
+
35
+ def __call__(
36
+ self,
37
+ attn: nn.Module,
38
+ x: torch.Tensor,
39
+ context: torch.Tensor,
40
+ context_lens: torch.Tensor,
41
+ audio_proj: torch.Tensor,
42
+ audio_context_lens: torch.Tensor,
43
+ latents_num_frames: int = 21,
44
+ audio_scale: float = 1.0,
45
+ ) -> torch.Tensor:
46
+ """
47
+ x: [B, L1, C].
48
+ context: [B, L2, C].
49
+ context_lens: [B].
50
+ audio_proj: [B, 21, L3, C]
51
+ audio_context_lens: [B*21].
52
+ """
53
+ context_img = context[:, :257]
54
+ context = context[:, 257:]
55
+ b, n, d = x.size(0), attn.num_heads, attn.head_dim
56
+
57
+ # compute query, key, value
58
+ q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
59
+ k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
60
+ v = attn.v(context).view(b, -1, n, d)
61
+ k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
62
+ v_img = attn.v_img(context_img).view(b, -1, n, d)
63
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
64
+ # compute attention
65
+ x = flash_attention(q, k, v, k_lens=context_lens)
66
+ x = x.flatten(2)
67
+ img_x = img_x.flatten(2)
68
+
69
+ if len(audio_proj.shape) == 4:
70
+ audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
71
+ ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
72
+ ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
73
+ audio_x = flash_attention(
74
+ audio_q, ip_key, ip_value, k_lens=audio_context_lens
75
+ )
76
+ audio_x = audio_x.view(b, q.size(1), n, d)
77
+ audio_x = audio_x.flatten(2)
78
+ elif len(audio_proj.shape) == 3:
79
+ ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
80
+ ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
81
+ audio_x = flash_attention(q, ip_key, ip_value, k_lens=audio_context_lens)
82
+ audio_x = audio_x.flatten(2)
83
+ # output
84
+ x = x + img_x + audio_x * audio_scale
85
+ x = attn.o(x)
86
+ return x
87
+
88
+
89
+ class FantasyTalkingAudioConditionModel(nn.Module):
90
+ def __init__(self, wan_dit: WanModel, audio_in_dim: int, audio_proj_dim: int):
91
+ super().__init__()
92
+
93
+ self.audio_in_dim = audio_in_dim
94
+ self.audio_proj_dim = audio_proj_dim
95
+
96
+ # audio proj model
97
+ self.proj_model = self.init_proj(self.audio_proj_dim)
98
+ self.set_audio_processor(wan_dit)
99
+
100
+ def init_proj(self, cross_attention_dim=5120):
101
+ proj_model = AudioProjModel(
102
+ audio_in_dim=self.audio_in_dim, cross_attention_dim=cross_attention_dim
103
+ )
104
+ return proj_model
105
+
106
+ def set_audio_processor(self, wan_dit):
107
+ attn_procs = {}
108
+ for name in wan_dit.attn_processors.keys():
109
+ attn_procs[name] = WanCrossAttentionProcessor(
110
+ context_dim=self.audio_proj_dim, hidden_dim=wan_dit.dim
111
+ )
112
+ wan_dit.set_attn_processor(attn_procs)
113
+
114
+ def load_audio_processor(self, ip_ckpt: str, wan_dit):
115
+ if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
116
+ state_dict = {"proj_model": {}, "audio_processor": {}}
117
+ with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
118
+ for key in f.keys():
119
+ if key.startswith("proj_model."):
120
+ state_dict["proj_model"][key.replace("proj_model.", "")] = (
121
+ f.get_tensor(key)
122
+ )
123
+ elif key.startswith("audio_processor."):
124
+ state_dict["audio_processor"][
125
+ key.replace("audio_processor.", "")
126
+ ] = f.get_tensor(key)
127
+ else:
128
+ state_dict = torch.load(ip_ckpt, map_location="cpu")
129
+ self.proj_model.load_state_dict(state_dict["proj_model"])
130
+ wan_dit.load_state_dict(state_dict["audio_processor"], strict=False)
131
+
132
+ def get_proj_fea(self, audio_fea=None):
133
+
134
+ return self.proj_model(audio_fea) if audio_fea is not None else None
135
+
136
+ def split_audio_sequence(self, audio_proj_length, num_frames=81):
137
+ """
138
+ Map the audio feature sequence to corresponding latent frame slices.
139
+
140
+ Args:
141
+ audio_proj_length (int): The total length of the audio feature sequence
142
+ (e.g., 173 in audio_proj[1, 173, 768]).
143
+ num_frames (int): The number of video frames in the training data (default: 81).
144
+
145
+ Returns:
146
+ list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
147
+ (within the audio feature sequence) corresponding to a latent frame.
148
+ """
149
+ # Average number of tokens per original video frame
150
+ tokens_per_frame = audio_proj_length / num_frames
151
+
152
+ # Each latent frame covers 4 video frames, and we want the center
153
+ tokens_per_latent_frame = tokens_per_frame * 4
154
+ half_tokens = int(tokens_per_latent_frame / 2)
155
+
156
+ pos_indices = []
157
+ for i in range(int((num_frames - 1) / 4) + 1):
158
+ if i == 0:
159
+ pos_indices.append(0)
160
+ else:
161
+ start_token = tokens_per_frame * ((i - 1) * 4 + 1)
162
+ end_token = tokens_per_frame * (i * 4 + 1)
163
+ center_token = int((start_token + end_token) / 2) - 1
164
+ pos_indices.append(center_token)
165
+
166
+ # Build index ranges centered around each position
167
+ pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
168
+
169
+ # Adjust the first range to avoid negative start index
170
+ pos_idx_ranges[0] = [
171
+ -(half_tokens * 2 - pos_idx_ranges[1][0]),
172
+ pos_idx_ranges[1][0],
173
+ ]
174
+
175
+ return pos_idx_ranges
176
+
177
+ def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
178
+ """
179
+ Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
180
+ if the range exceeds the input boundaries.
181
+
182
+ Args:
183
+ input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
184
+ pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
185
+ expand_length (int): Number of tokens to expand on both sides of each subsequence.
186
+
187
+ Returns:
188
+ sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
189
+ Each element is a padded subsequence.
190
+ k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
191
+ Useful for ignoring padding tokens in attention masks.
192
+ """
193
+ pos_idx_ranges = [
194
+ [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
195
+ ]
196
+ sub_sequences = []
197
+ seq_len = input_tensor.size(1) # 173
198
+ max_valid_idx = seq_len - 1 # 172
199
+ k_lens_list = []
200
+ for start, end in pos_idx_ranges:
201
+ # Calculate the fill amount
202
+ pad_front = max(-start, 0)
203
+ pad_back = max(end - max_valid_idx, 0)
204
+
205
+ # Calculate the start and end indices of the valid part
206
+ valid_start = max(start, 0)
207
+ valid_end = min(end, max_valid_idx)
208
+
209
+ # Extract the valid part
210
+ if valid_start <= valid_end:
211
+ valid_part = input_tensor[:, valid_start : valid_end + 1, :]
212
+ else:
213
+ valid_part = input_tensor.new_zeros(
214
+ (1, 0, input_tensor.size(2))
215
+ )
216
+
217
+ # In the sequence dimension (the 1st dimension) perform padding
218
+ padded_subseq = F.pad(
219
+ valid_part,
220
+ (0, 0, 0, pad_back + pad_front, 0, 0),
221
+ mode="constant",
222
+ value=0,
223
+ )
224
+ k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
225
+
226
+ sub_sequences.append(padded_subseq)
227
+ return torch.stack(sub_sequences, dim=1), torch.tensor(
228
+ k_lens_list, dtype=torch.long
229
+ )
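split_audio_sequence spreads the audio tokens evenly over the latent frames and cuts a fixed-width window around the center of each one. Reproducing the docstring's numbers (173 audio tokens, 81 video frames, 21 latent frames):

audio_proj_length, num_frames = 173, 81
tokens_per_frame = audio_proj_length / num_frames   # ~2.14 audio tokens per video frame
half_tokens = int(tokens_per_frame * 4 / 2)         # 4 tokens on either side of each center
num_latent_frames = (num_frames - 1) // 4 + 1       # 21
# The first ranges come out as [-7, 1], [1, 9], ...; negative starts are later
# zero-padded by split_tensor_with_padding.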
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch>=2.0.0
+ torchvision
+ cupy-cuda12x
+ transformers==4.46.2
+ controlnet-aux==0.0.7
+ imageio
+ imageio[ffmpeg]
+ safetensors
+ einops
+ sentencepiece
+ protobuf
+ modelscope
+ ftfy
+ librosa
utils.py ADDED
@@ -0,0 +1,49 @@
+ import imageio, librosa
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ import numpy as np
+ 
+ 
+ def resize_image_by_longest_edge(image_path, target_size):
+     image = Image.open(image_path).convert("RGB")
+     width, height = image.size
+     scale = target_size / max(width, height)
+     new_size = (int(width * scale), int(height * scale))
+     return image.resize(new_size, Image.LANCZOS)
+ 
+ 
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+     writer = imageio.get_writer(
+         save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
+     )
+     for frame in tqdm(frames, desc="Saving video"):
+         frame = np.array(frame)
+         writer.append_data(frame)
+     writer.close()
+ 
+ 
+ def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
+     sr = 16000
+     audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # resample the audio to 16 kHz
+ 
+     start_time = 0
+     # end_time = (0 + (num_frames - 1) * 1) / fps
+     end_time = num_frames / fps
+ 
+     start_sample = int(start_time * sr)
+     end_sample = int(end_time * sr)
+ 
+     try:
+         audio_segment = audio_input[start_sample:end_sample]
+     except:
+         audio_segment = audio_input
+ 
+     input_values = audio_processor(
+         audio_segment, sampling_rate=sample_rate, return_tensors="pt"
+     ).input_values.to("cuda")
+ 
+     with torch.no_grad():
+         fea = wav2vec(input_values).last_hidden_state
+ 
+     return fea
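For reference, wav2vec2-base emits roughly one feature vector per 20 ms of 16 kHz audio, so the clip kept here (num_frames / fps seconds) maps to on the order of 170-180 feature frames for an 81-frame, 23 fps video, which is the ~173-token sequence the docstrings in model.py work with. A rough sanity check, assuming that 20 ms hop:

num_frames, fps = 81, 23
clip_seconds = num_frames / fps                 # ~3.52 s of audio is kept
approx_feature_frames = int(clip_seconds * 50)  # ~176 frames at a 20 ms hop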