RSHVR committed
Commit 8d98b9d · verified · 1 Parent(s): eb62218

Update stt.py

Files changed (1)
  1. stt.py +71 -50
stt.py CHANGED
@@ -1,8 +1,13 @@
+# stt.py
 import os
 import torch
 import torchaudio
 import spaces
+import numpy as np
+from typing import Tuple
+from numpy.typing import NDArray
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
+import tempfile
 
 # Create directories
 os.makedirs("transcriptions", exist_ok=True)
@@ -20,63 +25,79 @@ WHISPER_MODEL_SIZES = {
     'large': 'openai/whisper-large-v3',
 }
 
-# Synchronous function with GPU decorator
-@spaces.GPU
-def _transcribe_audio_gpu(audio_file_path, model_size="base", language="en"):
-    global whisper_model, whisper_processor
-
-    try:
+class WhisperSTTModel:
+    def __init__(self, model_size="base", language="en"):
+        self.model_size = model_size
+        self.language = language
+        self._initialize_model()
+
+    @spaces.GPU
+    def _initialize_model(self):
+        global whisper_model, whisper_processor
+
         # Get model identifier
-        model_id = WHISPER_MODEL_SIZES.get(model_size.lower(), WHISPER_MODEL_SIZES['base'])
+        model_id = WHISPER_MODEL_SIZES.get(self.model_size.lower(), WHISPER_MODEL_SIZES['base'])
 
-        # Load model and processor on first use or if model size changes
+        # Load model and processor if not already loaded
         if whisper_model is None or whisper_processor is None or (whisper_model and whisper_model.config._name_or_path != model_id):
-            print(f"Loading Whisper {model_size} model...")
+            print(f"Loading Whisper {self.model_size} model...")
             whisper_processor = WhisperProcessor.from_pretrained(model_id)
             whisper_model = WhisperForConditionalGeneration.from_pretrained(model_id)
             print(f"Model loaded on device: {whisper_model.device}")
+
+    @spaces.GPU
+    def stt(self, audio: Tuple[int, NDArray[np.float32]]) -> str:
+        """Transcribe audio to text following the STTModel protocol"""
+        sample_rate, audio_array = audio
 
-        # Process audio
-        speech_array, sample_rate = torchaudio.load(audio_file_path)
-
-        # Convert to mono if needed
-        if speech_array.shape[0] > 1:
-            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
-        # Resample to 16kHz if needed
-        if sample_rate != 16000:
-            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-            speech_array = resampler(speech_array)
-
-        # Prepare inputs for the model
-        input_features = whisper_processor(
-            speech_array.squeeze().numpy(),
-            sampling_rate=16000,
-            return_tensors="pt"
-        ).input_features
-
-        # Generate transcription
-        generation_kwargs = {}
-
-        if language:
-            forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe")
-            generation_kwargs["forced_decoder_ids"] = forced_decoder_ids
-
-        # Run the model
-        with torch.no_grad():
-            predicted_ids = whisper_model.generate(input_features, **generation_kwargs)
+        try:
+            # Convert to mono if needed
+            if len(audio_array.shape) > 1 and audio_array.shape[0] > 1:
+                audio_array = np.mean(audio_array, axis=0)
 
-        # Decode the output
-        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
-
-        # Return the transcribed text
-        return transcription[0]
-
-    except Exception as e:
-        print(f"Error during transcription: {str(e)}")
-        return ""
+            # Convert numpy array to torch tensor
+            speech_array = torch.tensor(audio_array).unsqueeze(0)
+
+            # Resample to 16kHz if needed
+            if sample_rate != 16000:
+                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+                speech_array = resampler(speech_array)
+
+            # Prepare inputs for the model
+            input_features = whisper_processor(
+                speech_array.squeeze().numpy(),
+                sampling_rate=16000,
+                return_tensors="pt"
+            ).input_features
+
+            # Generate transcription
+            generation_kwargs = {}
+
+            if self.language:
+                forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language=self.language, task="transcribe")
+                generation_kwargs["forced_decoder_ids"] = forced_decoder_ids
+
+            # Run the model
+            with torch.no_grad():
+                predicted_ids = whisper_model.generate(input_features, **generation_kwargs)
+
+            # Decode the output
+            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+            # Return the transcribed text
+            return transcription[0]
+
+        except Exception as e:
+            print(f"Error during transcription: {str(e)}")
+            return ""
+
+# Create a singleton instance for easy import
+whisper_stt = WhisperSTTModel(model_size="base", language="en")
 
-# Async wrapper that calls the GPU function
+# Legacy function for backward compatibility
 async def transcribe_audio(audio_file_path, model_size="base", language="en"):
-    # Call the GPU-decorated function
-    return _transcribe_audio_gpu(audio_file_path, model_size, language)
+    """For compatibility with older code"""
+    # Load audio from file
+    speech_array, sample_rate = torchaudio.load(audio_file_path)
+    # Use the new model to transcribe
+    return whisper_stt.stt((sample_rate, speech_array.squeeze().numpy()))
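
For reference, the new `stt` method takes a `(sample_rate, audio_array)` tuple and returns the transcription as a string, while the old file-path entry point survives as the async `transcribe_audio` wrapper. A minimal usage sketch under those assumptions; the file name `example.wav` and the surrounding call pattern are illustrative and not part of this commit:

```python
# Illustrative caller for the updated stt.py; "example.wav" is a placeholder path.
import asyncio
import torchaudio

from stt import whisper_stt, transcribe_audio

# Protocol-style call: pass a (sample_rate, float32 audio array) tuple.
speech, sr = torchaudio.load("example.wav")              # speech: (channels, samples)
text = whisper_stt.stt((sr, speech.squeeze().numpy()))
print(text)

# Legacy path: the async wrapper still accepts a file path.
text_legacy = asyncio.run(transcribe_audio("example.wav"))
print(text_legacy)
```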