import spaces
import gradio as gr
import torch
from diffusers import LTXPipeline
import uuid
import time

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
# pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
# Sequential CPU offload moves each submodule to the GPU only while it runs,
# keeping peak VRAM low at the cost of slower inference.
pipe.enable_sequential_cpu_offload()

HEIGHT = 512       # output height in pixels
WIDTH = 768        # output width in pixels
N_FRAME = 161      # frame count; LTX expects a multiple of 8 plus 1
N_AVG_FRAME = 2    # tail frames to blend toward the head (loop experiment below)
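# With LTX-Video's 8x temporal and 32x spatial VAE compression, these settings
# give a latent grid of (N_FRAME - 1) // 8 + 1 = 21 frames of 16 x 24
# (HEIGHT // 32 x WIDTH // 32), which is what the _unpack_latents calls below assume.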

negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

# Kept for the commented-out prepare_latents_loop experiment below.
prepare_latents_original = pipe.prepare_latents

# _unpack_latents returns shape (B, C, F, H, W), with F, H, W in latent dimensions.

# def prepare_latents_loop(*args, **kwargs):
#     packed_latents = prepare_latents_original(*args, **kwargs)
#     unpacked_latents = pipe._unpack_latents(packed_latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
#     # Blend each of the last N_AVG_FRAME frames with the frame that follows it,
#     # wrapping around to frame 0. E.g. for frames 0,1,2,3,4 with N_AVG_FRAME=2:
#     # roll(-1) gives 1,2,3,4,0, so last_n = [3, 4] pairs with last_next_n = [4, 0],
#     # making frame 3 -> 0.75*3 + 0.25*4 and frame 4 -> 0.75*4 + 0.25*0.
#     last_n = unpacked_latents[:, :, -N_AVG_FRAME:, :, :]
#     last_next_n = torch.roll(unpacked_latents, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
#     avg_n = last_n * 0.75 + last_next_n * 0.25
#     unpacked_latents[:, :, -N_AVG_FRAME:, :, :] = avg_n
#     # Pack the latents back into the transformer's flat token layout.
#     packed_latents = pipe._pack_latents(unpacked_latents)
#     return packed_latents

# pipe.prepare_latents = prepare_latents_loop
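
# Minimal sanity check for the blend above (toy tensor; hypothetical helper,
# not wired into the app). For frames 0..4 with N_AVG_FRAME = 2, frame 3
# becomes 0.75*3 + 0.25*4 = 3.25 and frame 4 becomes 0.75*4 + 0.25*0 = 3.0,
# pulling the tail of the clip toward its head.
def _demo_tail_blend():
    toy = torch.arange(5, dtype=torch.float32).view(1, 1, 5, 1, 1)  # B, C, F, H, W
    tail = toy[:, :, -N_AVG_FRAME:, :, :]
    tail_next = torch.roll(toy, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
    blended = tail * 0.75 + tail_next * 0.25
    assert blended.flatten().tolist() == [3.25, 3.0]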

def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
    """Roll the latents one frame along the temporal axis after each denoising
    step, so the last-to-first frame transition is repeatedly denoised like an
    interior transition and the finished video loops seamlessly."""
    print("Rolling latents on step", step)
    latents = callback_kwargs.get("latents")
    unpacked_latents = pipeline._unpack_latents(latents, (N_FRAME - 1) // 8 + 1, HEIGHT // 32, WIDTH // 32, 1, 1)
    # the frame order after each denoising step will be 0,1,2 -> 2,0,1 -> 1,2,0 -> 0,1,2 ...
    modified_latents = torch.roll(unpacked_latents, shifts=1, dims=2)
    modified_latents = pipeline._pack_latents(modified_latents)
    return {"latents": modified_latents}

@spaces.GPU(duration=140)
def generate_gif(prompt, use_fixed_seed):
    seed = 0 if use_fixed_seed else torch.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=WIDTH,
        height=HEIGHT,
        num_frames=N_FRAME,
        num_inference_steps=50,
        decode_timestep=0.03,
        decode_noise_scale=0.025,
        generator=generator,
        callback_on_step_end=modify_latents_callback,
    ).frames[0]

    gif_path = f"/tmp/{uuid.uuid4().hex}.gif"

    bef = time.time()
    # imageio.mimsave(gif_path, output, format="GIF", fps=24, loop=0)
    gif_path = f"/tmp/{uuid.uuid4().hex}.webp"
    output[0].save(gif_path, format="WebP", save_all=True, append_images=output[1:], duration=1000/24, loop=0)
    print("GIF creation time:", time.time() - bef)
    return gif_path

with gr.Blocks() as demo:
    gr.Markdown("## LTX Video → Looping GIF Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=4)
            use_fixed_seed = gr.Checkbox(label="Use Fixed Seed", value=True)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gif_output = gr.Image(label="Looping GIF Result", type="filepath")

    generate_btn.click(
        fn=generate_gif,
        inputs=[prompt_input, use_fixed_seed],
        outputs=gif_output,
        concurrency_limit=1
    )

    gr.Dataset(
        components=[prompt_input, use_fixed_seed, gif_output],
        samples=[
            ["A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage.", False, "examples/woman.webp"],
            ["A sleek white car skids into a narrow alley on wet pavement, its rear tires releasing a thin cloud of smoky exhaust as it accelerates past flickering neon signs. Rain cascades from the eaves. The camera swoops low and follows behind, capturing towering skyscraper reflections in puddles and the car’s headlights. Lightning flashes overhead, intensifying the gritty atmosphere", True, "examples/car.webp"]
        ],
        label="Example Inputs and Outputs",
        type="index",
    )

demo.queue(max_size=5)
demo.launch(share=True)