Fix torch import error in translate-audio endpoint
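As the diff below suggests, the error came from `torch` and `librosa` being imported inside function bodies: an import inside `load_models_task` binds the module name only in that function's scope, so `translate_audio` hit a NameError the moment it touched `torch.cuda` or `librosa.resample`. A minimal standalone sketch of the failure mode (illustrative function bodies, not the actual app.py code):

# Sketch of the bug: an import inside one function is local to it,
# so a second function that references the module fails at call time.

def load_models_task():
    import torch  # visible only inside load_models_task
    return torch.__version__

def translate_audio():
    # torch was never bound at module scope -> NameError here
    return "cuda" if torch.cuda.is_available() else "cpu"

load_models_task()
try:
    translate_audio()
except NameError as exc:
    print(exc)  # name 'torch' is not defined

The commit hoists both imports to module level and adds logging around the audio pipeline for easier debugging.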
app.py CHANGED
@@ -6,8 +6,10 @@ import logging
 import threading
 import tempfile
 import uuid
+import torch
 import numpy as np
 import soundfile as sf
+import librosa
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 from fastapi.responses import JSONResponse
 from typing import Dict, Any, Optional
@@ -50,12 +52,10 @@ def load_models_task():
     try:
         loading_in_progress = True
 
-        #
+        # Load STT model
         logger.info("Starting to load STT model...")
-        import torch
         from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-        # Load STT model
         try:
             logger.info("Loading Whisper model...")
             model_status["stt"] = "loading"
@@ -177,15 +177,18 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
 
     try:
         # Read and preprocess the audio
+        logger.info(f"Reading audio file: {temp_path}")
         waveform, sample_rate = sf.read(temp_path)
+        logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
         if sample_rate != 16000:
             logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
-            import librosa
             waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
 
         # Process the audio with Whisper
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
         inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
+        logger.info("Audio processed, generating transcription...")
         with torch.no_grad():
             generated_ids = stt_model.generate(**inputs)
         transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -210,6 +213,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
             "output_audio": None
         }
     finally:
+        logger.info(f"Cleaning up temporary file: {temp_path}")
         os.unlink(temp_path)
 
 if __name__ == "__main__":
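With `librosa` now imported at module scope, the resampling step in `translate_audio` can be exercised on its own. A quick standalone check; the one-second 440 Hz tone at 44.1 kHz is made up for illustration:

import numpy as np
import librosa

# Fabricate one second of 440 Hz audio at 44.1 kHz, as an upload might arrive.
sr_in = 44100
t = np.linspace(0, 1, sr_in, endpoint=False)
waveform = np.sin(2 * np.pi * 440 * t).astype(np.float32)

# Same call as in translate_audio; librosa takes keyword orig_sr/target_sr.
resampled = librosa.resample(waveform, orig_sr=sr_in, target_sr=16000)
print(resampled.shape)  # (16000,) -- one second at the Whisper sample rate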
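To exercise the patched endpoint end to end, a multipart request along these lines should work; the `/translate-audio` path is taken from the commit title, while the host, port, and sample file are assumptions:

import requests

# Hypothetical client call; adjust the URL to wherever the Space is served.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/translate-audio",
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"source_lang": "en"},
    )
print(resp.status_code, resp.json())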