Jerich committed
Commit f8dca01 · verified · 1 parent: faf4aa8

Fix torch import error in translate-audio endpoint
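
The underlying bug: app.py imported torch inside load_models_task(), so the name was bound only in that function's local scope, and the translate-audio handler hit a NameError at request time; the same applied to librosa. A minimal, self-contained sketch of that failure mode (math stands in for torch here purely so the snippet runs without heavy dependencies):

# A function-local import binds the module name only in that function's
# scope; any other function that references it raises NameError.
def load_models():
    import math  # visible only inside load_models()
    return math.sqrt(2)

def use_math():
    return math.pi  # fails: 'math' was never bound at module level

load_models()  # runs fine
try:
    use_math()
except NameError as e:
    print(e)  # name 'math' is not defined

Hoisting the imports to module level, as this commit does for torch and librosa, makes the names visible to every function in the file.
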

Files changed (1): app.py (+8 −4)
app.py CHANGED

@@ -6,8 +6,10 @@ import logging
 import threading
 import tempfile
 import uuid
+import torch
 import numpy as np
 import soundfile as sf
+import librosa
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 from fastapi.responses import JSONResponse
 from typing import Dict, Any, Optional
@@ -50,12 +52,10 @@ def load_models_task():
     try:
         loading_in_progress = True
 
-        # Import heavy libraries only when needed
+        # Load STT model
         logger.info("Starting to load STT model...")
-        import torch
         from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-        # Load STT model
         try:
             logger.info("Loading Whisper model...")
             model_status["stt"] = "loading"
@@ -177,15 +177,18 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
 
     try:
         # Read and preprocess the audio
+        logger.info(f"Reading audio file: {temp_path}")
         waveform, sample_rate = sf.read(temp_path)
+        logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
         if sample_rate != 16000:
             logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
-            import librosa
             waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
 
         # Process the audio with Whisper
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
         inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
+        logger.info("Audio processed, generating transcription...")
         with torch.no_grad():
             generated_ids = stt_model.generate(**inputs)
             transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -210,6 +213,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
             "output_audio": None
         }
     finally:
+        logger.info(f"Cleaning up temporary file: {temp_path}")
         os.unlink(temp_path)
 
 if __name__ == "__main__":
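
A quick smoke test for the fixed endpoint (a sketch: the route path /translate-audio, port 8000, and a local sample.wav are assumptions; the audio and source_lang form fields come from the handler signature in the diff):

# Hypothetical client-side check that the endpoint no longer fails on the
# torch NameError; assumes the FastAPI app is serving on localhost:8000.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/translate-audio",
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"source_lang": "en"},
    )
print(resp.status_code)
print(resp.json())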