lucas-ventura commited on
Commit
e218cb8
·
verified ·
1 Parent(s): 042e2e0

Upload asr.py

Browse files
Files changed (1) hide show
  1. asr.py +61 -0
asr.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import whisperx
5
+ from whisperx.audio import SAMPLE_RATE
6
+
7
+ from src.data.chapters import sec_to_hms
8
+
9
+ # Set device and disable TF32 for consistent results
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ torch.backends.cuda.matmul.allow_tf32 = False
12
+ torch.backends.cudnn.allow_tf32 = False
13
+
14
+
15
+ class ASRProcessor:
16
+ """
17
+ Automatic Speech Recognition processor using WhisperX.
18
+
19
+ Transcribes audio files and returns time-aligned transcription segments.
20
+ """
21
+
22
+ def __init__(self, model_name="large-v2", compute_type="float16"):
23
+ self.model_name = model_name
24
+ self.model = whisperx.load_model(model_name, device, compute_type=compute_type)
25
+
26
+ def get_asr(self, audio_file, return_duration=True):
27
+ assert Path(audio_file).exists(), f"File {audio_file} does not exist"
28
+ audio = whisperx.load_audio(audio_file)
29
+ result = self.model.transcribe(audio, batch_size=1)
30
+ language = result["language"]
31
+ duration = audio.shape[0] / SAMPLE_RATE
32
+
33
+ # Align the transcription
34
+ model_a, metadata = whisperx.load_align_model(
35
+ language_code=language, device=device
36
+ )
37
+ aligned_result = whisperx.align(
38
+ result["segments"],
39
+ model_a,
40
+ metadata,
41
+ audio,
42
+ device,
43
+ return_char_alignments=False,
44
+ )
45
+
46
+ # Format the output
47
+ segments = [
48
+ {field: segment[field] for field in ["start", "end", "text"]}
49
+ for segment in aligned_result["segments"]
50
+ ]
51
+
52
+ asr_clean = []
53
+ for segment in segments:
54
+ t = segment["text"].strip()
55
+ s = sec_to_hms(segment["start"])
56
+ asr_clean.append(f"{s}: {t}")
57
+
58
+ if return_duration:
59
+ return "\n".join(asr_clean) + "\n", duration
60
+ else:
61
+ return "\n".join(asr_clean) + "\n"