# CallyticsDemo/src/audio/analysis.py
# Repository page header: bunyaminergen - commit "Initial" (1b97239)
# Standard library imports
import os
import wave
from typing import List, Dict, Annotated, Union, Tuple
# Related third-party imports
import nltk
import numpy as np
import soundfile as sf
from librosa.feature import mfcc
from scipy.fft import fft, fftfreq
class WordSpeakerMapper:
    """
    Maps words to speakers based on timestamps and aligns speaker tags after punctuation restoration.

    This class processes word timing information and assigns each word to a speaker
    based on the provided speaker timestamps. Missing timestamps are handled, and each
    word can be aligned to a speaker based on different reference points ('start', 'mid', or 'end').
    After punctuation restoration, word-speaker mapping can be realigned to ensure consistency
    within a sentence.

    Attributes
    ----------
    word_timestamps : List[Dict]
        List of word timing information with 'start', 'end', and 'text' keys.
        Word times are multiplied by 1000 before being compared with speaker segments.
    speaker_timestamps : List[List[int]]
        List of speaker segments, where each segment contains [start_time, end_time, speaker_id].
        Segment bounds are compared directly against word times * 1000, so they must be
        expressed in the scaled unit (milliseconds when word times are in seconds).
    word_speaker_mapping : List[Dict] or None
        Processed word-to-speaker mappings; None until `get_words_speaker_mapping` is called.

    Methods
    -------
    filter_missing_timestamps(word_timestamps, initial_timestamp=0, final_timestamp=None)
        Fills in missing start and end timestamps in word timing data.
    get_words_speaker_mapping(word_anchor_option='start')
        Maps words to speakers based on word and speaker timestamps.
    realign_with_punctuation(max_words_in_sentence=50)
        Makes speaker labels consistent within each punctuated sentence.
    """

    # A word ending with one of these characters closes a sentence.
    _SENTENCE_ENDINGS = (".", "?", "!")

    def __init__(
            self,
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            speaker_timestamps: Annotated[List[List[Union[int, float]]], "List of speaker segments"],
    ):
        """
        Initializes the WordSpeakerMapper with word and speaker timestamps.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information. The dictionaries are updated in place
            while missing timestamps are filled. May be empty.
        speaker_timestamps : List[List[int]]
            List of speaker segments.
        """
        self.word_timestamps = self.filter_missing_timestamps(word_timestamps)
        self.speaker_timestamps = speaker_timestamps
        self.word_speaker_mapping = None

    def filter_missing_timestamps(
            self,
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            initial_timestamp: Annotated[Union[int, float], "Start time of the first word"] = 0,
            final_timestamp: Annotated[Union[int, float, None], "End time of the last word"] = None
    ) -> Annotated[List[Dict], "List of word timestamps with missing values filled"]:
        """
        Fills in missing start and end timestamps.

        A word without a 'start' begins where the previous kept word ended and ends
        where the next timestamped word starts. Any timestamp-less words that directly
        follow it are merged into its text and dropped from the result. The input
        dictionaries are modified in place.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information that may contain missing timestamps.
        initial_timestamp : int, optional
            Start time of the first word, default is 0.
        final_timestamp : int, optional
            End time of the last word, if available.

        Returns
        -------
        List[Dict]
            List of word timestamps with missing values filled. Empty input yields
            an empty list.

        Examples
        --------
        >>> words = [{'text': 'Hello'}, {'text': 'world', 'start': 1.2, 'end': 2.0}]
        >>> mapper = WordSpeakerMapper([], [])
        >>> mapper.filter_missing_timestamps(words)
        [{'text': 'Hello', 'start': 0, 'end': 1.2}, {'text': 'world', 'start': 1.2, 'end': 2.0}]
        """
        if not word_timestamps:
            return []

        first_word = word_timestamps[0]
        if first_word.get("start") is None:
            first_word["start"] = initial_timestamp
            first_word["end"] = self._get_next_start_timestamp(word_timestamps, 0, final_timestamp)

        result = [first_word]
        for i, ws in enumerate(word_timestamps[1:], start=1):
            # Skip entries without text, including words that were already merged
            # into a predecessor (their text was set to None). Re-processing a
            # merged entry would try to concatenate None with a string.
            if ws.get("text") is None:
                continue
            if ws.get("start") is None:
                # The last kept word is the real predecessor once merged/skipped
                # entries are ignored.
                ws["start"] = result[-1]["end"]
                ws["end"] = self._get_next_start_timestamp(word_timestamps, i, final_timestamp)
            result.append(ws)
        return result

    @staticmethod
    def _get_next_start_timestamp(
            word_timestamps: Annotated[List[Dict], "List of word timing information"],
            current_word_index: Annotated[int, "Index of the current word"],
            final_timestamp: Annotated[Union[int, float, None], "Final timestamp if needed"]
    ) -> Annotated[Union[int, float], "Next start timestamp for filling missing values"]:
        """
        Finds the next start timestamp to fill in missing values.

        Walks forward from the current word. Every following word that has no
        'start' is merged into the current word's text and marked as removed
        (its 'text' becomes None). The first following word that has a 'start'
        provides the returned value.

        Parameters
        ----------
        word_timestamps : List[Dict]
            List of word timing information (modified in place when merging).
        current_word_index : int
            Index of the current word.
        final_timestamp : int, optional
            Final timestamp to use if no next timestamp is found.

        Returns
        -------
        int
            Start of the next timestamped word. If the current word is the last
            one, its own start. If every following word lacks a start,
            `final_timestamp`, or the current word's start when that is None.

        Examples
        --------
        >>> words = [{'text': 'Hello', 'start': 0.5}, {'text': 'world'}]
        >>> WordSpeakerMapper._get_next_start_timestamp(words, 0, 2)
        2
        >>> words[0]['text']
        'Hello world'
        """
        current_word = word_timestamps[current_word_index]
        last_index = len(word_timestamps) - 1
        if current_word_index == last_index:
            return current_word["start"]

        for next_word_index in range(current_word_index + 1, last_index + 1):
            next_word = word_timestamps[next_word_index]
            next_start = next_word.get("start")
            if next_start is not None:
                return next_start
            # The follower has no timing of its own: fold it into the current word.
            next_text = next_word.get("text")
            if next_text is not None:
                current_word["text"] += " " + next_text
            next_word["text"] = None

        # Ran past the end without finding a timestamp. Never hand back None,
        # because the end time is later multiplied to obtain milliseconds.
        if final_timestamp is not None:
            return final_timestamp
        return current_word["start"]

    @staticmethod
    def _get_word_ts_anchor(start: int, end: int, option: str) -> int:
        """
        Determines the anchor timestamp for a word.

        Parameters
        ----------
        start : int
            Start time of the word in milliseconds.
        end : int
            End time of the word in milliseconds.
        option : str
            Anchor point for timestamp calculation ('start', 'mid', or 'end').
            Any other value is treated as 'start'.

        Returns
        -------
        int
            Anchor timestamp for the word.

        Examples
        --------
        >>> WordSpeakerMapper._get_word_ts_anchor(500, 1200, 'mid')
        850
        """
        if option == "end":
            return end
        if option == "mid":
            return (start + end) // 2
        return start

    def get_words_speaker_mapping(self, word_anchor_option='start') -> List[Dict]:
        """
        Maps words to speakers based on their timestamps.

        Word times are multiplied by 1000 and truncated to integers, then compared
        with the speaker segment bounds as given. A word whose anchor falls in a gap
        after a segment inherits the previous segment's speaker. A word before the
        first segment gets speaker -1.

        Parameters
        ----------
        word_anchor_option : str, optional
            Anchor point for word mapping ('start', 'mid', or 'end'), default is 'start'.

        Returns
        -------
        List[Dict]
            List of word-to-speaker mappings with timestamps and speaker IDs.

        Examples
        --------
        >>> word_timestamps = [{'start': 0.5, 'end': 1.2, 'text': 'Hello'}]
        >>> speaker_timestamps = [[0, 1000, 1]]
        >>> mapper = WordSpeakerMapper(word_timestamps, speaker_timestamps)
        >>> mapper.get_words_speaker_mapping()
        [{'text': 'Hello', 'start_time': 500, 'end_time': 1200, 'speaker': 1}]
        """
        segments = self.speaker_timestamps
        num_speaker_ts = len(segments)
        turn_idx = 0  # only ever moves forward; assumes words are in time order
        wrd_spk_mapping = []

        for wrd_dict in self.word_timestamps:
            ws = int(wrd_dict["start"] * 1000)
            we = int(wrd_dict["end"] * 1000)
            wrd = wrd_dict["text"]
            wrd_pos = self._get_word_ts_anchor(ws, we, word_anchor_option)

            # Advance to the first segment that has not ended before the anchor.
            while turn_idx < num_speaker_ts and wrd_pos > segments[turn_idx][1]:
                turn_idx += 1

            sp = -1
            if turn_idx < num_speaker_ts and segments[turn_idx][0] <= wrd_pos <= segments[turn_idx][1]:
                sp = segments[turn_idx][2]
            elif turn_idx > 0:
                # In a gap (or past the last segment): keep the previous speaker.
                sp = segments[turn_idx - 1][2]

            wrd_spk_mapping.append(
                {"text": wrd, "start_time": ws, "end_time": we, "speaker": sp}
            )

        self.word_speaker_mapping = wrd_spk_mapping
        return self.word_speaker_mapping

    def realign_with_punctuation(self, max_words_in_sentence: int = 50) -> None:
        """
        Realigns word-speaker mapping after punctuation restoration.

        Wherever the speaker changes in the middle of a sentence, the whole sentence
        (bounded by sentence-ending punctuation and `max_words_in_sentence`) is
        reassigned to the speaker that labels most of its words. Ties are resolved
        in favour of the speaker that appears first in the sentence, so the result
        does not depend on hash ordering. Does nothing when no mapping is available.

        Parameters
        ----------
        max_words_in_sentence : int, optional
            Maximum number of words to consider for realignment in a sentence,
            default is 50.

        Examples
        --------
        >>> word_speaker_mapping = [
        ...     {"text": "Hello", "speaker": "Speaker 1"},
        ...     {"text": "world", "speaker": "Speaker 2"},
        ...     {"text": ".", "speaker": "Speaker 2"},
        ...     {"text": "How", "speaker": "Speaker 1"},
        ...     {"text": "are", "speaker": "Speaker 1"},
        ...     {"text": "you", "speaker": "Speaker 2"},
        ...     {"text": "?", "speaker": "Speaker 2"}
        ... ]
        >>> mapper = WordSpeakerMapper([], [])
        >>> mapper.word_speaker_mapping = word_speaker_mapping
        >>> mapper.realign_with_punctuation()
        >>> [wd["speaker"] for wd in mapper.word_speaker_mapping]
        ['Speaker 2', 'Speaker 2', 'Speaker 2', 'Speaker 1', 'Speaker 1', 'Speaker 1', 'Speaker 1']
        """
        if not self.word_speaker_mapping:
            return

        words_list = [wd['text'] for wd in self.word_speaker_mapping]
        speaker_list = [wd['speaker'] for wd in self.word_speaker_mapping]
        wsp_len = len(words_list)

        k = 0
        while k < wsp_len:
            speaker_changes_mid_sentence = (
                k < wsp_len - 1
                and speaker_list[k] != speaker_list[k + 1]
                and not self._is_sentence_end(words_list, k)
            )
            if speaker_changes_mid_sentence:
                left_idx = self._get_first_word_idx_of_sentence(
                    k, words_list, speaker_list, max_words_in_sentence
                )
                right_idx = (
                    self._get_last_word_idx_of_sentence(
                        k, words_list, max_words_in_sentence - (k - left_idx) - 1
                    )
                    if left_idx > -1
                    else -1
                )
                if min(left_idx, right_idx) == -1:
                    k += 1
                    continue

                spk_labels = speaker_list[left_idx:right_idx + 1]
                # dict.fromkeys keeps first-seen order and max() returns the first
                # maximum, so ties go to the earliest speaker deterministically.
                mod_speaker = max(dict.fromkeys(spk_labels), key=spk_labels.count)
                if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                    k += 1
                    continue

                speaker_list[left_idx:right_idx + 1] = [mod_speaker] * (
                    right_idx - left_idx + 1
                )
                k = right_idx
            k += 1

        for word_entry, speaker in zip(self.word_speaker_mapping, speaker_list):
            word_entry["speaker"] = speaker

    @staticmethod
    def _is_sentence_end(word_list: List[str], word_idx: int) -> bool:
        """
        Checks whether the word at `word_idx` ends with '.', '?' or '!'.

        Negative indices, empty words and None entries are never sentence ends.
        """
        if word_idx < 0:
            return False
        word = word_list[word_idx]
        return bool(word) and word.endswith(WordSpeakerMapper._SENTENCE_ENDINGS)

    @staticmethod
    def _get_first_word_idx_of_sentence(
            word_idx: int, word_list: List[str], speaker_list: List[str], max_words: int
    ) -> int:
        """
        Finds the first word index of a sentence for realignment.

        Moves left from `word_idx` while the speaker stays the same, the previous
        word does not end a sentence, and fewer than `max_words` steps were taken.

        Parameters
        ----------
        word_idx : int
            Current word index.
        word_list : List[str]
            List of words in the sentence.
        speaker_list : List[str]
            List of speakers for the words.
        max_words : int
            Maximum words to consider in the sentence.

        Returns
        -------
        int
            The index of the first word of the sentence, or -1 when the walk
            stopped somewhere that is not a sentence start.

        Examples
        --------
        >>> words_list = ["Hello", "world", ".", "How", "are", "you", "?"]
        >>> speakers_list = ["Speaker 1", "Speaker 1", "Speaker 1", "Speaker 2", "Speaker 2", "Speaker 2", "Speaker 2"]
        >>> WordSpeakerMapper._get_first_word_idx_of_sentence(4, words_list, speakers_list, 50)
        3
        """
        is_end = WordSpeakerMapper._is_sentence_end
        left_idx = word_idx
        while (
                left_idx > 0
                and word_idx - left_idx < max_words
                and speaker_list[left_idx - 1] == speaker_list[left_idx]
                and not is_end(word_list, left_idx - 1)
        ):
            left_idx -= 1
        if left_idx == 0 or is_end(word_list, left_idx - 1):
            return left_idx
        return -1

    @staticmethod
    def _get_last_word_idx_of_sentence(
            word_idx: int, word_list: List[str], max_words: int
    ) -> int:
        """
        Finds the last word index of a sentence for realignment.

        Moves right from `word_idx` until a sentence-ending word, the end of the
        list, or `max_words` steps.

        Parameters
        ----------
        word_idx : int
            Current word index.
        word_list : List[str]
            List of words in the sentence.
        max_words : int
            Maximum words to consider in the sentence.

        Returns
        -------
        int
            The index of the last word of the sentence, or -1 when the step
            limit was reached before the sentence ended.

        Examples
        --------
        >>> words_list = ["Hello", "world", ".", "How", "are", "you", "?"]
        >>> WordSpeakerMapper._get_last_word_idx_of_sentence(3, words_list, 50)
        6
        """
        is_end = WordSpeakerMapper._is_sentence_end
        final_idx = len(word_list) - 1
        right_idx = word_idx
        while (
                right_idx < final_idx
                and right_idx - word_idx < max_words
                and not is_end(word_list, right_idx)
        ):
            right_idx += 1
        if right_idx == final_idx or is_end(word_list, right_idx):
            return right_idx
        return -1
