Spaces:
Running
Running
File size: 5,287 Bytes
ec8f341 118eb66 fdfb611 77f9377 fdfb611 3704245 fdfb611 6b67faa 118eb66 3704245 6b67faa 3704245 fdfb611 3704245 fdfb611 3704245 77f9377 3704245 77f9377 5a00166 77f9377 118eb66 6b67faa d76cb95 5a00166 fdfb611 77f9377 6b67faa 77f9377 3704245 6b67faa 3704245 77f9377 fdfb611 77f9377 3704245 fdfb611 3704245 6b67faa 118eb66 fdfb611 3704245 118eb66 fdfb611 3704245 fdfb611 e8ecd28 ec8f341 118eb66 e8ecd28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import tempfile

# NOTE(review): `spaces` is kept ahead of torch as in the original ordering;
# confirm before letting an import sorter reorder this group.
import spaces
import torch
import gradio as gr
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from PIL import Image
# ────────────────────────────────────────────────────────────
# 1. Load & optimize the CogVideoX pipeline with CPU offload
# ────────────────────────────────────────────────────────────
# Load the CogVideoX 1.5 (5B) checkpoint with bfloat16 weights.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16)

# Let the pipeline shuttle its submodules between CPU and GPU automatically.
pipe.enable_model_cpu_offload()

# Run the VAE in slices for additional VRAM savings.
pipe.vae.enable_slicing()
# ────────────────────────────────────────────────────────────
# 2. Resolution parsing & sanitization
# ────────────────────────────────────────────────────────────
def make_divisible_by_8(x: int) -> int:
    """Round ``x`` down to the nearest multiple of 8.

    Floor division is used, so a value that is already a multiple of 8 is
    returned unchanged and any value in ``0..7`` becomes 0.

    Args:
        x: The integer to round.

    Returns:
        The largest multiple of 8 that is less than or equal to ``x``.
    """
    return (x // 8) * 8
def parse_resolution(res_str: str):
    """
    Convert strings like "480p" into (height, width) both divisible by 8
    while preserving ~16:9 aspect ratio.

    Surrounding whitespace and the case of the trailing "p" are ignored,
    so " 720P " parses the same as "720p".

    Args:
        res_str: Vertical resolution label, e.g. "360p", "480p", "720p".

    Returns:
        A ``(height, width)`` tuple of ints, each rounded down to a
        multiple of 8.

    Raises:
        ValueError: If the label does not contain an integer height, or if
            the rounded height or width would not be positive.
    """
    label = res_str.strip().lower().rstrip("p")
    try:
        h = int(label)
    except ValueError:
        raise ValueError(
            f"Invalid resolution {res_str!r}; expected a label such as '480p'"
        ) from None

    # Integer arithmetic gives the truncated 16:9 width without float rounding.
    w = h * 16 // 9

    height, width = make_divisible_by_8(h), make_divisible_by_8(w)
    if height <= 0 or width <= 0:
        # Rounding down to a multiple of 8 can collapse tiny inputs to 0,
        # which is not a usable frame size.
        raise ValueError(
            f"Resolution {res_str!r} is too small to produce a valid frame size"
        )
    return height, width
# ────────────────────────────────────────────────────────────
# 3. GPU-decorated video generation function
# ────────────────────────────────────────────────────────────
@spaces.GPU(duration=180)  # allow up to 180s of GPU time
def generate_video(
    prompt: str,
    steps: int,
    frames: int,
    fps: int,
    resolution: str
) -> str:
    """Generate a video from a text prompt and return the path to the MP4.

    The pipeline is called without explicit height/width, so it renders at
    its own default size; every frame is then resized to the resolution the
    user picked before the clip is written to disk.

    Args:
        prompt: Text description of the video to generate.
        steps: Number of inference steps passed to the pipeline.
        frames: Number of frames passed to the pipeline.
        fps: Frame rate used when writing the MP4.
        resolution: Resolution label such as "480p" (see parse_resolution).

    Returns:
        Path of the exported MP4 file.
    """
    # 3.1 Determine the target (output) resolution
    target_h, target_w = parse_resolution(resolution)

    # UI numeric widgets may hand over floats (e.g. 50.0); the values are
    # counts, so coerce them to ints before they reach the pipeline/exporter.
    steps, frames, fps = int(steps), int(frames), int(fps)

    # 3.2 Run the diffusion pipeline at its default resolution
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        num_frames=frames,
    )
    video_frames = output.frames[0]  # list of PIL Images at native size

    # 3.3 Resize frames to user-specified resolution (PIL wants (width, height))
    resized_frames = [
        frame.resize((target_w, target_h), Image.LANCZOS)
        for frame in video_frames
    ]

    # 3.4 Export to MP4 with the chosen FPS. Each call gets its own file so
    # overlapping requests cannot overwrite one another's output, which a
    # single shared "generated.mp4" would allow.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name
    return export_to_video(resized_frames, video_path, fps=fps)
# ────────────────────────────────────────────────────────────
# 4. Build the Gradio interface with interactive controls
# ────────────────────────────────────────────────────────────
# Two-column layout: generation controls on the left, resulting video on the
# right. The button click is wired to generate_video at the end.
with gr.Blocks(title="Textual Imagination: A text to video synthesis") as demo:
    # NOTE(review): the heading text is indented with the code; this relies
    # on gr.Markdown dedenting its value — confirm for the pinned Gradio version.
    gr.Markdown(
        """
        # 🎞️ Textual Imagination: A text to video synthesis
        Generate videos from text prompts.
        Adjust inference steps, frame count, fps, and resolution below.
        """
    )
    with gr.Row():
        # Left column: prompt and generation settings.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=2
            )
            steps_slider = gr.Slider(
                minimum=1, maximum=100, step=1, value=50,
                label="Inference Steps"
            )
            frames_slider = gr.Slider(
                minimum=16, maximum=320, step=1, value=161,
                label="Total Frames"
            )
            fps_slider = gr.Slider(
                minimum=1, maximum=60, step=1, value=16,
                label="Frames per Second (FPS)"
            )
            res_dropdown = gr.Dropdown(
                choices=["360p", "480p", "720p", "1080p"],
                value="480p",
                label="Resolution"
            )
            gen_button = gr.Button("Generate Video")
        # Right column: player for the generated clip.
        with gr.Column():
            video_output = gr.Video(
                label="Generated Video",
                format="mp4"
            )
    # Input order must match generate_video's positional parameters:
    # (prompt, steps, frames, fps, resolution).
    gen_button.click(
        fn=generate_video,
        inputs=[prompt_input, steps_slider, frames_slider, fps_slider, res_dropdown],
        outputs=video_output
    )
# ────────────────────────────────────────────────────────────
# 5. Launch: disable SSR so that launch() blocks and the app stays alive
# ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Only start the server when run as a script: listen on all interfaces,
    # port 7860, with server-side rendering turned off.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )
|