# Page-scrape residue (appears to be last-commit metadata), kept as a comment:
# Michael Hu - "fix build error" - 7eff88c
"""
Speech Recognition Module using Whisper Large-v3
Handles audio preprocessing and transcription
"""
import logging
import numpy as np
logger = logging.getLogger(__name__)
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from pydub import AudioSegment
import soundfile as sf # Add this import
def transcribe_audio(audio_path):
    """
    Convert an audio file to English text using the Whisper large-v3 model.

    The input is decoded with pydub, converted to 16 kHz mono and exported as
    a WAV file next to the source. The WAV is then read back and transcribed
    in consecutive 30-second windows, because the Whisper feature extractor
    only consumes the first 30 seconds of whatever it is handed; the decoded
    windows are concatenated in order.

    Args:
        audio_path: Path to input audio file (str or os.PathLike)

    Returns:
        Transcribed English text

    Raises:
        Exception: Any failure during conversion, model loading or generation
            is logged with its traceback and re-raised unchanged.
    """
    # Function-scope import so this block does not depend on the file header.
    import os

    logger.info(f"Starting transcription for: {audio_path}")

    try:
        # Audio conversion
        logger.info("Converting audio format")
        audio = AudioSegment.from_file(audio_path)
        processed_audio = audio.set_frame_rate(16000).set_channels(1)

        # Derive the WAV name from the real extension. A plain
        # str.replace(".mp3", ".wav") leaves non-mp3 names unchanged, which
        # made the export below overwrite the caller's source file.
        stem, extension = os.path.splitext(os.fspath(audio_path))
        if extension.lower() == ".wav":
            # Source is already a .wav: write beside it instead of over it.
            wav_path = stem + "_16k.wav"
        else:
            wav_path = stem + ".wav"
        processed_audio.export(wav_path, format="wav")
        logger.info(f"Audio converted to: {wav_path}")

        # Model initialization
        logger.info("Loading Whisper model")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            "openai/whisper-large-v3",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_safetensors=True
        ).to(device)

        processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
        logger.info("Model loaded successfully")

        # Processing
        logger.info("Processing audio input")
        logger.debug("Loading audio data")
        audio_data, sample_rate = sf.read(wav_path)
        audio_data = audio_data.astype(np.float32)

        # chunk_length_s / stride_length_s are ASR-pipeline options; passed to
        # the processor they are ignored and the audio is cut at 30 seconds.
        # Walk the signal in 30-second windows instead so nothing is dropped.
        # NOTE: windows do not overlap, so a word that straddles a boundary
        # may be split between two segments.
        window_samples = 30 * sample_rate

        # Transcription
        logger.info("Generating transcription")
        pieces = []
        with torch.no_grad():
            # max(..., 1) keeps a single pass for an empty signal, so that
            # case is still handed to the model exactly once, as before.
            for start in range(0, max(len(audio_data), 1), window_samples):
                inputs = processor(
                    audio_data[start:start + window_samples],
                    sampling_rate=sample_rate,
                    return_tensors="pt"
                ).to(device)

                outputs = model.generate(
                    **inputs,
                    language="en",
                    task="transcribe",
                    max_length=448,  # Explicit cap on generated tokens per window
                    no_repeat_ngram_size=3  # Prevent repetition in output
                )
                pieces.append(
                    processor.batch_decode(outputs, skip_special_tokens=True)[0]
                )

        # Segments are joined as decoded; for audio of 30 seconds or less this
        # is exactly the single decoded string the function always returned.
        result = "".join(pieces)

        logger.info("transcription: %s", result)
        logger.info("Transcription completed successfully")
        return result

    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}", exc_info=True)
        raise