# Standard library imports
import os
import wave
from collections import Counter
from typing import List, Dict, Annotated, Union, Tuple, Optional

# Related third-party imports
import nltk
import numpy as np
import soundfile as sf
from librosa.feature import mfcc
from scipy.fft import fft, fftfreq

# Punctuation marks that terminate a sentence in a punctuated transcript.
_SENTENCE_ENDINGS = (".", "?", "!")


def _is_sentence_end(text: Optional[str]) -> bool:
    """
    Return True if ``text`` is non-empty and ends with '.', '?' or '!'.

    Empty strings and ``None`` are reported as "not a sentence end" instead of
    raising ``IndexError`` / ``TypeError``.
    """
    return bool(text) and text.endswith(_SENTENCE_ENDINGS)


class WordSpeakerMapper:
    """
    Maps words to speakers based on timestamps and aligns speaker tags after punctuation restoration.

    This class processes word timing information and assigns each word to a speaker
    based on the provided speaker timestamps. Missing word timestamps are filled in,
    and each word can be anchored to a speaker segment by its 'start', 'mid', or
    'end' point. After punctuation restoration, the word-speaker mapping can be
    realigned so that speaker labels are consistent within a sentence.

    Attributes
    ----------
    word_timestamps : List[Dict]
        Word timing information with 'start' and 'end' (seconds) and 'text' keys,
        after missing timestamps have been filled in.
    speaker_timestamps : List[List[Union[int, float]]]
        Speaker segments ``[start_ms, end_ms, speaker_id]``. They are compared with
        word times converted to milliseconds and must be sorted by start time.
    word_speaker_mapping : List[Dict] or None
        Result of :meth:`get_words_speaker_mapping`, or None before it has run.

    Methods
    -------
    filter_missing_timestamps(word_timestamps, initial_timestamp=0, final_timestamp=None)
        Fills in missing start and end timestamps in word timing data.
    get_words_speaker_mapping(word_anchor_option='start')
        Maps words to speakers based on word and speaker timestamps.
    realign_with_punctuation(max_words_in_sentence=50)
        Makes speaker labels consistent within punctuated sentences.
    """

    def __init__(
            self,
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            speaker_timestamps: Annotated[List[List[Union[int, float]]], "List of speaker segments"],
    ):
        """
        Initializes the WordSpeakerMapper with word and speaker timestamps.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information ('start'/'end' in seconds, 'text').
            May be empty. The caller's dictionaries are not modified.
        speaker_timestamps : List[List[Union[int, float]]]
            List of speaker segments ``[start_ms, end_ms, speaker_id]``.
        """
        self.word_timestamps = self.filter_missing_timestamps(word_timestamps)
        self.speaker_timestamps = speaker_timestamps
        self.word_speaker_mapping = None

    def filter_missing_timestamps(
            self,
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            initial_timestamp: Annotated[float, "Start time of the first word"] = 0,
            final_timestamp: Annotated[Optional[float], "End time of the last word"] = None
    ) -> Annotated[List[Dict], "List of word timestamps with missing values filled"]:
        """
        Fills in missing start and end timestamps.

        A word without a 'start' takes the previous kept word's 'end' as its start
        (the first word uses ``initial_timestamp``). Any directly following words
        that also lack a 'start' are merged into its text. Its 'end' becomes the
        next available start time, or ``final_timestamp`` if none exists. Words
        whose text was merged away, and words after the first that have no 'text'
        key, are dropped.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information that may contain missing timestamps.
            The input dictionaries are copied, not modified.
        initial_timestamp : float, optional
            Start time of the first word if it is missing, default is 0.
        final_timestamp : float, optional
            End time to use when no later start time is available.

        Returns
        -------
        List[Dict]
            New list of word timestamp dictionaries with missing values filled.
            An empty input yields an empty list.

        Examples
        --------
        >>> mapper = WordSpeakerMapper([], [])
        >>> mapper.filter_missing_timestamps([{'text': 'Hello', 'end': 1.2}])
        [{'text': 'Hello', 'end': 1.2, 'start': 0}]
        >>> mapper.filter_missing_timestamps([
        ...     {'text': 'Hi', 'start': 0.0, 'end': 0.4},
        ...     {'text': 'there'},
        ...     {'text': 'you'},
        ...     {'text': 'ok', 'start': 1.0, 'end': 1.3},
        ... ])
        [{'text': 'Hi', 'start': 0.0, 'end': 0.4}, {'text': 'there you', 'start': 0.4, 'end': 1.0}, {'text': 'ok', 'start': 1.0, 'end': 1.3}]
        """
        if not word_timestamps:
            return []

        # Shallow copies: filling gaps and merging texts must not leak into the
        # caller's data.
        words = [dict(word) for word in word_timestamps]

        first = words[0]
        if first.get("start") is None:
            first["start"] = initial_timestamp
            first["end"] = self._get_next_start_timestamp(words, 0, final_timestamp)

        result = [first]
        for i, ws in enumerate(words[1:], start=1):
            if "text" not in ws:
                continue

            # Words whose text was already merged into a predecessor (text None)
            # must not be refilled. Otherwise the merge helper would try to
            # concatenate onto None.
            if ws.get("start") is None and ws["text"] is not None:
                ws["start"] = result[-1]["end"]
                ws["end"] = self._get_next_start_timestamp(words, i, final_timestamp)

            if ws["text"] is not None:
                result.append(ws)
        return result

    @staticmethod
    def _get_next_start_timestamp(
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            current_word_index: Annotated[int, "Index of the current word"],
            final_timestamp: Annotated[Optional[float], "Final timestamp if needed"]
    ) -> Annotated[Optional[float], "Next start timestamp for filling missing values"]:
        """
        Finds the next start timestamp to fill in missing values.

        Following words that have no 'start' are merged into the current word:
        their text is appended to the current word's text, and their own 'text'
        is set to None so that ``filter_missing_timestamps`` drops them. The
        ``word_timestamps`` dictionaries are modified in place.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information.
        current_word_index : int
            Index of the current word.
        final_timestamp : float, optional
            Value returned if no later word has a start time.

        Returns
        -------
        float or None
            The first later start time. For the last word, its own 'end' (or its
            'start' if it has no end). Otherwise ``final_timestamp``.

        Examples
        --------
        >>> words = [{'text': 'Hello', 'start': 0.5}, {'text': 'big'}, {'text': 'world', 'start': 2.0}]
        >>> WordSpeakerMapper._get_next_start_timestamp(words, 0, None)
        2.0
        >>> words[0]['text']
        'Hello big'
        """
        current = word_timestamps[current_word_index]

        if current_word_index == len(word_timestamps) - 1:
            # Nothing follows. Keep a known end rather than discarding it, and
            # fall back to a zero-length span at the word's start.
            own_end = current.get("end")
            return own_end if own_end is not None else current["start"]

        for next_word in word_timestamps[current_word_index + 1:]:
            if next_word.get("start") is not None:
                return next_word["start"]
            # Untimed follower: fold its text into the current word and mark it
            # as consumed.
            next_text = next_word.get("text")
            if next_text is not None:
                current["text"] += " " + next_text
                next_word["text"] = None

        return final_timestamp

    def get_words_speaker_mapping(self, word_anchor_option='start') -> List[Dict]:
        """
        Maps words to speakers based on their timestamps.

        Word times (seconds) are rounded to integer milliseconds. Each word's
        anchor point is looked up in ``speaker_timestamps`` with a single forward
        scan, so the segments must be sorted by start time.

        A word whose anchor lies inside a segment gets that segment's speaker. A
        word in a gap between segments, or after the last segment, gets the
        speaker of the preceding segment. A word before the first segment (or
        when there are no segments) gets -1.

        Parameters
        ----------
        word_anchor_option : str, optional
            Anchor point for word mapping ('start', 'mid', or 'end'), default is
            'start'. Any other value behaves like 'start'.

        Returns
        -------
        List[Dict]
            List of word-to-speaker mappings with keys 'text', 'start_time',
            'end_time' (milliseconds) and 'speaker'. Also stored in
            ``self.word_speaker_mapping``.

        Examples
        --------
        >>> word_timestamps = [{'start': 0.5, 'end': 1.2, 'text': 'Hello'}]
        >>> speaker_timestamps = [[0, 1000, 1]]
        >>> mapper = WordSpeakerMapper(word_timestamps, speaker_timestamps)
        >>> mapper.get_words_speaker_mapping()
        [{'text': 'Hello', 'start_time': 500, 'end_time': 1200, 'speaker': 1}]
        """

        def get_word_ts_anchor(start: int, end: int, option: str) -> int:
            """
            Determines the anchor timestamp for a word.

            Parameters
            ----------
            start : int
                Start time of the word in milliseconds.
            end : int
                End time of the word in milliseconds.
            option : str
                Anchor point for timestamp calculation ('start', 'mid', or 'end').

            Returns
            -------
            int
                Anchor timestamp for the word.

            Examples
            --------
            >>> get_word_ts_anchor(500, 1200, 'mid')
            850
            """
            if option == "end":
                return end
            if option == "mid":
                return (start + end) // 2
            return start

        speakers = self.speaker_timestamps
        num_speaker_ts = len(speakers)
        wrd_spk_mapping = []
        turn_idx = 0

        for wrd_dict in self.word_timestamps:
            # round() instead of int(): int() truncates float artefacts such as
            # 0.7999999999999999 s down to 799 ms instead of 800 ms.
            ws = int(round(wrd_dict["start"] * 1000))
            we = int(round(wrd_dict["end"] * 1000))
            wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)

            # Skip segments that end before this word; words are visited in
            # order, so turn_idx never needs to move backwards.
            while turn_idx < num_speaker_ts and wrd_pos > speakers[turn_idx][1]:
                turn_idx += 1

            sp = -1
            if turn_idx < num_speaker_ts and speakers[turn_idx][0] <= wrd_pos <= speakers[turn_idx][1]:
                sp = speakers[turn_idx][2]
            elif turn_idx > 0:
                sp = speakers[turn_idx - 1][2]

            wrd_spk_mapping.append(
                {"text": wrd_dict["text"], "start_time": ws, "end_time": we, "speaker": sp}
            )

        self.word_speaker_mapping = wrd_spk_mapping
        return self.word_speaker_mapping

    def realign_with_punctuation(self, max_words_in_sentence: int = 50) -> None:
        """
        Realigns word-speaker mapping after punctuation restoration.

        Wherever the speaker changes in the middle of a sentence (the word before
        the change does not end with '.', '?' or '!'), the sentence boundaries
        around that point are located. All words of the sentence are then
        relabelled with the most frequent speaker, provided that speaker holds at
        least ``len(sentence) // 2`` of the words. Ties go to the speaker that
        appears first in the sentence. If no boundary is found within
        ``max_words_in_sentence`` words, the labels are left unchanged.

        If :meth:`get_words_speaker_mapping` has not been called yet, it is
        called first with its default anchor.

        Parameters
        ----------
        max_words_in_sentence : int, optional
            Maximum number of words to consider for realignment in a sentence,
            default is 50.

        Examples
        --------
        >>> word_speaker_mapping = [
        ...     {"text": "Hello", "speaker": "Speaker 1"},
        ...     {"text": "world", "speaker": "Speaker 2"},
        ...     {"text": ".", "speaker": "Speaker 2"},
        ...     {"text": "How", "speaker": "Speaker 1"},
        ...     {"text": "are", "speaker": "Speaker 1"},
        ...     {"text": "you", "speaker": "Speaker 2"},
        ...     {"text": "?", "speaker": "Speaker 2"}
        ... ]
        >>> mapper = WordSpeakerMapper([], [])
        >>> mapper.word_speaker_mapping = word_speaker_mapping
        >>> mapper.realign_with_punctuation()
        >>> [w["speaker"] for w in mapper.word_speaker_mapping]
        ['Speaker 2', 'Speaker 2', 'Speaker 2', 'Speaker 1', 'Speaker 1', 'Speaker 1', 'Speaker 1']
        """
        if self.word_speaker_mapping is None:
            self.get_words_speaker_mapping()

        words_list = [wd["text"] for wd in self.word_speaker_mapping]
        speaker_list = [wd["speaker"] for wd in self.word_speaker_mapping]
        wsp_len = len(speaker_list)

        k = 0
        while k < wsp_len:
            if (
                    k < wsp_len - 1
                    and speaker_list[k] != speaker_list[k + 1]
                    and not _is_sentence_end(words_list[k])
            ):
                left_idx = self._get_first_word_idx_of_sentence(
                    k, words_list, speaker_list, max_words_in_sentence
                )
                # The right search gets whatever part of the word budget the
                # left search did not use.
                right_idx = (
                    self._get_last_word_idx_of_sentence(
                        k, words_list, max_words_in_sentence - (k - left_idx) - 1
                    )
                    if left_idx > -1
                    else -1
                )
                if min(left_idx, right_idx) == -1:
                    k += 1
                    continue

                spk_labels = speaker_list[left_idx:right_idx + 1]
                # Counter breaks ties by first occurrence, which keeps the result
                # deterministic. max(set(...)) depended on hash ordering.
                mod_speaker, mod_count = Counter(spk_labels).most_common(1)[0]
                if mod_count < len(spk_labels) // 2:
                    k += 1
                    continue

                speaker_list[left_idx:right_idx + 1] = [mod_speaker] * (right_idx - left_idx + 1)
                k = right_idx

            k += 1

        for entry, speaker in zip(self.word_speaker_mapping, speaker_list):
            entry["speaker"] = speaker

    @staticmethod
    def _get_first_word_idx_of_sentence(
            word_idx: int,
            word_list: List[str],
            speaker_list: List[str],
            max_words: int
    ) -> int:
        """
        Finds the first word index of a sentence for realignment.

        Walks left from ``word_idx`` while the previous word has the same speaker
        and does not end a sentence, for at most ``max_words`` steps.

        Parameters
        ----------
        word_idx : int
            Current word index.
        word_list : List[str]
            List of words.
        speaker_list : List[str]
            List of speakers for the words.
        max_words : int
            Maximum words to consider in the sentence.

        Returns
        -------
        int
            The index of the first word of the sentence. Returns -1 if the walk
            stopped without reaching the start of the list or a sentence end.

        Examples
        --------
        >>> words_list = ["Hello", "world", ".", "How", "are", "you", "?"]
        >>> speakers_list = ["Speaker 1", "Speaker 1", "Speaker 1", "Speaker 2", "Speaker 2", "Speaker 2", "Speaker 2"]
        >>> WordSpeakerMapper._get_first_word_idx_of_sentence(4, words_list, speakers_list, 50)
        3
        """
        left_idx = word_idx
        while (
                left_idx > 0
                and word_idx - left_idx < max_words
                and speaker_list[left_idx - 1] == speaker_list[left_idx]
                and not _is_sentence_end(word_list[left_idx - 1])
        ):
            left_idx -= 1

        return left_idx if left_idx == 0 or _is_sentence_end(word_list[left_idx - 1]) else -1

    @staticmethod
    def _get_last_word_idx_of_sentence(
            word_idx: int,
            word_list: List[str],
            max_words: int
    ) -> int:
        """
        Finds the last word index of a sentence for realignment.

        Walks right from ``word_idx`` until a word ends a sentence, for at most
        ``max_words`` steps.

        Parameters
        ----------
        word_idx : int
            Current word index.
        word_list : List[str]
            List of words.
        max_words : int
            Maximum words to consider in the sentence.

        Returns
        -------
        int
            The index of the last word of the sentence. Returns -1 if the walk
            stopped before reaching a sentence end or the end of the list.

        Examples
        --------
        >>> words_list = ["Hello", "world", ".", "How", "are", "you", "?"]
        >>> WordSpeakerMapper._get_last_word_idx_of_sentence(3, words_list, 50)
        6
        """
        last_index = len(word_list) - 1
        right_idx = word_idx
        while (
                right_idx < last_index
                and right_idx - word_idx < max_words
                and not _is_sentence_end(word_list[right_idx])
        ):
            right_idx += 1

        return right_idx if right_idx == last_index or _is_sentence_end(word_list[right_idx]) else -1


