Spaces:
Running
on
Zero
Running
on
Zero
Antoni Bigata
committed on
Commit
·
cf0da47
1
Parent(s):
ed769ff
requirements
Browse files
app.py
CHANGED
@@ -116,15 +116,77 @@ DEFAULT_AUDIO_PATH = os.path.join(
|
|
116 |
)
|
117 |
|
118 |
|
119 |
-
@spaces.GPU(duration=60)
|
120 |
-
def load_all_models():
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
vae_model = VaeWrapper("video")
|
129 |
|
130 |
vae_model = vae_model.half() # Convert to half precision
|
@@ -167,24 +229,6 @@ def load_all_models():
|
|
167 |
)
|
168 |
keyframe_model.en_and_decode_n_samples_a_time = 2
|
169 |
interpolation_model.en_and_decode_n_samples_a_time = 2
|
170 |
-
return (
|
171 |
-
keyframe_model,
|
172 |
-
interpolation_model,
|
173 |
-
vae_model,
|
174 |
-
hubert_model,
|
175 |
-
wavlm_model,
|
176 |
-
landmarks_extractor,
|
177 |
-
)
|
178 |
-
|
179 |
-
|
180 |
-
(
|
181 |
-
keyframe_model,
|
182 |
-
interpolation_model,
|
183 |
-
vae_model,
|
184 |
-
hubert_model,
|
185 |
-
wavlm_model,
|
186 |
-
landmarks_extractor,
|
187 |
-
) = load_all_models()
|
188 |
|
189 |
|
190 |
@spaces.GPU(duration=60)
|
|
|
116 |
)
|
117 |
|
118 |
|
119 |
+
# @spaces.GPU(duration=60)
|
120 |
+
# def load_all_models():
|
121 |
+
# global \
|
122 |
+
# keyframe_model, \
|
123 |
+
# interpolation_model, \
|
124 |
+
# vae_model, \
|
125 |
+
# hubert_model, \
|
126 |
+
# wavlm_model, \
|
127 |
+
# landmarks_extractor
|
128 |
+
# vae_model = VaeWrapper("video")
|
129 |
+
|
130 |
+
# vae_model = vae_model.half() # Convert to half precision
|
131 |
+
# try:
|
132 |
+
# vae_model = torch.compile(vae_model)
|
133 |
+
# print("Successfully compiled vae_model in FP16")
|
134 |
+
# except Exception as e:
|
135 |
+
# print(f"Warning: Failed to compile vae_model: {e}")
|
136 |
+
|
137 |
+
# hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
138 |
+
# hubert_model = hubert_model.half() # Convert to half precision
|
139 |
+
# try:
|
140 |
+
# hubert_model = torch.compile(hubert_model)
|
141 |
+
# print("Successfully compiled hubert_model in FP16")
|
142 |
+
# except Exception as e:
|
143 |
+
# print(f"Warning: Failed to compile hubert_model: {e}")
|
144 |
+
|
145 |
+
# wavlm_model = WavLM_wrapper(
|
146 |
+
# model_size="Base+",
|
147 |
+
# feed_as_frames=False,
|
148 |
+
# merge_type="None",
|
149 |
+
# model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
|
150 |
+
# ).cuda()
|
151 |
+
|
152 |
+
# wavlm_model = wavlm_model.half() # Convert to half precision
|
153 |
+
# try:
|
154 |
+
# wavlm_model = torch.compile(wavlm_model)
|
155 |
+
# print("Successfully compiled wavlm_model in FP16")
|
156 |
+
# except Exception as e:
|
157 |
+
# print(f"Warning: Failed to compile wavlm_model: {e}")
|
158 |
+
|
159 |
+
# landmarks_extractor = LandmarksExtractor()
|
160 |
+
# keyframe_model = load_model(
|
161 |
+
# config="keyframe.yaml",
|
162 |
+
# ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
|
163 |
+
# )
|
164 |
+
# interpolation_model = load_model(
|
165 |
+
# config="interpolation.yaml",
|
166 |
+
# ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
|
167 |
+
# )
|
168 |
+
# keyframe_model.en_and_decode_n_samples_a_time = 2
|
169 |
+
# interpolation_model.en_and_decode_n_samples_a_time = 2
|
170 |
+
# return (
|
171 |
+
# keyframe_model,
|
172 |
+
# interpolation_model,
|
173 |
+
# vae_model,
|
174 |
+
# hubert_model,
|
175 |
+
# wavlm_model,
|
176 |
+
# landmarks_extractor,
|
177 |
+
# )
|
178 |
+
|
179 |
+
|
180 |
+
# (
|
181 |
+
# keyframe_model,
|
182 |
+
# interpolation_model,
|
183 |
+
# vae_model,
|
184 |
+
# hubert_model,
|
185 |
+
# wavlm_model,
|
186 |
+
# landmarks_extractor,
|
187 |
+
# ) = load_all_models()
|
188 |
+
|
189 |
+
with spaces.GPU(duration=60) as gpu:
|
190 |
vae_model = VaeWrapper("video")
|
191 |
|
192 |
vae_model = vae_model.half() # Convert to half precision
|
|
|
229 |
)
|
230 |
keyframe_model.en_and_decode_n_samples_a_time = 2
|
231 |
interpolation_model.en_and_decode_n_samples_a_time = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
|
234 |
@spaces.GPU(duration=60)
|