File size: 5,287 Bytes
ec8f341
118eb66
 
fdfb611
 
77f9377
fdfb611
3704245
 
 
fdfb611
6b67faa
118eb66
 
3704245
 
6b67faa
3704245
 
 
 
 
 
 
fdfb611
3704245
 
fdfb611
3704245
 
 
 
 
 
 
77f9377
3704245
 
 
 
 
 
 
77f9377
 
5a00166
77f9377
118eb66
6b67faa
d76cb95
5a00166
fdfb611
77f9377
6b67faa
77f9377
 
 
 
 
 
 
 
3704245
6b67faa
3704245
 
 
77f9377
fdfb611
 
77f9377
 
3704245
fdfb611
 
3704245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b67faa
118eb66
fdfb611
3704245
118eb66
fdfb611
 
3704245
 
 
fdfb611
e8ecd28
ec8f341
 
118eb66
e8ecd28
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import tempfile

# Third-party imports keep their original relative order (`spaces` ahead of `torch`).
import spaces
import torch
import gradio as gr
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from PIL import Image

# ────────────────────────────────────────────────────────────
# 1. Load & optimize the CogVideoX pipeline with CPU offload
# ────────────────────────────────────────────────────────────
# Checkpoint to load; weights are kept in bfloat16.
_MODEL_ID = "THUDM/CogVideoX1.5-5B"

pipe = CogVideoXPipeline.from_pretrained(_MODEL_ID, torch_dtype=torch.bfloat16)
# Let the pipeline shuttle its submodules between CPU and GPU automatically.
pipe.enable_model_cpu_offload()
# Run the VAE in slices for additional VRAM savings.
pipe.vae.enable_slicing()

# ────────────────────────────────────────────────────────────
# 2. Resolution parsing & sanitization
# ────────────────────────────────────────────────────────────
def make_divisible_by_8(x: int) -> int:
    """Round ``x`` down (floor) to the closest multiple of 8."""
    remainder = x % 8
    return x - remainder

def parse_resolution(res_str: str):
    """
    Turn a label such as "480p" into a ``(height, width)`` tuple.

    Trailing "p" characters are stripped and the remainder is read as the
    height in pixels; the width is derived from it at a 16:9 ratio. Both
    values are then passed through ``make_divisible_by_8``, so the resulting
    aspect ratio is only approximately 16:9.
    """
    height = int(res_str.rstrip("p"))
    width = int(height * 16 / 9)
    return tuple(make_divisible_by_8(side) for side in (height, width))

# ────────────────────────────────────────────────────────────
# 3. GPU‑decorated video generation function
# ────────────────────────────────────────────────────────────
@spaces.GPU(duration=180)  # allow up to 180s of GPU time
def generate_video(
    prompt: str,
    steps: int,
    frames: int,
    fps: int,
    resolution: str
) -> str:
    """
    Generate a video from a text prompt and return the path of the MP4 file.

    The pipeline is called without an explicit height/width, so it renders at
    its own default size; the resulting frames are then resized to the
    resolution requested by the user before being encoded.

    Args:
        prompt: Text description of the video to generate.
        steps: Number of inference steps passed to the pipeline.
        frames: Number of frames requested from the pipeline.
        fps: Frame rate used when exporting the video.
        resolution: Resolution label such as "480p", parsed by
            ``parse_resolution``.

    Returns:
        Path of the exported MP4 file (a fresh temporary file per call).
    """
    # UI sliders may hand their value over as a float; the pipeline and the
    # exporter are given whole numbers.
    steps, frames, fps = int(steps), int(frames), int(fps)

    # 3.1 Determine the target (output) resolution
    target_h, target_w = parse_resolution(resolution)

    # 3.2 Run the diffusion pipeline at its default resolution
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        num_frames=frames,
    )
    video_frames = output.frames[0]  # frames of the first (only) generated video

    # 3.3 Resize frames to the user-specified resolution (PIL expects (width, height))
    resized_frames = [
        frame.resize((target_w, target_h), Image.LANCZOS)
        for frame in video_frames
    ]

    # 3.4 Export to MP4 with the chosen FPS. Each call gets its own file: a
    # fixed name in the working directory would be overwritten by every
    # subsequent request.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    video_path = export_to_video(resized_frames, output_path, fps=fps)
    return video_path

# ────────────────────────────────────────────────────────────
# 4. Build the Gradio interface with interactive controls
# ────────────────────────────────────────────────────────────
# Build the UI: one row with two columns — the first holds the input controls,
# the second shows the generated video.
with gr.Blocks(title="Textual Imagination: A text to video synthesis") as demo:
    # Page heading and a short usage hint.
    gr.Markdown(
        """
        # 🎞️ Textual Imagination: A text to video synthesis
        Generate videos from text prompts.
        Adjust inference steps, frame count, fps, and resolution below.
        """
    )
    with gr.Row():
        # Input controls.
        with gr.Column():
            # Free-text prompt, shown as a two-line text box.
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=2
            )
            # Inference steps: 1-100 in steps of 1, default 50.
            steps_slider = gr.Slider(
                minimum=1, maximum=100, step=1, value=50,
                label="Inference Steps"
            )
            # Frame count: 16-320 in steps of 1, default 161.
            frames_slider = gr.Slider(
                minimum=16, maximum=320, step=1, value=161,
                label="Total Frames"
            )
            # Playback rate of the exported video: 1-60 fps, default 16.
            fps_slider = gr.Slider(
                minimum=1, maximum=60, step=1, value=16,
                label="Frames per Second (FPS)"
            )
            # Output resolution label; the chosen string is handed to
            # generate_video as its `resolution` argument.
            res_dropdown = gr.Dropdown(
                choices=["360p", "480p", "720p", "1080p"],
                value="480p",
                label="Resolution"
            )
            gen_button = gr.Button("Generate Video")
        # Result display.
        with gr.Column():
            video_output = gr.Video(
                label="Generated Video",
                format="mp4"
            )

    # Wire the button to generate_video. The order of `inputs` matches the
    # function's positional parameters: prompt, steps, frames, fps, resolution.
    gen_button.click(
        fn=generate_video,
        inputs=[prompt_input, steps_slider, frames_slider, fps_slider, res_dropdown],
        outputs=video_output
    )

# ────────────────────────────────────────────────────────────
# 5. Launch: disable SSR so Gradio blocks and stays alive
# ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with server-side rendering off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)