class SentenceSpeakerMapper:
    """
    Groups words into sentences and assigns each sentence to a speaker.

    This class uses word-speaker mapping to group words into sentences based on punctuation
    and speaker changes. It uses the NLTK library to detect sentence boundaries.

    Attributes
    ----------
    sentence_checker : Callable
        Function to check for sentence breaks; called with the accumulated sentence
        text plus the next word and expected to return a truthy value when that
        text contains a sentence break.
    sentence_ending_punctuations : str
        String of punctuation characters that indicate sentence endings.

    Methods
    -------
    get_sentences_speaker_mapping(word_speaker_mapping)
        Groups words into sentences and assigns each sentence to a speaker.
    """

    def __init__(self):
        """
        Initializes the SentenceSpeakerMapper and downloads required NLTK resources.
        """
        # NOTE(review): the tokenizer below is constructed without trained
        # parameters, so the 'punkt' download may not be strictly required;
        # it is kept to preserve existing behavior - confirm before removing.
        nltk.download('punkt', quiet=True)
        self.sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
        self.sentence_ending_punctuations = ".?!"

    def get_sentences_speaker_mapping(
            self,
            word_speaker_mapping: Annotated[List[Dict], "List of words with speaker labels"]
    ) -> Annotated[List[Dict], "List of sentences with speaker labels and timing information"]:
        """
        Groups words into sentences and assigns each sentence to a speaker.

        A new sentence starts whenever the speaker changes or the sentence checker
        reports a break in the accumulated text followed by the next word. Every
        word is appended with a trailing space, so each sentence text ends with
        a space.

        Parameters
        ----------
        word_speaker_mapping : List[Dict]
            List of words with 'text', 'speaker', 'start_time' and 'end_time' keys.

        Returns
        -------
        List[Dict]
            List of sentences with speaker labels and timing information.
            Empty input yields an empty list.

        Examples
        --------
        >>> sentence_mapper = SentenceSpeakerMapper()
        >>> word_speaker_mapping = [
        ...     {'text': 'Hello', 'start_time': 0, 'end_time': 500, 'speaker': 1},
        ...     {'text': 'world.', 'start_time': 600, 'end_time': 1000, 'speaker': 1},
        ...     {'text': 'How', 'start_time': 1100, 'end_time': 1300, 'speaker': 2},
        ...     {'text': 'are', 'start_time': 1400, 'end_time': 1500, 'speaker': 2},
        ...     {'text': 'you?', 'start_time': 1600, 'end_time': 2000, 'speaker': 2},
        ... ]
        >>> sentence_mapper.get_sentences_speaker_mapping(word_speaker_mapping)
        [{'speaker': 'Speaker 1', 'start_time': 0, 'end_time': 1000, 'text': 'Hello world. '},
         {'speaker': 'Speaker 2', 'start_time': 1100, 'end_time': 2000, 'text': 'How are you? '}]
        """
        if not word_speaker_mapping:
            return []

        def open_sentence(word_dict: Dict) -> Dict:
            # Start a sentence record from its first word.
            return {
                "speaker": f"Speaker {word_dict['speaker']}",
                "start_time": word_dict["start_time"],
                "end_time": word_dict["end_time"],
                "text": word_dict["text"] + " ",
            }

        snts = []
        snt = open_sentence(word_speaker_mapping[0])
        prev_spk = word_speaker_mapping[0]['speaker']

        for word_dict in word_speaker_mapping[1:]:
            word, spk = word_dict["text"], word_dict["speaker"]
            if spk != prev_spk or self.sentence_checker(snt["text"] + word):
                snts.append(snt)
                snt = open_sentence(word_dict)
            else:
                snt["end_time"] = word_dict["end_time"]
                snt["text"] += word + " "
            prev_spk = spk

        snts.append(snt)
        return snts
