Update inference.py
inference.py  (+8, -6)
@@ -69,7 +69,7 @@ class InferenceRecipe:
         """Load and preprocess audio."""
         try:
             # Convert to tensor
-            wav = torch.from_numpy(audio_array).float().unsqueeze(0)
+            wav = torch.from_numpy(audio_array).float().unsqueeze(0).to(self.device)

             # Resample if needed
             if sample_rate != self.sample_rate:
@@ -93,15 +93,15 @@ class InferenceRecipe:
             raise

     def _pad_codes(self, all_codes, time_seconds=30):
-        """Pad codes to minimum length if needed."""
         try:
             min_frames = int(time_seconds * self.frame_rate)
             frame_size = int(self.sample_rate / self.frame_rate)
-
+
             if len(all_codes) < min_frames:
                 frames_to_add = min_frames - len(all_codes)
                 logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                 with torch.no_grad(), self.mimi.streaming(batch_size=1):
+                    # Create tensor on the correct device
                     chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
                     for _ in range(frames_to_add):
                         additional_code = self.mimi.encode(chunk)
@@ -137,15 +137,17 @@ class InferenceRecipe:
         """Run a warmup pass."""
         try:
             frame_size = int(self.sample_rate / self.frame_rate)
-
-
+            # Create tensor on the correct device from the start
+            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)

             with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
+                codes = self.mimi.encode(chunk)  # chunk already on correct device
                 tokens = self.lm_gen.step(codes[:, :, 0:1])
                 if tokens is not None:
                     _ = self.mimi.decode(tokens[:, 1:])

-
+            if self.device.type == 'cuda':
+                torch.cuda.synchronize()
             logger.info("Warmup pass completed")

         except Exception as e:
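The common thread in these changes is a single pattern: build tensors on `self.device` at the moment they are created, rather than allocating them on the CPU and relying on a later copy, and synchronize CUDA once the warmup work has finished. The following is a minimal, self-contained sketch of that pattern; the function names and the 24 kHz / 12.5 Hz defaults are illustrative assumptions and are not taken from inference.py.

    # Sketch of the device-placement pattern applied in this commit.
    # Names and default rates below are hypothetical, for illustration only.
    import numpy as np
    import torch

    def make_warmup_chunk(device: torch.device,
                          sample_rate: int = 24000,
                          frame_rate: float = 12.5) -> torch.Tensor:
        # Allocate the silent warmup frame directly on the target device,
        # so no later .to(device) copy (or device-mismatch error) is needed.
        frame_size = int(sample_rate / frame_rate)
        return torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device)

    def load_audio(audio_array: np.ndarray, device: torch.device) -> torch.Tensor:
        # Same idea for loaded audio: move it to the device as soon as it
        # becomes a tensor, mirroring the .to(self.device) change above.
        return torch.from_numpy(audio_array).float().unsqueeze(0).to(device)

    if __name__ == "__main__":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        chunk = make_warmup_chunk(device)
        wav = load_audio(np.zeros(24000, dtype=np.float32), device)
        # After GPU warmup work, synchronize so startup timing is deterministic.
        if device.type == "cuda":
            torch.cuda.synchronize()
        print(chunk.device, wav.shape)

Constructing tensors with an explicit `device=` argument, as the diff does for the warmup and padding chunks, avoids both an extra host-to-device copy and the mismatch errors that occur when a CPU tensor is fed to a model that already lives on the GPU.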