# CallyticsDemo / src / audio / alignment.py
# Author avatar caption: bunyaminergen's picture
# Commit message: Initial
# Commit: 1b97239
# Standard library imports
import os
from typing import Annotated, List, Dict
# Related third-party imports
import torch
from faster_whisper import decode_audio
from ctc_forced_aligner import (
generate_emissions,
get_alignments,
get_spans,
load_alignment_model,
postprocess_results,
preprocess_text,
)
class ForcedAligner:
    """
    ForcedAligner is a class for aligning audio to a provided transcript using a pre-trained alignment model.

    Attributes
    ----------
    device : str
        Device to run the model on ('cuda' for GPU or 'cpu').
    alignment_model : torch.nn.Module or None
        The pre-trained alignment model. When running on 'cuda' it is reset to
        ``None`` after every ``align`` call so the GPU memory can be reclaimed,
        and it is loaded again automatically on the next ``align`` call.
    alignment_tokenizer : Any
        Tokenizer for processing text in alignment.

    Methods
    -------
    align(audio_path, transcript, language, batch_size)
        Aligns audio with a transcript and returns word-level timing information.
    """

    def __init__(self, device: Annotated[str, "Device for model ('cuda' or 'cpu')"] = None):
        """
        Initialize the ForcedAligner with the specified device.

        Parameters
        ----------
        device : str, optional
            Device for running the model, by default 'cuda' if available, otherwise 'cpu'.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self._load_model()

    def _load_model(self) -> None:
        """Load the alignment model and tokenizer onto ``self.device``."""
        # Half precision on GPU, full precision on CPU.
        self.alignment_model, self.alignment_tokenizer = load_alignment_model(
            self.device,
            dtype=torch.float16 if self.device == 'cuda' else torch.float32,
        )

    def align(
            self,
            audio_path: Annotated[str, "Path to the audio file"],
            transcript: Annotated[str, "Transcript of the audio content"],
            language: Annotated[str, "Language of the transcript"] = 'en',
            batch_size: Annotated[int, "Batch size for emission generation"] = 8,
    ) -> Annotated[List[Dict[str, float]], "List of word alignment data with timestamps"]:
        """
        Aligns audio with a transcript and returns word-level timing information.

        The method can be called repeatedly on the same instance: if the model
        was released after a previous call (CUDA only), it is reloaded first.

        Parameters
        ----------
        audio_path : str
            Path to the audio file.
        transcript : str
            Transcript text corresponding to the audio.
        language : str, optional
            Language code for the transcript, default is 'en' (English).
        batch_size : int, optional
            Batch size for generating emissions, by default 8.

        Returns
        -------
        List[Dict[str, float]]
            The word timing entries exactly as returned by
            ``postprocess_results`` (the keys of each dictionary are defined
            by that function).

        Raises
        ------
        FileNotFoundError
            If the specified audio file does not exist.

        Examples
        --------
        >>> aligner = ForcedAligner()  # doctest: +SKIP
        >>> timings = aligner.align("path/to/audio.wav", "hello world")  # doctest: +SKIP
        """
        # Validate the input before any model work so a bad path fails fast.
        if not os.path.exists(audio_path):
            raise FileNotFoundError(
                f"The audio file at path '{audio_path}' was not found."
            )

        # A previous call on CUDA released the model; bring it back so the
        # instance stays usable for more than one alignment.
        if getattr(self, 'alignment_model', None) is None:
            self._load_model()

        try:
            speech_array = torch.from_numpy(decode_audio(audio_path))

            # Match the waveform's dtype/device to the model before inference.
            emissions, stride = generate_emissions(
                self.alignment_model,
                speech_array.to(self.alignment_model.dtype).to(self.alignment_model.device),
                batch_size=batch_size,
            )

            tokens_starred, text_starred = preprocess_text(
                transcript,
                romanize=True,
                language=language,
            )

            segments, scores, blank_token = get_alignments(
                emissions,
                tokens_starred,
                self.alignment_tokenizer,
            )

            spans = get_spans(tokens_starred, segments, blank_token)
            word_timestamps = postprocess_results(text_starred, spans, stride, scores)
        finally:
            # Free GPU memory whether or not alignment succeeded. The attribute
            # is kept (as None) instead of deleted so the next call can reload.
            if self.device == 'cuda':
                self.alignment_model = None
                torch.cuda.empty_cache()

        print(f"Word_Timestamps: {word_timestamps}")

        return word_timestamps
if __name__ == "__main__":
    # Demo run: align a sample recording against a fixed transcript and
    # print either the resulting timings or the missing-file message.
    forced_aligner = ForcedAligner()

    path = "example_audio.wav"
    audio_transcript = "This is a test transcript."

    try:
        word_timestamp = forced_aligner.align(path, audio_transcript)
    except FileNotFoundError as e:
        print(e)
    else:
        print(word_timestamp)