AnyaSchen commited on
Commit
7db1cf9
·
1 Parent(s): b0b8407

fix lang detection

Browse files
Files changed (3) hide show
  1. audio_processor.py +23 -1
  2. main.py +2 -10
  3. whisper_streaming_custom/backends.py +10 -11
audio_processor.py CHANGED
@@ -10,6 +10,7 @@ from typing import List, Dict, Any
10
  from timed_objects import ASRToken
11
  from whisper_streaming_custom.whisper_online import online_factory
12
  from core import WhisperLiveKit
 
13
 
14
  # Set up logging once
15
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -406,4 +407,25 @@ class AudioProcessor:
406
  logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
407
  await self.restart_ffmpeg()
408
  self.ffmpeg_process.stdin.write(message)
409
- self.ffmpeg_process.stdin.flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from timed_objects import ASRToken
11
  from whisper_streaming_custom.whisper_online import online_factory
12
  from core import WhisperLiveKit
13
+ import librosa
14
 
15
  # Set up logging once
16
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
407
  logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
408
  await self.restart_ffmpeg()
409
  self.ffmpeg_process.stdin.write(message)
410
+ self.ffmpeg_process.stdin.flush()
411
+
412
+ async def detect_language(self, file_path):
413
+ """Detect the language of the audio file.
414
+
415
+ Args:
416
+ file_path: Path to the audio file
417
+
418
+ Returns:
419
+ tuple: (detected_language, confidence, probabilities)
420
+ """
421
+ try:
422
+ # Use the ASR backend to detect language
423
+ if self.asr:
424
+ return self.asr.detect_language(file_path)
425
+ else:
426
+ raise RuntimeError("ASR backend not initialized")
427
+
428
+ except Exception as e:
429
+ logger.error(f"Error in language detection: {e}")
430
+ logger.error(f"Traceback: {traceback.format_exc()}")
431
+ raise
main.py CHANGED
@@ -11,8 +11,6 @@ import traceback
11
  import argparse
12
  import uvicorn
13
  import numpy as np
14
- import librosa
15
- import io
16
  import tempfile
17
 
18
  from core import WhisperLiveKit
@@ -64,14 +62,8 @@ async def detect_language(file: UploadFile = File(...)):
64
 
65
  # Use the audio processor for language detection
66
  if audio_processor:
67
- # Load audio using librosa
68
- audio, sr = librosa.load(file_path, sr=16000)
69
-
70
- # Convert to format expected by Whisper
71
- audio = (audio * 32768).astype(np.int16)
72
-
73
- # Detect language
74
- detected_lang, confidence, probs = audio_processor.detect_language(audio)
75
 
76
  # Clean up - remove the temporary file
77
  os.remove(file_path)
 
11
  import argparse
12
  import uvicorn
13
  import numpy as np
 
 
14
  import tempfile
15
 
16
  from core import WhisperLiveKit
 
62
 
63
  # Use the audio processor for language detection
64
  if audio_processor:
65
+ # Detect language using the audio processor
66
+ detected_lang, confidence, probs = await audio_processor.detect_language(file_path)
 
 
 
 
 
 
67
 
68
  # Clean up - remove the temporary file
69
  os.remove(file_path)
whisper_streaming_custom/backends.py CHANGED
@@ -89,7 +89,7 @@ class WhisperTimestampedASR(ASRBase):
89
  def set_translate_task(self):
90
  self.transcribe_kargs["task"] = "translate"
91
 
92
- def detect_language(self, audio):
93
  import whisper
94
  """
95
  Detect the language of the audio using Whisper's language detection.
@@ -103,12 +103,9 @@ class WhisperTimestampedASR(ASRBase):
103
  - confidence (float): Confidence score for the detected language
104
  - probabilities (dict): Dictionary of language probabilities
105
  """
106
- try:
107
- # Ensure audio is in the correct format
108
- if not isinstance(audio, np.ndarray):
109
- audio = np.array(audio)
110
-
111
  # Pad or trim audio to the correct length
 
112
  audio = whisper.pad_or_trim(audio)
113
 
114
  # Create mel spectrogram with correct dimensions
@@ -183,12 +180,12 @@ class FasterWhisperASR(ASRBase):
183
  def set_translate_task(self):
184
  self.transcribe_kargs["task"] = "translate"
185
 
186
- def detect_language(self, audio):
187
  """
188
  Detect the language of the audio using faster-whisper's language detection.
189
 
190
  Args:
191
- audio (np.ndarray): Audio data as numpy array
192
 
193
  Returns:
194
  tuple: (detected_language, confidence, probabilities)
@@ -197,9 +194,11 @@ class FasterWhisperASR(ASRBase):
197
  - probabilities (dict): Dictionary of language probabilities
198
  """
199
  try:
200
- # Ensure audio is in the correct format
201
- if not isinstance(audio, np.ndarray):
202
- audio = np.array(audio)
 
 
203
 
204
  # Use faster-whisper's detect_language method
205
  language, language_probability, all_language_probs = self.model.detect_language(
 
89
  def set_translate_task(self):
90
  self.transcribe_kargs["task"] = "translate"
91
 
92
+ def detect_language(self, audio_file_path):
93
  import whisper
94
  """
95
  Detect the language of the audio using Whisper's language detection.
 
103
  - confidence (float): Confidence score for the detected language
104
  - probabilities (dict): Dictionary of language probabilities
105
  """
106
+ try:
 
 
 
 
107
  # Pad or trim audio to the correct length
108
+ audio = whisper.load_audio(audio_file_path)
109
  audio = whisper.pad_or_trim(audio)
110
 
111
  # Create mel spectrogram with correct dimensions
 
180
  def set_translate_task(self):
181
  self.transcribe_kargs["task"] = "translate"
182
 
183
+ def detect_language(self, audio_file_path):
184
  """
185
  Detect the language of the audio using faster-whisper's language detection.
186
 
187
  Args:
188
+ audio_file_path: Path to the audio file
189
 
190
  Returns:
191
  tuple: (detected_language, confidence, probabilities)
 
194
  - probabilities (dict): Dictionary of language probabilities
195
  """
196
  try:
197
+ # Load audio using soundfile
198
+ audio, sr = sf.read(audio_file_path)
199
+
200
+ # Convert to format expected by Whisper (16-bit PCM)
201
+ audio = (audio * 32768).astype(np.int16)
202
 
203
  # Use faster-whisper's detect_language method
204
  language, language_probability, all_language_probs = self.model.detect_language(