class Audio:
    """
    A class to handle audio file analysis and property extraction.

    This class provides methods to load an audio file, process it, and
    extract various audio properties including spectral, temporal, and
    perceptual features.

    Parameters
    ----------
    audio_path : str
        Path to the audio file to be analyzed.

    Attributes
    ----------
    audio_path : str
        Path to the audio file.
    extension : str
        Lower-cased file extension of the audio file (including the dot).
    samples : int
        Total number of audio samples (after down-mixing to mono).
    duration : float
        Duration of the audio in seconds.
    data : np.ndarray
        Audio data loaded from the file as float32; multichannel input is
        averaged across channels.
    rate : int
        Sampling rate of the audio file.
    """

    def __init__(self, audio_path: str):
        """
        Initialize the Audio class with a given audio file path.

        Parameters
        ----------
        audio_path : str
            Path to the audio file.

        Raises
        ------
        TypeError
            If `audio_path` is not a non-empty string.
        FileNotFoundError
            If the file specified by `audio_path` does not exist.
        ValueError
            If the file has an unsupported extension or is empty.
        RuntimeError
            If there is an error reading the audio file.
        """
        if not isinstance(audio_path, str) or not audio_path:
            raise TypeError("audio_path must be a non-empty string")
        if not os.path.isfile(audio_path):
            raise FileNotFoundError(f"The specified audio file does not exist: {audio_path}")

        valid_extensions = [".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac"]
        extension = os.path.splitext(audio_path)[1].lower()
        if extension not in valid_extensions:
            raise ValueError(f"File extension {extension} is not recognized as a supported audio format.")

        try:
            self.data, self.rate = sf.read(audio_path, dtype='float32')
        except RuntimeError as e:
            raise RuntimeError(f"Error reading audio file: {audio_path}") from e

        if len(self.data) == 0:
            raise ValueError(f"Audio file is empty: {audio_path}")

        # Convert stereo or multichannel audio to mono
        if self.data.ndim > 1 and self.data.shape[1] > 1:
            self.data = np.mean(self.data, axis=1)

        self.audio_path = audio_path
        self.extension = extension
        self.samples = len(self.data)
        self.duration = self.samples / self.rate

    def properties(self) -> Tuple[
        str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]:
        """
        Extract various properties and features from the audio file.

        Returns
        -------
        Tuple[str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]
            A tuple containing:
            - File name (str)
            - File extension (str)
            - File path (str), absolute
            - Sample rate (int)
            - Minimum frequency (float): smallest non-zero FFT bin frequency,
              or 0.0 when the signal has no non-zero bin
            - Maximum frequency (float): largest absolute FFT bin frequency
            - Bit depth (Union[int, None]): only reported for WAV files the
              standard `wave` module can parse, otherwise None
            - Number of channels (int), as stored in the file
            - Duration (float)
            - Root mean square loudness (float)
            - A dictionary of extracted properties (Dict[str, float])

        Notes
        -----
        Properties extracted include:
        - Spectral bands energy
        - Zero Crossing Rate (ZCR)
        - Spectral Centroid
        - MFCCs (Mel Frequency Cepstral Coefficients)

        Examples
        --------
        >>> audio = Audio("sample.wav")
        >>> audio.properties()
        ('sample.wav', '.wav', '/path/to/sample.wav', 44100, 20.0, 20000.0, 16, 2, 5.2, 0.25, {...})
        """
        bands = [(20, 250), (250, 2000), (2000, 6000), (6000, 20000)]

        x = fft(self.data)
        xf = fftfreq(self.samples, 1 / self.rate)

        # A one-sample signal only has the 0 Hz bin; avoid reducing an empty array.
        nonzero_freqs = xf[xf != 0]
        min_freq = np.min(np.abs(nonzero_freqs)) if nonzero_freqs.size > 0 else 0.0
        max_freq = np.max(np.abs(xf))

        bit_depth = None
        channels = None
        if self.extension == ".wav":
            try:
                with wave.open(self.audio_path, 'r') as wav_file:
                    bit_depth = wav_file.getsampwidth() * 8
                    channels = wav_file.getnchannels()
            except wave.Error:
                # The stdlib reader rejects WAV encodings it does not support
                # even though soundfile already decoded the file in __init__;
                # fall back to soundfile's metadata below.
                bit_depth = None
                channels = None
        if channels is None:
            info = sf.info(self.audio_path)
            channels = info.channels

        duration = float(self.duration)
        loudness = np.sqrt(np.mean(self.data ** 2))

        s = np.abs(x)
        freqs = xf
        eq_properties = {}
        for low, high in bands:
            band_mask = (freqs >= low) & (freqs <= high)
            band_data = s[band_mask]
            band_energy = np.mean(band_data ** 2, axis=0) if band_data.size > 0 else 0
            eq_properties[f"EQ_{low}_{high}_Hz"] = band_energy

        # Sum of |diff(sign)| per sample: a direct -1 <-> +1 flip contributes 2.
        # NOTE(review): this is therefore about twice the conventional
        # crossings-per-sample rate - confirm the scale consumers expect.
        zcr = np.sum(np.abs(np.diff(np.sign(self.data)))) / len(self.data)

        magnitude_spectrum = np.abs(np.fft.rfft(self.data))
        freqs_centroid = np.fft.rfftfreq(len(self.data), 1.0 / self.rate)
        magnitude_total = np.sum(magnitude_spectrum)
        spectral_centroid = (
            np.sum(freqs_centroid * magnitude_spectrum) / magnitude_total
            if magnitude_total != 0
            else 0.0
        )

        mfccs = mfcc(y=self.data, sr=self.rate, n_mfcc=13)
        mfcc_mean = np.mean(mfccs, axis=1)

        eq_properties["RMSLoudness"] = float(loudness)
        eq_properties["ZeroCrossingRate"] = float(zcr)
        eq_properties["SpectralCentroid"] = float(spectral_centroid)
        for i, val in enumerate(mfcc_mean):
            eq_properties[f"MFCC_{i + 1}"] = float(val)

        # Normalize every value (including NumPy scalars) to a plain float.
        eq_properties_converted = {key: float(value) for key, value in eq_properties.items()}

        file_name = os.path.basename(self.audio_path)
        path = os.path.abspath(self.audio_path)
        bit_depth = int(bit_depth) if bit_depth is not None else None
        channels = int(channels) if channels is not None else 1

        return (
            file_name,
            self.extension,
            path,
            int(self.rate),
            float(min_freq),
            float(max_freq),
            bit_depth,
            channels,
            float(duration),
            float(loudness),
            eq_properties_converted
        )
if __name__ == "__main__":
    # Demo: map two words onto two speaker segments.
    # Word times are in seconds; the mapper multiplies them by 1000.
    words_timestamp = [
        {'text': 'Hello', 'start': 0.0, 'end': 1.2},
        {'text': 'world', 'start': 1.3, 'end': 2.0}
    ]
    # Speaker segments are compared against the scaled word times, so they
    # must be given in milliseconds: [start_ms, end_ms, speaker_id].
    speaker_timestamp = [
        [0, 1500, 1],
        [1600, 3000, 2]
    ]
    word_sentence_mapper = WordSpeakerMapper(words_timestamp, speaker_timestamp)
    word_speaker_maps = word_sentence_mapper.get_words_speaker_mapping()
    print("Word-Speaker Mapping:")
    print(word_speaker_maps)