1inkusFace committed (verified)
Commit ecea5f9 · 1 Parent(s): 073fba8

revert to SkyReels-V1

Files changed (1):
  app.py +86 -296
app.py CHANGED
@@ -1,20 +1,19 @@
  import spaces
+ 
  import gradio as gr
  import argparse
  import sys
+ import time
  import os
  import random
- import subprocess
- from PIL import Image
- import numpy as np
- 
- # Removed environment-specific lines
+ #sys.path.append("..")
+ from skyreelsinfer import TaskType
+ from skyreelsinfer.offload import OffloadConfig
+ from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
  from diffusers.utils import export_to_video
  from diffusers.utils import load_image
  
  import torch
- import logging
- from collections import OrderedDict
  
  torch.backends.cuda.matmul.allow_tf32 = False
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
@@ -25,309 +24,100 @@ torch.backends.cudnn.benchmark = False
  torch.set_float32_matmul_precision("highest")
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  
- logger = logging.getLogger(__name__)
- 
- 
- # --- Dummy Classes (Keep for standalone execution) ---
- class OffloadConfig:
-     def __init__(
-         self,
-         high_cpu_memory: bool = False,
-         parameters_level: bool = False,
-         compiler_transformer: bool = False,
-         compiler_cache: str = "",
-     ):
-         self.high_cpu_memory = high_cpu_memory
-         self.parameters_level = parameters_level
-         self.compiler_transformer = compiler_transformer
-         self.compiler_cache = compiler_cache
- 
- 
- class TaskType:  # Keep here for infer
-     T2V = 0
-     I2V = 1
- 
- 
- class LlamaModel:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return LlamaModel()
- 
-     def to(self, device):
-         return self
- 
- 
- class HunyuanVideoTransformer3DModel:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return HunyuanVideoTransformer3DModel()
- 
-     def to(self, device):
-         return self
- 
- 
- class SkyreelsVideoPipeline:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return SkyreelsVideoPipeline()
- 
-     def to(self, device):
-         return self
- 
-     def __call__(self, *args, **kwargs):
-         num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
-         height = kwargs.get("height", 512)
-         width = kwargs.get("width", 512)
- 
-         if "image" in kwargs:  # I2V
-             image = kwargs["image"]
-             # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
-             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
-             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
- 
-             # Create video by repeating the image
-             frames = image_tensor.repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
-             frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
-             # Correct shape: (1, C, T, H, W) - NO PERMUTE HERE
- 
-         else:  # T2V
-             frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W) - Correct!
- 
-         return type("obj", (object,), {"frames": frames})()  # No longer a list!
- 
-     def __init__(self):
-         super().__init__()
-         self._modules = OrderedDict()
-         self.vae = self.VAE()
-         self._modules["vae"] = self.vae
- 
-     def named_children(self):
-         return self._modules.items()
- 
-     class VAE:
-         def enable_tiling(self):
-             pass
- 
- 
- def quantize_(*args, **kwargs):
-     return
- 
- 
- def float8_weight_only():
-     return
- 
- 
- # --- End Dummy Classes ---
- 
- 
- class SkyReelsVideoSingleGpuInfer:
-     def _load_model(
-         self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
-     ):
-         logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
-         text_encoder = LlamaModel.from_pretrained(
-             base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
-         ).to("cpu")
-         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-             model_id, torch_dtype=torch.bfloat16, device="cpu"
-         ).to("cpu")
- 
-         if quant_model:
-             quantize_(text_encoder, float8_weight_only())
-             text_encoder.to("cpu")
-             torch.cuda.empty_cache()
-             quantize_(transformer, float8_weight_only())
-             transformer.to("cpu")
-             torch.cuda.empty_cache()
- 
-         pipe = SkyreelsVideoPipeline.from_pretrained(
-             base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
-         ).to("cpu")
-         pipe.vae.enable_tiling()
-         torch.cuda.empty_cache()
-         return pipe
- 
-     def __init__(
-         self,
-         task_type: TaskType,
-         model_id: str,
-         quant_model: bool = True,
-         is_offload: bool = True,
-         offload_config: OffloadConfig = OffloadConfig(),
-         enable_cfg_parallel: bool = True,
-     ):
-         self.task_type = task_type
-         self.model_id = model_id
-         self.quant_model = quant_model
-         self.is_offload = is_offload
-         self.offload_config = offload_config
-         self.enable_cfg_parallel = enable_cfg_parallel
-         self.pipe = None
-         self.is_initialized = False
-         self.gpu_device = None
- 
-     def initialize(self):
-         """Initializes the model and moves it to the GPU."""
-         if self.is_initialized:
-             return
- 
-         if not torch.cuda.is_available():
-             raise RuntimeError("CUDA is not available. Cannot initialize model.")
- 
-         self.gpu_device = "cuda:0"
-         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
- 
-         if self.is_offload:
-             pass
-         else:
-             self.pipe.to(self.gpu_device)
- 
-         if self.offload_config.compiler_transformer:
-             torch._dynamo.config.suppress_errors = True
-             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
-             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
-             self.pipe.transformer = torch.compile(
-                 self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
-             )
-         if self.offload_config.compiler_transformer:
-             self.warm_up()
-         self.is_initialized = True
- 
-     def warm_up(self):
-         if not self.is_initialized:
-             raise RuntimeError("Model must be initialized before warm-up.")
- 
-         init_kwargs = {
-             "prompt": "A woman is dancing in a room",
-             "height": 544,
-             "width": 960,
-             "guidance_scale": 6,
-             "num_inference_steps": 1,
-             "negative_prompt": "bad quality",
-             "num_frames": 16,
-             "generator": torch.Generator(self.gpu_device).manual_seed(42),
-             "embedded_guidance_scale": 1.0,
-         }
-         if self.task_type == TaskType.I2V:
-             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
-         self.pipe(**init_kwargs)
-         logger.info("Warm-up complete.")
- 
-     def infer(self, **kwargs):
-         """Handles inference requests."""
-         if not self.is_initialized:
-             self.initialize()
-         if "seed" in kwargs:
-             kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
-             del kwargs["seed"]
-         assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
-         result = self.pipe(**kwargs).frames  # Return the tensor directly
-         return result
- 
- 
- _predictor = None
- 
- 
- @spaces.GPU(duration=90)
- def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
-     """Generates a video based on the given prompt and seed.
- 
-     Args:
-         prompt: The text prompt to guide video generation.
-         seed: The random seed for reproducibility.
-         image: Optional path to an image for Image-to-Video.
- 
-     Returns:
-         A tuple containing the path to the generated video and the parameters used.
-     """
-     global _predictor
+ predictor = None
+ task_type = None
+ 
+ def get_transformer_model_id(task_type: str) -> str:
+     return "Skywork/SkyReels-V1-Hunyuan-I2V" if task_type == "i2v" else "Skywork/SkyReels-V1-Hunyuan-T2V"
+ 
+ def init_predictor(task_type: str, gpu_num: int = 1):
+     global predictor
+     predictor = SkyReelsVideoInfer(
+         task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
+         model_id=get_transformer_model_id(task_type),
+         quant_model=True,
+         world_size=gpu_num,
+         is_offload=True,
+         offload_config=OffloadConfig(
+             high_cpu_memory=True,
+             parameters_level=True,
+             compiler_transformer=False,
+         )
+     )
+ 
+ def generate_video(prompt, seed, image=None):
+     global task_type
+     print(f"image:{type(image)}")
  
      if seed == -1:
-         random.seed()
+         random.seed(time.time())
          seed = int(random.randrange(4294967294))
- 
-     if image is None:
-         task_type = TaskType.T2V
-         model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
-         kwargs = {
-             "prompt": prompt,
-             "height": 512,
-             "width": 512,
-             "num_frames": 16,
-             "num_inference_steps": 30,
-             "seed": seed,
-             "guidance_scale": 7.5,
-             "negative_prompt": "bad quality, worst quality",
-         }
-     else:
-         task_type = TaskType.I2V
-         model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
-         kwargs = {
-             "prompt": prompt,
-             "image": load_image(image),
-             "height": 512,
-             "width": 512,
-             "num_frames": 97,
-             "num_inference_steps": 30,
-             "seed": seed,
-             "guidance_scale": 6.0,
-             "embedded_guidance_scale": 1.0,
-             "negative_prompt": "Aerial view, low quality, bad hands",
-             "cfg_for": False,
-         }
- 
-     if _predictor is None:
-         _predictor = SkyReelsVideoSingleGpuInfer(
-             task_type=task_type,
-             model_id=model_id,
-             quant_model=True,
-             is_offload=True,
-             offload_config=OffloadConfig(
-                 high_cpu_memory=True,
-                 parameters_level=True,
-                 compiler_transformer=False,
-             ),
-         )
-         _predictor.initialize()
-         logger.info("Predictor initialized")
- 
-     with torch.no_grad():
-         output = _predictor.infer(**kwargs)
-     '''
-     output = (output.numpy() * 255).astype(np.uint8)
-     # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
-     output = output.transpose(0, 2, 3, 4, 1)
-     output = output[0]  # Remove batch dimension: (T, H, W, C)
-     '''
- 
-     save_dir = f"./result"
+ 
+     kwargs = {
+         "prompt": prompt,
+         "height": 512,
+         "width": 512,
+         "num_frames": 97,
+         "num_inference_steps": 30,
+         "seed": seed,
+         "guidance_scale": 6.0,
+         "embedded_guidance_scale": 1.0,
+         "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+         "cfg_for": False,
+     }
+ 
+     if task_type == "i2v":
+         assert image is not None, "please input image"
+         kwargs["image"] = load_image(image=image)
+     global predictor
+     output = predictor.inference(kwargs)
+     save_dir = f"./result/{task_type}"
      os.makedirs(save_dir, exist_ok=True)
-     video_out_file = f"{save_dir}/{seed}.mp4"
+     video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
      print(f"generate video, local path: {video_out_file}")
      export_to_video(output, video_out_file, fps=24)
      return video_out_file, kwargs
  
  
- def create_gradio_interface():
-     with gr.Blocks() as demo:
-         with gr.Row():
-             with gr.Column():
+ def create_gradio_interface(task_type):
+     """Create a Gradio interface based on the task type."""
+     if task_type == "i2v":
+         with gr.Blocks() as demo:
+             with gr.Row():
                  image = gr.Image(label="Upload Image", type="filepath")
                  prompt = gr.Textbox(label="Input Prompt")
                  seed = gr.Number(label="Random Seed", value=-1)
-             with gr.Column():
-                 submit_button = gr.Button("Generate Video")
-                 output_video = gr.Video(label="Generated Video")
-                 output_params = gr.Textbox(label="Output Parameters")
+                 submit_button = gr.Button("Generate Video")
+                 output_video = gr.Video(label="Generated Video")
+                 output_params = gr.Textbox(label="Output Parameters")
+ 
+             # Submit button logic
+             submit_button.click(
+                 fn=generate_video,
+                 inputs=[prompt, seed, image],
+                 outputs=[output_video, output_params],
+             )
  
-         submit_button.click(
-             fn=generate_video,
-             inputs=[prompt, seed, image],
-             outputs=[output_video, output_params],
-         )
-     return demo
+     elif task_type == "t2v":
+         with gr.Blocks() as demo:
+             with gr.Row():
+                 prompt = gr.Textbox(label="Input Prompt")
+                 seed = gr.Number(label="Random Seed", value=-1)
+                 submit_button = gr.Button("Generate Video")
+                 output_video = gr.Video(label="Generated Video")
+                 output_params = gr.Textbox(label="Output Parameters")
+ 
+             # Submit button logic
+             submit_button.click(
+                 fn=generate_video,
+                 inputs=[prompt, seed],
+                 outputs=[output_video, output_params],  # Pass task_type as additional input
+             )
+ 
+     return demo
  
  if __name__ == "__main__":
-     demo = create_gradio_interface()
-     demo.queue().launch()
+     # Parse command-line arguments
+     init_predictor(task_type="i2v", gpu_num=1)
+     demo = create_gradio_interface("i2v")
+     demo.launch()
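
Note (not part of the commit): in the reverted file, the module-level global task_type is initialized to None and never reassigned, so the "i2v" branch in generate_video only runs if a caller sets that global first. Below is a minimal headless sketch of driving the reverted app.py without the Gradio UI. It assumes the skyreelsinfer package and the spaces runtime are importable on a single-GPU machine; the image path and the module name "app" are hypothetical, and the prompt is borrowed from the removed warm-up code.

    # Sketch: exercise the reverted app.py programmatically (assumptions noted above).
    import app  # hypothetical module name for the reverted app.py

    app.task_type = "i2v"  # generate_video branches on this module-level global
    app.init_predictor(task_type="i2v", gpu_num=1)

    video_path, params = app.generate_video(
        prompt="A woman is dancing in a room",  # prompt reused from the removed warm-up code
        seed=-1,                                # -1 draws a random seed, as in the UI default
        image="./example.jpg",                  # hypothetical input image for the i2v path
    )
    print(video_path, params)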