Spaces:
Running
on
Zero
Running
on
Zero
Antoni Bigata
committed on
Commit
·
7746897
1
Parent(s):
2fb3e22
adapt for ZeroGPU
Browse files
app.py
CHANGED
@@ -23,6 +23,7 @@ from inference_functions import (
|
|
23 |
)
|
24 |
from wordle_game import WordleGame
|
25 |
import torch.cuda.amp as amp # Import amp for mixed precision
|
|
|
26 |
|
27 |
|
28 |
# Set default tensor type to float16 for faster computation
|
@@ -96,10 +97,26 @@ def load_model(
|
|
96 |
return model
|
97 |
|
98 |
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
vae_model = vae_model.half() # Convert to half precision
|
104 |
try:
|
105 |
vae_model = torch.compile(vae_model)
|
@@ -107,8 +124,7 @@ if torch.cuda.is_available():
|
|
107 |
except Exception as e:
|
108 |
print(f"Warning: Failed to compile vae_model: {e}")
|
109 |
|
110 |
-
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
111 |
-
if torch.cuda.is_available():
|
112 |
hubert_model = hubert_model.half() # Convert to half precision
|
113 |
try:
|
114 |
hubert_model = torch.compile(hubert_model)
|
@@ -116,13 +132,13 @@ if torch.cuda.is_available():
|
|
116 |
except Exception as e:
|
117 |
print(f"Warning: Failed to compile hubert_model: {e}")
|
118 |
|
119 |
-
wavlm_model = WavLM_wrapper(
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
).cuda()
|
125 |
-
|
126 |
wavlm_model = wavlm_model.half() # Convert to half precision
|
127 |
try:
|
128 |
wavlm_model = torch.compile(wavlm_model)
|
@@ -130,27 +146,23 @@ if torch.cuda.is_available():
|
|
130 |
except Exception as e:
|
131 |
print(f"Warning: Failed to compile wavlm_model: {e}")
|
132 |
|
133 |
-
landmarks_extractor = LandmarksExtractor()
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
os.path.dirname(__file__), "assets", "sample_video.mp4"
|
148 |
-
)
|
149 |
-
DEFAULT_AUDIO_PATH = os.path.join(
|
150 |
-
os.path.dirname(__file__), "assets", "sample_audio.wav"
|
151 |
-
)
|
152 |
|
153 |
|
|
|
154 |
@torch.no_grad()
|
155 |
def compute_video_embedding(video_reader, min_len):
|
156 |
"""Compute embeddings from video"""
|
@@ -200,6 +212,7 @@ def compute_video_embedding(video_reader, min_len):
|
|
200 |
return encoded, video_frames
|
201 |
|
202 |
|
|
|
203 |
@torch.no_grad()
|
204 |
def compute_hubert_embedding(raw_audio):
|
205 |
"""Compute embeddings from audio"""
|
@@ -246,6 +259,7 @@ def compute_hubert_embedding(raw_audio):
|
|
246 |
return audio_embeddings
|
247 |
|
248 |
|
|
|
249 |
@torch.no_grad()
|
250 |
def compute_wavlm_embedding(raw_audio):
|
251 |
"""Compute embeddings from audio"""
|
@@ -352,6 +366,7 @@ def extract_video_landmarks(video_frames):
|
|
352 |
return np.array(processed_landmarks)
|
353 |
|
354 |
|
|
|
355 |
@torch.no_grad()
|
356 |
def sample(
|
357 |
audio_list,
|
|
|
23 |
)
|
24 |
from wordle_game import WordleGame
|
25 |
import torch.cuda.amp as amp # Import amp for mixed precision
|
26 |
+
import spaces
|
27 |
|
28 |
|
29 |
# Set default tensor type to float16 for faster computation
|
|
|
97 |
return model
|
98 |
|
99 |
|
100 |
+
# Default media paths
|
101 |
+
DEFAULT_VIDEO_PATH = os.path.join(
|
102 |
+
os.path.dirname(__file__), "assets", "sample_video.mp4"
|
103 |
+
)
|
104 |
+
DEFAULT_AUDIO_PATH = os.path.join(
|
105 |
+
os.path.dirname(__file__), "assets", "sample_audio.wav"
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
@spaces.GPU(duration=60)
|
110 |
+
def load_all_models():
|
111 |
+
global \
|
112 |
+
keyframe_model, \
|
113 |
+
interpolation_model, \
|
114 |
+
vae_model, \
|
115 |
+
hubert_model, \
|
116 |
+
wavlm_model, \
|
117 |
+
landmarks_extractor
|
118 |
+
vae_model = VaeWrapper("video")
|
119 |
+
|
120 |
vae_model = vae_model.half() # Convert to half precision
|
121 |
try:
|
122 |
vae_model = torch.compile(vae_model)
|
|
|
124 |
except Exception as e:
|
125 |
print(f"Warning: Failed to compile vae_model: {e}")
|
126 |
|
127 |
+
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
|
|
128 |
hubert_model = hubert_model.half() # Convert to half precision
|
129 |
try:
|
130 |
hubert_model = torch.compile(hubert_model)
|
|
|
132 |
except Exception as e:
|
133 |
print(f"Warning: Failed to compile hubert_model: {e}")
|
134 |
|
135 |
+
wavlm_model = WavLM_wrapper(
|
136 |
+
model_size="Base+",
|
137 |
+
feed_as_frames=False,
|
138 |
+
merge_type="None",
|
139 |
+
model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
|
140 |
+
).cuda()
|
141 |
+
|
142 |
wavlm_model = wavlm_model.half() # Convert to half precision
|
143 |
try:
|
144 |
wavlm_model = torch.compile(wavlm_model)
|
|
|
146 |
except Exception as e:
|
147 |
print(f"Warning: Failed to compile wavlm_model: {e}")
|
148 |
|
149 |
+
landmarks_extractor = LandmarksExtractor()
|
150 |
+
keyframe_model = load_model(
|
151 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
|
152 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
|
153 |
+
)
|
154 |
+
interpolation_model = load_model(
|
155 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
|
156 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
|
157 |
+
)
|
158 |
+
keyframe_model.en_and_decode_n_samples_a_time = 2
|
159 |
+
interpolation_model.en_and_decode_n_samples_a_time = 2
|
160 |
|
161 |
+
|
162 |
+
load_all_models()
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
|
165 |
+
@spaces.GPU(duration=60)
|
166 |
@torch.no_grad()
|
167 |
def compute_video_embedding(video_reader, min_len):
|
168 |
"""Compute embeddings from video"""
|
|
|
212 |
return encoded, video_frames
|
213 |
|
214 |
|
215 |
+
@spaces.GPU(duration=120)
|
216 |
@torch.no_grad()
|
217 |
def compute_hubert_embedding(raw_audio):
|
218 |
"""Compute embeddings from audio"""
|
|
|
259 |
return audio_embeddings
|
260 |
|
261 |
|
262 |
+
@spaces.GPU(duration=120)
|
263 |
@torch.no_grad()
|
264 |
def compute_wavlm_embedding(raw_audio):
|
265 |
"""Compute embeddings from audio"""
|
|
|
366 |
return np.array(processed_landmarks)
|
367 |
|
368 |
|
369 |
+
@spaces.GPU(duration=600)
|
370 |
@torch.no_grad()
|
371 |
def sample(
|
372 |
audio_list,
|