1inkusFace commited on
Commit
227bc73
·
verified ·
1 Parent(s): c7cf3a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import argparse
4
+ import sys
5
+ import time
6
+ import os
7
+ import random
8
+ from skyreelsinfer import TaskType
9
+ from skyreelsinfer.offload import OffloadConfig
10
+ from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
11
+ from diffusers.utils import export_to_video
12
+ from diffusers.utils import load_image
13
+
14
+ predictor = None
15
+ task_type = None
16
+
17
+ def init_predictor():
18
+ global predictor
19
+ predictor = SkyReelsVideoInfer(
20
+ task_type= TaskType.I2V
21
+ model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
22
+ quant_model=True,
23
+ world_size=gpu_num,
24
+ is_offload=True,
25
+ offload_config=OffloadConfig(
26
+ high_cpu_memory=True,
27
+ parameters_level=True,
28
+ compiler_transformer=False,
29
+ )
30
+ )
31
+
32
+ @spaces.GPU(duration=90)
33
+ def generate_video(prompt, seed, image=None):
34
+ global task_type
35
+ print(f"image:{type(image)}")
36
+ if seed == -1:
37
+ random.seed(time.time())
38
+ seed = int(random.randrange(4294967294))
39
+ kwargs = {
40
+ "prompt": prompt,
41
+ "height": 512,
42
+ "width": 512,
43
+ "num_frames": 97,
44
+ "num_inference_steps": 30,
45
+ "seed": seed,
46
+ "guidance_scale": 6.0,
47
+ "embedded_guidance_scale": 1.0,
48
+ "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
49
+ "cfg_for": False,
50
+ }
51
+ assert image is not None, "please input image"
52
+ kwargs["image"] = load_image(image=image)
53
+ global predictor
54
+ output = predictor.inference(kwargs)
55
+ save_dir = f"./result/{task_type}"
56
+ os.makedirs(save_dir, exist_ok=True)
57
+ video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
58
+ print(f"generate video, local path: {video_out_file}")
59
+ export_to_video(output, video_out_file, fps=24)
60
+ return video_out_file, kwargs
61
+
62
+ def create_gradio_interface():
63
+ with gr.Blocks() as demo:
64
+ with gr.Row():
65
+ image = gr.Image(label="Upload Image", type="filepath")
66
+ prompt = gr.Textbox(label="Input Prompt")
67
+ seed = gr.Number(label="Random Seed", value=-1)
68
+ submit_button = gr.Button("Generate Video")
69
+ output_video = gr.Video(label="Generated Video")
70
+ output_params = gr.Textbox(label="Output Parameters")
71
+ submit_button.click(
72
+ fn=generate_video,
73
+ inputs=[prompt, seed, image],
74
+ outputs=[output_video, output_params],
75
+ )
76
+ return demo
77
+
78
+ if __name__ == "__main__":
79
+ init_predictor()
80
+ demo = create_gradio_interface()
81
+ demo.launch()