RSHVR committed
Commit 836c8e6 · verified · 1 Parent(s): 28d024b

Create stt.py

Files changed (1):
  1. stt.py +76 -0
stt.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ import torch
+ import torchaudio
+ import spaces  # Zero-GPU support on Hugging Face Spaces
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+ # Create output directory for saved transcriptions
+ os.makedirs("transcriptions", exist_ok=True)
+
+ # Global model handles, loaded lazily on first use
+ whisper_model = None
+ whisper_processor = None
+
+ # Model configurations
+ WHISPER_MODEL_SIZES = {
+     'tiny': 'openai/whisper-tiny',
+     'base': 'openai/whisper-base',
+     'small': 'openai/whisper-small',
+     'medium': 'openai/whisper-medium',
+     'large': 'openai/whisper-large-v3',
+ }
+
+ @spaces.GPU  # Request a GPU slice for this call on Zero-GPU Spaces
+ async def transcribe_audio(audio_file_path, model_size="base", language="en"):
+     global whisper_model, whisper_processor
+
+     try:
+         # Resolve the model identifier, falling back to 'base'
+         model_id = WHISPER_MODEL_SIZES.get(model_size.lower(), WHISPER_MODEL_SIZES['base'])
+
+         # Load model and processor on first use or when the model size changes
+         if (whisper_model is None or whisper_processor is None
+                 or whisper_model.config._name_or_path != model_id):
+             print(f"Loading Whisper {model_size} model...")
+             whisper_processor = WhisperProcessor.from_pretrained(model_id)
+             whisper_model = WhisperForConditionalGeneration.from_pretrained(model_id)
+             # Move the model to the GPU when one is available
+             if torch.cuda.is_available():
+                 whisper_model = whisper_model.to("cuda")
+             print(f"Model loaded on device: {whisper_model.device}")
+
+         # Load audio
+         speech_array, sample_rate = torchaudio.load(audio_file_path)
+
+         # Convert to mono if needed
+         if speech_array.shape[0] > 1:
+             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
+
+         # Resample to 16 kHz (the rate Whisper expects) if needed
+         if sample_rate != 16000:
+             resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+             speech_array = resampler(speech_array)
+
+         # Prepare inputs for the model, on the same device as the model
+         input_features = whisper_processor(
+             speech_array.squeeze().numpy(),
+             sampling_rate=16000,
+             return_tensors="pt"
+         ).input_features.to(whisper_model.device)
+
+         # Generate transcription
+         generation_kwargs = {}
+         if language:
+             forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(
+                 language=language, task="transcribe")
+             generation_kwargs["forced_decoder_ids"] = forced_decoder_ids
+
+         # Run the model without tracking gradients
+         with torch.no_grad():
+             predicted_ids = whisper_model.generate(input_features, **generation_kwargs)
+
+         # Decode token ids back to text
+         transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+         # Return the transcribed text
+         return transcription[0]
+
+     except Exception as e:
+         print(f"Error during transcription: {str(e)}")
+         return ""
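For reference, a minimal sketch of how the function might be called. Since transcribe_audio is declared async, it must be awaited inside an event loop; the file name sample.wav below is a hypothetical placeholder, not part of the commit.

    import asyncio

    async def main():
        # "sample.wav" is a placeholder path; any audio file torchaudio can load works
        text = await transcribe_audio("sample.wav", model_size="base", language="en")
        print(text)

    asyncio.run(main())

On a Zero-GPU Space the @spaces.GPU decorator handles GPU allocation per call, so no extra device setup is needed at the call site.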