imthanhlv committed
Commit 15c34d2 · verified · 1 Parent(s): 88f09e9

Upload 2 files

Files changed (2)
  1. app.py +100 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,103 @@
+import spaces
 import gradio as gr
+import torch
+from diffusers import LTXPipeline
+import uuid
+import time
+import types
+from typing import Optional
 
-def greet(name):
-    return "Hello " + name + "!!"
+pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
+pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+
+# pipe.vae.decode = vae_decode
+
+HEIGHT = 480
+WIDTH = 640
+N_FRAME = 161
+N_AVG_FRAME = 2
+
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+prepare_latents_original = pipe.prepare_latents
+
+# unpack will have shape B, C, F, H, W with F, H, W are in latent dim
+
+# def prepare_latents_loop(*args, **kwargs):
+#     packed_latents = prepare_latents_original(*args, **kwargs)
+#     unpacked_latents = pipe._unpack_latents(packed_latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
+#     # now average the first n and last n frames
+#     last_n = unpacked_latents[:, :, -N_AVG_FRAME:, :, :]
+#     # 0,1,2,3,4, roll -1 => 1,2,3,4,0
+#     # last n: [3, 4]
+#     # last_next_n: [4, 0]
+#     # then 3 will be 0.75*3 + 0.25*4, and 4 will be 0.75*4+0.25*0
+#     last_next_n = torch.roll(unpacked_latents, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
+#     avg_n = last_n * 0.75 + last_next_n * 0.25
+#     unpacked_latents[:, :, -N_AVG_FRAME:, :, :] = avg_n
+#     # pack the latents back
+#     packed_latents = pipe._pack_latents(unpacked_latents)
+#     return packed_latents
+
+# pipe.prepare_latents = prepare_latents_loop
+
+# with the shift it will become step=0 0,1,2,3 -> step=1 1,2,3,0 -> step=2 2,3,0,1 -> step=3 3,0,1,2 -> step=4 0,1,2,3
+# so we only shift (N_FRAME-1)//8+1 times
+
+def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
+    print("Rolling latents on step", step)
+    latents = callback_kwargs.get("latents")
+    unpacked_latents = pipeline._unpack_latents(latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
+    modified_latents = torch.roll(unpacked_latents, shifts=1, dims=2)
+    modified_latents = pipeline._pack_latents(modified_latents)
+    return {"latents": modified_latents}
+
+@spaces.GPU(duration=120)
+def generate_gif(prompt, use_fixed_seed):
+    seed = 0 if use_fixed_seed else torch.seed()
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    output = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        width=WIDTH,
+        height=HEIGHT,
+        num_frames=N_FRAME,
+        num_inference_steps=50,
+        decode_timestep=0.03,
+        decode_noise_scale=0.025,
+        generator=generator,
+        callback_on_step_end=modify_latents_callback,
+    ).frames[0]
+
+    gif_path = f"/tmp/{uuid.uuid4().hex}.gif"
+
+    bef = time.time()
+    # imageio.mimsave(gif_path, output, format="GIF", fps=24, loop=0)
+    gif_path = f"/tmp/{uuid.uuid4().hex}.webp"
+    output[0].save(gif_path, format="WebP", save_all=True, append_images=output[1:], duration=1000/24, loop=0)
+    print("GIF creation time:", time.time() - bef)
+    return gif_path
+
+with gr.Blocks() as demo:
+    gr.Markdown("## LTX Video → Looping GIF Generator")
+    with gr.Row():
+        with gr.Column():
+            prompt_input = gr.Textbox(label="Prompt", lines=4)
+            use_fixed_seed = gr.Checkbox(label="Use Fixed Seed", value=True)
+            generate_btn = gr.Button("Generate")
+        with gr.Column():
+            gif_output = gr.Image(label="Looping GIF Result", type="filepath")
+
+    generate_btn.click(
+        fn=generate_gif,
+        inputs=[prompt_input, use_fixed_seed],
+        outputs=gif_output,
+        concurrency_limit=1
+    )
+
+demo.queue(max_size=5)
+demo.launch(share=True)
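
The looping effect comes from modify_latents_callback above: at every denoising step it unpacks the latents to (B, C, F, H, W), rolls them by one position along the frame axis F, and packs them back, so each latent frame gets denoised at every temporal position and the sequence wraps around on itself. A minimal sketch of that roll behaviour on a toy tensor (illustration only, not part of the committed app.py; it assumes nothing beyond torch):

import torch

# Five stand-in "latent frames", labelled 0..4, playing the role of the
# frame axis (dims=2) in the unpacked latents from the callback above.
frames = torch.arange(5)

x = frames.clone()
for step in range(5):
    print(step, x.tolist())              # step 0: [0, 1, 2, 3, 4]; step 1: [4, 0, 1, 2, 3]; ...
    x = torch.roll(x, shifts=1, dims=0)  # same shifts=1 roll the callback applies per step

print(torch.equal(x, frames))            # True: after F single-step rolls the order is back to the start

This is the behaviour the comment above alludes to ("so we only shift (N_FRAME-1)//8+1 times"): with (161-1)//8+1 = 21 latent frames, 21 single-step rolls return the latents to their original order. The commented-out prepare_latents_loop sketches a complementary idea of blending the last N_AVG_FRAME latent frames with their cyclic successors (0.75/0.25) so the wrap-around point is smoother.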
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio
+diffusers==0.32.2
+transformers
+sentencepiece
+torch
+accelerate
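
The pin on diffusers==0.32.2 is what provides LTXPipeline, which app.py imports at module load, while transformers and sentencepiece cover the pipeline's T5 text encoder. A quick environment check (a hedged sketch, assuming these requirements are installed; the full app additionally needs a CUDA GPU):

import torch
import diffusers

print(diffusers.__version__)         # expected: 0.32.2
from diffusers import LTXPipeline    # fails on diffusers versions without LTX-Video support
print(torch.cuda.is_available())     # app.py calls pipe.to("cuda"), so this should be True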