class SentenceSpeakerMapper:
    """
    Groups words into sentences and assigns each sentence to a speaker.

    This class uses a word-speaker mapping to group words into sentences based on
    speaker changes and on sentence breaks detected by NLTK's Punkt tokenizer.

    Attributes
    ----------
    sentence_checker : Callable[[str], bool]
        ``PunktSentenceTokenizer().text_contains_sentbreak``. It is True when the
        text contains a sentence break followed by more text.
    sentence_ending_punctuations : str
        String of punctuation characters that indicate sentence endings.

    Methods
    -------
    get_sentences_speaker_mapping(word_speaker_mapping)
        Groups words into sentences and assigns each sentence to a speaker.
    """

    def __init__(self):
        """
        Initializes the SentenceSpeakerMapper and downloads the NLTK 'punkt'
        resource (quietly).
        """
        nltk.download('punkt', quiet=True)
        self.sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
        self.sentence_ending_punctuations = ".?!"

    def get_sentences_speaker_mapping(
            self,
            word_speaker_mapping: Annotated[List[Dict], "List of words with speaker labels"]
    ) -> Annotated[List[Dict], "List of sentences with speaker labels and timing information"]:
        """
        Groups words into sentences and assigns each sentence to a speaker.

        A new sentence starts whenever the speaker changes, or when appending the
        next word would make the accumulated text contain a sentence break. Each
        word in a sentence's text is followed by a single space, including the
        last one.

        Parameters
        ----------
        word_speaker_mapping : List[Dict]
            Words with 'text', 'start_time', 'end_time' and 'speaker' keys.

        Returns
        -------
        List[Dict]
            Sentences with 'speaker' (formatted as ``"Speaker {id}"``),
            'start_time', 'end_time' and 'text'. An empty input yields an empty
            list.

        Examples
        --------
        >>> sentence_mapper = SentenceSpeakerMapper()
        >>> word_speaker_map = [
        ...     {'text': 'Hello', 'start_time': 0, 'end_time': 500, 'speaker': 1},
        ...     {'text': 'world.', 'start_time': 600, 'end_time': 1000, 'speaker': 1},
        ...     {'text': 'How', 'start_time': 1100, 'end_time': 1300, 'speaker': 2},
        ...     {'text': 'are', 'start_time': 1400, 'end_time': 1500, 'speaker': 2},
        ...     {'text': 'you?', 'start_time': 1600, 'end_time': 2000, 'speaker': 2},
        ... ]
        >>> sentence_mapper.get_sentences_speaker_mapping(word_speaker_map)
        [{'speaker': 'Speaker 1', 'start_time': 0, 'end_time': 1000, 'text': 'Hello world. '}, {'speaker': 'Speaker 2', 'start_time': 1100, 'end_time': 2000, 'text': 'How are you? '}]
        """
        if not word_speaker_mapping:
            return []

        first = word_speaker_mapping[0]
        prev_spk = first['speaker']
        snt = {
            "speaker": f"Speaker {prev_spk}",
            "start_time": first['start_time'],
            "end_time": first['end_time'],
            "text": first['text'] + " ",
        }

        snts = []
        for word_dict in word_speaker_mapping[1:]:
            word, spk = word_dict["text"], word_dict["speaker"]
            s, e = word_dict["start_time"], word_dict["end_time"]
            if spk != prev_spk or self.sentence_checker(snt["text"] + word):
                snts.append(snt)
                snt = {
                    "speaker": f"Speaker {spk}",
                    "start_time": s,
                    "end_time": e,
                    "text": word + " ",
                }
            else:
                snt["end_time"] = e
                snt["text"] += word + " "
            prev_spk = spk

        snts.append(snt)
        return snts


