Spaces:
Paused
Paused
fix lang detection
Browse files
- audio_processor.py +23 -1
- main.py +2 -10
- whisper_streaming_custom/backends.py +10 -11
audio_processor.py
CHANGED
@@ -10,6 +10,7 @@ from typing import List, Dict, Any
|
|
10 |
from timed_objects import ASRToken
|
11 |
from whisper_streaming_custom.whisper_online import online_factory
|
12 |
from core import WhisperLiveKit
|
|
|
13 |
|
14 |
# Set up logging once
|
15 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
@@ -406,4 +407,25 @@ class AudioProcessor:
|
|
406 |
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
407 |
await self.restart_ffmpeg()
|
408 |
self.ffmpeg_process.stdin.write(message)
|
409 |
-
self.ffmpeg_process.stdin.flush()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from timed_objects import ASRToken
|
11 |
from whisper_streaming_custom.whisper_online import online_factory
|
12 |
from core import WhisperLiveKit
|
13 |
+
import librosa
|
14 |
|
15 |
# Set up logging once
|
16 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
407 |
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
408 |
await self.restart_ffmpeg()
|
409 |
self.ffmpeg_process.stdin.write(message)
|
410 |
+
self.ffmpeg_process.stdin.flush()
|
411 |
+
|
412 |
+
async def detect_language(self, file_path):
|
413 |
+
"""Detect the language of the audio file.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
file_path: Path to the audio file
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
tuple: (detected_language, confidence, probabilities)
|
420 |
+
"""
|
421 |
+
try:
|
422 |
+
# Use the ASR backend to detect language
|
423 |
+
if self.asr:
|
424 |
+
return self.asr.detect_language(file_path)
|
425 |
+
else:
|
426 |
+
raise RuntimeError("ASR backend not initialized")
|
427 |
+
|
428 |
+
except Exception as e:
|
429 |
+
logger.error(f"Error in language detection: {e}")
|
430 |
+
logger.error(f"Traceback: {traceback.format_exc()}")
|
431 |
+
raise
|
main.py
CHANGED
@@ -11,8 +11,6 @@ import traceback
|
|
11 |
import argparse
|
12 |
import uvicorn
|
13 |
import numpy as np
|
14 |
-
import librosa
|
15 |
-
import io
|
16 |
import tempfile
|
17 |
|
18 |
from core import WhisperLiveKit
|
@@ -64,14 +62,8 @@ async def detect_language(file: UploadFile = File(...)):
|
|
64 |
|
65 |
# Use the audio processor for language detection
|
66 |
if audio_processor:
|
67 |
-
#
|
68 |
-
|
69 |
-
|
70 |
-
# Convert to format expected by Whisper
|
71 |
-
audio = (audio * 32768).astype(np.int16)
|
72 |
-
|
73 |
-
# Detect language
|
74 |
-
detected_lang, confidence, probs = audio_processor.detect_language(audio)
|
75 |
|
76 |
# Clean up - remove the temporary file
|
77 |
os.remove(file_path)
|
|
|
11 |
import argparse
|
12 |
import uvicorn
|
13 |
import numpy as np
|
|
|
|
|
14 |
import tempfile
|
15 |
|
16 |
from core import WhisperLiveKit
|
|
|
62 |
|
63 |
# Use the audio processor for language detection
|
64 |
if audio_processor:
|
65 |
+
# Detect language using the audio processor
|
66 |
+
detected_lang, confidence, probs = await audio_processor.detect_language(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
# Clean up - remove the temporary file
|
69 |
os.remove(file_path)
|
whisper_streaming_custom/backends.py
CHANGED
@@ -89,7 +89,7 @@ class WhisperTimestampedASR(ASRBase):
|
|
89 |
def set_translate_task(self):
|
90 |
self.transcribe_kargs["task"] = "translate"
|
91 |
|
92 |
-
def detect_language(self,
|
93 |
import whisper
|
94 |
"""
|
95 |
Detect the language of the audio using Whisper's language detection.
|
@@ -103,12 +103,9 @@ class WhisperTimestampedASR(ASRBase):
|
|
103 |
- confidence (float): Confidence score for the detected language
|
104 |
- probabilities (dict): Dictionary of language probabilities
|
105 |
"""
|
106 |
-
try:
|
107 |
-
# Ensure audio is in the correct format
|
108 |
-
if not isinstance(audio, np.ndarray):
|
109 |
-
audio = np.array(audio)
|
110 |
-
|
111 |
# Pad or trim audio to the correct length
|
|
|
112 |
audio = whisper.pad_or_trim(audio)
|
113 |
|
114 |
# Create mel spectrogram with correct dimensions
|
@@ -183,12 +180,12 @@ class FasterWhisperASR(ASRBase):
|
|
183 |
def set_translate_task(self):
|
184 |
self.transcribe_kargs["task"] = "translate"
|
185 |
|
186 |
-
def detect_language(self,
|
187 |
"""
|
188 |
Detect the language of the audio using faster-whisper's language detection.
|
189 |
|
190 |
Args:
|
191 |
-
|
192 |
|
193 |
Returns:
|
194 |
tuple: (detected_language, confidence, probabilities)
|
@@ -197,9 +194,11 @@ class FasterWhisperASR(ASRBase):
|
|
197 |
- probabilities (dict): Dictionary of language probabilities
|
198 |
"""
|
199 |
try:
|
200 |
-
#
|
201 |
-
|
202 |
-
|
|
|
|
|
203 |
|
204 |
# Use faster-whisper's detect_language method
|
205 |
language, language_probability, all_language_probs = self.model.detect_language(
|
|
|
89 |
def set_translate_task(self):
|
90 |
self.transcribe_kargs["task"] = "translate"
|
91 |
|
92 |
+
def detect_language(self, audio_file_path):
|
93 |
import whisper
|
94 |
"""
|
95 |
Detect the language of the audio using Whisper's language detection.
|
|
|
103 |
- confidence (float): Confidence score for the detected language
|
104 |
- probabilities (dict): Dictionary of language probabilities
|
105 |
"""
|
106 |
+
try:
|
|
|
|
|
|
|
|
|
107 |
# Pad or trim audio to the correct length
|
108 |
+
audio = whisper.load_audio(audio_file_path)
|
109 |
audio = whisper.pad_or_trim(audio)
|
110 |
|
111 |
# Create mel spectrogram with correct dimensions
|
|
|
180 |
def set_translate_task(self):
|
181 |
self.transcribe_kargs["task"] = "translate"
|
182 |
|
183 |
+
def detect_language(self, audio_file_path):
|
184 |
"""
|
185 |
Detect the language of the audio using faster-whisper's language detection.
|
186 |
|
187 |
Args:
|
188 |
+
audio_file_path: Path to the audio file
|
189 |
|
190 |
Returns:
|
191 |
tuple: (detected_language, confidence, probabilities)
|
|
|
194 |
- probabilities (dict): Dictionary of language probabilities
|
195 |
"""
|
196 |
try:
|
197 |
+
# Load audio using soundfile
|
198 |
+
audio, sr = sf.read(audio_file_path)
|
199 |
+
|
200 |
+
# Convert to format expected by Whisper (16-bit PCM)
|
201 |
+
audio = (audio * 32768).astype(np.int16)
|
202 |
|
203 |
# Use faster-whisper's detect_language method
|
204 |
language, language_probability, all_language_probs = self.model.detect_language(
|