tezuesh committed on
Commit
76b1e32
·
verified ·
1 Parent(s): 452fa43

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +8 -6
inference.py CHANGED
@@ -69,7 +69,7 @@ class InferenceRecipe:
69
  """Load and preprocess audio."""
70
  try:
71
  # Convert to tensor
72
- wav = torch.from_numpy(audio_array).float().unsqueeze(0)
73
 
74
  # Resample if needed
75
  if sample_rate != self.sample_rate:
@@ -93,15 +93,15 @@ class InferenceRecipe:
93
  raise
94
 
95
  def _pad_codes(self, all_codes, time_seconds=30):
96
- """Pad codes to minimum length if needed."""
97
  try:
98
  min_frames = int(time_seconds * self.frame_rate)
99
  frame_size = int(self.sample_rate / self.frame_rate)
100
-
101
  if len(all_codes) < min_frames:
102
  frames_to_add = min_frames - len(all_codes)
103
  logger.info(f"Padding {frames_to_add} frames to reach minimum length")
104
  with torch.no_grad(), self.mimi.streaming(batch_size=1):
 
105
  chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
106
  for _ in range(frames_to_add):
107
  additional_code = self.mimi.encode(chunk)
@@ -137,15 +137,17 @@ class InferenceRecipe:
137
  """Run a warmup pass."""
138
  try:
139
  frame_size = int(self.sample_rate / self.frame_rate)
140
- chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32)
141
- codes = self.mimi.encode(chunk)
142
 
143
  with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
 
144
  tokens = self.lm_gen.step(codes[:, :, 0:1])
145
  if tokens is not None:
146
  _ = self.mimi.decode(tokens[:, 1:])
147
 
148
- torch.cuda.synchronize()
 
149
  logger.info("Warmup pass completed")
150
 
151
  except Exception as e:
 
69
  """Load and preprocess audio."""
70
  try:
71
  # Convert to tensor
72
+ wav = torch.from_numpy(audio_array).float().unsqueeze(0).to(self.device)
73
 
74
  # Resample if needed
75
  if sample_rate != self.sample_rate:
 
93
  raise
94
 
95
  def _pad_codes(self, all_codes, time_seconds=30):
 
96
  try:
97
  min_frames = int(time_seconds * self.frame_rate)
98
  frame_size = int(self.sample_rate / self.frame_rate)
99
+
100
  if len(all_codes) < min_frames:
101
  frames_to_add = min_frames - len(all_codes)
102
  logger.info(f"Padding {frames_to_add} frames to reach minimum length")
103
  with torch.no_grad(), self.mimi.streaming(batch_size=1):
104
+ # Create tensor on the correct device
105
  chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
106
  for _ in range(frames_to_add):
107
  additional_code = self.mimi.encode(chunk)
 
137
  """Run a warmup pass."""
138
  try:
139
  frame_size = int(self.sample_rate / self.frame_rate)
140
+ # Create tensor on the correct device from the start
141
+ chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
142
 
143
  with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
144
+ codes = self.mimi.encode(chunk) # chunk already on correct device
145
  tokens = self.lm_gen.step(codes[:, :, 0:1])
146
  if tokens is not None:
147
  _ = self.mimi.decode(tokens[:, 1:])
148
 
149
+ if self.device.type == 'cuda':
150
+ torch.cuda.synchronize()
151
  logger.info("Warmup pass completed")
152
 
153
  except Exception as e: