Varshitha2317 committed
Commit 0a82f14 · verified · 1 Parent(s): 7547a78

Create app.py

Files changed (1)
  1. app.py +186 -0
app.py ADDED
import spaces
import gradio as gr
import torch
import torchvision as tv
import random
import os
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
from glob import glob
from typing import Optional

from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler
from utils import load_lora_weights, save_video

# LOCAL = True
LOCAL = False

if LOCAL:
    svd_path = '/share2/duanyuxuan/diff_playground/diffusers_models/stable-video-diffusion-img2vid-xt-1-1'
    lora_file_path = '/share2/duanyuxuan/diff_playground/SVD-TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
else:
    svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
    lora_repo_path = 'RED-AIGC/TDD'
    lora_weight_name = 'svd-xt-1-1_tdd_lora_weights.safetensors'

if torch.cuda.is_available():
    noise_scheduler = TDDSVDStochasticIterativeScheduler(
        num_train_timesteps=250, sigma_min=0.002, sigma_max=700.0,
        sigma_data=1.0, s_noise=1.0, rho=7, clip_denoised=False,
    )

    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        svd_path, scheduler=noise_scheduler, torch_dtype=torch.float16, variant="fp16"
    ).to('cuda')
    if LOCAL:
        load_lora_weights(pipeline.unet, lora_file_path)
    else:
        load_lora_weights(pipeline.unet, lora_repo_path, weight_name=lora_weight_name)

max_64_bit_int = 2**63 - 1


@spaces.GPU
def sample(
    image: Image.Image,
    seed: Optional[int] = 1,
    randomize_seed: bool = False,
    num_inference_steps: int = 4,
    eta: float = 0.3,
    min_guidance_scale: float = 1.0,
    max_guidance_scale: float = 1.0,
    fps: int = 7,
    width: int = 512,
    height: int = 512,
    num_frames: int = 25,
    motion_bucket_id: int = 127,
    output_folder: str = "outputs_gradio",
):
    # Gamma-sampling strength for the TDD scheduler.
    pipeline.scheduler.set_eta(eta)

    if randomize_seed:
        seed = random.randint(0, max_64_bit_int)
    generator = torch.manual_seed(seed)

    os.makedirs(output_folder, exist_ok=True)
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

    with torch.autocast("cuda"):
        frames = pipeline(
            image, height=height, width=width,
            num_inference_steps=num_inference_steps,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_frames=num_frames, fps=fps, motion_bucket_id=motion_bucket_id,
            decode_chunk_size=8,
            noise_aug_strength=0.02,
            generator=generator,
        ).frames[0]
    save_video(frames, video_path, fps=fps, quality=5.0)
    torch.manual_seed(seed)

    return video_path, seed


def preprocess_image(image, height=512, width=512):
    # Center-crop non-square inputs to a square, then resize to the target size.
    image = image.convert('RGB')
    if image.size[0] != image.size[1]:
        image = tv.transforms.functional.pil_to_tensor(image)
        image = tv.transforms.functional.center_crop(image, min(image.shape[-2:]))
        image = tv.transforms.functional.to_pil_image(image)
    image = image.resize((width, height))
    return image


css = """
h1 {
    text-align: center;
    display: block;
}
.gradio-container {
    max-width: 70.5rem !important;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # Stable Video Diffusion distilled by ✨Target-Driven Distillation✨

        Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation method that greatly accelerates the inference of diffusion models. With its strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images in only a few steps.

        TDD can also be used to distill video generation models. This space presents TDD-distilled [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1).

        [**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)

        The code of this space is built on [AnimateLCM-SVD](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD), and we acknowledge their contribution.
        """
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Accordion("Options", open=True):
        seed = gr.Slider(
            label="Seed",
            value=1,
            randomize=False,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        min_guidance_scale = gr.Slider(
            label="Min guidance scale",
            info="min strength of classifier-free guidance",
            value=1.0,
            minimum=1.0,
            maximum=1.5,
        )
        max_guidance_scale = gr.Slider(
            label="Max guidance scale",
            info="max strength of classifier-free guidance; should not be less than Min guidance scale",
            value=1.0,
            minimum=1.0,
            maximum=3.0,
        )
        num_inference_steps = gr.Slider(
            label="Num inference steps",
            info="number of inference steps",
            value=4,
            minimum=4,
            maximum=8,
            step=1,
        )
        eta = gr.Slider(
            label="Eta",
            info="the value of gamma in gamma-sampling",
            value=0.3,
            minimum=0.0,
            maximum=1.0,
            step=0.1,
        )

    image.upload(fn=preprocess_image, inputs=image, outputs=image, queue=False)
    generate_btn.click(
        fn=sample,
        inputs=[
            image,
            seed,
            randomize_seed,
            num_inference_steps,
            eta,
            min_guidance_scale,
            max_guidance_scale,
        ],
        outputs=[video, seed],
        api_name="video",
    )
    # safetensors_dropdown.change(fn=model_select, inputs=safetensors_dropdown)

    # gr.Examples(
    #     examples=[
    #         ["examples/ipadapter_cat.jpg"],
    #     ],
    #     inputs=[image],
    #     outputs=[video, seed],
    #     fn=sample,
    #     cache_examples=True,
    # )

if __name__ == "__main__":
    if LOCAL:
        demo.queue().launch(share=True, server_name='0.0.0.0')
    else:
        demo.queue(api_open=False).launch(show_api=False)
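For a quick smoke test without the UI, the two functions above can be driven directly. A minimal sketch, assuming a CUDA machine with the same dependencies installed; "cat.jpg" is a hypothetical input path, and the @spaces.GPU decorator should be a no-op outside a ZeroGPU Space:

# Hypothetical smoke test for the functions defined in app.py.
from PIL import Image

img = Image.open("cat.jpg")      # placeholder path; any RGB image works
img = preprocess_image(img)      # center-crop to square, resize to 512x512
video_path, used_seed = sample(img, seed=42, num_inference_steps=4, eta=0.3)
print(f"saved {video_path} (seed {used_seed})")

Note that the diffusers SVD pipeline ramps classifier-free guidance linearly from min_guidance_scale (first frame) to max_guidance_scale (last frame), which is why the demo exposes both sliders; at the defaults of 1.0/1.0 the guidance is effectively constant.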