Antoni Bigata committed on
Commit
7746897
·
1 Parent(s): 2fb3e22

adapt for ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +46 -31
app.py CHANGED
@@ -23,6 +23,7 @@ from inference_functions import (
23
  )
24
  from wordle_game import WordleGame
25
  import torch.cuda.amp as amp # Import amp for mixed precision
 
26
 
27
 
28
  # Set default tensor type to float16 for faster computation
@@ -96,10 +97,26 @@ def load_model(
96
  return model
97
 
98
 
99
- # keyframe_model = KeyframeModel(device=device)
100
- # interpolation_model = InterpolationModel(device=device)
101
- vae_model = VaeWrapper("video")
102
- if torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  vae_model = vae_model.half() # Convert to half precision
104
  try:
105
  vae_model = torch.compile(vae_model)
@@ -107,8 +124,7 @@ if torch.cuda.is_available():
107
  except Exception as e:
108
  print(f"Warning: Failed to compile vae_model: {e}")
109
 
110
- hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
111
- if torch.cuda.is_available():
112
  hubert_model = hubert_model.half() # Convert to half precision
113
  try:
114
  hubert_model = torch.compile(hubert_model)
@@ -116,13 +132,13 @@ if torch.cuda.is_available():
116
  except Exception as e:
117
  print(f"Warning: Failed to compile hubert_model: {e}")
118
 
119
- wavlm_model = WavLM_wrapper(
120
- model_size="Base+",
121
- feed_as_frames=False,
122
- merge_type="None",
123
- model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
124
- ).cuda()
125
- if torch.cuda.is_available():
126
  wavlm_model = wavlm_model.half() # Convert to half precision
127
  try:
128
  wavlm_model = torch.compile(wavlm_model)
@@ -130,27 +146,23 @@ if torch.cuda.is_available():
130
  except Exception as e:
131
  print(f"Warning: Failed to compile wavlm_model: {e}")
132
 
133
- landmarks_extractor = LandmarksExtractor()
134
- # keyframe_model = load_model(
135
- # config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
136
- # ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
137
- # )
138
- # interpolation_model = load_model(
139
- # config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
140
- # ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
141
- # )
142
- # keyframe_model.en_and_decode_n_samples_a_time = 2
143
- # interpolation_model.en_and_decode_n_samples_a_time = 2
144
 
145
- # Default media paths
146
- DEFAULT_VIDEO_PATH = os.path.join(
147
- os.path.dirname(__file__), "assets", "sample_video.mp4"
148
- )
149
- DEFAULT_AUDIO_PATH = os.path.join(
150
- os.path.dirname(__file__), "assets", "sample_audio.wav"
151
- )
152
 
153
 
 
154
  @torch.no_grad()
155
  def compute_video_embedding(video_reader, min_len):
156
  """Compute embeddings from video"""
@@ -200,6 +212,7 @@ def compute_video_embedding(video_reader, min_len):
200
  return encoded, video_frames
201
 
202
 
 
203
  @torch.no_grad()
204
  def compute_hubert_embedding(raw_audio):
205
  """Compute embeddings from audio"""
@@ -246,6 +259,7 @@ def compute_hubert_embedding(raw_audio):
246
  return audio_embeddings
247
 
248
 
 
249
  @torch.no_grad()
250
  def compute_wavlm_embedding(raw_audio):
251
  """Compute embeddings from audio"""
@@ -352,6 +366,7 @@ def extract_video_landmarks(video_frames):
352
  return np.array(processed_landmarks)
353
 
354
 
 
355
  @torch.no_grad()
356
  def sample(
357
  audio_list,
 
23
  )
24
  from wordle_game import WordleGame
25
  import torch.cuda.amp as amp # Import amp for mixed precision
26
+ import spaces
27
 
28
 
29
  # Set default tensor type to float16 for faster computation
 
97
  return model
98
 
99
 
100
+ # Default media paths
101
+ DEFAULT_VIDEO_PATH = os.path.join(
102
+ os.path.dirname(__file__), "assets", "sample_video.mp4"
103
+ )
104
+ DEFAULT_AUDIO_PATH = os.path.join(
105
+ os.path.dirname(__file__), "assets", "sample_audio.wav"
106
+ )
107
+
108
+
109
+ @spaces.GPU(duration=60)
110
+ def load_all_models():
111
+ global \
112
+ keyframe_model, \
113
+ interpolation_model, \
114
+ vae_model, \
115
+ hubert_model, \
116
+ wavlm_model, \
117
+ landmarks_extractor
118
+ vae_model = VaeWrapper("video")
119
+
120
  vae_model = vae_model.half() # Convert to half precision
121
  try:
122
  vae_model = torch.compile(vae_model)
 
124
  except Exception as e:
125
  print(f"Warning: Failed to compile vae_model: {e}")
126
 
127
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
 
128
  hubert_model = hubert_model.half() # Convert to half precision
129
  try:
130
  hubert_model = torch.compile(hubert_model)
 
132
  except Exception as e:
133
  print(f"Warning: Failed to compile hubert_model: {e}")
134
 
135
+ wavlm_model = WavLM_wrapper(
136
+ model_size="Base+",
137
+ feed_as_frames=False,
138
+ merge_type="None",
139
+ model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
140
+ ).cuda()
141
+
142
  wavlm_model = wavlm_model.half() # Convert to half precision
143
  try:
144
  wavlm_model = torch.compile(wavlm_model)
 
146
  except Exception as e:
147
  print(f"Warning: Failed to compile wavlm_model: {e}")
148
 
149
+ landmarks_extractor = LandmarksExtractor()
150
+ keyframe_model = load_model(
151
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
152
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
153
+ )
154
+ interpolation_model = load_model(
155
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
156
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
157
+ )
158
+ keyframe_model.en_and_decode_n_samples_a_time = 2
159
+ interpolation_model.en_and_decode_n_samples_a_time = 2
160
 
161
+
162
+ load_all_models()
 
 
 
 
 
163
 
164
 
165
+ @spaces.GPU(duration=60)
166
  @torch.no_grad()
167
  def compute_video_embedding(video_reader, min_len):
168
  """Compute embeddings from video"""
 
212
  return encoded, video_frames
213
 
214
 
215
+ @spaces.GPU(duration=120)
216
  @torch.no_grad()
217
  def compute_hubert_embedding(raw_audio):
218
  """Compute embeddings from audio"""
 
259
  return audio_embeddings
260
 
261
 
262
+ @spaces.GPU(duration=120)
263
  @torch.no_grad()
264
  def compute_wavlm_embedding(raw_audio):
265
  """Compute embeddings from audio"""
 
366
  return np.array(processed_landmarks)
367
 
368
 
369
+ @spaces.GPU(duration=600)
370
  @torch.no_grad()
371
  def sample(
372
  audio_list,