Upload 2 files
Browse files- app.py +100 -4
- requirements.txt +6 -0
app.py
CHANGED
@@ -1,7 +1,103 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
import gradio as gr
|
3 |
+
import torch
|
4 |
+
from diffusers import LTXPipeline
|
5 |
+
import uuid
|
6 |
+
import time
|
7 |
+
import types
|
8 |
+
from typing import Optional
|
9 |
|
10 |
+
# Checkpoint identifier and the torch.compile options shared by both compiled submodules.
_MODEL_ID = "a-r-r-o-w/LTX-Video-0.9.1-diffusers"
_COMPILE_KWARGS = {"mode": "reduce-overhead", "fullgraph": True}

# Load the pipeline in bfloat16, move it to the GPU, then compile the
# transformer and the VAE.
pipe = LTXPipeline.from_pretrained(_MODEL_ID, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.transformer = torch.compile(pipe.transformer, **_COMPILE_KWARGS)
pipe.vae = torch.compile(pipe.vae, **_COMPILE_KWARGS)

# pipe.vae.decode = vae_decode

# Output size (pixels) and clip length (frames) passed to every pipeline call.
HEIGHT = 480
WIDTH = 640
N_FRAME = 161
# In the visible code this is only used by the commented-out
# prepare_latents_loop experiment.
N_AVG_FRAME = 2

negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

# Reference to the pipeline's own prepare_latents, kept for the (currently
# disabled) prepare_latents_loop wrapper.
prepare_latents_original = pipe.prepare_latents
|
26 |
+
|
27 |
+
# Unpacked latents have shape (B, C, F, H, W), where F, H and W are in latent dimensions.
|
28 |
+
|
29 |
+
# def prepare_latents_loop(*args, **kwargs):
|
30 |
+
# packed_latents = prepare_latents_original(*args, **kwargs)
|
31 |
+
# unpacked_latents = pipe._unpack_latents(packed_latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
|
32 |
+
# # now average the first n and last n frames
|
33 |
+
# last_n = unpacked_latents[:, :, -N_AVG_FRAME:, :, :]
|
34 |
+
# # 0,1,2,3,4, roll -1 => 1,2,3,4,0
|
35 |
+
# # last n: [3, 4]
|
36 |
+
# # last_next_n: [4, 0]
|
37 |
+
# # then 3 will be 0.75*3 + 0.25*4, and 4 will be 0.75*4+0.25*0
|
38 |
+
# last_next_n = torch.roll(unpacked_latents, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
|
39 |
+
# avg_n = last_n * 0.75 + last_next_n * 0.25
|
40 |
+
# unpacked_latents[:, :, -N_AVG_FRAME:, :, :] = avg_n
|
41 |
+
# # pack the latents back
|
42 |
+
# packed_latents = pipe._pack_latents(unpacked_latents)
|
43 |
+
# return packed_latents
|
44 |
+
|
45 |
+
# pipe.prepare_latents = prepare_latents_loop
|
46 |
+
|
47 |
+
# with the shift it will become step=0 0,1,2,3 -> step=1 1,2,3,0 -> step=2 2,3,0,1 -> step=3 3,0,1,2 -> step=4 0,1,2,3
|
48 |
+
# so we only shift (N_FRAME-1)//8+1 times
|
49 |
+
|
50 |
+
def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
    """Rotate the latent video by one frame; used as a step-end callback.

    The packed latents are unpacked with the pipeline's own helper, shifted by
    one position along dim 2 (the frame axis, per the file's note that
    unpacked latents are B, C, F, H, W), and packed again.

    Args:
        pipeline: Pipeline object providing ``_unpack_latents`` and
            ``_pack_latents``.
        step: Index of the step that just finished; only used for logging.
        timestep: Unused here.
        callback_kwargs: Mapping expected to carry the packed tensor under
            ``"latents"``.

    Returns:
        A dict with the rolled, re-packed tensor under ``"latents"``.
    """
    print("Rolling latents on step", step)

    # Latent-space sizes derived from the module-level generation settings.
    latent_frames = (N_FRAME - 1) // 8 + 1
    latent_height = HEIGHT // 32
    latent_width = WIDTH // 32

    packed = callback_kwargs.get("latents")
    video = pipeline._unpack_latents(
        packed, latent_frames, latent_height, latent_width, 1, 1
    )
    rolled = torch.roll(video, shifts=1, dims=2)
    return {"latents": pipeline._pack_latents(rolled)}
|
57 |
+
|
58 |
+
@spaces.GPU(duration=120)
def generate_gif(prompt, use_fixed_seed):
    """Generate a looping animation for ``prompt`` and return its file path.

    Runs the module-level ``pipe`` at WIDTH x HEIGHT for N_FRAME frames with
    ``modify_latents_callback`` attached as the step-end callback, then writes
    the resulting frames to an animated WebP file under ``/tmp``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        use_fixed_seed: When truthy the generator is seeded with 0; otherwise
            the seed comes from ``torch.seed()``.

    Returns:
        Path of the written ``.webp`` file. The "gif" naming is historical:
        the output format is WebP.
    """
    seed = 0 if use_fixed_seed else torch.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=WIDTH,
        height=HEIGHT,
        num_frames=N_FRAME,
        num_inference_steps=50,
        decode_timestep=0.03,
        decode_noise_scale=0.025,
        generator=generator,
        callback_on_step_end=modify_latents_callback,
    ).frames[0]

    # Animated WebP replaced the earlier GIF export:
    # imageio.mimsave(gif_path, output, format="GIF", fps=24, loop=0)
    gif_path = f"/tmp/{uuid.uuid4().hex}.webp"

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()
    output[0].save(
        gif_path,
        format="WebP",
        save_all=True,
        append_images=output[1:],
        duration=1000 / 24,  # milliseconds per frame, i.e. 24 fps
        loop=0,
    )
    print("GIF creation time:", time.perf_counter() - started)
    return gif_path
|
84 |
+
|
85 |
+
# Gradio front end: a prompt box, a seed toggle and a button in one column,
# and the generated animation in the other.
with gr.Blocks() as demo:
    gr.Markdown("## LTX Video → Looping GIF Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=4)
            use_fixed_seed = gr.Checkbox(label="Use Fixed Seed", value=True)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gif_output = gr.Image(label="Looping GIF Result", type="filepath")

    # The button runs generate_gif; this event is limited to one concurrent run.
    generate_btn.click(
        fn=generate_gif,
        inputs=[prompt_input, use_fixed_seed],
        outputs=gif_output,
        concurrency_limit=1,
    )

# Queue at most 5 pending requests, then start the app.
demo.queue(max_size=5).launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
diffusers==0.32.2
|
3 |
+
transformers
|
4 |
+
sentencepiece
|
5 |
+
torch
|
6 |
+
accelerate
|