lorocksUMD committed on
Commit 0ce559f · verified · 1 Parent(s): 709a33a

Update app.py

Files changed (1)
  1. app.py +184 -140
app.py CHANGED
@@ -1,154 +1,198 @@
- import gradio as gr
- import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
  else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
-
- if __name__ == "__main__":
-     demo.launch()
+ import csv
+ import os
+ import tempfile
+
+ import gradio as gr
+ import requests
  import torch
+ import torchvision
+ import torchvision.transforms as T
+ from PIL import Image
+ from featup.util import norm
+ from torchaudio.functional import resample
+
+ from denseav.train import LitAVAligner
+ from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
+ from denseav.shared import norm, crop_to_divisor, blur_dim
+ from os.path import join
+
+
+ mode = "hf"
+
+ if mode == "local":
+     sample_videos_dir = "samples"
  else:
+     os.environ['TORCH_HOME'] = '/tmp/.cache'
+     os.environ['HF_HOME'] = '/tmp/.cache'
+     os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
+     os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
+     os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
+     sample_videos_dir = "/tmp/samples"
+
+
+ def download_video(url, save_path):
+     response = requests.get(url)
+     with open(save_path, 'wb') as file:
+         file.write(response.content)
+
+
+ base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
+ sample_videos_urls = {
+     "puppies.mp4": base_url + "puppies.mp4",
+     "peppers.mp4": base_url + "peppers.mp4",
+     "boat.mp4": base_url + "boat.mp4",
+     "elephant2.mp4": base_url + "elephant2.mp4",
+
+ }
+
+ # Ensure the directory for sample videos exists
+ os.makedirs(sample_videos_dir, exist_ok=True)
+
+ # Download each sample video
+ for filename, url in sample_videos_urls.items():
+     save_path = os.path.join(sample_videos_dir, filename)
+     # Download the video if it doesn't already exist
+     if not os.path.exists(save_path):
+         print(f"Downloading {filename}...")
+         download_video(url, save_path)
+     else:
+         print(f"{filename} already exists. Skipping download.")
+
+ csv.field_size_limit(100000000)
+ options = ['language', "sound-language", "sound"]
+ load_size = 224
+ plot_size = 224
+
+ video_input = gr.Video(label="Choose a video to featurize", height=480)
+ model_option = gr.Radio(options, value="language", label='Choose a model')
+
+ video_output1 = gr.Video(label="Audio Video Attention", height=480)
+ video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
+                          height=480)
+ video_output3 = gr.Video(label="Visual Features", height=480)
+
+ models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}
+
+
+ def process_video(video, model_option):
+     # model = models[model_option].cuda()
+     model = models[model_option]
+
+     original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
+     sample_rate = 16000
+
+     if info["audio_fps"] != sample_rate:
+         audio = resample(audio, info["audio_fps"], sample_rate)
+     audio = audio[0].unsqueeze(0)
+
+     img_transform = T.Compose([
+         T.Resize(load_size, Image.BILINEAR),
+         lambda x: crop_to_divisor(x, 8),
+         lambda x: x.to(torch.float32) / 255,
+         norm])
+
+     frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)
+
+     plotting_img_transform = T.Compose([
+         T.Resize(plot_size, Image.BILINEAR),
+         lambda x: crop_to_divisor(x, 8),
+         lambda x: x.to(torch.float32) / 255])
+
+     frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))
+
+     with torch.no_grad():
+         # audio_feats = model.forward_audio({"audio": audio.cuda()})
+         audio_feats = model.forward_audio({"audio": audio})
+         audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
+         # image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
+         image_feats = model.forward_image({"frames": frames.unsqueeze(0)}, max_batch_size=2)
+         image_feats = {k: v.cpu() for k, v in image_feats.items()}
+
+         sim_by_head = model.sim_agg.get_pairwise_sims(
+             {**image_feats, **audio_feats},
+             raw=False,
+             agg_sim=False,
+             agg_heads=False
+         ).mean(dim=-2).cpu()
+
+         sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
+         print(sim_by_head.shape)
+
+     temp_video_path_1 = tempfile.mktemp(suffix='.mp4')
+
+     plot_attention_video(
+         sim_by_head,
+         frames_to_plot,
+         audio,
+         info["video_fps"],
+         sample_rate,
+         temp_video_path_1)
+
+     if model_option == "sound_and_language":
+         temp_video_path_2 = tempfile.mktemp(suffix='.mp4')
+
+         plot_2head_attention_video(
+             sim_by_head,
+             frames_to_plot,
+             audio,
+             info["video_fps"],
+             sample_rate,
+             temp_video_path_2)
+
+     else:
+         temp_video_path_2 = None
+
+     temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
+     temp_video_path_4 = tempfile.mktemp(suffix='.mp4')
+
+     plot_feature_video(
+         image_feats["image_feats"].cpu(),
+         audio_feats['audio_feats'].cpu(),
+         frames_to_plot,
+         audio,
+         info["video_fps"],
+         sample_rate,
+         temp_video_path_3,
+         temp_video_path_4,
+     )
+     # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
+
+     return temp_video_path_1, temp_video_path_2, temp_video_path_3
+
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         gr.Markdown("## Visualizing Sound and Language with DenseAV")
+         gr.Markdown(
+             "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
+         with gr.Row():
+             with gr.Column(scale=1):
+                 model_option.render()
+             with gr.Column(scale=3):
+                 video_input.render()
+         with gr.Row():
+             submit_button = gr.Button("Submit")
+         with gr.Row():
+             gr.Examples(
+                 examples=[
+                     [join(sample_videos_dir, "puppies.mp4"), "sound_and_language"],
+                     [join(sample_videos_dir, "peppers.mp4"), "language"],
+                     [join(sample_videos_dir, "elephant2.mp4"), "language"],
+                     [join(sample_videos_dir, "boat.mp4"), "language"]
+
+                 ],
+                 inputs=[video_input, model_option]
              )
+         with gr.Row():
+             video_output1.render()
+             video_output2.render()
+             video_output3.render()
+
+     submit_button.click(fn=process_video, inputs=[video_input, model_option],
+                         outputs=[video_output1, video_output2, video_output3])
+
+
+ if mode == "local":
+     demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
+ else:
+     demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
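
For a quick sanity check of the new pipeline outside the Gradio UI, the same preprocessing and audio-feature path can be exercised directly. The sketch below is not part of the commit: it assumes the denseav/featup dependencies imported above are installed and that the download loop has already populated /tmp/samples, and it repeats the relevant calls from process_video rather than importing app.py (which would immediately call demo.launch()).

import torch
import torchvision
from torchaudio.functional import resample

from denseav.train import LitAVAligner

# Checkpoint name follows the f"mhamilton723/DenseAV-{o}" pattern used in app.py.
model = LitAVAligner.from_pretrained("mhamilton723/DenseAV-language")

# Read the first 10 seconds of a downloaded sample, as process_video does.
frames, audio, info = torchvision.io.read_video(
    "/tmp/samples/puppies.mp4", end_pts=10, pts_unit="sec")

# Mono audio at 16 kHz with a leading batch dimension, mirroring process_video.
sample_rate = 16000
if info["audio_fps"] != sample_rate:
    audio = resample(audio, info["audio_fps"], sample_rate)
audio = audio[0].unsqueeze(0)

with torch.no_grad():
    audio_feats = model.forward_audio({"audio": audio})

# Per app.py, forward_audio returns a dict of tensors; inspect their shapes.
print({k: tuple(v.shape) for k, v in audio_feats.items()})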