class Audio:
    """
    A class to handle audio file analysis and property extraction.

    This class loads an audio file (mixed down to mono) and extracts file
    properties together with spectral, temporal, and MFCC features.

    Parameters
    ----------
    audio_path : str
        Path to the audio file to be analyzed.

    Attributes
    ----------
    audio_path : str
        Path to the audio file.
    extension : str
        Lower-cased file extension of the audio file.
    samples : int
        Total number of (mono) audio samples.
    duration : float
        Duration of the audio in seconds.
    data : np.ndarray
        One-dimensional float32 audio data loaded from the file.
    rate : int
        Sampling rate of the audio file.
    """

    # Maps a libsndfile subtype to bits per sample. Used when the ``wave``
    # module cannot parse a WAV header (e.g. IEEE-float WAV files).
    _SUBTYPE_BIT_DEPTH = {
        "PCM_S8": 8,
        "PCM_U8": 8,
        "PCM_16": 16,
        "PCM_24": 24,
        "PCM_32": 32,
        "FLOAT": 32,
        "DOUBLE": 64,
    }

    def __init__(self, audio_path: str):
        """
        Initialize the Audio class with a given audio file path.

        Parameters
        ----------
        audio_path : str
            Path to the audio file.

        Raises
        ------
        TypeError
            If `audio_path` is not a non-empty string.
        FileNotFoundError
            If the file specified by `audio_path` does not exist.
        ValueError
            If the file has an unsupported extension or is empty.
        RuntimeError
            If there is an error reading the audio file.
        """
        if not isinstance(audio_path, str) or not audio_path:
            raise TypeError("audio_path must be a non-empty string")

        if not os.path.isfile(audio_path):
            raise FileNotFoundError(f"The specified audio file does not exist: {audio_path}")

        valid_extensions = [".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac"]
        extension = os.path.splitext(audio_path)[1].lower()
        if extension not in valid_extensions:
            raise ValueError(f"File extension {extension} is not recognized as a supported audio format.")

        try:
            self.data, self.rate = sf.read(audio_path, dtype='float32')
        except RuntimeError as e:
            raise RuntimeError(f"Error reading audio file: {audio_path}") from e

        if len(self.data) == 0:
            raise ValueError(f"Audio file is empty: {audio_path}")

        # Down-mix any 2-D (frames, channels) array, including a single-column
        # one, so that the FFT/MFCC code below always receives a 1-D signal.
        if self.data.ndim > 1:
            self.data = np.mean(self.data, axis=1)

        self.audio_path = audio_path
        self.extension = extension
        self.samples = len(self.data)
        self.duration = self.samples / self.rate

    def _read_bit_depth_and_channels(self) -> Tuple[Optional[int], int]:
        """
        Read the bit depth and channel count from the file header.

        Returns
        -------
        Tuple[Optional[int], int]
            ``(bit_depth, channels)``. For WAV files the bit depth comes from the
            ``wave`` module, falling back to the soundfile subtype if ``wave``
            rejects the header. For other formats the bit depth is None.
        """
        if self.extension == ".wav":
            try:
                with wave.open(self.audio_path, 'r') as wav_file:
                    return wav_file.getsampwidth() * 8, wav_file.getnchannels()
            except wave.Error:
                # ``wave`` only handles integer PCM; libsndfile already read the
                # file successfully in __init__, so ask it instead.
                info = sf.info(self.audio_path)
                return self._SUBTYPE_BIT_DEPTH.get(info.subtype), info.channels

        info = sf.info(self.audio_path)
        return None, info.channels

    def properties(self) -> Tuple[
        str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]:
        """
        Extract various properties and features from the audio file.

        Returns
        -------
        Tuple[str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]
            A tuple containing:
            - File name (str)
            - File extension (str)
            - Absolute file path (str)
            - Sample rate (int)
            - Minimum frequency (float): smallest non-zero FFT bin frequency,
              i.e. ``rate / samples`` (0.0 for a single-sample signal)
            - Maximum frequency (float): largest absolute FFT bin frequency
              (about ``rate / 2``)
            - Bit depth (Union[int, None]): None for non-WAV files
            - Number of channels in the file (int)
            - Duration in seconds (float)
            - Root mean square loudness (float)
            - A dictionary of extracted properties (Dict[str, float])

        Notes
        -----
        Properties extracted include:
        - Spectral band energies ``EQ_{low}_{high}_Hz`` (mean squared FFT
          magnitude in each band)
        - RMS loudness
        - Zero Crossing Rate (ZCR): sign changes per sample, each crossing
          counted once
        - Spectral Centroid
        - Mean of 13 MFCCs (Mel Frequency Cepstral Coefficients)

        Examples
        --------
        >>> audio = Audio("sample.wav")
        >>> audio.properties()
        ('sample.wav', '.wav', '/path/to/sample.wav', 44100, 0.19, 22050.0, 16, 2, 5.2, 0.25, {...})
        """
        bands = [(20, 250), (250, 2000), (2000, 6000), (6000, 20000)]

        x = fft(self.data)
        xf = fftfreq(self.samples, 1 / self.rate)

        nonzero_freqs = np.abs(xf[xf != 0])
        # A single-sample signal has only the DC bin; np.min would raise on it.
        min_freq = float(np.min(nonzero_freqs)) if nonzero_freqs.size else 0.0
        max_freq = float(np.max(np.abs(xf)))

        bit_depth, channels = self._read_bit_depth_and_channels()

        duration = float(self.duration)
        loudness = float(np.sqrt(np.mean(self.data ** 2)))

        # Band energies use only the non-negative frequencies of the full FFT.
        magnitude = np.abs(x)
        eq_properties: Dict[str, float] = {}
        for low, high in bands:
            band_data = magnitude[(xf >= low) & (xf <= high)]
            band_energy = np.mean(band_data ** 2) if band_data.size > 0 else 0.0
            eq_properties[f"EQ_{low}_{high}_Hz"] = float(band_energy)

        # |sign(x[n]) - sign(x[n-1])| equals 2 at each crossing, so halve the
        # sum (standard 1/(2N) ZCR formula) to count every crossing once.
        zcr = np.sum(np.abs(np.diff(np.sign(self.data)))) / (2 * len(self.data))

        magnitude_spectrum = np.abs(np.fft.rfft(self.data))
        freqs_centroid = np.fft.rfftfreq(len(self.data), 1.0 / self.rate)
        spectrum_sum = np.sum(magnitude_spectrum)
        spectral_centroid = (
            np.sum(freqs_centroid * magnitude_spectrum) / spectrum_sum if spectrum_sum != 0 else 0.0
        )

        mfcc_mean = np.mean(mfcc(y=self.data, sr=self.rate, n_mfcc=13), axis=1)

        eq_properties["RMSLoudness"] = loudness
        eq_properties["ZeroCrossingRate"] = float(zcr)
        eq_properties["SpectralCentroid"] = float(spectral_centroid)
        for i, val in enumerate(mfcc_mean, start=1):
            eq_properties[f"MFCC_{i}"] = float(val)

        file_name = os.path.basename(self.audio_path)
        path = os.path.abspath(self.audio_path)
        bit_depth = int(bit_depth) if bit_depth is not None else None
        channels = int(channels) if channels is not None else 1

        return (
            file_name,
            self.extension,
            path,
            int(self.rate),
            min_freq,
            max_freq,
            bit_depth,
            channels,
            duration,
            loudness,
            eq_properties,
        )


if __name__ == "__main__":
    words_timestamp = [
        {'text': 'Hello', 'start': 0.0, 'end': 1.2},
        {'text': 'world', 'start': 1.3, 'end': 2.0}
    ]
    # Speaker segments are in milliseconds. Word times are in seconds and are
    # converted to milliseconds inside get_words_speaker_mapping.
    speaker_timestamp = [
        [0, 1500, 1],
        [1600, 3000, 2]
    ]

    word_sentence_mapper = WordSpeakerMapper(words_timestamp, speaker_timestamp)
    word_speaker_maps = word_sentence_mapper.get_words_speaker_mapping()

    print("Word-Speaker Mapping:")
    print(word_speaker_maps)