Spaces: Running on Zero
wangmengchao committed
Commit · 282b272
1 Parent(s): a893799
init
Browse files
- app.py +312 -0
- diffsynth/__init__.py +5 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +650 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/video.py +173 -0
- diffsynth/pipelines/__init__.py +1 -0
- diffsynth/pipelines/base.py +127 -0
- diffsynth/pipelines/wan_video.py +290 -0
- diffsynth/prompters/__init__.py +1 -0
- diffsynth/prompters/base_prompter.py +70 -0
- diffsynth/prompters/wan_prompter.py +108 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +105 -0
- diffsynth/schedulers/flow_match.py +79 -0
- diffsynth/vram_management/__init__.py +1 -0
- diffsynth/vram_management/layers.py +95 -0
- infer.py +214 -0
- model.py +229 -0
- requirements.txt +14 -0
- utils.py +49 -0
app.py
ADDED
@@ -0,0 +1,312 @@
import gradio as gr
from pathlib import Path
import argparse
from datetime import datetime
import librosa
from infer import load_models, main


pipe, fantasytalking, wav2vec_processor, wav2vec = None, None, None, None

def generate_video(
    image_path,
    audio_path,
    prompt,
    prompt_cfg_scale,
    audio_cfg_scale,
    audio_weight,
    image_size,
    max_num_frames,
    inference_steps,
    seed,
):
    # Create the output directory if it doesn't exist
    output_dir = Path("./output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Convert paths to absolute Path objects and normalize them
    print(image_path)
    image_path = Path(image_path).absolute().as_posix()
    audio_path = Path(audio_path).absolute().as_posix()

    # Parse the arguments
    args = create_args(
        image_path=image_path,
        audio_path=audio_path,
        prompt=prompt,
        output_dir=str(output_dir),
        audio_weight=audio_weight,
        prompt_cfg_scale=prompt_cfg_scale,
        audio_cfg_scale=audio_cfg_scale,
        image_size=image_size,
        max_num_frames=max_num_frames,
        inference_steps=inference_steps,
        seed=seed,
    )

    try:
        global pipe, fantasytalking, wav2vec_processor, wav2vec
        if pipe is None:
            # Load the models once and reuse them across requests
            pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
        output_path = main(
            args, pipe, fantasytalking, wav2vec_processor, wav2vec
        )
        return output_path  # Ensure the output path is returned
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        raise gr.Error(f"Error during processing: {str(e)}")


def create_args(
    image_path: str,
    audio_path: str,
    prompt: str,
    output_dir: str,
    audio_weight: float,
    prompt_cfg_scale: float,
    audio_cfg_scale: float,
    image_size: int,
    max_num_frames: int,
    inference_steps: int,
    seed: int,
) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wan_model_dir",
        type=str,
        default="./models/Wan2.1-I2V-14B-720P",
        required=False,
        help="The dir of the Wan I2V 14B model.",
    )
    parser.add_argument(
        "--fantasytalking_model_path",
        type=str,
        default="./models/fantasytalking_model.ckpt",
        required=False,
        help="The .ckpt path of the FantasyTalking model.",
    )
    parser.add_argument(
        "--wav2vec_model_dir",
        type=str,
        default="./models/wav2vec2-base-960h",
        required=False,
        help="The dir of the wav2vec model.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="./assets/images/woman.png",
        required=False,
        help="The path of the image.",
    )
    parser.add_argument(
        "--audio_path",
        type=str,
        default="./assets/audios/woman.wav",
        required=False,
        help="The path of the audio.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman is talking.",
        required=False,
        help="The prompt.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Dir to save the video.",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="The image will be resized proportionally to this size.",
    )
    parser.add_argument(
        "--audio_scale",
        type=float,
        default=1.0,
        help="Audio condition weight.",
    )
    parser.add_argument(
        "--prompt_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Prompt CFG scale.",
    )
    parser.add_argument(
        "--audio_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Audio CFG scale.",
    )
    parser.add_argument(
        "--max_num_frames",
        type=int,
        default=81,
        required=False,
        help="The maximum number of frames to generate; audio beyond max_num_frames/fps seconds is truncated.",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=20,
        required=False,
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=23,
        required=False,
    )
    parser.add_argument(
        "--num_persistent_param_in_dit",
        type=int,
        default=None,
        required=False,
        help="Maximum number of parameters kept resident in VRAM; use a smaller number to reduce the VRAM required.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1111,
        required=False,
    )
    args = parser.parse_args(
        [
            "--image_path",
            image_path,
            "--audio_path",
            audio_path,
            "--prompt",
            prompt,
            "--output_dir",
            output_dir,
            "--image_size",
            str(image_size),
            "--audio_scale",
            str(audio_weight),
            "--prompt_cfg_scale",
            str(prompt_cfg_scale),
            "--audio_cfg_scale",
            str(audio_cfg_scale),
            "--max_num_frames",
            str(max_num_frames),
            "--inference_steps",
            str(inference_steps),
            "--seed",
            str(seed),
        ]
    )
    print(args)
    return args


# Create the Gradio interface
with gr.Blocks(title="FantasyTalking Video Generation") as demo:
    gr.Markdown(
        """
        # FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis

        <div align="center">
            <strong> Mengchao Wang1* Qiang Wang1* Fan Jiang1†
            Yaqi Fan2 Yunpeng Zhang1,2 YongGang Qi2‡
            Kun Zhao1 Mu Xu1 </strong>
        </div>

        <div align="center">
            <strong>1AMAP, Alibaba Group  2Beijing University of Posts and Telecommunications</strong>
        </div>

        <div style="display:flex;justify-content:center;column-gap:4px;">
            <a href="https://github.com/Fantasy-AMAP/fantasy-talking">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://arxiv.org/abs/2504.04842">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
        </div>
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="filepath")
            audio_input = gr.Audio(label="Input Audio", type="filepath")
            prompt_input = gr.Text(label="Input Prompt")
            with gr.Row():
                prompt_cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=9.0,
                    value=5.0,
                    step=0.5,
                    label="Prompt CFG Scale",
                )
                audio_cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=9.0,
                    value=5.0,
                    step=0.5,
                    label="Audio CFG Scale",
                )
                audio_weight = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="Audio Weight",
                )
            with gr.Row():
                image_size = gr.Number(
                    value=512, label="Width/Height Max Size", precision=0
                )
                max_num_frames = gr.Number(
                    value=81, label="Maximum Frames", precision=0
                )
                inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=20, step=1, label="Inference Steps"
                )

            with gr.Row():
                seed = gr.Number(value=1247, label="Random Seed", precision=0)

            process_btn = gr.Button("Generate Video")

        with gr.Column():
            video_output = gr.Video(label="Output Video")

    gr.Examples(
        examples=[
            [
                # Repo-relative example assets (same files as the argparse defaults above)
                "./assets/images/woman.png",
                "./assets/audios/woman.wav",
            ],
        ],
        inputs=[image_input, audio_input],
    )

    process_btn.click(
        fn=generate_video,
        inputs=[
            image_input,
            audio_input,
            prompt_input,
            prompt_cfg_scale,
            audio_cfg_scale,
            audio_weight,
            image_size,
            max_num_frames,
            inference_steps,
            seed,
        ],
        outputs=video_output,
    )

if __name__ == "__main__":
    demo.launch(inbrowser=True, share=True)
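For reference, a minimal sketch of exercising this entry point without the Gradio UI, reusing generate_video as defined above. It assumes the weights referenced by the argparse defaults have already been downloaded under ./models/ and that the example assets exist; those paths are assumptions, not part of this commit.

# Hypothetical smoke test for app.py (asset and model paths are assumptions)
from app import generate_video

video_path = generate_video(
    image_path="./assets/images/woman.png",
    audio_path="./assets/audios/woman.wav",
    prompt="A woman is talking.",
    prompt_cfg_scale=5.0,
    audio_cfg_scale=5.0,
    audio_weight=1.0,
    image_size=512,
    max_num_frames=81,
    inference_steps=20,
    seed=1111,
)
print(video_path)  # path of the generated video under ./output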
diffsynth/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .data import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
diffsynth/configs/__init__.py
ADDED
File without changes
diffsynth/configs/model_config.py
ADDED
@@ -0,0 +1,650 @@
1 |
+
from typing_extensions import Literal, TypeAlias
|
2 |
+
|
3 |
+
from ..models.wan_video_dit import WanModel
|
4 |
+
from ..models.wan_video_text_encoder import WanTextEncoder
|
5 |
+
from ..models.wan_video_image_encoder import WanImageEncoder
|
6 |
+
from ..models.wan_video_vae import WanVideoVAE
|
7 |
+
|
8 |
+
|
9 |
+
model_loader_configs = [
|
10 |
+
# These configs are provided for detecting model type automatically.
|
11 |
+
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
12 |
+
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
13 |
+
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
14 |
+
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
15 |
+
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
16 |
+
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
17 |
+
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
18 |
+
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
19 |
+
]
|
20 |
+
huggingface_model_loader_configs = [
|
21 |
+
# These configs are provided for detecting model type automatically.
|
22 |
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
23 |
+
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
24 |
+
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
25 |
+
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
26 |
+
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
27 |
+
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
28 |
+
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
29 |
+
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
30 |
+
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
31 |
+
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
32 |
+
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
33 |
+
]
|
34 |
+
patch_model_loader_configs = [
|
35 |
+
# These configs are provided for detecting model type automatically.
|
36 |
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
37 |
+
# ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
38 |
+
]
|
39 |
+
|
40 |
+
preset_models_on_huggingface = {
|
41 |
+
"HunyuanDiT": [
|
42 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
43 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
44 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
45 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
46 |
+
],
|
47 |
+
"stable-video-diffusion-img2vid-xt": [
|
48 |
+
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
49 |
+
],
|
50 |
+
"ExVideo-SVD-128f-v1": [
|
51 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
52 |
+
],
|
53 |
+
# Stable Diffusion
|
54 |
+
"StableDiffusion_v15": [
|
55 |
+
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
56 |
+
],
|
57 |
+
"DreamShaper_8": [
|
58 |
+
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
59 |
+
],
|
60 |
+
# Textual Inversion
|
61 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
62 |
+
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
63 |
+
],
|
64 |
+
# Stable Diffusion XL
|
65 |
+
"StableDiffusionXL_v1": [
|
66 |
+
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
67 |
+
],
|
68 |
+
"BluePencilXL_v200": [
|
69 |
+
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
70 |
+
],
|
71 |
+
"StableDiffusionXL_Turbo": [
|
72 |
+
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
73 |
+
],
|
74 |
+
# Stable Diffusion 3
|
75 |
+
"StableDiffusion3": [
|
76 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
77 |
+
],
|
78 |
+
"StableDiffusion3_without_T5": [
|
79 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
80 |
+
],
|
81 |
+
# ControlNet
|
82 |
+
"ControlNet_v11f1p_sd15_depth": [
|
83 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
84 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
85 |
+
],
|
86 |
+
"ControlNet_v11p_sd15_softedge": [
|
87 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
88 |
+
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
89 |
+
],
|
90 |
+
"ControlNet_v11f1e_sd15_tile": [
|
91 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
92 |
+
],
|
93 |
+
"ControlNet_v11p_sd15_lineart": [
|
94 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
95 |
+
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
96 |
+
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
97 |
+
],
|
98 |
+
"ControlNet_union_sdxl_promax": [
|
99 |
+
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
100 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
101 |
+
],
|
102 |
+
# AnimateDiff
|
103 |
+
"AnimateDiff_v2": [
|
104 |
+
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
105 |
+
],
|
106 |
+
"AnimateDiff_xl_beta": [
|
107 |
+
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
108 |
+
],
|
109 |
+
|
110 |
+
# Qwen Prompt
|
111 |
+
"QwenPrompt": [
|
112 |
+
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
113 |
+
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
114 |
+
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
115 |
+
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
116 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
117 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
118 |
+
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
119 |
+
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
120 |
+
],
|
121 |
+
# Beautiful Prompt
|
122 |
+
"BeautifulPrompt": [
|
123 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
124 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
125 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
126 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
127 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
128 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
129 |
+
],
|
130 |
+
# Omost prompt
|
131 |
+
"OmostPrompt":[
|
132 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
133 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
134 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
135 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
136 |
+
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
137 |
+
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
138 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
139 |
+
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
140 |
+
],
|
141 |
+
# Translator
|
142 |
+
"opus-mt-zh-en": [
|
143 |
+
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
144 |
+
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
145 |
+
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
146 |
+
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
147 |
+
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
148 |
+
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
149 |
+
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
150 |
+
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
151 |
+
],
|
152 |
+
# IP-Adapter
|
153 |
+
"IP-Adapter-SD": [
|
154 |
+
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
155 |
+
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
156 |
+
],
|
157 |
+
"IP-Adapter-SDXL": [
|
158 |
+
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
159 |
+
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
160 |
+
],
|
161 |
+
"SDXL-vae-fp16-fix": [
|
162 |
+
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
163 |
+
],
|
164 |
+
# Kolors
|
165 |
+
"Kolors": [
|
166 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
167 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
168 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
169 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
170 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
171 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
172 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
173 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
174 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
175 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
176 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
177 |
+
],
|
178 |
+
# FLUX
|
179 |
+
"FLUX.1-dev": [
|
180 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
181 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
182 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
183 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
184 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
185 |
+
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
186 |
+
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
187 |
+
],
|
188 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
189 |
+
"file_list": [
|
190 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
191 |
+
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
192 |
+
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
193 |
+
],
|
194 |
+
"load_path": [
|
195 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
196 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
197 |
+
],
|
198 |
+
},
|
199 |
+
# RIFE
|
200 |
+
"RIFE": [
|
201 |
+
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
202 |
+
],
|
203 |
+
# CogVideo
|
204 |
+
"CogVideoX-5B": [
|
205 |
+
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
206 |
+
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
207 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
208 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
209 |
+
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
210 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
211 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
212 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
213 |
+
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
214 |
+
],
|
215 |
+
# Stable Diffusion 3.5
|
216 |
+
"StableDiffusion3.5-large": [
|
217 |
+
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
218 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
219 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
220 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
221 |
+
],
|
222 |
+
}
|
223 |
+
preset_models_on_modelscope = {
|
224 |
+
# Hunyuan DiT
|
225 |
+
"HunyuanDiT": [
|
226 |
+
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
227 |
+
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
228 |
+
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
229 |
+
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
230 |
+
],
|
231 |
+
# Stable Video Diffusion
|
232 |
+
"stable-video-diffusion-img2vid-xt": [
|
233 |
+
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
234 |
+
],
|
235 |
+
# ExVideo
|
236 |
+
"ExVideo-SVD-128f-v1": [
|
237 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
238 |
+
],
|
239 |
+
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
240 |
+
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
241 |
+
],
|
242 |
+
# Stable Diffusion
|
243 |
+
"StableDiffusion_v15": [
|
244 |
+
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
245 |
+
],
|
246 |
+
"DreamShaper_8": [
|
247 |
+
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
248 |
+
],
|
249 |
+
"AingDiffusion_v12": [
|
250 |
+
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
251 |
+
],
|
252 |
+
"Flat2DAnimerge_v45Sharp": [
|
253 |
+
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
254 |
+
],
|
255 |
+
# Textual Inversion
|
256 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
257 |
+
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
258 |
+
],
|
259 |
+
# Stable Diffusion XL
|
260 |
+
"StableDiffusionXL_v1": [
|
261 |
+
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
262 |
+
],
|
263 |
+
"BluePencilXL_v200": [
|
264 |
+
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
265 |
+
],
|
266 |
+
"StableDiffusionXL_Turbo": [
|
267 |
+
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
268 |
+
],
|
269 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
270 |
+
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
271 |
+
],
|
272 |
+
# Stable Diffusion 3
|
273 |
+
"StableDiffusion3": [
|
274 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
275 |
+
],
|
276 |
+
"StableDiffusion3_without_T5": [
|
277 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
278 |
+
],
|
279 |
+
# ControlNet
|
280 |
+
"ControlNet_v11f1p_sd15_depth": [
|
281 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
282 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
283 |
+
],
|
284 |
+
"ControlNet_v11p_sd15_softedge": [
|
285 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
286 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
287 |
+
],
|
288 |
+
"ControlNet_v11f1e_sd15_tile": [
|
289 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
290 |
+
],
|
291 |
+
"ControlNet_v11p_sd15_lineart": [
|
292 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
293 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
294 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
295 |
+
],
|
296 |
+
"ControlNet_union_sdxl_promax": [
|
297 |
+
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
298 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
299 |
+
],
|
300 |
+
"Annotators:Depth": [
|
301 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
302 |
+
],
|
303 |
+
"Annotators:Softedge": [
|
304 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
305 |
+
],
|
306 |
+
"Annotators:Lineart": [
|
307 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
308 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
309 |
+
],
|
310 |
+
"Annotators:Normal": [
|
311 |
+
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
312 |
+
],
|
313 |
+
"Annotators:Openpose": [
|
314 |
+
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
315 |
+
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
316 |
+
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
317 |
+
],
|
318 |
+
# AnimateDiff
|
319 |
+
"AnimateDiff_v2": [
|
320 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
321 |
+
],
|
322 |
+
"AnimateDiff_xl_beta": [
|
323 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
324 |
+
],
|
325 |
+
# RIFE
|
326 |
+
"RIFE": [
|
327 |
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
328 |
+
],
|
329 |
+
# Qwen Prompt
|
330 |
+
"QwenPrompt": {
|
331 |
+
"file_list": [
|
332 |
+
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
333 |
+
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
334 |
+
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
335 |
+
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
336 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
337 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
338 |
+
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
339 |
+
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
340 |
+
],
|
341 |
+
"load_path": [
|
342 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
343 |
+
],
|
344 |
+
},
|
345 |
+
# Beautiful Prompt
|
346 |
+
"BeautifulPrompt": {
|
347 |
+
"file_list": [
|
348 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
349 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
350 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
351 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
352 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
353 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
354 |
+
],
|
355 |
+
"load_path": [
|
356 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
357 |
+
],
|
358 |
+
},
|
359 |
+
# Omost prompt
|
360 |
+
"OmostPrompt": {
|
361 |
+
"file_list": [
|
362 |
+
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
363 |
+
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
364 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
365 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
366 |
+
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
367 |
+
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
368 |
+
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
369 |
+
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
370 |
+
],
|
371 |
+
"load_path": [
|
372 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
373 |
+
],
|
374 |
+
},
|
375 |
+
# Translator
|
376 |
+
"opus-mt-zh-en": {
|
377 |
+
"file_list": [
|
378 |
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
379 |
+
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
380 |
+
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
381 |
+
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
382 |
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
383 |
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
384 |
+
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
385 |
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
386 |
+
],
|
387 |
+
"load_path": [
|
388 |
+
"models/translator/opus-mt-zh-en",
|
389 |
+
],
|
390 |
+
},
|
391 |
+
# IP-Adapter
|
392 |
+
"IP-Adapter-SD": [
|
393 |
+
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
394 |
+
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
395 |
+
],
|
396 |
+
"IP-Adapter-SDXL": [
|
397 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
398 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
399 |
+
],
|
400 |
+
# Kolors
|
401 |
+
"Kolors": {
|
402 |
+
"file_list": [
|
403 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
404 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
405 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
406 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
407 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
408 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
409 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
410 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
411 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
412 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
413 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
414 |
+
],
|
415 |
+
"load_path": [
|
416 |
+
"models/kolors/Kolors/text_encoder",
|
417 |
+
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
418 |
+
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
419 |
+
],
|
420 |
+
},
|
421 |
+
"SDXL-vae-fp16-fix": [
|
422 |
+
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
423 |
+
],
|
424 |
+
# FLUX
|
425 |
+
"FLUX.1-dev": {
|
426 |
+
"file_list": [
|
427 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
428 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
429 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
430 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
431 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
432 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
433 |
+
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
434 |
+
],
|
435 |
+
"load_path": [
|
436 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
437 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
438 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
439 |
+
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
440 |
+
],
|
441 |
+
},
|
442 |
+
"FLUX.1-schnell": {
|
443 |
+
"file_list": [
|
444 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
445 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
446 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
447 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
448 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
449 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
450 |
+
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
451 |
+
],
|
452 |
+
"load_path": [
|
453 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
454 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
455 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
456 |
+
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
457 |
+
],
|
458 |
+
},
|
459 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
460 |
+
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
461 |
+
],
|
462 |
+
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
463 |
+
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
464 |
+
],
|
465 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
466 |
+
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
467 |
+
],
|
468 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
469 |
+
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
470 |
+
],
|
471 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
472 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
473 |
+
],
|
474 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
475 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
476 |
+
],
|
477 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
478 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
479 |
+
],
|
480 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
481 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
482 |
+
],
|
483 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
484 |
+
"file_list": [
|
485 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
486 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
487 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
488 |
+
],
|
489 |
+
"load_path": [
|
490 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
491 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
492 |
+
],
|
493 |
+
},
|
494 |
+
# ESRGAN
|
495 |
+
"ESRGAN_x4": [
|
496 |
+
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
497 |
+
],
|
498 |
+
# RIFE
|
499 |
+
"RIFE": [
|
500 |
+
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
501 |
+
],
|
502 |
+
# Omnigen
|
503 |
+
"OmniGen-v1": {
|
504 |
+
"file_list": [
|
505 |
+
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
506 |
+
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
507 |
+
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
508 |
+
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
509 |
+
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
510 |
+
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
511 |
+
],
|
512 |
+
"load_path": [
|
513 |
+
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
514 |
+
"models/OmniGen/OmniGen-v1/model.safetensors",
|
515 |
+
]
|
516 |
+
},
|
517 |
+
# CogVideo
|
518 |
+
"CogVideoX-5B": {
|
519 |
+
"file_list": [
|
520 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
521 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
522 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
523 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
524 |
+
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
525 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
526 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
527 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
528 |
+
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
529 |
+
],
|
530 |
+
"load_path": [
|
531 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
532 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
533 |
+
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
534 |
+
],
|
535 |
+
},
|
536 |
+
# Stable Diffusion 3.5
|
537 |
+
"StableDiffusion3.5-large": [
|
538 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
539 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
540 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
541 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
542 |
+
],
|
543 |
+
"StableDiffusion3.5-medium": [
|
544 |
+
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
545 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
546 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
547 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
548 |
+
],
|
549 |
+
"StableDiffusion3.5-large-turbo": [
|
550 |
+
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
551 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
552 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
553 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
554 |
+
],
|
555 |
+
"HunyuanVideo":{
|
556 |
+
"file_list": [
|
557 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
558 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
559 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
560 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
561 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
562 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
563 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
564 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
565 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
566 |
+
],
|
567 |
+
"load_path": [
|
568 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
569 |
+
"models/HunyuanVideo/text_encoder_2",
|
570 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
571 |
+
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
572 |
+
],
|
573 |
+
},
|
574 |
+
"HunyuanVideo-fp8":{
|
575 |
+
"file_list": [
|
576 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
577 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
578 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
579 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
580 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
581 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
582 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
583 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
584 |
+
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
585 |
+
],
|
586 |
+
"load_path": [
|
587 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
588 |
+
"models/HunyuanVideo/text_encoder_2",
|
589 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
590 |
+
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
591 |
+
],
|
592 |
+
},
|
593 |
+
}
|
594 |
+
Preset_model_id: TypeAlias = Literal[
|
595 |
+
"HunyuanDiT",
|
596 |
+
"stable-video-diffusion-img2vid-xt",
|
597 |
+
"ExVideo-SVD-128f-v1",
|
598 |
+
"ExVideo-CogVideoX-LoRA-129f-v1",
|
599 |
+
"StableDiffusion_v15",
|
600 |
+
"DreamShaper_8",
|
601 |
+
"AingDiffusion_v12",
|
602 |
+
"Flat2DAnimerge_v45Sharp",
|
603 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
604 |
+
"StableDiffusionXL_v1",
|
605 |
+
"BluePencilXL_v200",
|
606 |
+
"StableDiffusionXL_Turbo",
|
607 |
+
"ControlNet_v11f1p_sd15_depth",
|
608 |
+
"ControlNet_v11p_sd15_softedge",
|
609 |
+
"ControlNet_v11f1e_sd15_tile",
|
610 |
+
"ControlNet_v11p_sd15_lineart",
|
611 |
+
"AnimateDiff_v2",
|
612 |
+
"AnimateDiff_xl_beta",
|
613 |
+
"RIFE",
|
614 |
+
"BeautifulPrompt",
|
615 |
+
"opus-mt-zh-en",
|
616 |
+
"IP-Adapter-SD",
|
617 |
+
"IP-Adapter-SDXL",
|
618 |
+
"StableDiffusion3",
|
619 |
+
"StableDiffusion3_without_T5",
|
620 |
+
"Kolors",
|
621 |
+
"SDXL-vae-fp16-fix",
|
622 |
+
"ControlNet_union_sdxl_promax",
|
623 |
+
"FLUX.1-dev",
|
624 |
+
"FLUX.1-schnell",
|
625 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
626 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
627 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
628 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
629 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
630 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
631 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
632 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
633 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
634 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
635 |
+
"QwenPrompt",
|
636 |
+
"OmostPrompt",
|
637 |
+
"ESRGAN_x4",
|
638 |
+
"RIFE",
|
639 |
+
"OmniGen-v1",
|
640 |
+
"CogVideoX-5B",
|
641 |
+
"Annotators:Depth",
|
642 |
+
"Annotators:Softedge",
|
643 |
+
"Annotators:Lineart",
|
644 |
+
"Annotators:Normal",
|
645 |
+
"Annotators:Openpose",
|
646 |
+
"StableDiffusion3.5-large",
|
647 |
+
"StableDiffusion3.5-medium",
|
648 |
+
"HunyuanVideo",
|
649 |
+
"HunyuanVideo-fp8",
|
650 |
+
]
|
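As context for the configuration above: per the format comments in the file, model_loader_configs maps a hash of a checkpoint's state-dict keys to the model name(s) and class(es) to instantiate. A hypothetical sketch of registering one more auto-detectable checkpoint, following the same tuple layout; the hash value here is a placeholder, not a real entry from this commit.

# Hypothetical extension of model_loader_configs (hash is a placeholder)
from diffsynth.configs.model_config import model_loader_configs
from diffsynth.models.wan_video_dit import WanModel

model_loader_configs.append(
    # (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "0123456789abcdef0123456789abcdef", ["wan_video_dit"], [WanModel], "civitai")
)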
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames
diffsynth/data/video.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
import imageio, os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
class LowMemoryVideo:
|
8 |
+
def __init__(self, file_name):
|
9 |
+
self.reader = imageio.get_reader(file_name)
|
10 |
+
|
11 |
+
def __len__(self):
|
12 |
+
return self.reader.count_frames()
|
13 |
+
|
14 |
+
def __getitem__(self, item):
|
15 |
+
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
|
16 |
+
|
17 |
+
def __del__(self):
|
18 |
+
self.reader.close()
|
19 |
+
|
20 |
+
|
21 |
+
def split_file_name(file_name):
|
22 |
+
result = []
|
23 |
+
number = -1
|
24 |
+
for i in file_name:
|
25 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
26 |
+
if number == -1:
|
27 |
+
number = 0
|
28 |
+
number = number*10 + ord(i) - ord("0")
|
29 |
+
else:
|
30 |
+
if number != -1:
|
31 |
+
result.append(number)
|
32 |
+
number = -1
|
33 |
+
result.append(i)
|
34 |
+
if number != -1:
|
35 |
+
result.append(number)
|
36 |
+
result = tuple(result)
|
37 |
+
return result
|
38 |
+
|
39 |
+
|
40 |
+
def search_for_images(folder):
|
41 |
+
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
42 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
43 |
+
file_list = [i[1] for i in sorted(file_list)]
|
44 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
45 |
+
return file_list
|
46 |
+
|
47 |
+
|
48 |
+
class LowMemoryImageFolder:
|
49 |
+
def __init__(self, folder, file_list=None):
|
50 |
+
if file_list is None:
|
51 |
+
self.file_list = search_for_images(folder)
|
52 |
+
else:
|
53 |
+
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.file_list)
|
57 |
+
|
58 |
+
def __getitem__(self, item):
|
59 |
+
return Image.open(self.file_list[item]).convert("RGB")
|
60 |
+
|
61 |
+
def __del__(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
+
def crop_and_resize(image, height, width):
|
66 |
+
image = np.array(image)
|
67 |
+
image_height, image_width, _ = image.shape
|
68 |
+
if image_height / image_width < height / width:
|
69 |
+
croped_width = int(image_height / height * width)
|
70 |
+
left = (image_width - croped_width) // 2
|
71 |
+
image = image[:, left: left+croped_width]
|
72 |
+
image = Image.fromarray(image).resize((width, height))
|
73 |
+
else:
|
74 |
+
croped_height = int(image_width / width * height)
|
75 |
+
left = (image_height - croped_height) // 2
|
76 |
+
image = image[left: left+croped_height, :]
|
77 |
+
image = Image.fromarray(image).resize((width, height))
|
78 |
+
return image
|
79 |
+
|
80 |
+
|
81 |
+
class VideoData:
|
82 |
+
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
|
83 |
+
if video_file is not None:
|
84 |
+
self.data_type = "video"
|
85 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
86 |
+
elif image_folder is not None:
|
87 |
+
self.data_type = "images"
|
88 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
89 |
+
else:
|
90 |
+
raise ValueError("Cannot open video or image folder")
|
91 |
+
self.length = None
|
92 |
+
self.set_shape(height, width)
|
93 |
+
|
94 |
+
def raw_data(self):
|
95 |
+
frames = []
|
96 |
+
for i in range(self.__len__()):
|
97 |
+
frames.append(self.__getitem__(i))
|
98 |
+
return frames
|
99 |
+
|
100 |
+
def set_length(self, length):
|
101 |
+
self.length = length
|
102 |
+
|
103 |
+
def set_shape(self, height, width):
|
104 |
+
self.height = height
|
105 |
+
self.width = width
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
if self.length is None:
|
109 |
+
return len(self.data)
|
110 |
+
else:
|
111 |
+
return self.length
|
112 |
+
|
113 |
+
def shape(self):
|
114 |
+
if self.height is not None and self.width is not None:
|
115 |
+
return self.height, self.width
|
116 |
+
else:
|
117 |
+
height, width, _ = self.__getitem__(0).shape
|
118 |
+
return height, width
|
119 |
+
|
120 |
+
def __getitem__(self, item):
|
121 |
+
frame = self.data.__getitem__(item)
|
122 |
+
width, height = frame.size
|
123 |
+
if self.height is not None and self.width is not None:
|
124 |
+
if self.height != height or self.width != width:
|
125 |
+
frame = crop_and_resize(frame, self.height, self.width)
|
126 |
+
return frame
|
127 |
+
|
128 |
+
def __del__(self):
|
129 |
+
pass
|
130 |
+
|
131 |
+
def save_images(self, folder):
|
132 |
+
os.makedirs(folder, exist_ok=True)
|
133 |
+
for i in tqdm(range(self.__len__()), desc="Saving images"):
|
134 |
+
frame = self.__getitem__(i)
|
135 |
+
frame.save(os.path.join(folder, f"{i}.png"))
|
136 |
+
|
137 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
138 |
+
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
139 |
+
for frame in tqdm(frames, desc="Saving video"):
|
140 |
+
frame = np.array(frame)
|
141 |
+
writer.append_data(frame)
|
142 |
+
writer.close()
|
143 |
+
# def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
144 |
+
# writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=["-crf", "0", "-preset", "veryslow"])
|
145 |
+
# for frame in tqdm(frames, desc="Saving video"):
|
146 |
+
# frame = np.array(frame)
|
147 |
+
# writer.append_data(frame)
|
148 |
+
# writer.close()
|
149 |
+
|
150 |
+
# def save_video_h264(frames, save_path, fps, ffmpeg_params=None):
|
151 |
+
# import imageio.v3 as iio
|
152 |
+
# from tqdm import tqdm
|
153 |
+
# import numpy as np
|
154 |
+
|
155 |
+
# if ffmpeg_params is None:
|
156 |
+
# ffmpeg_params = ["-crf", "0", "-preset", "ultrafast"] # 无损 H.264
|
157 |
+
|
158 |
+
# writer = iio.get_writer(save_path, fps=fps, codec="libx264", ffmpeg_params=ffmpeg_params)
|
159 |
+
# for frame in tqdm(frames, desc="Saving video"):
|
160 |
+
# writer.append_data(np.array(frame))
|
161 |
+
# writer.close()
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
def save_frames(frames, save_path):
|
166 |
+
os.makedirs(save_path, exist_ok=True)
|
167 |
+
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
168 |
+
frame.save(os.path.join(save_path, f"{i}.png"))
|
169 |
+
|
170 |
+
|
171 |
+
if __name__=='__main__':
|
172 |
+
frames = [Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)) for i in range(81)]
|
173 |
+
save_video(frames,"haha.mp4",23,5)
|
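A minimal usage sketch of the helpers above, assuming this Space's diffsynth package is importable and that some clip exists at the hypothetical path ./example.mp4; VideoData decodes frames lazily and crop_and_resize center-crops them to the requested shape:

from diffsynth.data.video import VideoData, save_video

video = VideoData(video_file="./example.mp4", height=480, width=832)   # hypothetical input path
frames = video.raw_data()              # list of PIL images, center-cropped and resized to 832x480
save_video(frames, "./resized.mp4", fps=24, quality=5)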
diffsynth/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
from .wan_video import WanVideoPipeline
diffsynth/pipelines/base.py
ADDED
@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur


class BasePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.cpu_offload = False
        self.model_names = []

    def check_resize_height_width(self, height, width):
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width

    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    def preprocess_images(self, images):
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image

    def vae_output_to_video(self, vae_output):
        video = vae_output.cpu().permute(1, 2, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video

    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        if len(latents) > 0:
            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
            height, width = value.shape[-2:]
            weight = torch.ones_like(value)
            for latent, mask, scale in zip(latents, masks, scales):
                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
                mask = blur(mask)
                value += latent * mask * scale
                weight += mask * scale
            value /= weight
        return value

    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
        noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
        return noise_pred

    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        local_prompts = local_prompts or []
        masks = masks or []
        mask_scales = mask_scales or []
        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extended_prompt_dict.get("masks", [])
        mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
        return prompt, local_prompts, masks, mask_scales

    def enable_cpu_offload(self):
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=[]):
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if model is not None:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is not None:
                if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                    for module in model.modules():
                        if hasattr(module, "onload"):
                            module.onload()
                else:
                    model.to(self.device)
        # refresh the cuda cache
        torch.cuda.empty_cache()

    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return noise
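check_resize_height_width rounds the resolution up with the usual ceil-to-multiple trick; a quick illustration with arbitrary numbers:

factor, height = 64, 500
rounded = (height + factor - 1) // factor * factor
print(rounded)   # 512, the smallest multiple of 64 that is >= 500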
diffsynth/pipelines/wan_video.py
ADDED
@@ -0,0 +1,290 @@
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm

from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample


class WanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['text_encoder', 'dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        if device is None: device = model_manager.device
        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        return pipe

    def denoising_model(self):
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
        return {"context": prompt_emb}

    def encode_image(self, image, num_frames, height, width):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            image = self.preprocess_image(image.resize((width, height))).to(self.device)
            clip_context = self.image_encoder.encode_image([image])
            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
            msk[:, 1:] = 0
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
            msk = msk.transpose(1, 2)[0]
            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
            y = torch.concat([msk, y])
        return {"clip_fea": clip_context, "y": [y]}

    def tensor2video(self, frames):
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def set_ip(self, local_path):
        pass

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        audio_cfg_scale=None,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(30, 52),
        tile_stride=(15, 26),
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        **kwargs,
    ):
        # Parameter check
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)

        # Initialize noise
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
        if input_video is not None:
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Encode image
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_emb = self.encode_image(input_image, num_frames, height, width)
        else:
            image_emb = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        self.load_models_to_device(["dit"])
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
                timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)

                # Inference
                noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs)  # (zt, audio, prompt)
                if audio_cfg_scale is not None:
                    audio_scale = kwargs['audio_scale']
                    kwargs['audio_scale'] = 0.0
                    noise_pred_noaudio = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs)  # (zt, 0, prompt)
                    # kwargs['ip_scale'] = ip_scale
                    if cfg_scale != 1.0:  # prompt cfg
                        noise_pred_no_cond = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs)  # (zt, 0, 0)
                        noise_pred = noise_pred_no_cond + cfg_scale * (noise_pred_noaudio - noise_pred_no_cond) + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
                    else:
                        noise_pred = noise_pred_noaudio + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
                    kwargs['audio_scale'] = audio_scale
                else:
                    if cfg_scale != 1.0:
                        noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs)  # (zt, audio, 0)
                        noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                    else:
                        noise_pred = noise_pred_posi

                # Scheduler
                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # Decode
        self.load_models_to_device(['vae'])
        frames = self.decode_video(latents, **tiler_kwargs)
        self.load_models_to_device([])
        frames = self.tensor2video(frames[0])

        return frames
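When audio_cfg_scale is set, the three DiT predictions above are blended by a nested classifier-free guidance. The same arithmetic with random tensors as stand-ins (purely illustrative, no model involved):

import torch

cfg_scale, audio_cfg_scale = 5.0, 5.0
# Stand-ins for the predictions conditioned on (nothing), (prompt only), (prompt + audio).
e_uncond, e_prompt, e_prompt_audio = torch.randn(3, 16)
noise_pred = (
    e_uncond
    + cfg_scale * (e_prompt - e_uncond)                  # prompt guidance
    + audio_cfg_scale * (e_prompt_audio - e_prompt)      # audio guidance on top of the prompted prediction
)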
diffsynth/prompters/__init__.py
ADDED
@@ -0,0 +1 @@
from .wan_prompter import WanPrompter
diffsynth/prompters/base_prompter.py
ADDED
@@ -0,0 +1,70 @@
from ..models.model_manager import ModelManager
import torch


def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    # Get model_max_length from self.tokenizer
    length = tokenizer.model_max_length if max_length is None else max_length

    # To avoid the warning, set self.tokenizer.model_max_length to +oo.
    tokenizer.model_max_length = 99999999

    # Tokenize it!
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Determine the real length.
    max_length = (input_ids.shape[1] + length - 1) // length * length

    # Restore tokenizer.model_max_length
    tokenizer.model_max_length = length

    # Tokenize it again with fixed length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids

    # Reshape input_ids to fit the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))

    return input_ids


class BasePrompter:
    def __init__(self):
        self.refiners = []
        self.extenders = []

    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
        for refiner_class in refiner_classes:
            refiner = refiner_class.from_model_manager(model_manager)
            self.refiners.append(refiner)

    def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
        for extender_class in extender_classes:
            extender = extender_class.from_model_manager(model_manager)
            self.extenders.append(extender)

    @torch.no_grad()
    def process_prompt(self, prompt, positive=True):
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
        else:
            for refiner in self.refiners:
                prompt = refiner(prompt, positive=positive)
        return prompt

    @torch.no_grad()
    def extend_prompt(self, prompt: str, positive=True):
        extended_prompt = dict(prompt=prompt)
        for extender in self.extenders:
            extended_prompt = extender(extended_prompt)
        return extended_prompt
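tokenize_long_prompt pads a long prompt up to the next multiple of the tokenizer's model_max_length and then folds it into fixed-size rows; the same arithmetic with illustrative numbers (77 is only an example value):

length, n_tokens = 77, 150                                # hypothetical model_max_length and prompt length
max_length = (n_tokens + length - 1) // length * length
num_sentence = max_length // length
print(max_length, num_sentence)                           # 154 2 -> encoded as 2 chunks of 77 tokens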
diffsynth/prompters/wan_prompter.py
ADDED
@@ -0,0 +1,108 @@
from .base_prompter import BasePrompter
from ..models.wan_video_text_encoder import WanTextEncoder
from transformers import AutoTokenizer
import os, torch
import ftfy
import html
import string
import regex as re


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text


class WanPrompter(BasePrompter):

    def __init__(self, tokenizer_path=None, text_len=512):
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        if tokenizer_path is not None:
            self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        prompt = self.process_prompt(prompt, positive=positive)

        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = self.text_encoder(ids, mask)
        prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
        return prompt_emb
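The cleaning helpers are plain string utilities (ftfy plus regex, both assumed installed); a quick demonstration of what the two normalization modes produce:

from diffsynth.prompters.wan_prompter import basic_clean, whitespace_clean, canonicalize

text = "A  woman_is   talking!"
print(whitespace_clean(basic_clean(text)))   # "A woman_is talking!"  (the 'whitespace' mode used by WanPrompter)
print(canonicalize(text))                    # "a woman is talking"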
diffsynth/schedulers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .ddim import EnhancedDDIMScheduler
from .continuous_ode import ContinuousODEScheduler
from .flow_match import FlowMatchScheduler
diffsynth/schedulers/continuous_ode.py
ADDED
@@ -0,0 +1,59 @@
import torch


class ContinuousODEScheduler():

    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
        ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
        min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
        max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
        self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
        self.timesteps = torch.log(self.sigmas) * 0.25

    def step(self, model_output, timestep, sample, to_final=False):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample *= (sigma*sigma + 1).sqrt()
        estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
        if to_final or timestep_id + 1 >= len(self.timesteps):
            prev_sample = estimated_sample
        else:
            sigma_ = self.sigmas[timestep_id + 1]
            derivative = 1 / sigma * (sample - estimated_sample)
            prev_sample = sample + derivative * (sigma_ - sigma)
            prev_sample /= (sigma_*sigma_ + 1).sqrt()
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        # This scheduler doesn't support this function.
        pass

    def add_noise(self, original_samples, noise, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
        return sample

    def training_target(self, sample, noise, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        weight = (1 + sigma*sigma).sqrt() / sigma
        return weight
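set_timesteps builds a rho-spaced sigma schedule between sigma_max and sigma_min and conditions the model on 0.25*log(sigma); a quick peek at the values it produces:

from diffsynth.schedulers.continuous_ode import ContinuousODEScheduler

scheduler = ContinuousODEScheduler(num_inference_steps=5)
print(scheduler.sigmas)      # 5 sigmas decreasing from 700.0 towards 0.002
print(scheduler.timesteps)   # 0.25 * log(sigma), used as the timestep input to the model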
diffsynth/schedulers/ddim.py
ADDED
@@ -0,0 +1,105 @@
import torch, math


class EnhancedDDIMScheduler():

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        if rescale_zero_terminal_snr:
            self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
        self.alphas_cumprod = self.alphas_cumprod.tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type

    def rescale_zero_terminal_snr(self, alphas_cumprod):
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt.square()  # Revert sqrt

        return alphas_bar

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample

    def step(self, model_output, timestep, sample, to_final=False):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def training_target(self, sample, noise, timestep):
        if self.prediction_type == "epsilon":
            return noise
        else:
            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target

    def training_weight(self, timestep):
        return 1.0
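A small consistency check for the epsilon-prediction path, using only tensors and no model: add_noise mixes sample and noise with sqrt(alpha_bar) weights, and training_target then simply returns the noise:

import torch
from diffsynth.schedulers.ddim import EnhancedDDIMScheduler

scheduler = EnhancedDDIMScheduler(prediction_type="epsilon")
scheduler.set_timesteps(20)

sample = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(sample)
timestep = scheduler.timesteps[0]                    # 999.0, the noisiest step
noisy = scheduler.add_noise(sample, noise, timestep)
print(torch.allclose(scheduler.training_target(sample, noise, timestep), noise))   # True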
diffsynth/schedulers/flow_match.py
ADDED
@@ -0,0 +1,79 @@
import torch


class FlowMatchScheduler():

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
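A minimal sketch of the scheduler's calling convention in isolation, stepping a random latent with a random stand-in for the DiT output (shapes are arbitrary):

import torch
from diffsynth.schedulers.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=10)

latents = torch.randn(1, 16, 21, 60, 104)        # same layout as the Wan video latents
for timestep in scheduler.timesteps:
    model_output = torch.randn_like(latents)     # stand-in for the DiT prediction
    latents = scheduler.step(model_output, timestep, latents)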
diffsynth/vram_management/__init__.py
ADDED
@@ -0,0 +1 @@
from .layers import *
diffsynth/vram_management/layers.py
ADDED
@@ -0,0 +1,95 @@
import torch, copy
from ..models.utils import init_weights_on_device


def cast_to(weight, dtype, device):
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight)
    return r


class AutoWrappedModule(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        super().__init__()
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            module = self.module
        else:
            module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
        return module(*args, **kwargs)


class AutoWrappedLinear(torch.nn.Linear):
    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, x, *args, **kwargs):
        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
            weight, bias = self.weight, self.bias
        else:
            weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        return torch.nn.functional.linear(x, weight, bias)


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    for name, module in model.named_children():
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                num_param = sum(p.numel() for p in module.parameters())
                if max_num_param is not None and total_num_param + num_param > max_num_param:
                    module_config_ = overflow_module_config
                else:
                    module_config_ = module_config
                module_ = target_module(module, **module_config_)
                setattr(model, name, module_)
                total_num_param += num_param
                break
        else:
            total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
    return total_num_param


def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
    model.vram_management_enabled = True
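The wrappers can be attached to any torch module tree; a toy sketch with a plain Linear stack, keeping everything on CPU just to show the mechanics (dtype/device choices are illustrative only, and the diffsynth package from this Space is assumed importable):

import torch
from diffsynth.vram_management import enable_vram_management, AutoWrappedLinear

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",
        onload_dtype=torch.float32, onload_device="cpu",
        computation_dtype=torch.float32, computation_device="cpu",
    ),
)
print(model(torch.randn(2, 8)).shape)   # forward casts weights to the computation dtype/device on the fly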
infer.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffsynth import ModelManager, WanVideoPipeline
|
3 |
+
from PIL import Image
|
4 |
+
import argparse
|
5 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2Model
|
6 |
+
import librosa
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
import cv2
|
10 |
+
from model import FantasyTalkingAudioConditionModel
|
11 |
+
from utils import save_video, get_audio_features, resize_image_by_longest_edge
|
12 |
+
from pathlib import Path
|
13 |
+
from datetime import datetime
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
17 |
+
|
18 |
+
parser.add_argument(
|
19 |
+
"--wan_model_dir",
|
20 |
+
type=str,
|
21 |
+
default="./models/Wan2.1-I2V-14B-720P",
|
22 |
+
required=False,
|
23 |
+
help="The dir of the Wan I2V 14B model.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--fantasytalking_model_path",
|
27 |
+
type=str,
|
28 |
+
default="./models/fantasytalking_model.ckpt",
|
29 |
+
required=False,
|
30 |
+
help="The .ckpt path of fantasytalking model.",
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--wav2vec_model_dir",
|
34 |
+
type=str,
|
35 |
+
default="./models/wav2vec2-base-960h",
|
36 |
+
required=False,
|
37 |
+
help="The dir of wav2vec model.",
|
38 |
+
)
|
39 |
+
|
40 |
+
parser.add_argument(
|
41 |
+
"--image_path",
|
42 |
+
type=str,
|
43 |
+
default="./assets/images/woman.png",
|
44 |
+
required=False,
|
45 |
+
help="The path of the image.",
|
46 |
+
)
|
47 |
+
|
48 |
+
parser.add_argument(
|
49 |
+
"--audio_path",
|
50 |
+
type=str,
|
51 |
+
default="./assets/audios/woman.wav",
|
52 |
+
required=False,
|
53 |
+
help="The path of the audio.",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--prompt",
|
57 |
+
type=str,
|
58 |
+
default="A woman is talking.",
|
59 |
+
required=False,
|
60 |
+
help="prompt.",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--output_dir",
|
64 |
+
type=str,
|
65 |
+
default="./output",
|
66 |
+
help="Dir to save the model.",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--image_size",
|
70 |
+
type=int,
|
71 |
+
default=512,
|
72 |
+
help="The image will be resized proportionally to this size.",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--audio_scale",
|
76 |
+
type=float,
|
77 |
+
default=1.0,
|
78 |
+
help="Audio condition injection weight",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--prompt_cfg_scale",
|
82 |
+
type=float,
|
83 |
+
default=5.0,
|
84 |
+
required=False,
|
85 |
+
help="Prompt cfg scale",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--audio_cfg_scale",
|
89 |
+
type=float,
|
90 |
+
default=5.0,
|
91 |
+
required=False,
|
92 |
+
help="Audio cfg scale",
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--max_num_frames",
|
96 |
+
type=int,
|
97 |
+
default=81,
|
98 |
+
required=False,
|
99 |
+
help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated."
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--fps",
|
103 |
+
type=int,
|
104 |
+
default=23,
|
105 |
+
required=False,
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--num_persistent_param_in_dit",
|
109 |
+
type=int,
|
110 |
+
default=None,
|
111 |
+
required=False,
|
112 |
+
help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required"
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--seed",
|
116 |
+
type=int,
|
117 |
+
default=1111,
|
118 |
+
required=False,
|
119 |
+
)
|
120 |
+
args = parser.parse_args()
|
121 |
+
return args
|
122 |
+
|
123 |
+
def load_models(args):
|
124 |
+
# Load Wan I2V models
|
125 |
+
model_manager = ModelManager(device="cpu")
|
126 |
+
model_manager.load_models(
|
127 |
+
[
|
128 |
+
[
|
129 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
|
130 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
|
131 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
|
132 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
|
133 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
|
134 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
|
135 |
+
f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
|
136 |
+
],
|
137 |
+
f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
138 |
+
f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
|
139 |
+
f"{args.wan_model_dir}/Wan2.1_VAE.pth",
|
140 |
+
],
|
141 |
+
# torch_dtype=torch.float8_e4m3fn,  # Uncomment (and remove the line below) to enable FP8 quantization and reduce VRAM usage.
|
142 |
+
torch_dtype=torch.bfloat16,  # bfloat16 weights (FP8 quantization disabled).
|
143 |
+
)
|
144 |
+
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
145 |
+
|
146 |
+
# Load FantasyTalking weights
|
147 |
+
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
|
148 |
+
fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
|
149 |
+
|
150 |
+
# You can set `num_persistent_param_in_dit` to a small number to reduce the VRAM required.
|
151 |
+
pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
|
152 |
+
|
153 |
+
# Load wav2vec models
|
154 |
+
wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
|
155 |
+
wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
|
156 |
+
|
157 |
+
return pipe,fantasytalking,wav2vec_processor,wav2vec
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
def main(args,pipe,fantasytalking,wav2vec_processor,wav2vec):
|
162 |
+
os.makedirs(args.output_dir,exist_ok=True)
|
163 |
+
|
164 |
+
duration = librosa.get_duration(path=args.audio_path)  # "filename=" is deprecated in librosa >= 0.10
|
165 |
+
num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
|
166 |
+
|
167 |
+
audio_wav2vec_fea = get_audio_features(wav2vec,wav2vec_processor,args.audio_path,args.fps,num_frames)
|
168 |
+
image = resize_image_by_longest_edge(args.image_path,args.image_size)
|
169 |
+
width, height = image.size
|
170 |
+
|
171 |
+
audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
|
172 |
+
pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1),num_frames=num_frames)
|
173 |
+
audio_proj_split,audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea,pos_idx_ranges,expand_length=4) # [b,21,9+8,768]
|
174 |
+
|
175 |
+
# Image-to-video
|
176 |
+
video_audio = pipe(
|
177 |
+
prompt=args.prompt,
|
178 |
+
negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
179 |
+
input_image=image,
|
180 |
+
width=width,
|
181 |
+
height=height,
|
182 |
+
num_frames=num_frames,
|
183 |
+
num_inference_steps=30,
|
184 |
+
seed=args.seed, tiled=True,
|
185 |
+
audio_scale=args.audio_scale,
|
186 |
+
cfg_scale=args.prompt_cfg_scale,
|
187 |
+
audio_cfg_scale=args.audio_cfg_scale,
|
188 |
+
audio_proj=audio_proj_split,
|
189 |
+
audio_context_lens=audio_context_lens,
|
190 |
+
latents_num_frames=(num_frames-1)//4+1
|
191 |
+
)
|
192 |
+
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
193 |
+
save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
|
194 |
+
save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
|
195 |
+
|
196 |
+
save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
|
197 |
+
final_command = [
|
198 |
+
"ffmpeg", "-y",
|
199 |
+
"-i", save_path_tmp,
|
200 |
+
"-i", args.audio_path,
|
201 |
+
"-c:v", "libx264",
|
202 |
+
"-c:a", "aac",
|
203 |
+
"-shortest",
|
204 |
+
save_path
|
205 |
+
]
|
206 |
+
subprocess.run(final_command, check=True)
|
207 |
+
os.remove(save_path_tmp)
|
208 |
+
return save_path
|
209 |
+
|
210 |
+
if __name__ == "__main__":
|
211 |
+
args = parse_args()
|
212 |
+
pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args)
|
213 |
+
|
214 |
+
main(args,pipe,fantasytalking,wav2vec_processor,wav2vec)
|
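For reference, a minimal sketch of driving infer.py programmatically rather than from the command line (illustration only, not part of this commit; it assumes the default model paths from parse_args() have already been downloaded):

from argparse import Namespace
from infer import load_models, main

args = Namespace(
    wan_model_dir="./models/Wan2.1-I2V-14B-720P",
    fantasytalking_model_path="./models/fantasytalking_model.ckpt",
    wav2vec_model_dir="./models/wav2vec2-base-960h",
    image_path="./assets/images/woman.png",
    audio_path="./assets/audios/woman.wav",
    prompt="A woman is talking.",
    output_dir="./output",
    image_size=512,
    audio_scale=1.0,
    prompt_cfg_scale=5.0,
    audio_cfg_scale=5.0,
    max_num_frames=81,
    fps=23,
    num_persistent_param_in_dit=None,  # lower this to trade speed for VRAM
    seed=1111,
)

# load_models() builds the Wan2.1 pipeline, the FantasyTalking audio adapter and the
# wav2vec2 feature extractor; main() renders the video and returns the muxed .mp4 path.
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
print(main(args, pipe, fantasytalking, wav2vec_processor, wav2vec))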
model.py
ADDED
@@ -0,0 +1,229 @@
1 |
+
from diffsynth.models.wan_video_dit import flash_attention, WanModel
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from safetensors import safe_open
|
7 |
+
|
8 |
+
|
9 |
+
class AudioProjModel(nn.Module):
|
10 |
+
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
|
11 |
+
super().__init__()
|
12 |
+
self.cross_attention_dim = cross_attention_dim
|
13 |
+
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
|
14 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
15 |
+
|
16 |
+
def forward(self, audio_embeds):
|
17 |
+
context_tokens = self.proj(audio_embeds)
|
18 |
+
context_tokens = self.norm(context_tokens)
|
19 |
+
return context_tokens # [B,L,C]
|
20 |
+
|
21 |
+
|
22 |
+
class WanCrossAttentionProcessor(nn.Module):
|
23 |
+
def __init__(self, context_dim, hidden_dim):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.context_dim = context_dim
|
27 |
+
self.hidden_dim = hidden_dim
|
28 |
+
|
29 |
+
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
30 |
+
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
31 |
+
|
32 |
+
nn.init.zeros_(self.k_proj.weight)
|
33 |
+
nn.init.zeros_(self.v_proj.weight)
|
34 |
+
|
35 |
+
def __call__(
|
36 |
+
self,
|
37 |
+
attn: nn.Module,
|
38 |
+
x: torch.Tensor,
|
39 |
+
context: torch.Tensor,
|
40 |
+
context_lens: torch.Tensor,
|
41 |
+
audio_proj: torch.Tensor,
|
42 |
+
audio_context_lens: torch.Tensor,
|
43 |
+
latents_num_frames: int = 21,
|
44 |
+
audio_scale: float = 1.0,
|
45 |
+
) -> torch.Tensor:
|
46 |
+
"""
|
47 |
+
x: [B, L1, C].
|
48 |
+
context: [B, L2, C].
|
49 |
+
context_lens: [B].
|
50 |
+
audio_proj: [B, 21, L3, C]
|
51 |
+
audio_context_lens: [B*21].
|
52 |
+
"""
|
53 |
+
context_img = context[:, :257]
|
54 |
+
context = context[:, 257:]
|
55 |
+
b, n, d = x.size(0), attn.num_heads, attn.head_dim
|
56 |
+
|
57 |
+
# compute query, key, value
|
58 |
+
q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
|
59 |
+
k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
|
60 |
+
v = attn.v(context).view(b, -1, n, d)
|
61 |
+
k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
|
62 |
+
v_img = attn.v_img(context_img).view(b, -1, n, d)
|
63 |
+
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
64 |
+
# compute attention
|
65 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
66 |
+
x = x.flatten(2)
|
67 |
+
img_x = img_x.flatten(2)
|
68 |
+
|
69 |
+
if len(audio_proj.shape) == 4:
|
70 |
+
audio_q = q.view(b * latents_num_frames, -1, n, d)  # [b*21, L1/21, n, d]: one chunk of query tokens per latent frame
|
71 |
+
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
72 |
+
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
|
73 |
+
audio_x = flash_attention(
|
74 |
+
audio_q, ip_key, ip_value, k_lens=audio_context_lens
|
75 |
+
)
|
76 |
+
audio_x = audio_x.view(b, q.size(1), n, d)
|
77 |
+
audio_x = audio_x.flatten(2)
|
78 |
+
elif len(audio_proj.shape) == 3:
|
79 |
+
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
|
80 |
+
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
|
81 |
+
audio_x = flash_attention(q, ip_key, ip_value, k_lens=audio_context_lens)
|
82 |
+
audio_x = audio_x.flatten(2)
|
83 |
+
# output
|
84 |
+
x = x + img_x + audio_x * audio_scale
|
85 |
+
x = attn.o(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class FantasyTalkingAudioConditionModel(nn.Module):
|
90 |
+
def __init__(self, wan_dit: WanModel, audio_in_dim: int, audio_proj_dim: int):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.audio_in_dim = audio_in_dim
|
94 |
+
self.audio_proj_dim = audio_proj_dim
|
95 |
+
|
96 |
+
# audio proj model
|
97 |
+
self.proj_model = self.init_proj(self.audio_proj_dim)
|
98 |
+
self.set_audio_processor(wan_dit)
|
99 |
+
|
100 |
+
def init_proj(self, cross_attention_dim=5120):
|
101 |
+
proj_model = AudioProjModel(
|
102 |
+
audio_in_dim=self.audio_in_dim, cross_attention_dim=cross_attention_dim
|
103 |
+
)
|
104 |
+
return proj_model
|
105 |
+
|
106 |
+
def set_audio_processor(self, wan_dit):
|
107 |
+
attn_procs = {}
|
108 |
+
for name in wan_dit.attn_processors.keys():
|
109 |
+
attn_procs[name] = WanCrossAttentionProcessor(
|
110 |
+
context_dim=self.audio_proj_dim, hidden_dim=wan_dit.dim
|
111 |
+
)
|
112 |
+
wan_dit.set_attn_processor(attn_procs)
|
113 |
+
|
114 |
+
def load_audio_processor(self, ip_ckpt: str, wan_dit):
|
115 |
+
if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
|
116 |
+
state_dict = {"proj_model": {}, "audio_processor": {}}
|
117 |
+
with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
|
118 |
+
for key in f.keys():
|
119 |
+
if key.startswith("proj_model."):
|
120 |
+
state_dict["proj_model"][key.replace("proj_model.", "")] = (
|
121 |
+
f.get_tensor(key)
|
122 |
+
)
|
123 |
+
elif key.startswith("audio_processor."):
|
124 |
+
state_dict["audio_processor"][
|
125 |
+
key.replace("audio_processor.", "")
|
126 |
+
] = f.get_tensor(key)
|
127 |
+
else:
|
128 |
+
state_dict = torch.load(ip_ckpt, map_location="cpu")
|
129 |
+
self.proj_model.load_state_dict(state_dict["proj_model"])
|
130 |
+
wan_dit.load_state_dict(state_dict["audio_processor"], strict=False)
|
131 |
+
|
132 |
+
def get_proj_fea(self, audio_fea=None):
|
133 |
+
|
134 |
+
return self.proj_model(audio_fea) if audio_fea is not None else None
|
135 |
+
|
136 |
+
def split_audio_sequence(self, audio_proj_length, num_frames=81):
|
137 |
+
"""
|
138 |
+
Map the audio feature sequence to corresponding latent frame slices.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
audio_proj_length (int): The total length of the audio feature sequence
|
142 |
+
(e.g., 173 in audio_proj[1, 173, 768]).
|
143 |
+
num_frames (int): The number of video frames in the training data (default: 81).
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
|
147 |
+
(within the audio feature sequence) corresponding to a latent frame.
|
148 |
+
"""
|
149 |
+
# Average number of tokens per original video frame
|
150 |
+
tokens_per_frame = audio_proj_length / num_frames
|
151 |
+
|
152 |
+
# Each latent frame covers 4 video frames, and we want the center
|
153 |
+
tokens_per_latent_frame = tokens_per_frame * 4
|
154 |
+
half_tokens = int(tokens_per_latent_frame / 2)
|
155 |
+
|
156 |
+
pos_indices = []
|
157 |
+
for i in range(int((num_frames - 1) / 4) + 1):
|
158 |
+
if i == 0:
|
159 |
+
pos_indices.append(0)
|
160 |
+
else:
|
161 |
+
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
|
162 |
+
end_token = tokens_per_frame * (i * 4 + 1)
|
163 |
+
center_token = int((start_token + end_token) / 2) - 1
|
164 |
+
pos_indices.append(center_token)
|
165 |
+
|
166 |
+
# Build index ranges centered around each position
|
167 |
+
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
|
168 |
+
|
169 |
+
# Adjust the first range to avoid negative start index
|
170 |
+
pos_idx_ranges[0] = [
|
171 |
+
-(half_tokens * 2 - pos_idx_ranges[1][0]),
|
172 |
+
pos_idx_ranges[1][0],
|
173 |
+
]
|
174 |
+
|
175 |
+
return pos_idx_ranges
|
176 |
+
|
177 |
+
def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
|
178 |
+
"""
|
179 |
+
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
|
180 |
+
if the range exceeds the input boundaries.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
|
184 |
+
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
|
185 |
+
expand_length (int): Number of tokens to expand on both sides of each subsequence.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
|
189 |
+
Each element is a padded subsequence.
|
190 |
+
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
|
191 |
+
Useful for ignoring padding tokens in attention masks.
|
192 |
+
"""
|
193 |
+
pos_idx_ranges = [
|
194 |
+
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
|
195 |
+
]
|
196 |
+
sub_sequences = []
|
197 |
+
seq_len = input_tensor.size(1) # 173
|
198 |
+
max_valid_idx = seq_len - 1 # 172
|
199 |
+
k_lens_list = []
|
200 |
+
for start, end in pos_idx_ranges:
|
201 |
+
# Amount of padding needed where the range extends past the sequence boundaries
|
202 |
+
pad_front = max(-start, 0)
|
203 |
+
pad_back = max(end - max_valid_idx, 0)
|
204 |
+
|
205 |
+
# Calculate the start and end indices of the valid part
|
206 |
+
valid_start = max(start, 0)
|
207 |
+
valid_end = min(end, max_valid_idx)
|
208 |
+
|
209 |
+
# Extract the valid part
|
210 |
+
if valid_start <= valid_end:
|
211 |
+
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
|
212 |
+
else:
|
213 |
+
valid_part = input_tensor.new_zeros(
|
214 |
+
(1, 0, input_tensor.size(2))
|
215 |
+
)
|
216 |
+
|
217 |
+
# Zero-pad on the right along the sequence dimension (dim 1)
|
218 |
+
padded_subseq = F.pad(
|
219 |
+
valid_part,
|
220 |
+
(0, 0, 0, pad_back + pad_front, 0, 0),
|
221 |
+
mode="constant",
|
222 |
+
value=0,
|
223 |
+
)
|
224 |
+
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
|
225 |
+
|
226 |
+
sub_sequences.append(padded_subseq)
|
227 |
+
return torch.stack(sub_sequences, dim=1), torch.tensor(
|
228 |
+
k_lens_list, dtype=torch.long
|
229 |
+
)
|
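As a sanity check on the index arithmetic in split_audio_sequence, here is a standalone sketch that reproduces the [[-7, 1], [1, 9], ...] example quoted in the split_tensor_with_padding docstring (illustration only, not part of this commit):

# 81 video frames -> 21 latent frames; 173 wav2vec tokens spread over those frames.
audio_proj_length, num_frames = 173, 81
tokens_per_frame = audio_proj_length / num_frames      # ~2.14 tokens per video frame
half = int(tokens_per_frame * 4 / 2)                   # each latent frame covers 4 video frames -> 4
num_latent_frames = (num_frames - 1) // 4 + 1          # 21

centers = [0] + [
    int((tokens_per_frame * ((i - 1) * 4 + 1) + tokens_per_frame * (i * 4 + 1)) / 2) - 1
    for i in range(1, num_latent_frames)
]
ranges = [[c - half, c + half] for c in centers]
ranges[0] = [-(2 * half - ranges[1][0]), ranges[1][0]]  # first window is clamped against the second
print(ranges[:3])  # [[-7, 1], [1, 9], [9, 17]]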
requirements.txt
ADDED
@@ -0,0 +1,14 @@
1 |
+
torch>=2.0.0
|
2 |
+
torchvision
|
3 |
+
cupy-cuda12x
|
4 |
+
transformers==4.46.2
|
5 |
+
controlnet-aux==0.0.7
|
6 |
+
imageio
|
7 |
+
imageio[ffmpeg]
|
8 |
+
safetensors
|
9 |
+
einops
|
10 |
+
sentencepiece
|
11 |
+
protobuf
|
12 |
+
modelscope
|
13 |
+
ftfy
|
14 |
+
librosa
|
utils.py
ADDED
@@ -0,0 +1,49 @@
1 |
+
import imageio, librosa
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def resize_image_by_longest_edge(image_path, target_size):
|
9 |
+
image = Image.open(image_path).convert("RGB")
|
10 |
+
width, height = image.size
|
11 |
+
scale = target_size / max(width, height)
|
12 |
+
new_size = (int(width * scale), int(height * scale))
|
13 |
+
return image.resize(new_size, Image.LANCZOS)
|
14 |
+
|
15 |
+
|
16 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
17 |
+
writer = imageio.get_writer(
|
18 |
+
save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
|
19 |
+
)
|
20 |
+
for frame in tqdm(frames, desc="Saving video"):
|
21 |
+
frame = np.array(frame)
|
22 |
+
writer.append_data(frame)
|
23 |
+
writer.close()
|
24 |
+
|
25 |
+
|
26 |
+
def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
|
27 |
+
sr = 16000
|
28 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # sample rate is 16 kHz
|
29 |
+
|
30 |
+
start_time = 0
|
31 |
+
# end_time = (0 + (num_frames - 1) * 1) / fps
|
32 |
+
end_time = num_frames / fps
|
33 |
+
|
34 |
+
start_sample = int(start_time * sr)
|
35 |
+
end_sample = int(end_time * sr)
|
36 |
+
|
37 |
+
try:
|
38 |
+
audio_segment = audio_input[start_sample:end_sample]
|
39 |
+
except Exception:
|
40 |
+
audio_segment = audio_input
|
41 |
+
|
42 |
+
input_values = audio_processor(
|
43 |
+
audio_segment, sampling_rate=sample_rate, return_tensors="pt"
|
44 |
+
).input_values.to("cuda")
|
45 |
+
|
46 |
+
with torch.no_grad():
|
47 |
+
fea = wav2vec(input_values).last_hidden_state
|
48 |
+
|
49 |
+
return fea
|
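Finally, a hypothetical end-to-end use of these helpers outside the pipeline (not part of this commit; it assumes the wav2vec2-base-960h checkpoint and the sample assets referenced by infer.py are available locally and that a CUDA device is present):

from transformers import Wav2Vec2Model, Wav2Vec2Processor
from utils import get_audio_features, resize_image_by_longest_edge

processor = Wav2Vec2Processor.from_pretrained("./models/wav2vec2-base-960h")
wav2vec = Wav2Vec2Model.from_pretrained("./models/wav2vec2-base-960h").to("cuda")

image = resize_image_by_longest_edge("./assets/images/woman.png", 512)
features = get_audio_features(
    wav2vec, processor, "./assets/audios/woman.wav", fps=23, num_frames=81
)
# wav2vec2 emits roughly 50 feature vectors per second of 16 kHz audio, so
# 81 frames at 23 fps (~3.5 s) gives a tensor of roughly [1, 173, 768].
print(image.size, features.shape)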