+

+
+
+
+
+
+
+[LinkedIn](https://linkedin.com/in/bunyaminergen)
+
+# Callytics
+
+`Callytics` is an advanced call analytics solution that leverages speech recognition and large language model (LLM)
+technologies to analyze phone conversations from customer service and call centers. By processing both the
+audio and text of each call, it provides insights such as sentiment analysis, topic detection, conflict detection,
+profanity detection, and summarization. These techniques help businesses optimize customer interactions,
+identify areas for improvement, and enhance overall service quality.
+
+When an audio file is placed in the `.data/input` directory, the entire pipeline automatically starts running, and the
+resulting data is inserted into the database.
+
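+For reference, the pipeline is triggered by the watcher-based entry point in `main.py` (included in this repository);
+the snippet below is a simplified sketch of that file rather than a separate API:
+
+```python
+# Simplified sketch of the entry point in main.py: a Watcher monitors
+# .data/input and triggers the async pipeline for every newly created file.
+from src.utils.utils import Watcher
+
+
+async def process(path: str) -> None:
+    print(f"Processing new audio file: {path}")
+    # main.py awaits its own `main(path)` coroutine here, which runs speech
+    # enhancement, diarization, transcription, LLM analysis, and the DB insert.
+
+
+if __name__ == "__main__":
+    Watcher.start_watcher(".data/input", process)
+```
+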
+**Note**: _This is version `v1.1.0`; many new features will be added, models
+will be fine-tuned or trained from scratch, and various optimizations will be applied. For more information,
+see the [Upcoming](#upcoming) section._
+
+**Note**: _If you would like to contribute to this repository,
+please read the [CONTRIBUTING](.docs/documentation/CONTRIBUTING.md) first._
+
+
+
+---
+
+### Table of Contents
+
+- [Prerequisites](#prerequisites)
+- [Architecture](#architecture)
+- [Math And Algorithm](#math-and-algorithm)
+- [Features](#features)
+- [Demo](#demo)
+- [Installation](#installation)
+- [File Structure](#file-structure)
+- [Database Structure](#database-structure)
+- [Datasets](#datasets)
+- [Version Control System](#version-control-system)
+- [Upcoming](#upcoming)
+- [Documentations](#documentations)
+- [License](#licence)
+- [Links](#links)
+- [Team](#team)
+- [Contact](#contact)
+- [Citation](#citation)
+
+---
+
+### Prerequisites
+
+##### General
+
+- `Python 3.11` _(or above)_
+
+##### Llama
+
+- `GPU (min 24GB)` _(or above)_
+- `Hugging Face Credentials (Account, Token)`
+- `Llama-3.2-11B-Vision-Instruct` _(or above)_
+
+##### OpenAI
+
+- `GPU (min 12GB)` _(for other processes such as `faster whisper` & `NeMo`)_
+- At least one of the following is required:
+ - `OpenAI Credentials (Account, API Key)`
+ - `Azure OpenAI Credentials (Account, API Key, API Base URL)`
+
+---
+
+### Architecture
+
+
+
+---
+
+### Math and Algorithm
+
+This section describes the mathematical models and algorithms used in the project.
+
+_**Note**: This section covers only the mathematical concepts and algorithms specific to this repository, not the
+underlying models. Please refer to `RESOURCES` under the [Documentations](#documentations) section for the
+repositories and models utilized or referenced._
+
+##### Silence Duration Calculation
+
+Silence durations are derived from the time intervals between speech segments. Let
+
+$$S = \{s_1, s_2, \ldots, s_n\}$$
+
+denote _the set of silence durations (in seconds)_ between consecutive speech segments.
+
+- **A user-defined factor**:
+
+$$\text{factor} \in \mathbb{R}^{+}$$
+
+To determine a threshold that distinguishes _significant_ silence from trivial gaps, two statistical methods can be
+applied:
+
+**1. Standard Deviation-Based Threshold**
+
+- _Mean_:
+
+$$\mu = \frac{1}{n}\sum_{i=1}^{n}s_i$$
+
+- _Standard Deviation_:
+
+$$
+\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(s_i - \mu)^2}
+$$
+
+- _Threshold_:
+
+$$
+T_{\text{std}} = \sigma \cdot \text{factor}
+$$
+
+**2. Median + Interquartile Range (IQR) Threshold**
+
+- _Median_:
+
+_Let:_
+
+$$ S = \{s_{(1)} \leq s_{(2)} \leq \cdots \leq s_{(n)}\} $$
+
+be an ordered set.
+
+_Then:_
+
+$$
+M = \text{median}(S) =
+\begin{cases}
+s_{\frac{n+1}{2}}, & \text{if } n \text{ is odd}, \\\\[6pt]
+\frac{s_{\frac{n}{2}} + s_{\frac{n}{2}+1}}{2}, & \text{if } n \text{ is even}.
+\end{cases}
+$$
+
+- _Quartiles:_
+
+$$
+Q_1 = s_{(\lfloor 0.25n \rfloor)}, \quad Q_3 = s_{(\lfloor 0.75n \rfloor)}
+$$
+
+- _IQR_:
+
+$$
+\text{IQR} = Q_3 - Q_1
+$$
+
+- **Threshold:**
+
+$$
+T_{\text{median\\_iqr}} = M + (\text{IQR} \times \text{factor})
+$$
+
+**Total Silence Above Threshold**
+
+Once the threshold
+
+$$T \in \{T_{\text{std}},\ T_{\text{median\\_iqr}}\}$$
+
+is defined, we sum only those silence durations that meet or exceed it:
+
+$$
+\text{TotalSilence} = \sum_{i=1}^{n} s_i \cdot \mathbf{1}(s_i \geq T)
+$$
+
+where $$\mathbf{1}(s_i \geq T)$$ is an indicator function defined as:
+
+$$
+\mathbf{1}(s_i \geq T) =
+\begin{cases}
+1 & \text{if } s_i \geq T \\
+0 & \text{otherwise}
+\end{cases}
+$$
+
+**Summary:**
+
+- **Identify the silence durations:**
+
+$$
+S = \{s_1, s_2, \ldots, s_n\}
+$$
+
+- **Determine the threshold using either:**
+
+_Standard deviation-based:_
+
+$$
+T = \sigma \cdot \text{factor}
+$$
+
+_Median+IQR-based:_
+
+$$
+T = M + (\text{IQR} \cdot \text{factor})
+$$
+
+- **Compute the total silence above this threshold:**
+
+$$
+\text{TotalSilence} = \sum_{i=1}^{n} s_i \cdot \mathbf{1}(s_i \geq T)
+$$
+
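+In the codebase this logic is exposed through `SilenceStats` in `src/audio/metrics.py` (used in Step 16 of `main.py`).
+The following is a minimal, self-contained sketch of the formulas above, not the repository implementation itself:
+
+```python
+# Minimal sketch of the silence thresholding described above.
+# `silences` holds silence durations in seconds; `factor` is the user-defined factor.
+import statistics
+
+
+def total_silence_std(silences: list[float], factor: float) -> float:
+    """Total silence above the standard deviation based threshold T = sigma * factor."""
+    threshold = statistics.pstdev(silences) * factor
+    return sum(s for s in silences if s >= threshold)
+
+
+def total_silence_median_iqr(silences: list[float], factor: float) -> float:
+    """Total silence above the median + IQR based threshold T = M + IQR * factor."""
+    ordered = sorted(silences)
+    median = statistics.median(ordered)
+    # Simple floor-based quartiles, approximating the definition above (0-indexed).
+    q1 = ordered[int(0.25 * len(ordered))]
+    q3 = ordered[int(0.75 * len(ordered))]
+    threshold = median + (q3 - q1) * factor
+    return sum(s for s in silences if s >= threshold)
+
+
+print(total_silence_std([0.2, 0.4, 1.5, 3.0, 0.3], factor=0.99))
+```
+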
+---
+
+### Features
+
+- [x] Speech Enhancement
+- [x] Sentiment Analysis
+- [x] Profanity Word Detection
+- [x] Summary
+- [x] Conflict Detection
+- [x] Topic Detection
+
+---
+
+### Demo
+
+
+
+---
+
+### Installation
+
+##### Linux/Ubuntu
+
+```bash
+sudo apt update -y && sudo apt upgrade -y
+```
+
+```bash
+sudo apt install -y ffmpeg build-essential g++
+```
+
+```bash
+git clone https://github.com/bunyaminergen/Callytics
+```
+
+```bash
+cd Callytics
+```
+
+```bash
+conda env create -f environment.yaml
+```
+
+```bash
+conda activate Callytics
+```
+
+##### Environment
+
+`.env` file sample:
+
+```Text
+# CREDENTIALS
+# OPENAI
+OPENAI_API_KEY=
+
+# HUGGINGFACE
+HUGGINGFACE_TOKEN=
+
+# AZURE OPENAI
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_API_BASE=
+AZURE_OPENAI_API_VERSION=
+
+# DATABASE
+DB_NAME=
+DB_USER=
+DB_PASSWORD=
+DB_HOST=
+DB_PORT=
+DB_URL=
+```
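+
+These variables are read at runtime via `python-dotenv` (listed in `requirements.txt`). Below is a minimal sketch of
+loading and checking them, assuming the `.env` file sits in the project root and the OpenAI backend is used; the exact
+set of required keys depends on which backend you configure:
+
+```python
+# Minimal sketch: load .env and fail fast if a required key is missing.
+import os
+
+from dotenv import load_dotenv
+
+load_dotenv()  # reads .env from the current working directory
+
+required = ["OPENAI_API_KEY", "HUGGINGFACE_TOKEN"]  # adjust for your backend
+missing = [key for key in required if not os.getenv(key)]
+if missing:
+    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
+```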
+
+---
+
+##### Database
+
+_This section provides an `example database` and `tables` with a `well-structured`, `simple design`. If you create the
+tables and columns with the same structure in your remote database, you will not encounter errors in the code. However,
+if you want to change the database structure, you will also need to refactor the code._
+
+**Note**: _Refer to the [Database Structure](#database-structure) section for the database schema and tables._
+
+```bash
+sqlite3 .db/Callytics.sqlite < src/db/sql/Schema.sql
+```
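+
+After running the command above, you can verify that the schema was created using Python's built-in `sqlite3` module.
+This is only a quick sanity check; the authoritative schema lives in `src/db/sql/Schema.sql`:
+
+```python
+# Quick sanity check: list the tables created from Schema.sql.
+import sqlite3
+
+with sqlite3.connect(".db/Callytics.sqlite") as connection:
+    cursor = connection.execute(
+        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
+    )
+    print([row[0] for row in cursor.fetchall()])
+```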
+
+##### Grafana
+
+_This section explains how to install `Grafana` in your `local` environment. Since Grafana is a third-party
+open-source monitoring application, you must handle its installation yourself and connect your database. Of course, you
+can also use `Grafana Cloud` instead of a `local` environment._
+
+```bash
+sudo apt update -y && sudo apt upgrade -y
+```
+
+```bash
+sudo apt install -y apt-transport-https software-properties-common wget
+```
+
+```bash
+wget -q -O - https://packages.grafana.com/gpg.key | sudo apt-key add -
+```
+
+```bash
+echo "deb https://packages.grafana.com/oss/deb stable main" | sudo tee /etc/apt/sources.list.d/grafana.list
+```
+
+```bash
+sudo apt install -y grafana
+```
+
+```bash
+sudo systemctl start grafana-server
+sudo systemctl enable grafana-server
+sudo systemctl daemon-reload
+```
+
+Once the service is running, the Grafana web interface is available at:
+
+```Text
+http://localhost:3000
+```
+
+**SQLite Plugin**
+
+```bash
+sudo grafana-cli plugins install frser-sqlite-datasource
+```
+
+```bash
+sudo systemctl restart grafana-server
+```
+
+```bash
+sudo systemctl daemon-reload
+```
+
+### File Structure
+
+```Text
+.
+├── automation
+│ └── service
+│ └── callytics.service
+├── config
+│ ├── config.yaml
+│ ├── nemo
+│ │ └── diar_infer_telephonic.yaml
+│ └── prompt.yaml
+├── .data
+│ ├── example
+│ │ └── LogisticsCallCenterConversation.mp3
+│ └── input
+├── .db
+│ └── Callytics.sqlite
+├── .docs
+│ ├── documentation
+│ │ ├── CONTRIBUTING.md
+│ │ └── RESOURCES.md
+│ └── img
+│ ├── Callytics.drawio
+│ ├── Callytics.gif
+│ ├── CallyticsIcon.png
+│ ├── Callytics.png
+│ ├── Callytics.svg
+│ └── database.png
+├── .env
+├── environment.yaml
+├── .gitattributes
+├── .github
+│ └── CODEOWNERS
+├── .gitignore
+├── LICENSE
+├── main.py
+├── README.md
+├── requirements.txt
+└── src
+ ├── audio
+ │ ├── alignment.py
+ │ ├── analysis.py
+ │ ├── effect.py
+ │ ├── error.py
+ │ ├── io.py
+ │ ├── metrics.py
+ │ ├── preprocessing.py
+ │ ├── processing.py
+ │ └── utils.py
+ ├── db
+ │ ├── manager.py
+ │ └── sql
+ │ ├── AudioPropertiesInsert.sql
+ │ ├── Schema.sql
+ │ ├── TopicFetch.sql
+ │ ├── TopicInsert.sql
+ │ └── UtteranceInsert.sql
+ ├── text
+ │ ├── llm.py
+ │ ├── model.py
+ │ ├── prompt.py
+ │ └── utils.py
+ └── utils
+ └── utils.py
+
+19 directories, 43 files
+```
+
+---
+
+### Database Structure
+
+
+
+
+---
+
+### Datasets
+
+- [Callytics Speaker Verification Dataset *(CSVD)*](.data/groundtruth/speakerverification/DatasetCard.md)
+
+---
+
+### Version Control System
+
+##### Releases
+
+- [v1.0.0](https://github.com/bunyaminergen/Callytics/archive/refs/tags/v1.0.0.zip) _.zip_
+- [v1.0.0](https://github.com/bunyaminergen/Callytics/archive/refs/tags/v1.0.0.tar.gz) _.tar.gz_
+
+
+- [v1.1.0](https://github.com/bunyaminergen/Callytics/archive/refs/tags/v1.1.0.zip) _.zip_
+- [v1.1.0](https://github.com/bunyaminergen/Callytics/archive/refs/tags/v1.1.0.tar.gz) _.tar.gz_
+
+##### Branches
+
+- [main](https://github.com/bunyaminergen/Callytics/tree/main)
+- [develop](https://github.com/bunyaminergen/Callytics/tree/develop)
+
---
-title: CallyticsDemo
-emoji: 🏢
-colorFrom: blue
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.23.3
-app_file: app.py
-pinned: false
-license: gpl-3.0
-short_description: CallyticsDemo
+
+### Upcoming
+
+- [ ] **Speech Emotion Recognition:** Develop a model to automatically detect emotions from speech data.
+- [ ] **New Forced Alignment Model:** Train a forced alignment model from scratch.
+- [ ] **New Vocal Separation Model:** Train a vocal separation model from scratch.
+- [ ] **Unit Tests:** Add a comprehensive unit testing script to validate functionality.
+- [ ] **Logging Logic:** Implement a more comprehensive and structured logging mechanism.
+- [ ] **Warnings:** Add meaningful and detailed warning messages for better user guidance.
+- [ ] **Real-Time Analysis:** Enable real-time analysis capabilities within the system.
+- [ ] **Dockerization:** Containerize the repository to ensure seamless deployment and environment consistency.
+- [ ] **New Transcription Models:** Integrate and test new transcription models
+  such as [AIOLA’s Multi-Head Speech Recognition Model](https://venturebeat.com/ai/aiola-drops-ultra-fast-multi-head-speech-recognition-model-beats-openai-whisper/).
+- [ ] **Noise Reduction Model:** Identify, test, and integrate a deep learning-based noise reduction model. Consider
+ existing models like **Facebook Research Denoiser**, **Noise2Noise**, **Audio Denoiser CNN**. Write test scripts for
+ evaluation, and if necessary, train a new model for optimal performance.
+
+##### Considerations
+
+- [ ] Detect CSR's identity via Voice Recognition/Identification instead of Diarization and LLM.
+- [ ] Transform the code structure into a pipeline for better modularity and scalability.
+- [ ] Publish the repository as a Python package on **PyPI** for wider distribution.
+- [ ] Convert the repository into a Linux package to support Linux-based systems.
+- [ ] Implement a two-step processing workflow: perform **diarization** (speaker segmentation) first, then apply
+  **transcription** for each identified speaker separately. This approach can improve transcription accuracy by
+  leveraging speaker separation.
+- [ ] Enable **parallel processing** for tasks such as diarization, transcription, and model inference to improve
+ overall system performance and reduce processing time.
+- [ ] Explore using **Docker Compose** for multi-container orchestration if required.
+- [ ] Upload the models and relevant resources to **Hugging Face** for easier access, sharing, and community
+ collaboration.
+- [ ] Consider writing a **Command Line Interface (CLI)** to simplify user interaction and improve usability.
+- [ ] Test the ability to use **different language models (LLMs)** for specific tasks. For instance, using **BERT** for
+ profanity detection. Evaluate their performance and suitability for different use cases as a feature.
+
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+### Documentations
+
+- [RESOURCES](.docs/documentation/RESOURCES.md)
+- [CONTRIBUTING](.docs/documentation/CONTRIBUTING.md)
+- [PRESENTATION](.docs/presentation/CallyticsPresentationEN.pdf)
+
+---
+
+### Licence
+
+- [LICENSE](LICENSE)
+
+---
+
+### Links
+
+- [Github](https://github.com/bunyaminergen/Callytics)
+- [Website](https://bunyaminergen.com)
+- [Linkedin](https://www.linkedin.com/in/bunyaminergen)
+
+---
+
+### Team
+
+- [Bunyamin Ergen](https://www.linkedin.com/in/bunyaminergen)
+
+---
+
+### Contact
+
+- [Mail](mailto:info@bunyaminergen.com)
+
+---
+
+### Citation
+
+```bibtex
+@software{ Callytics,
+ author = {Bunyamin Ergen},
+ title = {{Callytics}},
+ year = {2024},
+ month = {12},
+ url = {https://github.com/bunyaminergen/Callytics},
+ version = {v1.1.0},
+}
+```
+
+---
diff --git a/automation/service/callytics.service b/automation/service/callytics.service
new file mode 100644
index 0000000000000000000000000000000000000000..48d4ca10ba0598d724656ee0b96b87f2fd03f21b
--- /dev/null
+++ b/automation/service/callytics.service
@@ -0,0 +1,19 @@
+[Unit]
+Description=Callytics
+After=network.target
+
+[Service]
+Type=simple
+User=bunyamin
+EnvironmentFile=/home/bunyamin/Callytics/.env
+WorkingDirectory=/home/bunyamin/Callytics
+ExecStart=/bin/bash -c "source /home/bunyamin/anaconda3/etc/profile.d/conda.sh \
+ && conda activate Callytics \
+ && python /home/bunyamin/Callytics/main.py"
+Restart=on-failure
+RestartSec=5
+StandardOutput=journal
+StandardError=journal
+
+[Install]
+WantedBy=multi-user.target
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9cb8913f31bc97af49ace6583f495ecc3094728b
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,26 @@
+runtime:
+ device: "cpu" # Options: "cpu", "cuda"
+ compute_type: "int8" # Options: "int8", "float16"
+ cuda_alloc_conf: "expandable_segments:True" # PyTorch CUDA Memory Management
+
+language:
+ audio: "en" # Options: "en", "tr"
+ text: "en" # Options: "en", "tr"
+
+models:
+ llama:
+ model_name: "meta-llama/Llama-3.2-3B-Instruct" # Options: "meta-llama/Llama-3.2-3B-Instruct", etc.
+ huggingface_api_key: "${HUGGINGFACE_TOKEN}"
+
+ openai:
+ model_name: "gpt-4o" # Options: "gpt-4", "gpt-4o", etc.
+ openai_api_key: "${OPENAI_API_KEY}"
+
+ azure_openai:
+ model_name: "gpt-4o" # Options: "gpt-4", "gpt-4o", etc.
+ azure_openai_api_key: "${AZURE_OPENAI_API_KEY}"
+ azure_openai_api_base: "${AZURE_OPENAI_API_BASE}"
+ azure_openai_api_version: "${AZURE_OPENAI_API_VERSION}"
+
+ mpsenet:
+ model_name: "JacobLinCool/MP-SENet-DNS" # Options: "JacobLinCool/MP-SENet-DNS", "JacobLinCool/MP-SENet-VB"
diff --git a/config/nemo/diar_infer_telephonic.yaml b/config/nemo/diar_infer_telephonic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68df031e5f67b682259dd1d973010579c4ac1ae2
--- /dev/null
+++ b/config/nemo/diar_infer_telephonic.yaml
@@ -0,0 +1,86 @@
+name: "ClusterDiarizer"
+
+num_workers: 1
+sample_rate: 16000
+batch_size: 64
+device: cuda
+verbose: True
+
+diarizer:
+ manifest_filepath: .temp/manifest.json
+ out_dir: .temp
+ oracle_vad: False
+ collar: 0.25
+ ignore_overlap: True
+
+ vad:
+ model_path: vad_multilingual_marblenet
+ external_vad_manifest: null
+ parameters:
+ window_length_in_sec: 0.15
+ shift_length_in_sec: 0.01
+ smoothing: "median"
+ overlap: 0.5
+ onset: 0.1
+ offset: 0.1
+ pad_onset: 0.1
+ pad_offset: 0
+ min_duration_on: 0
+ min_duration_off: 0.2
+ filter_speech_first: True
+
+ speaker_embeddings:
+ model_path: titanet_large
+ parameters:
+ window_length_in_sec: [ 1.5,1.25,1.0,0.75,0.5 ]
+ shift_length_in_sec: [ 0.75,0.625,0.5,0.375,0.25 ]
+ multiscale_weights: [ 1,1,1,1,1 ]
+ save_embeddings: True
+
+ clustering:
+ parameters:
+ oracle_num_speakers: False
+ max_num_speakers: 8
+ enhanced_count_thres: 80
+ max_rp_threshold: 0.25
+ sparse_search_volume: 30
+ maj_vote_spk_count: False
+ chunk_cluster_count: 50
+ embeddings_per_chunk: 10000
+
+ msdd_model:
+ model_path: diar_msdd_telephonic
+ parameters:
+ use_speaker_model_from_ckpt: True
+ infer_batch_size: 25
+ sigmoid_threshold: [ 0.7 ]
+ seq_eval_mode: False
+ split_infer: True
+ diar_window_length: 50
+ overlap_infer_spk_limit: 5
+
+ asr:
+ model_path: stt_en_conformer_ctc_large
+ parameters:
+ asr_based_vad: False
+ asr_based_vad_threshold: 1.0
+ asr_batch_size: null
+ decoder_delay_in_sec: null
+ word_ts_anchor_offset: null
+ word_ts_anchor_pos: "start"
+ fix_word_ts_with_VAD: False
+ colored_text: False
+ print_time: True
+ break_lines: False
+
+ ctc_decoder_parameters:
+ pretrained_language_model: null
+ beam_width: 32
+ alpha: 0.5
+ beta: 2.5
+
+ realigning_lm_parameters:
+ arpa_language_model: null
+ min_number_of_words: 3
+ max_number_of_words: 10
+ logprob_diff_threshold: 1.2
diff --git a/config/prompt.yaml b/config/prompt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..671b4f96ada50a102159e841272c69cc49ff1ce6
--- /dev/null
+++ b/config/prompt.yaml
@@ -0,0 +1,107 @@
+Classification:
+ system: >
+ Your task is to identify the role of each speaker as either 'Customer' or 'Customer Service Representative (CSR)'.
+ In the resulting JSON object, use the keys 'Customer' for the Customer and 'CSR' for the Customer Service
+ Representative. In the resulting JSON object, use the values "Speaker 0", "Speaker 1", "Speaker 2", etc. Please
+ respond with a valid JSON object. Ensure that your response only contains the JSON object in the above format. Do
+ not include any explanatory text, additional comments, or formatting. Now, analyze the following conversation:
+
+ user: >
+ {user_context}
+
+
+SentimentAnalysis:
+ system: >
+ You are a sentiment analysis tool. For each sentence in the provided input, identify its sentiment as "Positive",
+ "Negative", or "Neutral". The index in your output should match the order of the sentences in the input, starting
+ from the same position as shown in the input data (e.g., index always start from 0). Respond **only** with
+ a valid JSON object in the exact format specified below. Do not include any additional text. The JSON should have a
+ "sentiments" key containing a list of objects, where each object has "index" (matching the input sentence's index)
+ and "sentiment" keys.
+
+ Example:
+ {{
+ "sentiments": [
+ {{"index": 0, "sentiment": "Positive"}},
+ {{"index": 1, "sentiment": "Neutral"}},
+ {{"index": 2, "sentiment": "Negative"}}
+ ]
+ }}
+
+ Analyze the following conversation and ensure the indices match the input:
+ user: >
+ {user_context}
+
+
+ProfanityWordDetection:
+ system: >
+ You are a profanity word detection tool. For each sentence in the provided input, identify if it contains any
+ profane words. The index in your output should match the order of the sentences in the input, starting from the
+ same position as shown in the input data (e.g., index always start from 0). Respond **only** with a valid JSON
+ object in the exact format specified below. Do not include any additional text. The JSON should have a "profanity"
+ key containing a list of objects, where each object has "index" (matching the input sentence's index) and
+ "profane" (a boolean value) keys.
+
+ Example:
+ {{
+ "profanity": [
+ {{"index": 0, "profane": "true"}},
+ {{"index": 1, "profane": "false"}},
+ {{"index": 2, "profane": "true"}}
+ ]
+ }}
+
+ Analyze the following conversation and ensure the indices match the input:
+ user: >
+ {user_context}
+
+Summary:
+ system: >
+ Your task is to summarize the entire conversation in a single sentence. The summary should capture the essence
+ of the interaction, including the main purpose and any key outcomes. Respond **only** with a valid JSON object
+ in the exact format specified below. Do not include any additional text. The JSON should have a single key
+ "summary" with a string value.
+
+ Example:
+ {
+ "summary": "The customer requested a copy of their invoice and the CSR confirmed it would be sent by email."
+ }
+
+ Now, summarize the following conversation:
+ user: >
+ {user_context}
+
+ConflictDetection:
+ system: >
+ Your task is to determine if there is any conflict or disagreement between the speakers in the given conversation.
+ A conflict is defined as any instance where the speakers express opposing views, argue, or express frustration.
+ Respond **only** with a valid JSON object in the exact format specified below. Do not include any additional text.
+ The JSON should have a single key "conflict" with a boolean value.
+
+ Example:
+ {
+ "conflict": true
+ }
+
+ Now, analyze the following conversation:
+ user: >
+ {user_context}
+
+TopicDetection:
+ system: >
+ Your task is to identify the topic of a conversation. You will receive a conversation transcript and a
+ list of predefined topics. Your job is to determine which topic best matches the conversation. If none
+ of the provided topics match, suggest a new topic based on the conversation content.
+
+ Here is the list of predefined topics: {system_context}.
+
+ Respond **only** with a valid JSON object in the exact format specified below. Do not include any additional text.
+ The JSON should have one key: "topic" (a string that is the matched topic or new topic).
+
+ Example:
+ {{
+ "topic": "Billing"
+ }}
+
+ user: >
+ {user_context}
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c47ba4e6ee3a59b5c53fbce4bd6723b480f8576
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,31 @@
+name: Callytics
+channels:
+ - defaults
+ - conda-forge
+dependencies:
+ - python=3.11
+ - pip:
+ - cython==3.0.11
+ - nemo_toolkit[asr]>=2.dev
+ - nltk==3.9.1
+ - faster-whisper==1.1.0
+ - demucs==4.0.1
+ - deepmultilingualpunctuation @ git+https://github.com/oliverguhr/deepmultilingualpunctuation.git@5a0dd7f4fd56687f59405aa8eba1144393d8b74b
+ - ctc-forced-aligner @ git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git@c7cc7ce609e5f8f1f553fbd1e53124447ffe46d8
+ - openai==1.57.0
+ - accelerate>=0.26.0
+ - torch==2.5.1
+ - pydub==0.25.1
+ - omegaconf==2.3.0
+ - python-dotenv==1.0.1
+ - transformers==4.47.0
+ - librosa==0.10.2.post1
+ - soundfile==0.12.1
+ - noisereduce==3.0.3
+ - numpy==1.26.4
+ - pyannote.audio==3.3.2
+ - watchdog==6.0.0
+ - scipy==1.14.1
+ - IPython==8.30.0
+ - pyyaml==6.0.2
+ - MPSENet==1.0.3
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd17206696fea548d2177171e8da273950486986
--- /dev/null
+++ b/main.py
@@ -0,0 +1,292 @@
+# Standard library imports
+import os
+
+# Related third-party imports
+from omegaconf import OmegaConf
+from nemo.collections.asr.models.msdd_models import NeuralDiarizer
+
+# Local imports
+from src.audio.utils import Formatter
+from src.audio.metrics import SilenceStats
+from src.audio.error import DialogueDetecting
+from src.audio.alignment import ForcedAligner
+from src.audio.effect import DemucsVocalSeparator
+from src.audio.preprocessing import SpeechEnhancement
+from src.audio.io import SpeakerTimestampReader, TranscriptWriter
+from src.audio.analysis import WordSpeakerMapper, SentenceSpeakerMapper, Audio
+from src.audio.processing import AudioProcessor, Transcriber, PunctuationRestorer
+from src.text.utils import Annotator
+from src.text.llm import LLMOrchestrator, LLMResultHandler
+from src.utils.utils import Cleaner, Watcher
+from src.db.manager import Database
+
+
+async def main(audio_file_path: str):
+ """
+ Process an audio file to perform diarization, transcription, punctuation restoration,
+ and speaker role classification.
+
+ Parameters
+ ----------
+ audio_file_path : str
+ The path to the input audio file to be processed.
+
+ Returns
+ -------
+ None
+ """
+ # Paths
+ config_nemo = "config/nemo/diar_infer_telephonic.yaml"
+ manifest_path = ".temp/manifest.json"
+ temp_dir = ".temp"
+ rttm_file_path = os.path.join(temp_dir, "pred_rttms", "mono_file.rttm")
+ transcript_output_path = ".temp/output.txt"
+ srt_output_path = ".temp/output.srt"
+ config_path = "config/config.yaml"
+ prompt_path = "config/prompt.yaml"
+ db_path = ".db/Callytics.sqlite"
+ db_topic_fetch_path = "src/db/sql/TopicFetch.sql"
+ db_topic_insert_path = "src/db/sql/TopicInsert.sql"
+ db_audio_properties_insert_path = "src/db/sql/AudioPropertiesInsert.sql"
+ db_utterance_insert_path = "src/db/sql/UtteranceInsert.sql"
+
+ # Configuration
+ config = OmegaConf.load(config_path)
+ device = config.runtime.device
+ compute_type = config.runtime.compute_type
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config.runtime.cuda_alloc_conf
+
+ # Initialize Classes
+ dialogue_detector = DialogueDetecting(delete_original=True)
+ enhancer = SpeechEnhancement(config_path=config_path, output_dir=temp_dir)
+ separator = DemucsVocalSeparator()
+ processor = AudioProcessor(audio_path=audio_file_path, temp_dir=temp_dir)
+ transcriber = Transcriber(device=device, compute_type=compute_type)
+ aligner = ForcedAligner(device=device)
+ llm_handler = LLMOrchestrator(config_path=config_path, prompt_config_path=prompt_path, model_id="openai")
+ llm_result_handler = LLMResultHandler()
+ cleaner = Cleaner()
+ formatter = Formatter()
+ db = Database(db_path)
+ audio_feature_extractor = Audio(audio_file_path)
+
+ # Step 1: Detect Dialogue
+ has_dialogue = dialogue_detector.process(audio_file_path)
+ if not has_dialogue:
+ return
+
+ # Step 2: Speech Enhancement
+ audio_path = enhancer.enhance_audio(
+ input_path=audio_file_path,
+ output_path=os.path.join(temp_dir, "enhanced.wav"),
+ noise_threshold=0.0001,
+ verbose=True
+ )
+
+ # Step 3: Vocal Separation
+ vocal_path = separator.separate_vocals(audio_file=audio_path, output_dir=temp_dir)
+
+ # Step 4: Transcription
+ transcript, info = transcriber.transcribe(audio_path=vocal_path)
+ detected_language = info["language"]
+
+ # Step 5: Forced Alignment
+ word_timestamps = aligner.align(
+ audio_path=vocal_path,
+ transcript=transcript,
+ language=detected_language
+ )
+
+ # Step 6: Diarization
+ processor.audio_path = vocal_path
+ mono_audio_path = processor.convert_to_mono()
+ processor.audio_path = mono_audio_path
+ processor.create_manifest(manifest_path)
+ cfg = OmegaConf.load(config_nemo)
+ cfg.diarizer.manifest_filepath = manifest_path
+ cfg.diarizer.out_dir = temp_dir
+ msdd_model = NeuralDiarizer(cfg=cfg)
+ msdd_model.diarize()
+
+ # Step 7: Processing Transcript
+ # Step 7.1: Speaker Timestamps
+ speaker_reader = SpeakerTimestampReader(rttm_path=rttm_file_path)
+ speaker_ts = speaker_reader.read_speaker_timestamps()
+
+ # Step 7.2: Mapping Words
+ word_speaker_mapper = WordSpeakerMapper(word_timestamps, speaker_ts)
+ wsm = word_speaker_mapper.get_words_speaker_mapping()
+
+ # Step 7.3: Punctuation Restoration
+ punct_restorer = PunctuationRestorer(language=detected_language)
+ wsm = punct_restorer.restore_punctuation(wsm)
+ word_speaker_mapper.word_speaker_mapping = wsm
+ word_speaker_mapper.realign_with_punctuation()
+ wsm = word_speaker_mapper.word_speaker_mapping
+
+ # Step 7.4: Mapping Sentences
+ sentence_mapper = SentenceSpeakerMapper()
+ ssm = sentence_mapper.get_sentences_speaker_mapping(wsm)
+
+ # Step 8 (Optional): Write Transcript and SRT Files
+ writer = TranscriptWriter()
+ writer.write_transcript(ssm, transcript_output_path)
+ writer.write_srt(ssm, srt_output_path)
+
+ # Step 9: Classify Speaker Roles
+ speaker_roles = await llm_handler.generate("Classification", ssm)
+
+ # Step 9.1: LLM results validate and fallback
+ ssm = llm_result_handler.validate_and_fallback(speaker_roles, ssm)
+ llm_result_handler.log_result(ssm, speaker_roles)
+
+ # Step 10: Sentiment Analysis
+ ssm_with_indices = formatter.add_indices_to_ssm(ssm)
+ annotator = Annotator(ssm_with_indices)
+ sentiment_results = await llm_handler.generate("SentimentAnalysis", user_input=ssm)
+ annotator.add_sentiment(sentiment_results)
+
+ # Step 11: Profanity Word Detection
+ profane_results = await llm_handler.generate("ProfanityWordDetection", user_input=ssm)
+ annotator.add_profanity(profane_results)
+
+ # Step 12: Summary
+ summary_result = await llm_handler.generate("Summary", user_input=ssm)
+ annotator.add_summary(summary_result)
+
+ # Step 13: Conflict Detection
+ conflict_result = await llm_handler.generate("ConflictDetection", user_input=ssm)
+ annotator.add_conflict(conflict_result)
+
+ # Step 14: Topic Detection
+ topics = db.fetch(db_topic_fetch_path)
+ topic_result = await llm_handler.generate(
+ "TopicDetection",
+ user_input=ssm,
+ system_input=topics
+ )
+ annotator.add_topic(topic_result)
+
+ # Step 15: File/Audio Feature Extraction
+ props = audio_feature_extractor.properties()
+
+ (
+ name,
+ file_extension,
+ absolute_file_path,
+ sample_rate,
+ min_frequency,
+ max_frequency,
+ audio_bit_depth,
+ num_channels,
+ audio_duration,
+ rms_loudness,
+ final_features
+ ) = props
+
+ rms_loudness_db = final_features["RMSLoudness"]
+ zero_crossing_rate_db = final_features["ZeroCrossingRate"]
+ spectral_centroid_db = final_features["SpectralCentroid"]
+ eq_20_250_db = final_features["EQ_20_250_Hz"]
+ eq_250_2000_db = final_features["EQ_250_2000_Hz"]
+ eq_2000_6000_db = final_features["EQ_2000_6000_Hz"]
+ eq_6000_20000_db = final_features["EQ_6000_20000_Hz"]
+ mfcc_values = [final_features[f"MFCC_{i}"] for i in range(1, 14)]
+
+ final_output = annotator.finalize()
+
+ # Step 16: Total Silence Calculation
+ stats = SilenceStats.from_segments(final_output['ssm'])
+ t_std = stats.threshold_std(factor=0.99)
+ final_output["silence"] = t_std
+
+ print("Final_Output:", final_output)
+
+ # Step 17: Database
+ # Step 17.1: Insert File Table
+ summary = final_output.get("summary", "")
+ conflict_flag = 1 if final_output.get("conflict", False) else 0
+ silence_value = final_output.get("silence", 0.0)
+ detected_topic = final_output.get("topic", "Unknown")
+
+ topic_id = db.get_or_insert_topic_id(detected_topic, topics, db_topic_insert_path)
+
+ params = (
+ name,
+ topic_id,
+ file_extension,
+ absolute_file_path,
+ sample_rate,
+ min_frequency,
+ max_frequency,
+ audio_bit_depth,
+ num_channels,
+ audio_duration,
+ rms_loudness_db,
+ zero_crossing_rate_db,
+ spectral_centroid_db,
+ eq_20_250_db,
+ eq_250_2000_db,
+ eq_2000_6000_db,
+ eq_6000_20000_db,
+ *mfcc_values,
+ summary,
+ conflict_flag,
+ silence_value
+ )
+
+ last_id = db.insert(db_audio_properties_insert_path, params)
+ print(f"Audio properties inserted successfully into the File table with ID: {last_id}")
+
+ # Step 17.2: Insert Utterance Table
+ utterances = final_output["ssm"]
+
+ for utterance in utterances:
+ file_id = last_id
+ speaker = utterance["speaker"]
+ sequence = utterance["index"]
+ start_time = utterance["start_time"] / 1000.0
+ end_time = utterance["end_time"] / 1000.0
+ content = utterance["text"]
+ sentiment = utterance["sentiment"]
+ profane = 1 if utterance["profane"] else 0
+
+ utterance_params = (
+ file_id,
+ speaker,
+ sequence,
+ start_time,
+ end_time,
+ content,
+ sentiment,
+ profane
+ )
+
+ db.insert(db_utterance_insert_path, utterance_params)
+
+ print("Utterances inserted successfully into the Utterance table.")
+
+ # Step 18: Clean Up
+ cleaner.cleanup(temp_dir, audio_file_path)
+
+
+async def process(path: str):
+ """
+ Asynchronous callback function that is triggered when a new audio file is detected.
+
+ Parameters
+ ----------
+ path : str
+ The path to the newly created audio file.
+
+ Returns
+ -------
+ None
+ """
+ print(f"Processing new audio file: {path}")
+ await main(path)
+
+
+if __name__ == "__main__":
+ directory_to_watch = ".data/input"
+ Watcher.start_watcher(directory_to_watch, process)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a4a23e06fbcf42a1bc0609de4682e154c0b67215
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+nemo_toolkit[asr]>=2.dev
+nltk==3.9.1
+faster-whisper==1.1.0
+demucs==4.0.1
+deepmultilingualpunctuation @ git+https://github.com/oliverguhr/deepmultilingualpunctuation.git@5a0dd7f4fd56687f59405aa8eba1144393d8b74b
+ctc-forced-aligner @ git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git@c7cc7ce609e5f8f1f553fbd1e53124447ffe46d8
+openai==1.57.0
+accelerate>=0.26.0
+torch==2.5.1
+pydub==0.25.1
+omegaconf==2.3.0
+python-dotenv==1.0.1
+transformers==4.47.0
+librosa==0.10.2.post1
+soundfile==0.12.1
+noisereduce==3.0.3
+numpy==1.26.4
+pyannote.audio==3.3.2
+watchdog==6.0.0
+scipy==1.14.1
+IPython==8.30.0
+pyyaml==6.0.2
+MPSENet==1.0.3
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad443218e16663eb726341da74b0283a4e3dcb66
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,5 @@
+# Standard library imports
+import warnings
+
+warnings.resetwarnings()
+warnings.simplefilter("always")
diff --git a/src/audio/__init__.py b/src/audio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/audio/alignment.py b/src/audio/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..076fc517e9b70d2f6b13ea64e71995d86a40c7c9
--- /dev/null
+++ b/src/audio/alignment.py
@@ -0,0 +1,137 @@
+# Standard library imports
+import os
+from typing import Annotated, List, Dict
+
+# Related third-party imports
+import torch
+from faster_whisper import decode_audio
+from ctc_forced_aligner import (
+ generate_emissions,
+ get_alignments,
+ get_spans,
+ load_alignment_model,
+ postprocess_results,
+ preprocess_text,
+)
+
+
+class ForcedAligner:
+ """
+ ForcedAligner is a class for aligning audio to a provided transcript using a pre-trained alignment model.
+
+ Attributes
+ ----------
+ device : str
+ Device to run the model on ('cuda' for GPU or 'cpu').
+ alignment_model : torch.nn.Module
+ The pre-trained alignment model.
+ alignment_tokenizer : Any
+ Tokenizer for processing text in alignment.
+
+ Methods
+ -------
+ align(audio_path, transcript, language, batch_size)
+ Aligns audio with a transcript and returns word-level timing information.
+ """
+
+ def __init__(self, device: Annotated[str, "Device for model ('cuda' or 'cpu')"] = None):
+ """
+ Initialize the ForcedAligner with the specified device.
+
+ Parameters
+ ----------
+ device : str, optional
+ Device for running the model, by default 'cuda' if available, otherwise 'cpu'.
+ """
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+ self.alignment_model, self.alignment_tokenizer = load_alignment_model(
+ self.device,
+ dtype=torch.float16 if self.device == 'cuda' else torch.float32,
+ )
+
+ def align(
+ self,
+ audio_path: Annotated[str, "Path to the audio file"],
+ transcript: Annotated[str, "Transcript of the audio content"],
+ language: Annotated[str, "Language of the transcript"] = 'en',
+ batch_size: Annotated[int, "Batch size for emission generation"] = 8,
+ ) -> Annotated[List[Dict[str, float]], "List of word alignment data with timestamps"]:
+ """
+ Aligns audio with a transcript and returns word-level timing information.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to the audio file.
+ transcript : str
+ Transcript text corresponding to the audio.
+ language : str, optional
+ Language code for the transcript, default is 'en' (English).
+ batch_size : int, optional
+ Batch size for generating emissions, by default 8.
+
+ Returns
+ -------
+ List[Dict[str, float]]
+ A list of dictionaries containing word timing information.
+
+ Raises
+ ------
+ FileNotFoundError
+ If the specified audio file does not exist.
+
+ Examples
+ --------
+ >>> aligner = ForcedAligner()
+ >>> aligner.align("path/to/audio.wav", "hello world")
+ [{'word': 'hello', 'start': 0.0, 'end': 0.5}, {'word': 'world', 'start': 0.6, 'end': 1.0}]
+ """
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(
+ f"The audio file at path '{audio_path}' was not found."
+ )
+
+ speech_array = torch.from_numpy(decode_audio(audio_path))
+
+ emissions, stride = generate_emissions(
+ self.alignment_model,
+ speech_array.to(self.alignment_model.dtype).to(self.alignment_model.device),
+ batch_size=batch_size,
+ )
+
+ tokens_starred, text_starred = preprocess_text(
+ transcript,
+ romanize=True,
+ language=language,
+ )
+
+ segments, scores, blank_token = get_alignments(
+ emissions,
+ tokens_starred,
+ self.alignment_tokenizer,
+ )
+
+ spans = get_spans(tokens_starred, segments, blank_token)
+
+ word_timestamps = postprocess_results(text_starred, spans, stride, scores)
+
+ if self.device == 'cuda':
+ del self.alignment_model
+ torch.cuda.empty_cache()
+
+ print(f"Word_Timestamps: {word_timestamps}")
+
+ return word_timestamps
+
+
+if __name__ == "__main__":
+
+ forced_aligner = ForcedAligner()
+ try:
+ path = "example_audio.wav"
+ audio_transcript = "This is a test transcript."
+ word_timestamp = forced_aligner.align(path, audio_transcript)
+ print(word_timestamp)
+ except FileNotFoundError as e:
+ print(e)
\ No newline at end of file
diff --git a/src/audio/analysis.py b/src/audio/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e368140c7e43ec907f20536b7521f0d4c6630e49
--- /dev/null
+++ b/src/audio/analysis.py
@@ -0,0 +1,715 @@
+# Standard library imports
+import os
+import wave
+from typing import List, Dict, Annotated, Union, Tuple
+
+# Related third-party imports
+import nltk
+import numpy as np
+import soundfile as sf
+from librosa.feature import mfcc
+from scipy.fft import fft, fftfreq
+
+
+class WordSpeakerMapper:
+ """
+ Maps words to speakers based on timestamps and aligns speaker tags after punctuation restoration.
+
+ This class processes word timing information and assigns each word to a speaker
+ based on the provided speaker timestamps. Missing timestamps are handled, and each
+ word can be aligned to a speaker based on different reference points ('start', 'mid', or 'end').
+ After punctuation restoration, word-speaker mapping can be realigned to ensure consistency
+ within a sentence.
+
+ Attributes
+ ----------
+ word_timestamps : List[Dict]
+ List of word timing information with 'start', 'end', and 'text' keys.
+ speaker_timestamps : List[List[int]]
+ List of speaker segments, where each segment contains [start_time, end_time, speaker_id].
+ word_speaker_mapping : List[Dict] or None
+ Processed word-to-speaker mappings.
+
+ Methods
+ -------
+ filter_missing_timestamps(word_timestamps, initial_timestamp=0, final_timestamp=None)
+ Fills in missing start and end timestamps in word timing data.
+ get_words_speaker_mapping(word_anchor_option='start')
+ Maps words to speakers based on word and speaker timestamps.
+ """
+
+ def __init__(
+ self,
+ word_timestamps: Annotated[List[Dict], "List of word timing information"],
+ speaker_timestamps: Annotated[List[List[Union[int, float]]], "List of speaker segments"],
+ ):
+ """
+ Initializes the WordSpeakerMapper with word and speaker timestamps.
+
+ Parameters
+ ----------
+ word_timestamps : List[Dict]
+ List of word timing information.
+ speaker_timestamps : List[List[int]]
+ List of speaker segments.
+ """
+ self.word_timestamps = self.filter_missing_timestamps(word_timestamps)
+ self.speaker_timestamps = speaker_timestamps
+ self.word_speaker_mapping = None
+
+ def filter_missing_timestamps(
+ self,
+ word_timestamps: Annotated[List[Dict], "List of word timing information"],
+ initial_timestamp: Annotated[int, "Start time of the first word"] = 0,
+ final_timestamp: Annotated[int, "End time of the last word"] = None
+ ) -> Annotated[List[Dict], "List of word timestamps with missing values filled"]:
+ """
+ Fills in missing start and end timestamps.
+
+ Parameters
+ ----------
+ word_timestamps : List[Dict]
+ List of word timing information that may contain missing timestamps.
+ initial_timestamp : int, optional
+ Start time of the first word, default is 0.
+ final_timestamp : int, optional
+ End time of the last word, if available.
+
+ Returns
+ -------
+ List[Dict]
+ List of word timestamps with missing values filled.
+
+ Examples
+ --------
+ >>> word_timestamps = [{'text': 'Hello', 'end': 1.2}]
+ >>> mapper = WordSpeakerMapper([], [])
+ >>> mapper.filter_missing_timestamps(word_timestamps)
+ [{'text': 'Hello', 'start': 0, 'end': 1.2}]
+ """
+ if word_timestamps[0].get("start") is None:
+ word_timestamps[0]["start"] = initial_timestamp
+ word_timestamps[0]["end"] = self._get_next_start_timestamp(word_timestamps, 0, final_timestamp)
+
+ result = [word_timestamps[0]]
+
+ for i, ws in enumerate(word_timestamps[1:], start=1):
+ if "text" not in ws:
+ continue
+
+ if ws.get("start") is None:
+ ws["start"] = word_timestamps[i - 1]["end"]
+ ws["end"] = self._get_next_start_timestamp(word_timestamps, i, final_timestamp)
+
+ if ws["text"] is not None:
+ result.append(ws)
+ return result
+
+ @staticmethod
+ def _get_next_start_timestamp(
+ word_timestamps: Annotated[List[Dict], "List of word timing information"],
+ current_word_index: Annotated[int, "Index of the current word"],
+ final_timestamp: Annotated[int, "Final timestamp if needed"]
+ ) -> Annotated[int, "Next start timestamp for filling missing values"]:
+ """
+ Finds the next start timestamp to fill in missing values.
+
+ Parameters
+ ----------
+ word_timestamps : List[Dict]
+ List of word timing information.
+ current_word_index : int
+ Index of the current word.
+ final_timestamp : int, optional
+ Final timestamp to use if no next timestamp is found.
+
+ Returns
+ -------
+ int
+ Next start timestamp for filling missing values.
+
+ Examples
+ --------
+ >>> word_timestamps = [{'start': 0.5, 'text': 'Hello'}, {'end': 2.0}]
+ >>> mapper = WordSpeakerMapper([], [])
+ >>> mapper._get_next_start_timestamp(word_timestamps, 0, 2)
+ """
+ if current_word_index == len(word_timestamps) - 1:
+ return word_timestamps[current_word_index]["start"]
+
+ next_word_index = current_word_index + 1
+ while next_word_index < len(word_timestamps):
+ if word_timestamps[next_word_index].get("start") is None:
+ word_timestamps[current_word_index]["text"] += (
+ " " + word_timestamps[next_word_index]["text"]
+ )
+ word_timestamps[next_word_index]["text"] = None
+ next_word_index += 1
+ if next_word_index == len(word_timestamps):
+ return final_timestamp
+ else:
+ return word_timestamps[next_word_index]["start"]
+ return final_timestamp
+
+ def get_words_speaker_mapping(self, word_anchor_option='start') -> List[Dict]:
+ """
+ Maps words to speakers based on their timestamps.
+
+ Parameters
+ ----------
+ word_anchor_option : str, optional
+ Anchor point for word mapping ('start', 'mid', or 'end'), default is 'start'.
+
+ Returns
+ -------
+ List[Dict]
+ List of word-to-speaker mappings with timestamps and speaker IDs.
+
+ Examples
+ --------
+ >>> word_timestamps = [{'start': 0.5, 'end': 1.2, 'text': 'Hello'}]
+ >>> speaker_timestamps = [[0, 1000, 1]]
+ >>> mapper = WordSpeakerMapper(word_timestamps, speaker_timestamps)
+ >>> mapper.get_words_speaker_mapping()
+ [{'text': 'Hello', 'start_time': 500, 'end_time': 1200, 'speaker': 1}]
+ """
+
+ def get_word_ts_anchor(start: int, end: int, option: str) -> int:
+ """
+ Determines the anchor timestamp for a word.
+
+ Parameters
+ ----------
+ start : int
+ Start time of the word in milliseconds.
+ end : int
+ End time of the word in milliseconds.
+ option : str
+ Anchor point for timestamp calculation ('start', 'mid', or 'end').
+
+ Returns
+ -------
+ int
+ Anchor timestamp for the word.
+
+ Examples
+ --------
+ >>> get_word_ts_anchor(500, 1200, 'mid')
+ 850
+ """
+ if option == "end":
+ return end
+ elif option == "mid":
+ return (start + end) // 2
+ return start
+
+ wrd_spk_mapping = []
+ turn_idx = 0
+ num_speaker_ts = len(self.speaker_timestamps)
+
+ for wrd_dict in self.word_timestamps:
+ ws, we, wrd = (
+ int(wrd_dict["start"] * 1000),
+ int(wrd_dict["end"] * 1000),
+ wrd_dict["text"],
+ )
+ wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
+
+ sp = -1
+
+ while turn_idx < num_speaker_ts and wrd_pos > self.speaker_timestamps[turn_idx][1]:
+ turn_idx += 1
+
+ if turn_idx < num_speaker_ts and self.speaker_timestamps[turn_idx][0] <= wrd_pos <= \
+ self.speaker_timestamps[turn_idx][1]:
+ sp = self.speaker_timestamps[turn_idx][2]
+ elif turn_idx > 0:
+ sp = self.speaker_timestamps[turn_idx - 1][2]
+
+ wrd_spk_mapping.append(
+ {"text": wrd, "start_time": ws, "end_time": we, "speaker": sp}
+ )
+
+ self.word_speaker_mapping = wrd_spk_mapping
+ return self.word_speaker_mapping
+
+ def realign_with_punctuation(self, max_words_in_sentence: int = 50) -> None:
+ """
+ Realigns word-speaker mapping after punctuation restoration.
+
+ This method ensures consistent speaker mapping within sentences by analyzing
+ punctuation and adjusting speaker labels for words that are part of the same sentence.
+
+ Parameters
+ ----------
+ max_words_in_sentence : int, optional
+ Maximum number of words to consider for realignment in a sentence,
+ default is 50.
+
+ Examples
+ --------
+ >>> word_speaker_mapping = [
+ ... {"text": "Hello", "speaker": "Speaker 1"},
+ ... {"text": "world", "speaker": "Speaker 2"},
+ ... {"text": ".", "speaker": "Speaker 2"},
+ ... {"text": "How", "speaker": "Speaker 1"},
+ ... {"text": "are", "speaker": "Speaker 1"},
+ ... {"text": "you", "speaker": "Speaker 2"},
+ ... {"text": "?", "speaker": "Speaker 2"}
+ ... ]
+ >>> mapper = WordSpeakerMapper([], [])
+ >>> mapper.word_speaker_mapping = word_speaker_mapping
+ >>> mapper.realign_with_punctuation()
+ >>> print(mapper.word_speaker_mapping)
+ [{'text': 'Hello', 'speaker': 'Speaker 1'},
+ {'text': 'world', 'speaker': 'Speaker 1'},
+ {'text': '.', 'speaker': 'Speaker 1'},
+ {'text': 'How', 'speaker': 'Speaker 1'},
+ {'text': 'are', 'speaker': 'Speaker 1'},
+ {'text': 'you', 'speaker': 'Speaker 1'},
+ {'text': '?', 'speaker': 'Speaker 1'}]
+ """
+ sentence_ending_punctuations = ".?!"
+
+ def is_word_sentence_end(word_index: Annotated[int, "Index of the word to check"]) -> Annotated[
+ bool, "True if the word is a sentence end, False otherwise"]:
+ """
+ Checks if a word is the end of a sentence based on punctuation.
+
+ This method determines whether a word at the given index marks
+ the end of a sentence by checking if the last character of the
+ word is a sentence-ending punctuation (e.g., '.', '!', or '?').
+
+ Parameters
+ ----------
+ word_index : int
+ Index of the word to check in the `word_speaker_mapping`.
+
+ Returns
+ -------
+ bool
+ True if the word at the given index is the end of a sentence,
+ False otherwise.
+
+ """
+ return (
+ word_index >= 0
+ and self.word_speaker_mapping[word_index]["text"][-1] in sentence_ending_punctuations
+ )
+
+ wsp_len = len(self.word_speaker_mapping)
+ words_list = [wd['text'] for wd in self.word_speaker_mapping]
+ speaker_list = [wd['speaker'] for wd in self.word_speaker_mapping]
+
+ k = 0
+ while k < len(self.word_speaker_mapping):
+ if (
+ k < wsp_len - 1
+ and speaker_list[k] != speaker_list[k + 1]
+ and not is_word_sentence_end(k)
+ ):
+ left_idx = self._get_first_word_idx_of_sentence(
+ k, words_list, speaker_list, max_words_in_sentence
+ )
+ right_idx = (
+ self._get_last_word_idx_of_sentence(
+ k, words_list, max_words_in_sentence - (k - left_idx) - 1
+ )
+ if left_idx > -1
+ else -1
+ )
+ if min(left_idx, right_idx) == -1:
+ k += 1
+ continue
+
+ spk_labels = speaker_list[left_idx:right_idx + 1]
+ mod_speaker = max(set(spk_labels), key=spk_labels.count)
+ if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
+ k += 1
+ continue
+
+ speaker_list[left_idx:right_idx + 1] = [mod_speaker] * (
+ right_idx - left_idx + 1
+ )
+ k = right_idx
+
+ k += 1
+
+ for idx in range(len(self.word_speaker_mapping)):
+ self.word_speaker_mapping[idx]["speaker"] = speaker_list[idx]
+
+ @staticmethod
+ def _get_first_word_idx_of_sentence(
+ word_idx: int, word_list: List[str], speaker_list: List[str], max_words: int
+ ) -> int:
+ """
+ Finds the first word index of a sentence for realignment.
+
+ Parameters
+ ----------
+ word_idx : int
+ Current word index.
+ word_list : List[str]
+ List of words in the sentence.
+ speaker_list : List[str]
+ List of speakers for the words.
+ max_words : int
+ Maximum words to consider in the sentence.
+
+ Returns
+ -------
+ int
+ The index of the first word of the sentence.
+
+ Examples
+ --------
+ >>> word_list = ["Hello", "world", ".", "How", "are", "you", "?"]
+ >>> speaker_list = ["Speaker 1", "Speaker 1", "Speaker 1", "Speaker 2", "Speaker 2", "Speaker 2", "Speaker 2"]
+ >>> WordSpeakerMapper._get_first_word_idx_of_sentence(4, word_list, speaker_list, 50)
+ 3
+ """
+ sentence_ending_punctuations = ".?!"
+ is_word_sentence_end = (
+ lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
+ )
+ left_idx = word_idx
+ while (
+ left_idx > 0
+ and word_idx - left_idx < max_words
+ and speaker_list[left_idx - 1] == speaker_list[left_idx]
+ and not is_word_sentence_end(left_idx - 1)
+ ):
+ left_idx -= 1
+
+ return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
+
+ @staticmethod
+ def _get_last_word_idx_of_sentence(
+ word_idx: int, word_list: List[str], max_words: int
+ ) -> int:
+ """
+ Finds the last word index of a sentence for realignment.
+
+ Parameters
+ ----------
+ word_idx : int
+ Current word index.
+ word_list : List[str]
+ List of words in the sentence.
+ max_words : int
+ Maximum words to consider in the sentence.
+
+ Returns
+ -------
+ int
+ The index of the last word of the sentence.
+
+ Examples
+ --------
+ >>> word_list = ["Hello", "world", ".", "How", "are", "you", "?"]
+ >>> WordSpeakerMapper._get_last_word_idx_of_sentence(3, word_list, 50)
+ 6
+ """
+ sentence_ending_punctuations = ".?!"
+ is_word_sentence_end = (
+ lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
+ )
+ right_idx = word_idx
+ while (
+ right_idx < len(word_list) - 1
+ and right_idx - word_idx < max_words
+ and not is_word_sentence_end(right_idx)
+ ):
+ right_idx += 1
+
+ return (
+ right_idx
+ if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
+ else -1
+ )
+
+
+class SentenceSpeakerMapper:
+ """
+ Groups words into sentences and assigns each sentence to a speaker.
+
+ This class uses word-speaker mapping to group words into sentences based on punctuation
+ and speaker changes. It uses the NLTK library to detect sentence boundaries.
+
+ Attributes
+ ----------
+ sentence_checker : Callable
+ Function to check for sentence breaks.
+ sentence_ending_punctuations : str
+ String of punctuation characters that indicate sentence endings.
+
+ Methods
+ -------
+ get_sentences_speaker_mapping(word_speaker_mapping)
+ Groups words into sentences and assigns each sentence to a speaker.
+ """
+
+ def __init__(self):
+ """
+ Initializes the SentenceSpeakerMapper and downloads required NLTK resources.
+ """
+ nltk.download('punkt', quiet=True)
+ self.sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
+ self.sentence_ending_punctuations = ".?!"
+
+ def get_sentences_speaker_mapping(
+ self,
+ word_speaker_mapping: Annotated[List[Dict], "List of words with speaker labels"]
+ ) -> Annotated[List[Dict], "List of sentences with speaker labels and timing information"]:
+ """
+ Groups words into sentences and assigns each sentence to a speaker.
+
+ Parameters
+ ----------
+ word_speaker_mapping : List[Dict]
+ List of words with speaker labels.
+
+ Returns
+ -------
+ List[Dict]
+ List of sentences with speaker labels and timing information.
+
+ Examples
+ --------
+ >>> sentence_mapper = SentenceSpeakerMapper()
+ >>> word_speaker_mapping = [
+ ... {'text': 'Hello', 'start_time': 0, 'end_time': 500, 'speaker': 1},
+ ... {'text': 'world.', 'start_time': 600, 'end_time': 1000, 'speaker': 1},
+ ... {'text': 'How', 'start_time': 1100, 'end_time': 1300, 'speaker': 2},
+ ... {'text': 'are', 'start_time': 1400, 'end_time': 1500, 'speaker': 2},
+ ... {'text': 'you?', 'start_time': 1600, 'end_time': 2000, 'speaker': 2},
+ ... ]
+ >>> sentence_mapper.get_sentences_speaker_mapping(word_speaker_mapping)
+ [{'speaker': 'Speaker 1', 'start_time': 0, 'end_time': 1000, 'text': 'Hello world. '},
+ {'speaker': 'Speaker 2', 'start_time': 1100, 'end_time': 2000, 'text': 'How are you?'}]
+ """
+ snts = []
+ prev_spk = word_speaker_mapping[0]['speaker']
+ snt = {
+ "speaker": f"Speaker {prev_spk}",
+ "start_time": word_speaker_mapping[0]['start_time'],
+ "end_time": word_speaker_mapping[0]['end_time'],
+ "text": word_speaker_mapping[0]['text'] + " ",
+ }
+
+ for word_dict in word_speaker_mapping[1:]:
+ word, spk = word_dict["text"], word_dict["speaker"]
+ s, e = word_dict["start_time"], word_dict["end_time"]
+ if spk != prev_spk or self.sentence_checker(snt["text"] + word):
+ snts.append(snt)
+ snt = {
+ "speaker": f"Speaker {spk}",
+ "start_time": s,
+ "end_time": e,
+ "text": word + " ",
+ }
+ else:
+ snt["end_time"] = e
+ snt["text"] += word + " "
+ prev_spk = spk
+
+ snts.append(snt)
+ return snts
+
+
+class Audio:
+ """
+ A class to handle audio file analysis and property extraction.
+
+ This class provides methods to load an audio file, process it, and
+ extract various audio properties including spectral, temporal, and
+ perceptual features.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to the audio file to be analyzed.
+
+ Attributes
+ ----------
+ audio_path : str
+ Path to the audio file.
+ extension : str
+ File extension of the audio file.
+ samples : int
+ Total number of audio samples.
+ duration : float
+ Duration of the audio in seconds.
+ data : np.ndarray
+ Audio data loaded from the file.
+ rate : int
+ Sampling rate of the audio file.
+ """
+
+ def __init__(self, audio_path: str):
+ """
+ Initialize the Audio class with a given audio file path.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to the audio file.
+
+ Raises
+ ------
+ TypeError
+ If `audio_path` is not a non-empty string.
+ FileNotFoundError
+ If the file specified by `audio_path` does not exist.
+ ValueError
+ If the file has an unsupported extension or is empty.
+ RuntimeError
+ If there is an error reading the audio file.
+ """
+ if not isinstance(audio_path, str) or not audio_path:
+ raise TypeError("audio_path must be a non-empty string")
+
+ if not os.path.isfile(audio_path):
+ raise FileNotFoundError(f"The specified audio file does not exist: {audio_path}")
+
+ valid_extensions = [".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aac"]
+ extension = os.path.splitext(audio_path)[1].lower()
+ if extension not in valid_extensions:
+ raise ValueError(f"File extension {extension} is not recognized as a supported audio format.")
+
+ try:
+ self.data, self.rate = sf.read(audio_path, dtype='float32')
+ except RuntimeError as e:
+ raise RuntimeError(f"Error reading audio file: {audio_path}") from e
+
+ if len(self.data) == 0:
+ raise ValueError(f"Audio file is empty: {audio_path}")
+
+ # Convert stereo or multichannel audio to mono
+ if len(self.data.shape) > 1 and self.data.shape[1] > 1:
+ self.data = np.mean(self.data, axis=1)
+
+ self.audio_path = audio_path
+ self.extension = extension
+ self.samples = len(self.data)
+ self.duration = self.samples / self.rate
+
+ def properties(self) -> Tuple[
+ str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]:
+ """
+ Extract various properties and features from the audio file.
+
+ Returns
+ -------
+ Tuple[str, str, str, int, float, float, Union[int, None], int, float, float, Dict[str, float]]
+ A tuple containing:
+ - File name (str)
+ - File extension (str)
+ - File path (str)
+ - Sample rate (int)
+ - Minimum frequency (float)
+ - Maximum frequency (float)
+ - Bit depth (Union[int, None])
+ - Number of channels (int)
+ - Duration (float)
+ - Root mean square loudness (float)
+ - A dictionary of extracted properties (Dict[str, float])
+
+ Notes
+ -----
+ Properties extracted include:
+ - Spectral bands energy
+ - Zero Crossing Rate (ZCR)
+ - Spectral Centroid
+ - MFCCs (Mel Frequency Cepstral Coefficients)
+
+ Examples
+ --------
+ >>> audio = Audio("sample.wav")
+ >>> audio.properties()
+ ('sample.wav', '.wav', '/path/to/sample.wav', 44100, 20.0, 20000.0, 16, 2, 5.2, 0.25, {...})
+ """
+ bands = [(20, 250), (250, 2000), (2000, 6000), (6000, 20000)]
+
+ x = fft(self.data)
+ xf = fftfreq(self.samples, 1 / self.rate)
+
+ nonzero_indices = np.where(xf != 0)[0]
+ min_freq = np.min(np.abs(xf[nonzero_indices]))
+ max_freq = np.max(np.abs(xf))
+
+ bit_depth = None
+ if self.extension == ".wav":
+ with wave.open(self.audio_path, 'r') as wav_file:
+ bit_depth = wav_file.getsampwidth() * 8
+ channels = wav_file.getnchannels()
+ else:
+ info = sf.info(self.audio_path)
+ channels = info.channels
+
+ duration = float(self.duration)
+ loudness = np.sqrt(np.mean(self.data ** 2))
+
+ s = np.abs(x)
+ freqs = xf
+ eq_properties = {}
+ for band in bands:
+ band_mask = (freqs >= band[0]) & (freqs <= band[1])
+ band_data = s[band_mask]
+ band_energy = np.mean(band_data ** 2, axis=0) if band_data.size > 0 else 0
+ eq_properties[f"EQ_{band[0]}_{band[1]}_Hz"] = band_energy
+
+ zcr = np.sum(np.abs(np.diff(np.sign(self.data)))) / len(self.data)
+
+ magnitude_spectrum = np.abs(np.fft.rfft(self.data))
+ freqs_centroid = np.fft.rfftfreq(len(self.data), 1.0 / self.rate)
+ spectral_centroid = (np.sum(freqs_centroid * magnitude_spectrum) /
+ np.sum(magnitude_spectrum)) if np.sum(magnitude_spectrum) != 0 else 0.0
+
+ mfccs = mfcc(y=self.data, sr=self.rate, n_mfcc=13)
+
+ mfcc_mean = np.mean(mfccs, axis=1)
+
+ eq_properties["RMSLoudness"] = float(loudness)
+ eq_properties["ZeroCrossingRate"] = float(zcr)
+ eq_properties["SpectralCentroid"] = float(spectral_centroid)
+ for i, val in enumerate(mfcc_mean):
+ eq_properties[f"MFCC_{i + 1}"] = float(val)
+
+ eq_properties_converted = {key: float(value) for key, value in eq_properties.items()}
+
+ file_name = os.path.basename(self.audio_path)
+ path = os.path.abspath(self.audio_path)
+
+ bit_depth = int(bit_depth) if bit_depth is not None else None
+ channels = int(channels) if channels is not None else 1
+
+ return (
+ file_name,
+ self.extension,
+ path,
+ int(self.rate),
+ float(min_freq),
+ float(max_freq),
+ bit_depth,
+ channels,
+ float(duration),
+ float(loudness),
+ eq_properties_converted
+ )
+
+
+if __name__ == "__main__":
+ words_timestamp = [
+ {'text': 'Hello', 'start': 0.0, 'end': 1.2},
+ {'text': 'world', 'start': 1.3, 'end': 2.0}
+ ]
+ speaker_timestamp = [
+ [0.0, 1.5, 1],
+ [1.6, 3.0, 2]
+ ]
+
+    word_speaker_mapper = WordSpeakerMapper(words_timestamp, speaker_timestamp)
+    word_speaker_maps = word_speaker_mapper.get_words_speaker_mapping()
+ print("Word-Speaker Mapping:")
+ print(word_speaker_maps)
diff --git a/src/audio/effect.py b/src/audio/effect.py
new file mode 100644
index 0000000000000000000000000000000000000000..1512fc1e554547345567cbd4008d99fb722afcdc
--- /dev/null
+++ b/src/audio/effect.py
@@ -0,0 +1,114 @@
+# Standard library imports
+import os
+import warnings
+from typing import Annotated, Optional
+
+# Related third-party imports
+import demucs.separate
+
+
+class DemucsVocalSeparator:
+ """
+ A class for separating vocals from an audio file using the Demucs model.
+
+ This class utilizes the Demucs model to separate specified audio stems (e.g., vocals) from an input audio file.
+ It supports saving the separated outputs to a specified directory.
+
+ Attributes
+ ----------
+ model_name : str
+ Name of the Demucs model to use for separation.
+ two_stems : str
+ The stem to isolate (e.g., "vocals").
+
+ Methods
+ -------
+ separate_vocals(audio_file: str, output_dir: str) -> Optional[str]
+ Separates vocals (or other specified stem) from the audio file and returns the path to the separated file.
+
+ """
+
+ def __init__(
+ self,
+ model_name: Annotated[str, "Demucs model name to use for separation"] = "htdemucs",
+ two_stems: Annotated[str, "Stem to isolate (e.g., vocals, drums)"] = "vocals"
+ ):
+ """
+ Initializes the DemucsVocalSeparator with the given parameters.
+
+ Parameters
+ ----------
+ model_name : str, optional
+ Name of the Demucs model to use for separation (default is "htdemucs").
+ two_stems : str, optional
+ The stem to isolate (default is "vocals").
+ """
+ self.model_name = model_name
+ self.two_stems = two_stems
+
+ def separate_vocals(self, audio_file: str, output_dir: str) -> Optional[str]:
+ """
+ Separates vocals (or other specified stem) from the audio file.
+
+ This method invokes the Demucs model to isolate a specified audio stem (e.g., vocals).
+ The output is saved in WAV format in the specified output directory.
+
+ Parameters
+ ----------
+ audio_file : str
+ Path to the input audio file.
+ output_dir : str
+ Directory where the separated files will be saved.
+
+ Returns
+ -------
+ Optional[str]
+ Path to the separated vocal file if successful, or the original audio file path if not.
+
+        Warns
+        -----
+        UserWarning
+            If vocal separation fails or the separated file is not found.
+
+ Examples
+ --------
+ >>> separator = DemucsVocalSeparator()
+ >>> vocal_path = separator.separate_vocals("path/to/audio/file.mp3", "output_dir")
+ Vocal separation successful! Outputs saved in WAV format at 'output_dir' directory.
+ """
+ demucs_args = [
+ "--two-stems", self.two_stems,
+ "-n", self.model_name,
+ "-o", output_dir,
+ audio_file
+ ]
+
+ try:
+ demucs.separate.main(demucs_args)
+ print(f"Vocal separation successful! Outputs saved in WAV format at '{output_dir}' directory.")
+
+ output_path = os.path.join(
+ output_dir, self.model_name,
+ os.path.splitext(os.path.basename(audio_file))[0]
+ )
+ vocal_file = os.path.join(output_path, f"{self.two_stems}.wav")
+
+ if os.path.exists(vocal_file):
+ return vocal_file
+ else:
+ print("Separated vocal file not found. Returning original audio file path.")
+ warnings.warn("Vocal separation was unsuccessful; using the original audio file.", stacklevel=2)
+ return audio_file
+
+ except Exception as e:
+ print(f"An error occurred during vocal separation: {e}")
+ warnings.warn("Vocal separation failed; proceeding with the original audio file.", stacklevel=2)
+ return audio_file
+
+
+if __name__ == "__main__":
+ file = "example_audio.mp3"
+ output_directory = "separated_outputs"
+ vocal_separator = DemucsVocalSeparator()
+ separated_file_path = vocal_separator.separate_vocals(file, output_directory)
+ print(f"Separated file path: {separated_file_path}")
diff --git a/src/audio/error.py b/src/audio/error.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9587b59c5015743a8b342533175206e49de45c1
--- /dev/null
+++ b/src/audio/error.py
@@ -0,0 +1,214 @@
+# Standard library imports
+import os
+import logging
+import subprocess
+from typing import Annotated
+
+# Related third-party imports
+from pyannote.audio import Pipeline
+
+logging.basicConfig(level=logging.INFO)
+
+
+class DialogueDetecting:
+ """
+ Class for detecting dialogue in audio files using speaker diarization.
+
+ This class processes audio files by dividing them into chunks, applying a
+ pre-trained speaker diarization model, and detecting if there are multiple
+ speakers in the audio.
+
+ Parameters
+ ----------
+ pipeline_model : str, optional
+ Name of the pre-trained diarization model. Defaults to "pyannote/speaker-diarization".
+ chunk_duration : int, optional
+ Duration of each chunk in seconds. Defaults to 5.
+ sample_rate : int, optional
+ Sampling rate for the processed audio chunks. Defaults to 16000.
+ channels : int, optional
+ Number of audio channels. Defaults to 1.
+ delete_original : bool, optional
+ If True, deletes the original audio file when no dialogue is detected. Defaults to False.
+ skip_if_no_dialogue : bool, optional
+ If True, skips further processing if no dialogue is detected. Defaults to False.
+ temp_dir : str, optional
+ Directory for temporary chunk files. Defaults to ".temp".
+
+ Attributes
+ ----------
+ pipeline : Pipeline
+ Instance of the PyAnnote pipeline for speaker diarization.
+ """
+
+ def __init__(self,
+ pipeline_model: str = "pyannote/speaker-diarization",
+ chunk_duration: int = 5,
+ sample_rate: int = 16000,
+ channels: int = 1,
+ delete_original: bool = False,
+ skip_if_no_dialogue: bool = False,
+ temp_dir: str = ".temp"):
+ self.pipeline_model = pipeline_model
+ self.chunk_duration = chunk_duration
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.delete_original = delete_original
+ self.skip_if_no_dialogue = skip_if_no_dialogue
+ self.temp_dir = temp_dir
+ self.pipeline = Pipeline.from_pretrained(pipeline_model)
+
+ if not os.path.exists(self.temp_dir):
+ os.makedirs(self.temp_dir)
+
+ @staticmethod
+ def get_audio_duration(audio_file: Annotated[str, "Path to the audio file"]) -> Annotated[
+ float, "Duration of the audio in seconds"]:
+ """
+ Get the duration of an audio file in seconds.
+
+ Parameters
+ ----------
+ audio_file : str
+ Path to the audio file.
+
+ Returns
+ -------
+ float
+ Duration of the audio file in seconds.
+
+ Examples
+ --------
+ >>> DialogueDetecting.get_audio_duration("example.wav")
+ 120.5
+ """
+ result = subprocess.run(
+ ["ffprobe", "-v", "error", "-show_entries", "format=duration",
+ "-of", "default=noprint_wrappers=1:nokey=1", audio_file],
+ capture_output=True, text=True, check=True
+ )
+ return float(result.stdout.strip())
+
+ def create_chunk(self, audio_file: str, chunk_file: str, start_time: float, end_time: float):
+ """
+ Create a chunk of the audio file.
+
+ Parameters
+ ----------
+ audio_file : str
+ Path to the original audio file.
+ chunk_file : str
+ Path to save the generated chunk file.
+ start_time : float
+ Start time of the chunk in seconds.
+ end_time : float
+ End time of the chunk in seconds.
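+
+        Examples
+        --------
+        Illustrative sketch only; it assumes ``ffmpeg`` is available on the
+        system PATH, and the paths and times below are hypothetical:
+
+        >>> detector = DialogueDetecting()
+        >>> detector.create_chunk("example.wav", ".temp/chunk_0.wav", 0.0, 5.0)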
+ """
+ duration = end_time - start_time
+ subprocess.run([
+ "ffmpeg", "-y",
+ "-ss", str(start_time),
+ "-t", str(duration),
+ "-i", audio_file,
+ "-ar", str(self.sample_rate),
+ "-ac", str(self.channels),
+ "-f", "wav",
+ chunk_file
+ ], check=True)
+
+ def process_chunk(self, chunk_file: Annotated[str, "Path to the chunk file"]) -> Annotated[
+ set, "Set of detected speaker labels"]:
+ """
+ Process a single chunk of audio to detect speakers.
+
+ Parameters
+ ----------
+ chunk_file : str
+ Path to the chunk file.
+
+ Returns
+ -------
+ set
+ Set of detected speaker labels in the chunk.
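+
+        Examples
+        --------
+        Illustrative sketch only; the chunk path is hypothetical and the
+        returned labels depend on the pre-trained diarization model:
+
+        >>> detector = DialogueDetecting()
+        >>> detector.process_chunk(".temp/chunk_0.wav")
+        {'SPEAKER_00', 'SPEAKER_01'}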
+ """
+ diarization = self.pipeline(chunk_file)
+ speakers_in_chunk = set()
+ for segment, track, label in diarization.itertracks(yield_label=True):
+ speakers_in_chunk.add(label)
+ return speakers_in_chunk
+
+ def process(self, audio_file: Annotated[str, "Path to the input audio file"]) -> Annotated[
+ bool, "True if dialogue detected, False otherwise"]:
+ """
+ Process the audio file to detect dialogue.
+
+ Parameters
+ ----------
+ audio_file : str
+ Path to the audio file.
+
+ Returns
+ -------
+ bool
+ True if at least two speakers are detected, False otherwise.
+
+ Examples
+ --------
+ >>> dialogue_detector = DialogueDetecting()
+ >>> dialogue_detector.process("example.wav")
+ True
+ """
+ total_duration = self.get_audio_duration(audio_file)
+ num_chunks = int(total_duration // self.chunk_duration) + 1
+
+ speakers_detected = set()
+ chunk_files = []
+
+ try:
+ for i in range(num_chunks):
+ start_time = i * self.chunk_duration
+ end_time = min(float((i + 1) * self.chunk_duration), total_duration)
+
+ if end_time - start_time < 1.0:
+ logging.info("Last chunk is too short to process.")
+ break
+
+ chunk_file = os.path.join(self.temp_dir, f"chunk_{i}.wav")
+ chunk_files.append(chunk_file)
+ logging.info(f"Creating chunk: {chunk_file}")
+ self.create_chunk(audio_file, chunk_file, start_time, end_time)
+
+ logging.info(f"Processing chunk: {chunk_file}")
+ chunk_speakers = self.process_chunk(chunk_file)
+ speakers_detected.update(chunk_speakers)
+
+ if len(speakers_detected) >= 2:
+ logging.info("At least two speakers detected, stopping.")
+ return True
+
+ if len(speakers_detected) < 2:
+ logging.info("No dialogue detected or only one speaker found.")
+ if self.delete_original:
+ logging.info(f"No dialogue found. Deleting original file: {audio_file}")
+ os.remove(audio_file)
+ if self.skip_if_no_dialogue:
+ logging.info("Skipping further processing due to lack of dialogue.")
+ return False
+
+ finally:
+ logging.info("Cleaning up temporary chunk files.")
+ for chunk_file in chunk_files:
+ if os.path.exists(chunk_file):
+ os.remove(chunk_file)
+
+ if os.path.exists(self.temp_dir) and not os.listdir(self.temp_dir):
+ os.rmdir(self.temp_dir)
+
+ return len(speakers_detected) >= 2
+
+
+if __name__ == "__main__":
+ processor = DialogueDetecting(delete_original=True)
+ audio_path = ".data/example/kafkasya.mp3"
+ process_result = processor.process(audio_path)
+ print("Dialogue detected:", process_result)
diff --git a/src/audio/io.py b/src/audio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfd0bbb77f567e4a3b4d91c61664b925d2965be
--- /dev/null
+++ b/src/audio/io.py
@@ -0,0 +1,248 @@
+# Standard library imports
+import os
+from typing import List, Dict, Annotated
+
+
+class SpeakerTimestampReader:
+ """
+ A class to read and parse speaker timestamps from an RTTM file.
+
+ Attributes
+ ----------
+ rttm_path : str
+ Path to the RTTM file containing speaker timestamps.
+
+ Methods
+ -------
+ read_speaker_timestamps()
+ Reads the RTTM file and extracts speaker timestamps.
+
+ Parameters
+ ----------
+ rttm_path : str
+ Path to the RTTM file containing speaker timestamps.
+
+ Raises
+ ------
+ FileNotFoundError
+ If the RTTM file does not exist at the specified path.
+
+ """
+
+ def __init__(self, rttm_path: str):
+ """
+ Initializes the SpeakerTimestampReader with the path to an RTTM file.
+
+ Parameters
+ ----------
+ rttm_path : str
+ Path to the RTTM file containing speaker timestamps.
+
+ Raises
+ ------
+ FileNotFoundError
+ If the RTTM file does not exist at the specified path.
+ """
+ if not os.path.isfile(rttm_path):
+ raise FileNotFoundError(f"RTTM file not found at: {rttm_path}")
+ self.rttm_path = rttm_path
+
+ def read_speaker_timestamps(self) -> List[List[float]]:
+ """
+ Reads the RTTM file and extracts speaker timestamps.
+
+ Returns
+ -------
+ List[List[float]]
+ A list where each sublist contains [start_time, end_time, speaker_label].
+
+ Notes
+ -----
+ - The times are converted to milliseconds.
+ - Lines with invalid data are skipped.
+
+ Examples
+ --------
+ >>> reader = SpeakerTimestampReader("path/to/rttm_file.rttm")
+ >>> timestamps = reader.read_speaker_timestamps()
+ Speaker_Timestamps: [[0.0, 2000.0, 1], [2100.0, 4000.0, 2]]
+ """
+ speaker_ts = []
+ with open(self.rttm_path) as f:
+ lines = f.readlines()
+ for line in lines:
+ line_list = line.strip().split()
+ try:
+ if len(line_list) < 8:
+ print(f"Skipping line due to unexpected format: {line.strip()}")
+ continue
+
+ start_time = float(line_list[3]) * 1000
+ duration = float(line_list[4]) * 1000
+ end_time = start_time + duration
+
+ speaker_label_str = line_list[7]
+ speaker_label = int(speaker_label_str.split("_")[-1])
+
+ speaker_ts.append([start_time, end_time, speaker_label])
+ except (ValueError, IndexError) as e:
+ print(f"Skipping line due to parsing error: {line.strip()} - {e}")
+ continue
+
+ print(f"Speaker_Timestamps: {speaker_ts}")
+ return speaker_ts
+
+
+class TranscriptWriter:
+ """
+ A class to write speaker-aware transcripts in plain text or SRT formats.
+
+ Methods
+ -------
+ write_transcript(sentences_speaker_mapping, file_path)
+ Writes the speaker-aware transcript to a text file.
+ write_srt(sentences_speaker_mapping, file_path)
+ Writes the speaker-aware transcript to an SRT file format.
+ """
+
+ def __init__(self):
+ """
+ Initializes the TranscriptWriter.
+ """
+ pass
+
+ @staticmethod
+ def write_transcript(sentences_speaker_mapping: List[Dict], file_path: str):
+ """
+ Writes the speaker-aware transcript to a text file.
+
+ Parameters
+ ----------
+ sentences_speaker_mapping : List[Dict]
+ List of sentences with speaker labels, where each dictionary contains:
+ - "speaker": Speaker label (e.g., Speaker 1, Speaker 2).
+ - "text": Text of the spoken sentence.
+ file_path : str
+ Path to the output text file.
+
+ Examples
+ --------
+ >>> sentences_speaker_map = [{"speaker": "Speaker 1", "text": "Hello."},
+ {"speaker": "Speaker 2", "text": "Hi there."}]
+ >>> TranscriptWriter.write_transcript(sentences_speaker_mapping, "output.txt")
+ """
+ with open(file_path, "w", encoding="utf-8") as f:
+ previous_speaker = sentences_speaker_mapping[0]["speaker"]
+ f.write(f"{previous_speaker}: ")
+
+ for sentence_dict in sentences_speaker_mapping:
+ speaker = sentence_dict["speaker"]
+ sentence = sentence_dict["text"].strip()
+
+ if speaker != previous_speaker:
+ f.write(f"\n\n{speaker}: ")
+ previous_speaker = speaker
+
+ f.write(sentence + " ")
+
+ @staticmethod
+ def write_srt(sentences_speaker_mapping: List[Dict], file_path: str):
+ """
+ Writes the speaker-aware transcript to an SRT file format.
+
+ Parameters
+ ----------
+ sentences_speaker_mapping : List[Dict]
+ List of sentences with speaker labels and timestamps, where each dictionary contains:
+ - "start_time": Start time of the sentence in milliseconds.
+ - "end_time": End time of the sentence in milliseconds.
+ - "speaker": Speaker label.
+ - "text": Text of the spoken sentence.
+ file_path : str
+ Path to the output SRT file.
+
+ Notes
+ -----
+ The function formats timestamps in the HH:MM:SS,mmm format for SRT.
+
+ Examples
+ --------
+ >>> sentences_speaker_map = [{"start_time": 0, "end_time": 2000,
+ "speaker": "Speaker 1", "text": "Hello."}]
+ >>> TranscriptWriter.write_srt(sentences_speaker_mapping, "output.srt")
+ """
+
+ def format_timestamp(milliseconds: Annotated[float, "Time in milliseconds"]) -> Annotated[
+ str, "Formatted timestamp in HH:MM:SS,mmm"]:
+ """
+ Converts a time value in milliseconds to an SRT timestamp format.
+
+ This function takes a time value in milliseconds and formats it into
+ the standard SRT (SubRip Subtitle) timestamp format: `HH:MM:SS,mmm`.
+
+ Parameters
+ ----------
+ milliseconds : float
+ Time value in milliseconds to be converted.
+
+ Returns
+ -------
+ str
+ A string representing the time in `HH:MM:SS,mmm` format.
+
+ Raises
+ ------
+ ValueError
+ If the input time is negative.
+
+ Examples
+ --------
+ >>> format_timestamp(3723001)
+ '01:02:03,001'
+ >>> format_timestamp(0)
+ '00:00:00,000'
+ >>> format_timestamp(59_999.9)
+ '00:00:59,999'
+
+ Notes
+ -----
+ The function ensures the correct zero-padding for hours, minutes,
+ seconds, and milliseconds to meet the SRT format requirements.
+ """
+ if milliseconds < 0:
+ raise ValueError("Time in milliseconds cannot be negative.")
+
+ hours = int(milliseconds // 3_600_000)
+ minutes = int((milliseconds % 3_600_000) // 60_000)
+ seconds = int((milliseconds % 60_000) // 1_000)
+ milliseconds = int(milliseconds % 1_000)
+
+ return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ for i, segment in enumerate(sentences_speaker_mapping, start=1):
+ start_time = format_timestamp(segment['start_time'])
+ end_time = format_timestamp(segment['end_time'])
+ speaker = segment['speaker']
+ text = segment['text'].strip().replace('-->', '->')
+
+ f.write(f"{i}\n")
+ f.write(f"{start_time} --> {end_time}\n")
+ f.write(f"{speaker}: {text}\n\n")
+
+
+if __name__ == "__main__":
+ example_rttm_path = "example.rttm"
+ try:
+ timestamp_reader = SpeakerTimestampReader(example_rttm_path)
+ extracted_speaker_timestamps = timestamp_reader.read_speaker_timestamps()
+ except FileNotFoundError as file_error:
+ print(file_error)
+
+ example_sentences_mapping = [
+ {"speaker": "Speaker 1", "text": "Hello there.", "start_time": 0, "end_time": 2000},
+ {"speaker": "Speaker 2", "text": "How are you?", "start_time": 2100, "end_time": 4000},
+ ]
+ transcript_writer = TranscriptWriter()
+ transcript_writer.write_transcript(example_sentences_mapping, "output.txt")
+ transcript_writer.write_srt(example_sentences_mapping, "output.srt")
diff --git a/src/audio/metrics.py b/src/audio/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..660449aa3856e54872df2b3c112c78b2c969bca0
--- /dev/null
+++ b/src/audio/metrics.py
@@ -0,0 +1,242 @@
+# Standard library imports
+import math
+from typing import Annotated, List, Dict
+
+# Related third-party imports
+import numpy as np
+
+
+class SilenceStats:
+ """
+ A class to compute and analyze statistics for silence durations
+ between speech segments.
+
+ This class provides methods to compute common statistical metrics
+ (mean, median, standard deviation, interquartile range) and thresholds
+ based on silence durations.
+
+ Attributes
+ ----------
+ silence_durations : List[float]
+ A sorted list of silence durations.
+
+ Methods
+ -------
+ from_segments(segments)
+ Class method to create a SilenceStats instance from speech segments.
+ median()
+ Compute the median silence duration.
+ mean()
+ Compute the mean silence duration.
+ std()
+ Compute the standard deviation of silence durations.
+ iqr()
+ Compute the interquartile range (IQR) of silence durations.
+ threshold_std(factor=0.95)
+ Compute threshold based on standard deviation.
+ threshold_median_iqr(factor=1.5)
+ Compute threshold based on median + IQR.
+ total_silence_above_threshold(threshold)
+ Compute total silence above a given threshold.
+ """
+
+ def __init__(self, silence_durations: Annotated[List[float], "List of silence durations"]):
+ """
+ Initialize the SilenceStats class with a list of silence durations.
+
+ Parameters
+ ----------
+ silence_durations : List[float]
+ List of silence durations (non-negative values).
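+
+        Raises
+        ------
+        ValueError
+            If any element of `silence_durations` is negative or not a number.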
+ """
+ if not all(isinstance(x, (int, float)) and x >= 0 for x in silence_durations):
+ raise ValueError("silence_durations must be a list of non-negative numbers.")
+ self.silence_durations = sorted(silence_durations)
+
+ @classmethod
+ def from_segments(cls, segments: Annotated[List[Dict], "List of speech segments"]) -> "SilenceStats":
+ """
+ Create a SilenceStats instance from a list of speech segments.
+
+ Parameters
+ ----------
+ segments : List[Dict]
+ List of speech segments, where each segment contains 'start_time'
+ and 'end_time' keys.
+
+ Returns
+ -------
+ SilenceStats
+ A SilenceStats instance with computed silence durations.
+
+ Examples
+ --------
+        >>> segments = [{"start_time": 0, "end_time": 5}, {"start_time": 10, "end_time": 15}]
+ >>> stat = SilenceStats.from_segments(segments)
+ >>> stat.silence_durations
+ [5]
+ """
+ segments_sorted = sorted(segments, key=lambda x: x['start_time'])
+ durations = [
+ segments_sorted[i + 1]['start_time'] - segments_sorted[i]['end_time']
+ for i in range(len(segments_sorted) - 1)
+ if (segments_sorted[i + 1]['start_time'] - segments_sorted[i]['end_time']) > 0
+ ]
+ return cls(durations)
+
+ def median(self) -> Annotated[float, "Median of silence durations"]:
+ """
+ Compute the median silence duration.
+
+ Returns
+ -------
+ float
+ The median of the silence durations.
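+
+        Examples
+        --------
+        Minimal sketch with hypothetical durations:
+
+        >>> SilenceStats([1.0, 5.0, 3.0]).median()
+        3.0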
+ """
+ n = len(self.silence_durations)
+ if n == 0:
+ return 0.0
+ mid = n // 2
+ if n % 2 == 0:
+ return (self.silence_durations[mid - 1] + self.silence_durations[mid]) / 2
+ return self.silence_durations[mid]
+
+ def mean(self) -> Annotated[float, "Mean of silence durations"]:
+ """
+ Compute the mean silence duration.
+
+ Returns
+ -------
+ float
+ The mean of the silence durations.
+ """
+ return sum(self.silence_durations) / len(self.silence_durations) if self.silence_durations else 0.0
+
+ def std(self) -> Annotated[float, "Standard deviation of silence durations"]:
+ """
+ Compute the standard deviation of silence durations.
+
+ Returns
+ -------
+ float
+ The standard deviation of the silence durations.
+ """
+ n = len(self.silence_durations)
+ if n == 0:
+ return 0.0
+ mu = self.mean()
+ var = sum((x - mu) ** 2 for x in self.silence_durations) / n
+ return math.sqrt(var)
+
+ def iqr(self) -> Annotated[float, "Interquartile range (IQR) of silence durations"]:
+ """
+ Compute the Interquartile Range (IQR).
+
+ Returns
+ -------
+ float
+ The IQR of the silence durations.
+ """
+ if not self.silence_durations:
+ return 0.0
+ q1 = np.percentile(self.silence_durations, 25)
+ q3 = np.percentile(self.silence_durations, 75)
+ return q3 - q1
+
+ def threshold_std(self, factor: Annotated[float, "Scaling factor for std threshold"] = 0.95) -> float:
+ """
+ Compute the threshold based on standard deviation.
+
+ Parameters
+ ----------
+ factor : float, optional
+ A scaling factor for the standard deviation, by default 0.95.
+
+ Returns
+ -------
+ float
+ Threshold based on standard deviation.
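+
+        Examples
+        --------
+        Minimal sketch with hypothetical durations (mean 2.0, standard deviation 1.0):
+
+        >>> SilenceStats([1.0, 3.0]).threshold_std(factor=0.5)
+        0.5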
+ """
+ return self.std() * factor
+
+ def threshold_median_iqr(self, factor: Annotated[float, "Scaling factor for IQR"] = 1.5) -> float:
+ """
+ Compute the threshold based on median and IQR.
+
+ Parameters
+ ----------
+ factor : float, optional
+ A scaling factor for the IQR, by default 1.5.
+
+ Returns
+ -------
+ float
+ Threshold based on median and IQR.
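+
+        Examples
+        --------
+        Minimal sketch with hypothetical durations (median 2.5, IQR 1.5); the
+        result is cast to ``float`` because `iqr` is computed with NumPy:
+
+        >>> float(SilenceStats([1.0, 2.0, 3.0, 4.0]).threshold_median_iqr(factor=2.0))
+        5.5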
+ """
+ return self.median() + (self.iqr() * factor)
+
+ def total_silence_above_threshold(
+ self, threshold: Annotated[float, "Threshold value for silence"]
+ ) -> Annotated[float, "Total silence above the threshold"]:
+ """
+ Compute the total silence above the given threshold.
+
+ Parameters
+ ----------
+ threshold : float
+ The threshold value to compare silence durations.
+
+ Returns
+ -------
+ float
+ Total silence duration above the threshold.
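+
+        Examples
+        --------
+        Minimal sketch with hypothetical durations:
+
+        >>> SilenceStats([0.5, 1.5, 3.0]).total_silence_above_threshold(1.0)
+        4.5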
+ """
+ return sum(s for s in self.silence_durations if s >= threshold)
+
+
+if __name__ == "__main__":
+ final_ssm = {
+ 'ssm': [
+ {'speaker': 'Customer', 'start_time': 8500, 'end_time': 9760, 'text': 'Hey, G-Chance, this is Jennifer. ',
+ 'index': 0, 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 10660, 'end_time': 11560, 'text': 'Yes, hi, Jennifer. ', 'index': 1,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 11620, 'end_time': 12380, 'text': "Good afternoon, ma'am. ", 'index': 2,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 83880, 'end_time': 85460, 'text': 'Okay. ', 'index': 24,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 85500, 'end_time': 85620, 'text': 'Yeah. ', 'index': 25,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 86400, 'end_time': 90320,
+ 'text': "So I'll be sending this shipping documents right after this call. ", 'index': 26,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'CSR', 'start_time': 90400, 'end_time': 91160, 'text': 'Thank you so much. ', 'index': 27,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'Customer', 'start_time': 92060, 'end_time': 92680, 'text': 'Okay, thank you. ', 'index': 28,
+ 'sentiment': 'Neutral', 'profane': False},
+ {'speaker': 'Customer', 'start_time': 93880, 'end_time': 98220, 'text': 'All right, bye-bye. ', 'index': 29,
+ 'sentiment': 'Neutral', 'profane': False}
+ ],
+ 'summary': 'Gabby from Transplace AP Team called Jennifer to request copies of a carrier invoice, bill of '
+ 'lading, and proof of delivery, and Jennifer provided her email for Gabby to send the shipping '
+ 'documents.',
+ 'conflict': False,
+ 'topic': 'Invoice and Shipping Documents Request'
+ }
+
+ stats = SilenceStats.from_segments(final_ssm['ssm'])
+
+ print("Mean:", stats.mean())
+ print("Median:", stats.median())
+ print("Std Dev:", stats.std())
+ print("IQR:", stats.iqr())
+
+ t_std = stats.threshold_std(factor=0.97)
+ t_median_iqr = stats.threshold_median_iqr(factor=1.49)
+ print("Threshold (std-based):", t_std)
+ print("Threshold (median+IQR):", t_median_iqr)
+
+ print("Total silence (std-based):", stats.total_silence_above_threshold(t_std))
+ print("Total silence (median+IQR-based):", stats.total_silence_above_threshold(t_median_iqr))
+ final_ssm["silence"] = t_std
+ print(final_ssm)
diff --git a/src/audio/preprocessing.py b/src/audio/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2d181eb404bdf502d28c98b7f9482938f70253
--- /dev/null
+++ b/src/audio/preprocessing.py
@@ -0,0 +1,267 @@
+# Standard library imports
+import os
+from typing import Annotated
+
+# Related third-party imports
+import librosa
+import soundfile as sf
+from librosa.feature import rms
+from omegaconf import OmegaConf
+from noisereduce import reduce_noise
+from MPSENet import MPSENet
+
+# Local imports
+from src.utils.utils import Logger
+
+
+class Denoiser:
+ """
+ A class to handle audio denoising using librosa and noisereduce.
+
+ This class provides methods to load noisy audio, apply denoising, and
+ save the cleaned output to disk.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file that specifies runtime settings.
+ output_dir : str, optional
+ Directory to save cleaned audio files. Defaults to ".temp".
+
+ Attributes
+ ----------
+ config : omegaconf.DictConfig
+ Loaded configuration data.
+ output_dir : str
+ Directory to save cleaned audio files.
+ logger : Logger
+ Logger instance for recording messages.
+ """
+
+ def __init__(self, config_path: Annotated[str, "Path to the config file"],
+ output_dir: Annotated[str, "Default directory to save cleaned audio files"] = ".temp") -> None:
+ """
+ Initialize the Denoiser class.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file that specifies runtime settings.
+ output_dir : str, optional
+ Default directory to save cleaned audio files. Defaults to ".temp".
+ """
+ self.config = OmegaConf.load(config_path)
+ self.output_dir = output_dir
+ os.makedirs(self.output_dir, exist_ok=True)
+ self.logger = Logger(name="DenoiserLogger")
+
+ def denoise_audio(
+ self,
+ input_path: Annotated[str, "Path to the noisy audio file"],
+ output_dir: Annotated[str, "Directory to save the cleaned audio file"],
+ noise_threshold: Annotated[float, "Noise threshold value to decide if denoising is needed"],
+ print_output: Annotated[bool, "Whether to log the process to console"] = False,
+ ) -> str:
+ """
+ Denoise an audio file using noisereduce and librosa.
+
+ Parameters
+ ----------
+ input_path : str
+ Path to the noisy input audio file.
+ output_dir : str
+ Directory to save the cleaned audio file.
+ noise_threshold : float
+ Noise threshold value to decide if denoising is needed.
+ print_output : bool, optional
+ Whether to log the process to the console. Defaults to False.
+
+ Returns
+ -------
+ str
+ Path to the saved audio file if denoising is performed, otherwise the original audio file path.
+
+ Examples
+ --------
+        >>> denoiser = Denoiser("config.yaml")
+ >>> input_file = "noisy_audio.wav"
+ >>> output_directory = "cleaned_audio"
+ >>> noise_thresh = 0.02
+ >>> result = denoiser.denoise_audio(input_file, output_directory, noise_thresh)
+ >>> print(result)
+ cleaned_audio/denoised.wav
+ """
+ self.logger.log(f"Loading: {input_path}", print_output=print_output)
+
+ noisy_waveform, sr = librosa.load(input_path, sr=None)
+
+ noise_level = rms(y=noisy_waveform).mean()
+ self.logger.log(f"Calculated noise level: {noise_level}", print_output=print_output)
+
+ if noise_level < noise_threshold:
+ self.logger.log("Noise level is below the threshold. Skipping denoising.", print_output=print_output)
+ return input_path
+
+ self.logger.log("Denoising process started...", print_output=print_output)
+
+ cleaned_waveform = reduce_noise(y=noisy_waveform, sr=sr)
+
+ output_path = os.path.join(output_dir, "denoised.wav")
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ sf.write(output_path, cleaned_waveform, sr)
+
+ self.logger.log(f"Denoising completed! Cleaned file: {output_path}", print_output=print_output)
+
+ return output_path
+
+
+class SpeechEnhancement:
+ """
+ A class for speech enhancement using the MPSENet model.
+
+ This class provides methods to load audio, apply enhancement using a
+ pre-trained MPSENet model, and save the enhanced output.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file specifying runtime settings.
+ output_dir : str, optional
+ Directory to save enhanced audio files. Defaults to ".temp".
+
+ Attributes
+ ----------
+ config : omegaconf.DictConfig
+ Loaded configuration data.
+ output_dir : str
+ Directory to save enhanced audio files.
+ model_name : str
+ Name of the pre-trained model.
+ device : str
+ Device to run the model (e.g., "cpu" or "cuda").
+ model : MPSENet
+ Pre-trained MPSENet model instance.
+ """
+
+ def __init__(
+ self,
+ config_path: Annotated[str, "Path to the config file"],
+ output_dir: Annotated[str, "Default directory to save enhanced audio files"] = ".temp"
+ ) -> None:
+ """
+ Initialize the SpeechEnhancement class.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file specifying runtime settings.
+ output_dir : str, optional
+ Directory to save enhanced audio files. Defaults to ".temp".
+ """
+ self.config = OmegaConf.load(config_path)
+ self.output_dir = output_dir
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ self.model_name = self.config.models.mpsenet.model_name
+ self.device = self.config.runtime.device
+
+ self.model = MPSENet.from_pretrained(self.model_name).to(self.device)
+
+ def enhance_audio(
+ self,
+ input_path: Annotated[str, "Path to the original audio file"],
+ output_path: Annotated[str, "Path to save the enhanced audio file"],
+ noise_threshold: Annotated[float, "Noise threshold value to decide if enhancement is needed"],
+ verbose: Annotated[bool, "Whether to log additional info to console"] = False,
+ ) -> str:
+ """
+ Enhance an audio file using the MPSENet model.
+
+ Parameters
+ ----------
+ input_path : str
+ Path to the original input audio file.
+ output_path : str
+ Path to save the enhanced audio file.
+ noise_threshold : float
+ Noise threshold value to decide if enhancement is needed.
+ verbose : bool, optional
+ Whether to log additional info to the console. Defaults to False.
+
+ Returns
+ -------
+ str
+ Path to the enhanced audio file if enhancement is performed, otherwise the original file path.
+
+ Examples
+ --------
+ >>> enhancer = SpeechEnhancement("config.yaml")
+ >>> input_file = "raw_audio.wav"
+ >>> output_file = "enhanced_audio.wav"
+ >>> noise_thresh = 0.03
+ >>> result = enhancer.enhance_audio(input_file, output_file, noise_thresh)
+ >>> print(result)
+ enhanced_audio.wav
+ """
+ raw_waveform, sr_raw = librosa.load(input_path, sr=None)
+ noise_level = rms(y=raw_waveform).mean()
+
+ if verbose:
+ print(f"[SpeechEnhancement] Detected noise level: {noise_level:.6f}")
+
+ if noise_level < noise_threshold:
+ if verbose:
+ print(f"[SpeechEnhancement] Noise level < {noise_threshold} → enhancement skipped.")
+ return input_path
+
+ sr_model = self.model.h.sampling_rate
+ waveform, sr = librosa.load(input_path, sr=sr_model)
+
+ if verbose:
+ print(f"[SpeechEnhancement] Enhancement with MPSENet started using model: {self.model_name}")
+
+ enhanced_waveform, sr_out, _ = self.model(waveform)
+
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ sf.write(output_path, enhanced_waveform, sr_out)
+
+ if verbose:
+ print(f"[SpeechEnhancement] Enhancement complete. Saved to: {output_path}")
+
+ return output_path
+
+
+if __name__ == "__main__":
+
+ test_config_path = "config/config.yaml"
+ noisy_audio_file = ".data/example/noisy/LookOncetoHearTargetSpeechHearingwithNoisyExamples.mp3"
+ temp_dir = ".temp"
+
+ denoiser = Denoiser(config_path=test_config_path, output_dir=temp_dir)
+ denoised_path = denoiser.denoise_audio(
+ input_path=noisy_audio_file,
+ output_dir=temp_dir,
+ noise_threshold=0.005,
+ print_output=True
+ )
+ if denoised_path == noisy_audio_file:
+ print("Denoising skipped due to low noise level.")
+ else:
+ print(f"Denoising completed! Cleaned file saved at: {denoised_path}")
+
+ speech_enhancer = SpeechEnhancement(config_path=test_config_path, output_dir=temp_dir)
+ enhanced_audio_path = os.path.join(temp_dir, "enhanced_audio.wav")
+
+ result_path = speech_enhancer.enhance_audio(
+ input_path=denoised_path,
+ output_path=enhanced_audio_path,
+ noise_threshold=0.005,
+ verbose=True
+ )
+
+ if result_path == denoised_path:
+ print("Enhancement skipped due to low noise level.")
+ else:
+ print(f"Speech enhancement completed! Enhanced file saved at: {result_path}")
diff --git a/src/audio/processing.py b/src/audio/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5554b51c21148518e841ec59096d014aa880316
--- /dev/null
+++ b/src/audio/processing.py
@@ -0,0 +1,614 @@
+# Standard library imports
+import os
+import re
+import json
+from io import TextIOWrapper
+from typing import Annotated, Optional, Tuple, List, Dict
+
+# Related third-party imports
+import torch
+import faster_whisper
+from pydub import AudioSegment
+from deepmultilingualpunctuation import PunctuationModel
+
+# Local imports
+from src.audio.utils import TokenizerUtils
+
+
+class AudioProcessor:
+ """
+ A class to handle various audio processing tasks, such as conversion,
+ trimming, merging, and audio transformations.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to the audio file to process.
+ temp_dir : str, optional
+ Directory for storing temporary files. Defaults to ".temp".
+
+ Attributes
+ ----------
+ audio_path : str
+ Path to the input audio file.
+ temp_dir : str
+ Path to the temporary directory for processed files.
+ mono_audio_path : Optional[str]
+ Path to the mono audio file after conversion.
+
+ Methods
+ -------
+ convert_to_mono()
+ Converts the audio file to mono.
+ get_duration()
+ Gets the duration of the audio file in seconds.
+ change_format(new_format)
+ Converts the audio file to a new format.
+ trim_audio(start_time, end_time)
+ Trims the audio file to the specified time range.
+ adjust_volume(change_in_db)
+ Adjusts the volume of the audio file.
+ get_channels()
+ Gets the number of audio channels.
+ fade_in_out(fade_in_duration, fade_out_duration)
+ Applies fade-in and fade-out effects to the audio.
+ merge_audio(other_audio_path)
+ Merges the current audio with another audio file.
+ split_audio(chunk_duration)
+ Splits the audio file into chunks of a specified duration.
+ create_manifest(manifest_path)
+ Creates a manifest file containing metadata about the audio.
+ """
+
+ def __init__(
+ self,
+ audio_path: Annotated[str, "Path to the audio file"],
+ temp_dir: Annotated[str, "Directory for temporary processed files"] = ".temp"
+ ) -> None:
+ if not isinstance(audio_path, str):
+ raise TypeError("Expected 'audio_path' to be a string.")
+ if not isinstance(temp_dir, str):
+ raise TypeError("Expected 'temp_dir' to be a string.")
+
+ self.audio_path = audio_path
+ self.temp_dir = temp_dir
+ self.mono_audio_path = None
+ os.makedirs(temp_dir, exist_ok=True)
+
+ def convert_to_mono(self) -> Annotated[str, "Path to the mono audio file"]:
+ """
+ Convert the audio file to mono.
+
+ Returns
+ -------
+ str
+ Path to the mono audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> mono_path = processor.convert_to_mono()
+ >>> isinstance(mono_path, str)
+ True
+ """
+ sound = AudioSegment.from_file(self.audio_path)
+ mono_sound = sound.set_channels(1)
+ self.mono_audio_path = os.path.join(self.temp_dir, "mono_file.wav")
+ mono_sound.export(self.mono_audio_path, format="wav")
+ return self.mono_audio_path
+
+ def get_duration(self) -> Annotated[float, "Audio duration in seconds"]:
+ """
+ Get the duration of the audio file.
+
+ Returns
+ -------
+ float
+ Duration of the audio in seconds.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> duration = processor.get_duration()
+ >>> isinstance(duration, float)
+ True
+ """
+ sound = AudioSegment.from_file(self.audio_path)
+ return len(sound) / 1000.0
+
+ def change_format(
+ self, new_format: Annotated[str, "New audio format"]
+ ) -> Annotated[str, "Path to converted audio file"]:
+ """
+ Convert the audio file to a new format.
+
+ Parameters
+ ----------
+ new_format : str
+ Desired format for the output audio file.
+
+ Returns
+ -------
+ str
+ Path to the converted audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> converted_path = processor.change_format("mp3")
+ >>> isinstance(converted_path, str)
+ True
+ """
+ if not isinstance(new_format, str):
+ raise TypeError("Expected 'new_format' to be a string.")
+
+ sound = AudioSegment.from_file(self.audio_path)
+ output_path = os.path.join(self.temp_dir, f"converted_file.{new_format}")
+ sound.export(output_path, format=new_format)
+ return output_path
+
+ def trim_audio(
+ self, start_time: Annotated[float, "Start time in seconds"],
+ end_time: Annotated[float, "End time in seconds"]
+ ) -> Annotated[str, "Path to trimmed audio file"]:
+ """
+ Trim the audio file to the specified duration.
+
+ Parameters
+ ----------
+ start_time : float
+ Start time in seconds.
+ end_time : float
+ End time in seconds.
+
+ Returns
+ -------
+ str
+ Path to the trimmed audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> trimmed_path = processor.trim_audio(0.0, 10.0)
+ >>> isinstance(trimmed_path, str)
+ True
+ """
+ if not isinstance(start_time, (int, float)):
+ raise TypeError("Expected 'start_time' to be a float or int.")
+ if not isinstance(end_time, (int, float)):
+ raise TypeError("Expected 'end_time' to be a float or int.")
+
+ sound = AudioSegment.from_file(self.audio_path)
+ trimmed_audio = sound[start_time * 1000:end_time * 1000]
+ trimmed_audio_path = os.path.join(self.temp_dir, "trimmed_file.wav")
+ trimmed_audio.export(trimmed_audio_path, format="wav")
+ return trimmed_audio_path
+
+ def adjust_volume(
+ self, change_in_db: Annotated[float, "Volume change in dB"]
+ ) -> Annotated[str, "Path to volume-adjusted audio file"]:
+ """
+ Adjust the volume of the audio file.
+
+ Parameters
+ ----------
+ change_in_db : float
+ Volume change in decibels.
+
+ Returns
+ -------
+ str
+ Path to the volume-adjusted audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> adjusted_path = processor.adjust_volume(5.0)
+ >>> isinstance(adjusted_path, str)
+ True
+ """
+ if not isinstance(change_in_db, (int, float)):
+ raise TypeError("Expected 'change_in_db' to be a float or int.")
+
+ sound = AudioSegment.from_file(self.audio_path)
+ adjusted_audio = sound + change_in_db
+ adjusted_audio_path = os.path.join(self.temp_dir, "adjusted_volume.wav")
+ adjusted_audio.export(adjusted_audio_path, format="wav")
+ return adjusted_audio_path
+
+ def get_channels(self) -> Annotated[int, "Number of channels"]:
+ """
+ Get the number of audio channels.
+
+ Returns
+ -------
+ int
+ Number of audio channels.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> channels = processor.get_channels()
+ >>> isinstance(channels, int)
+ True
+ """
+ sound = AudioSegment.from_file(self.audio_path)
+ return sound.channels
+
+ def fade_in_out(
+ self, fade_in_duration: Annotated[float, "Fade-in duration in seconds"],
+ fade_out_duration: Annotated[float, "Fade-out duration in seconds"]
+ ) -> Annotated[str, "Path to faded audio file"]:
+ """
+ Apply fade-in and fade-out effects to the audio file.
+
+ Parameters
+ ----------
+ fade_in_duration : float
+ Duration of the fade-in effect in seconds.
+ fade_out_duration : float
+ Duration of the fade-out effect in seconds.
+
+ Returns
+ -------
+ str
+ Path to the faded audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> faded_path = processor.fade_in_out(1.0, 2.0)
+ >>> isinstance(faded_path, str)
+ True
+ """
+ if not isinstance(fade_in_duration, (int, float)):
+ raise TypeError("Expected 'fade_in_duration' to be a float or int.")
+ if not isinstance(fade_out_duration, (int, float)):
+ raise TypeError("Expected 'fade_out_duration' to be a float or int.")
+
+ sound = AudioSegment.from_file(self.audio_path)
+ faded_audio = sound.fade_in(fade_in_duration * 1000).fade_out(fade_out_duration * 1000)
+ faded_audio_path = os.path.join(self.temp_dir, "faded_audio.wav")
+ faded_audio.export(faded_audio_path, format="wav")
+ return faded_audio_path
+
+ def merge_audio(
+ self, other_audio_path: Annotated[str, "Path to other audio file"]
+ ) -> Annotated[str, "Path to merged audio file"]:
+ """
+ Merge the current audio file with another audio file.
+
+ Parameters
+ ----------
+ other_audio_path : str
+ Path to the other audio file.
+
+ Returns
+ -------
+ str
+ Path to the merged audio file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> merged_path = processor.merge_audio("other_example.wav")
+ >>> isinstance(merged_path, str)
+ True
+ """
+ if not isinstance(other_audio_path, str):
+ raise TypeError("Expected 'other_audio_path' to be a string.")
+
+ sound1 = AudioSegment.from_file(self.audio_path)
+ sound2 = AudioSegment.from_file(other_audio_path)
+ merged_audio = sound1 + sound2
+ merged_audio_path = os.path.join(self.temp_dir, "merged_audio.wav")
+ merged_audio.export(merged_audio_path, format="wav")
+ return merged_audio_path
+
+ def split_audio(
+ self, chunk_duration: Annotated[float, "Chunk duration in seconds"]
+ ) -> Annotated[List[str], "Paths to audio chunks"]:
+ """
+ Split the audio file into chunks of the specified duration.
+
+ Parameters
+ ----------
+ chunk_duration : float
+ Duration of each chunk in seconds.
+
+ Returns
+ -------
+ List[str]
+ Paths to the generated audio chunks.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> chunks = processor.split_audio(10.0)
+ >>> isinstance(chunks, list)
+ True
+ """
+ if not isinstance(chunk_duration, (int, float)):
+ raise TypeError("Expected 'chunk_duration' to be a float or int.")
+
+ sound = AudioSegment.from_file(self.audio_path)
+ chunk_paths = []
+
+ for i in range(0, len(sound), int(chunk_duration * 1000)):
+ chunk = sound[i:i + int(chunk_duration * 1000)]
+ chunk_path = os.path.join(self.temp_dir, f"chunk_{i // 1000}.wav")
+ chunk.export(chunk_path, format="wav")
+ chunk_paths.append(chunk_path)
+
+ return chunk_paths
+
+ def create_manifest(
+ self,
+ manifest_path: Annotated[str, "Manifest file path"]
+ ) -> None:
+ """
+ Create a manifest file containing metadata about the audio file.
+
+ Parameters
+ ----------
+ manifest_path : str
+ Path to the manifest file.
+
+ Examples
+ --------
+ >>> processor = AudioProcessor("example.wav")
+ >>> processor.create_manifest("manifest.json")
+ """
+ duration = self.get_duration()
+ manifest_entry = {
+ "audio_filepath": self.audio_path,
+ "offset": 0,
+ "duration": duration,
+ "label": "infer",
+ "text": "-",
+ "rttm_filepath": None,
+ "uem_filepath": None
+ }
+ with open(manifest_path, 'w', encoding='utf-8') as f: # type: TextIOWrapper
+ json.dump(manifest_entry, f)
+
+
+class Transcriber:
+ """
+ A class for transcribing audio files using a pre-trained Whisper model.
+
+ Parameters
+ ----------
+ model_name : str, optional
+ Name of the model to load. Defaults to 'large-v3'.
+ device : str, optional
+ Device to use for model inference ('cpu' or 'cuda'). Defaults to 'cpu'.
+ compute_type : str, optional
+ Data type for model computation ('int8', 'float16', etc.). Defaults to 'int8'.
+
+ Attributes
+ ----------
+ model : faster_whisper.WhisperModel
+ Loaded Whisper model for transcription.
+ device : str
+ Device used for inference.
+
+ Methods
+ -------
+ transcribe(audio_path, language=None, suppress_numerals=False)
+ Transcribes the audio file into text.
+ """
+
+ def __init__(
+ self,
+ model_name: Annotated[str, "Name of the model to load"] = 'large-v3',
+ device: Annotated[str, "Device to use for model inference"] = 'cpu',
+ compute_type: Annotated[str, "Data type for model computation, e.g., 'int8' or 'float16'"] = 'int8'
+ ) -> None:
+ if not isinstance(model_name, str):
+ raise TypeError("Expected 'model_name' to be of type str")
+ if not isinstance(device, str):
+ raise TypeError("Expected 'device' to be of type str")
+ if not isinstance(compute_type, str):
+ raise TypeError("Expected 'compute_type' to be of type str")
+
+ self.device = device
+ self.model = faster_whisper.WhisperModel(
+ model_name, device=device, compute_type=compute_type
+ )
+
+ def transcribe(
+ self,
+ audio_path: Annotated[str, "Path to the audio file to transcribe"],
+ language: Annotated[Optional[str], "Language code for transcription, e.g., 'en' for English"] = None,
+ suppress_numerals: Annotated[bool, "Whether to suppress numerals in the transcription"] = False
+ ) -> Annotated[Tuple[str, dict], "Transcription text and additional information"]:
+ """
+ Transcribe an audio file into text.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to the audio file.
+ language : str, optional
+ Language code for transcription (e.g., 'en' for English).
+ suppress_numerals : bool, optional
+ Whether to suppress numerals in the transcription. Defaults to False.
+
+ Returns
+ -------
+ Tuple[str, dict]
+ The transcribed text and additional transcription metadata.
+
+ Examples
+ --------
+ >>> transcriber = Transcriber()
+        >>> text, info = transcriber.transcribe("example.wav")
+ >>> isinstance(text, str)
+ True
+ >>> isinstance(info, dict)
+ True
+ """
+ if not isinstance(audio_path, str):
+ raise TypeError("Expected 'audio_path' to be of type str")
+ if language is not None and not isinstance(language, str):
+ raise TypeError("Expected 'language' to be of type str if provided")
+ if not isinstance(suppress_numerals, bool):
+ raise TypeError("Expected 'suppress_numerals' to be of type bool")
+
+ audio_waveform = faster_whisper.decode_audio(audio_path)
+ suppress_tokens = [-1]
+ if suppress_numerals:
+ suppress_tokens = TokenizerUtils.find_numeral_symbol_tokens(
+ self.model.hf_tokenizer
+ )
+
+ transcript_segments, info = self.model.transcribe(
+ audio_waveform,
+ language=language,
+ suppress_tokens=suppress_tokens,
+ without_timestamps=True,
+ vad_filter=True,
+ log_progress=True,
+ )
+
+ transcript = ''.join(segment.text for segment in transcript_segments)
+ info = vars(info)
+
+ if self.device == 'cuda':
+ del self.model
+ torch.cuda.empty_cache()
+
+ print(transcript, info)
+
+ return transcript, info
+
+
+class PunctuationRestorer:
+ """
+ A class for restoring punctuation in transcribed text.
+
+ Parameters
+ ----------
+ language : str, optional
+ Language for punctuation restoration. Defaults to 'en'.
+
+ Attributes
+ ----------
+ language : str
+ Language used for punctuation restoration.
+ punct_model : PunctuationModel
+ Model for predicting punctuation.
+ supported_languages : List[str]
+ List of languages supported by the model.
+
+ Methods
+ -------
+ restore_punctuation(word_speaker_mapping)
+ Restores punctuation in the provided text based on word mappings.
+ """
+
+ def __init__(self, language: Annotated[str, "Language for punctuation restoration"] = 'en') -> None:
+ self.language = language
+ self.punct_model = PunctuationModel(model="kredor/punctuate-all")
+ self.supported_languages = [
+ "en", "fr", "de", "es", "it", "nl", "pt", "bg", "pl", "cs", "sk", "sl",
+ ]
+
+ def restore_punctuation(
+ self, word_speaker_mapping: Annotated[List[Dict], "List of word-speaker mappings"]
+ ) -> Annotated[List[Dict], "Word mappings with restored punctuation"]:
+ """
+ Restore punctuation for transcribed text.
+
+ Parameters
+ ----------
+ word_speaker_mapping : List[Dict]
+ List of dictionaries containing word and speaker mappings.
+
+ Returns
+ -------
+ List[Dict]
+ Updated list with punctuation restored.
+
+ Examples
+ --------
+ >>> restorer = PunctuationRestorer()
+ >>> mapping = [{"text": "hello"}, {"text": "world"}]
+ >>> result = restorer.restore_punctuation(mapping)
+ >>> isinstance(result, list)
+ True
+ >>> "text" in result[0]
+ True
+ """
+ if self.language not in self.supported_languages:
+ print(f"Punctuation restoration is not available for {self.language} language.")
+ return word_speaker_mapping
+
+ words_list = [word_dict["text"] for word_dict in word_speaker_mapping]
+ labeled_words = self.punct_model.predict(words_list)
+
+ ending_puncts = ".?!"
+ model_puncts = ".,;:!?"
+ is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
+
+ for word_dict, labeled_tuple in zip(word_speaker_mapping, labeled_words):
+ word = word_dict["text"]
+ if (
+ word
+ and labeled_tuple[1] in ending_puncts
+ and (word[-1] not in model_puncts or is_acronym(word))
+ ):
+ word += labeled_tuple[1]
+ word = word.rstrip(".") if word.endswith("..") else word
+ word_dict["text"] = word
+
+ return word_speaker_mapping
+
+
+if __name__ == "__main__":
+ sample_audio_path = "sample_audio.wav"
+ audio_processor_instance = AudioProcessor(sample_audio_path)
+
+ mono_audio_path = audio_processor_instance.convert_to_mono()
+ print(f"Mono audio file saved at: {mono_audio_path}")
+
+ audio_duration = audio_processor_instance.get_duration()
+ print(f"Audio duration: {audio_duration} seconds")
+
+ converted_audio_path = audio_processor_instance.change_format("mp3")
+ print(f"Converted audio file saved at: {converted_audio_path}")
+
+ audio_path_trimmed = audio_processor_instance.trim_audio(0.0, 10.0)
+ print(f"Trimmed audio file saved at: {audio_path_trimmed}")
+
+ volume_adjusted_audio_path = audio_processor_instance.adjust_volume(5.0)
+ print(f"Volume adjusted audio file saved at: {volume_adjusted_audio_path}")
+
+ additional_audio_path = "additional_audio.wav"
+ merged_audio_output_path = audio_processor_instance.merge_audio(additional_audio_path)
+ print(f"Merged audio file saved at: {merged_audio_output_path}")
+
+ audio_chunk_paths = audio_processor_instance.split_audio(10.0)
+ print(f"Audio chunks saved at: {audio_chunk_paths}")
+
+ output_manifest_path = "output_manifest.json"
+ audio_processor_instance.create_manifest(output_manifest_path)
+ print(f"Manifest file saved at: {output_manifest_path}")
+
+ transcriber_instance = Transcriber()
+ transcribed_text_output, transcription_metadata = transcriber_instance.transcribe(sample_audio_path)
+ print(f"Transcribed Text: {transcribed_text_output}")
+ print(f"Transcription Info: {transcription_metadata}")
+
+ word_mapping_example = [
+ {"text": "hello"},
+ {"text": "world"},
+ {"text": "this"},
+ {"text": "is"},
+ {"text": "a"},
+ {"text": "test"}
+ ]
+ punctuation_restorer_instance = PunctuationRestorer()
+ punctuation_restored_mapping = punctuation_restorer_instance.restore_punctuation(word_mapping_example)
+ print(f"Restored Mapping: {punctuation_restored_mapping}")
diff --git a/src/audio/utils.py b/src/audio/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..522dcc196a5887c90c2dfff9dba6777671a151f4
--- /dev/null
+++ b/src/audio/utils.py
@@ -0,0 +1,189 @@
+# Standard library imports
+import warnings
+from typing import List, Dict, Union
+
+
+class TokenizerUtils:
+ """
+ Utility class for handling token-related operations, particularly for identifying tokens
+ that contain numerals or specific symbols.
+
+    The class is intended to be used as a static utility; its ``__init__`` exists
+    only for completeness and performs no initialization.
+
+ Methods
+ -------
+ find_numeral_symbol_tokens(tokenizer)
+ Returns a list of token IDs that include numerals or symbols like '%', '$', or '£'.
+ """
+
+ def __init__(self):
+ """Initialize the TokenizerUtils class. This method is present for completeness."""
+ pass
+
+ @staticmethod
+ def find_numeral_symbol_tokens(tokenizer) -> List[int]:
+ """
+ Identifies tokens that contain numerals or certain symbols in the tokenizer vocabulary.
+
+ Parameters
+ ----------
+ tokenizer : Any
+ Tokenizer object with a 'get_vocab' method, typically from Hugging Face's tokenizer library.
+
+ Returns
+ -------
+ List[int]
+ List of token IDs for tokens that contain numerals or symbols.
+
+ Examples
+ --------
+ >>> TokenizerUtils.find_numeral_symbol_tokens(tokenizer)
+ [-1, 123, 456, 789]
+ """
+ numeral_symbol_tokens = [-1]
+ for token, token_id in tokenizer.get_vocab().items():
+ if any(c in "0123456789%$£" for c in token):
+ numeral_symbol_tokens.append(token_id)
+ return numeral_symbol_tokens
+
+
+class Formatter:
+ """
+ A utility class for formatting audio-related data, such as sentence-speaker mappings.
+
+ Methods
+ -------
+ add_indices_to_ssm(ssm: List[Dict], reference_length: int = None) -> List[Dict]:
+ Adds an index key to each item in the SSM list and checks for length mismatches with a reference.
+ format_ssm_as_dialogue(
+ ssm: List[Dict],
+ print_output: bool = False,
+ return_dict: bool = False
+ ) -> Union[str, Dict[str, List[str]]]:
+ Formats sentence-speaker mappings into a readable dialogue format and optionally prints it or returns a
+ dictionary grouped by speakers.
+ """
+
+ @staticmethod
+ def add_indices_to_ssm(ssm: List[Dict], reference_length: int = None) -> List[Dict]:
+ """
+ Adds an index key to each item in the SSM list and optionally checks for length mismatches with a reference
+ length.
+
+ Parameters
+ ----------
+ ssm : List[Dict]
+ The final SSM data.
+ reference_length : int, optional
+ A reference length to compare the SSM length against, default is None.
+
+ Returns
+ -------
+ List[Dict]
+ The SSM data with added index keys and any necessary adjustments.
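+
+        Examples
+        --------
+        Minimal sketch with a single hypothetical mapping entry:
+
+        >>> Formatter.add_indices_to_ssm([{"speaker": "CSR", "text": "Hello."}])
+        [{'speaker': 'CSR', 'text': 'Hello.', 'index': 0}]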
+ """
+ if reference_length is not None and len(ssm) != reference_length:
+ warnings.warn(
+ f"Mismatch: SSM Length = {len(ssm)}, Reference Length = {reference_length}. "
+ f"Adjusting to match lengths...",
+ UserWarning,
+ )
+
+ for idx, item in enumerate(ssm):
+ item["index"] = idx
+
+ if reference_length is not None:
+ if len(ssm) > reference_length:
+ ssm = ssm[:reference_length]
+ elif len(ssm) < reference_length:
+ for i in range(len(ssm), reference_length):
+ ssm.append({
+ "index": i,
+ "speaker": "Unknown",
+ "start_time": None,
+ "end_time": None,
+ "text": "[Placeholder]"
+ })
+
+ return ssm
+
+ @staticmethod
+ def format_ssm_as_dialogue(
+ ssm: List[Dict],
+ print_output: bool = False,
+ return_dict: bool = False
+ ) -> Union[str, Dict[str, List[str]]]:
+ """
+ Formats the sentence-speaker mapping (ssm) as a dialogue and optionally prints the result or returns it as a
+ dictionary grouped by speakers.
+
+ Parameters
+ ----------
+ ssm : List[Dict]
+ List of sentences with speaker labels.
+ print_output : bool, optional
+ Whether to print the formatted dialogue, default is False.
+ return_dict : bool, optional
+ Whether to return the response as a dictionary grouped by speakers, default is False.
+
+ Returns
+ -------
+ Union[str, Dict[str, List[str]]]
+ If `return_dict` is True, returns a dictionary with speakers as keys and lists of their sentences as values.
+ Otherwise, returns the formatted dialogue string.
+ """
+ dialogue_dict: Dict[str, List[str]] = {}
+
+ for sentence in ssm:
+ speaker = sentence['speaker']
+ text = sentence['text'].strip()
+
+ if speaker in dialogue_dict:
+ dialogue_dict[speaker].append(text)
+ else:
+ dialogue_dict[speaker] = [text]
+
+ if print_output:
+ print("Formatted Dialogue:")
+ for speaker, texts in dialogue_dict.items():
+ for text in texts:
+ print(f"{speaker}: {text}")
+ print()
+
+ if return_dict:
+ return dialogue_dict
+
+ formatted_dialogue = "\n\n".join(
+ [f"{speaker}: {text}" for speaker, texts in dialogue_dict.items() for text in texts]
+ )
+ return formatted_dialogue
+
+
+if __name__ == "__main__":
+ # noinspection PyMissingOrEmptyDocstring
+ class DummyTokenizer:
+ @staticmethod
+ def get_vocab():
+ return {
+ "hello": 1,
+ "world": 2,
+ "100%": 3,
+ "$value": 4,
+ "item_123": 5,
+ "£price": 6
+ }
+
+
+ dummy_tokenizer = DummyTokenizer()
+ numeral_tokens = TokenizerUtils.find_numeral_symbol_tokens(dummy_tokenizer)
+ print(f"Numeral and symbol tokens: {numeral_tokens}")
+
+ speaker_sentence_mapping = [
+ {"speaker": "Speaker 1", "text": "Hello, how are you?"},
+ {"speaker": "Speaker 2", "text": "I'm fine, thank you! And you?"},
+ {"speaker": "Speaker 1", "text": "I'm doing great, thanks for asking."}
+ ]
+
+ formatted_dialogue_str = Formatter.format_ssm_as_dialogue(speaker_sentence_mapping, print_output=True)
+ print(f"Formatted Dialogue:\n{formatted_dialogue_str}")
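
A minimal sketch of `Formatter.add_indices_to_ssm`, which the `__main__` demo above does not exercise; the segments and the reference length are made up, and the package is assumed to be importable from the repository root:

```python
from src.audio.utils import Formatter

# Three real segments, but the reference expects five entries.
ssm = [
    {"speaker": "Speaker 1", "start_time": 0.0, "end_time": 1.2, "text": "Hello."},
    {"speaker": "Speaker 2", "start_time": 1.3, "end_time": 2.0, "text": "Hi there."},
    {"speaker": "Speaker 1", "start_time": 2.1, "end_time": 3.0, "text": "How can I help?"},
]

# A UserWarning is emitted for the length mismatch, an "index" key is added to each item,
# and the list is padded with "[Placeholder]" entries up to the reference length.
padded = Formatter.add_indices_to_ssm(ssm, reference_length=5)
for item in padded:
    print(item["index"], item["speaker"], item["text"])
```
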
diff --git a/src/db/__init__.py b/src/db/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/db/manager.py b/src/db/manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a0573a5f1eb02628cf82dad6421dd98ef4a9c9b
--- /dev/null
+++ b/src/db/manager.py
@@ -0,0 +1,149 @@
+# Standard library imports
+import sqlite3
+from typing import Annotated, List, Tuple, Optional
+
+
+class Database:
+ """
+ A class to interact with an SQLite database.
+
+ This class provides methods to fetch data, insert data, and handle specific
+ tasks like fetching or inserting topic IDs in a database.
+
+ Parameters
+ ----------
+ db_path : str
+ The path to the SQLite database file.
+
+ Attributes
+ ----------
+ db_path : str
+ The path to the SQLite database file.
+ """
+
+ def __init__(self, db_path: Annotated[str, "Path to the SQLite database"]):
+ """
+ Initializes the Database class with the provided database path.
+
+ Parameters
+ ----------
+ db_path : str
+ The path to the SQLite database file.
+ """
+ self.db_path = db_path
+
+ def fetch(
+ self,
+ sql_file_path: Annotated[str, "Path to the SQL file"]
+ ) -> Annotated[List[Tuple], "Results fetched from the query"]:
+ """
+ Executes a SELECT query from an SQL file and fetches the results.
+
+ Parameters
+ ----------
+ sql_file_path : str
+ Path to the SQL file containing the SELECT query.
+
+ Returns
+ -------
+ List[Tuple]
+ A list of tuples representing rows returned by the query.
+
+ Examples
+ --------
+ >>> db = Database("example.db")
+ >>> result = db.fetch("select_query.sql")
+        >>> print(result)
+ [(1, 'data1'), (2, 'data2')]
+ """
+ with open(sql_file_path, encoding='utf-8') as f:
+ query = f.read()
+
+ conn = sqlite3.connect(self.db_path)
+ cursor = conn.cursor()
+ cursor.execute(query)
+ results = cursor.fetchall()
+ conn.close()
+
+ return results
+
+ def insert(
+ self,
+ sql_file_path: Annotated[str, "Path to the SQL file"],
+ params: Optional[Annotated[Tuple, "Query parameters"]] = None
+ ) -> Annotated[int, "ID of the last inserted row"]:
+ """
+ Executes an INSERT query from an SQL file and returns the last row ID.
+
+ Parameters
+ ----------
+ sql_file_path : str
+ Path to the SQL file containing the INSERT query.
+ params : tuple, optional
+ Parameters for the query. Defaults to None.
+
+ Returns
+ -------
+ int
+ The ID of the last inserted row.
+
+ Examples
+ --------
+ >>> db = Database("example.db")
+        >>> last_id = db.insert("insert_query.sql", ("value1", "value2"))
+        >>> print(last_id)
+ 3
+ """
+ with open(sql_file_path, encoding='utf-8') as f:
+ query = f.read()
+
+ conn = sqlite3.connect(self.db_path)
+ cursor = conn.cursor()
+ if params is not None:
+ cursor.execute(query, params)
+ else:
+ cursor.execute(query)
+ conn.commit()
+ last_id = cursor.lastrowid
+ conn.close()
+ return last_id
+
+ def get_or_insert_topic_id(
+ self,
+ detected_topic: Annotated[str, "Topic to detect or insert"],
+ topics: Annotated[List[Tuple], "Existing topics with IDs"],
+ db_topic_insert_path: Annotated[str, "Path to the SQL file for inserting topics"]
+ ) -> Annotated[int, "Topic ID"]:
+ """
+ Fetches an existing topic ID or inserts a new one and returns its ID.
+
+ Parameters
+ ----------
+ detected_topic : str
+ The topic to be detected or inserted.
+ topics : List[Tuple[int, str]]
+ A list of existing topics as (id, name) tuples.
+ db_topic_insert_path : str
+ Path to the SQL file for inserting a new topic.
+
+ Returns
+ -------
+ int
+ The ID of the detected or newly inserted topic.
+
+ Examples
+ --------
+ >>> db = Database("example.db")
+        >>> topics = [(1, 'Python'), (2, 'SQL')]
+        >>> topic_id = db.get_or_insert_topic_id("AI", topics, "insert_topic.sql")
+        >>> print(topic_id)
+ 3
+ """
+ detected_topic_lower = detected_topic.lower()
+ topic_map = {t[1].lower(): t[0] for t in topics}
+
+ if detected_topic_lower in topic_map:
+ return topic_map[detected_topic_lower]
+ else:
+ topic_id = self.insert(db_topic_insert_path, (detected_topic,))
+ return topic_id
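
A short usage sketch for the `Database` helper; the database file name is illustrative, the schema is assumed to have been applied already, and the SQL files referenced are the ones added in this changeset:

```python
from src.db.manager import Database

db = Database("callytics.db")  # illustrative path

# Fetch existing topics as (ID, Name) tuples.
topics = db.fetch("src/db/sql/TopicFetch.sql")
print("Known topics:", topics)

# Reuse an existing topic ID when the name matches case-insensitively,
# otherwise insert a new Topic row and return its ID.
topic_id = db.get_or_insert_topic_id("Billing", topics, "src/db/sql/TopicInsert.sql")
print("Topic ID:", topic_id)
```
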
diff --git a/src/db/sql/AudioPropertiesInsert.sql b/src/db/sql/AudioPropertiesInsert.sql
new file mode 100644
index 0000000000000000000000000000000000000000..dbc7032704d0c6931655c9bdfddffe164ff4c7a6
--- /dev/null
+++ b/src/db/sql/AudioPropertiesInsert.sql
@@ -0,0 +1,34 @@
+INSERT INTO File (Name,
+ TopicID,
+ Extension,
+ Path,
+ Rate,
+ MinFreq,
+ MaxFreq,
+ BitDepth,
+ Channels,
+ Duration,
+ RMSLoudness,
+ ZeroCrossingRate,
+ SpectralCentroid,
+ EQ_20_250_Hz,
+ EQ_250_2000_Hz,
+ EQ_2000_6000_Hz,
+ EQ_6000_20000_Hz,
+ MFCC_1,
+ MFCC_2,
+ MFCC_3,
+ MFCC_4,
+ MFCC_5,
+ MFCC_6,
+ MFCC_7,
+ MFCC_8,
+ MFCC_9,
+ MFCC_10,
+ MFCC_11,
+ MFCC_12,
+ MFCC_13,
+ Summary,
+ Conflict,
+ Silence)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
diff --git a/src/db/sql/Schema.sql b/src/db/sql/Schema.sql
new file mode 100644
index 0000000000000000000000000000000000000000..c489d6dbef0cc24ceaa104ac510bb3f18af8c7d7
--- /dev/null
+++ b/src/db/sql/Schema.sql
@@ -0,0 +1,62 @@
+CREATE TABLE Topic
+(
+ ID INTEGER PRIMARY KEY AUTOINCREMENT,
+ Name TEXT NOT NULL UNIQUE CHECK (length(Name) <= 500)
+);
+
+INSERT INTO Topic (Name)
+VALUES ('Unknown');
+
+CREATE TABLE File
+(
+ ID INTEGER PRIMARY KEY AUTOINCREMENT,
+ Name TEXT NOT NULL,
+ TopicID INTEGER,
+ Extension TEXT,
+ Path TEXT,
+ Rate INTEGER,
+ MinFreq REAL,
+ MaxFreq REAL,
+ BitDepth INTEGER,
+ Channels INTEGER,
+ Duration REAL,
+ RMSLoudness REAL,
+ ZeroCrossingRate REAL,
+ SpectralCentroid REAL,
+ EQ_20_250_Hz REAL,
+ EQ_250_2000_Hz REAL,
+ EQ_2000_6000_Hz REAL,
+ EQ_6000_20000_Hz REAL,
+ MFCC_1 REAL,
+ MFCC_2 REAL,
+ MFCC_3 REAL,
+ MFCC_4 REAL,
+ MFCC_5 REAL,
+ MFCC_6 REAL,
+ MFCC_7 REAL,
+ MFCC_8 REAL,
+ MFCC_9 REAL,
+ MFCC_10 REAL,
+ MFCC_11 REAL,
+ MFCC_12 REAL,
+ MFCC_13 REAL,
+ Summary TEXT NOT NULL,
+ Conflict INTEGER NOT NULL CHECK (Conflict IN (0, 1)),
+ Silence REAL NOT NULL,
+
+ FOREIGN KEY (TopicID) REFERENCES Topic (ID)
+);
+
+CREATE TABLE Utterance
+(
+ ID INTEGER PRIMARY KEY AUTOINCREMENT,
+ FileID INTEGER NOT NULL,
+ Speaker TEXT CHECK (Speaker IN ('Customer', 'CSR')) NOT NULL,
+ Sequence INTEGER NOT NULL,
+ StartTime REAL NOT NULL,
+ EndTime REAL NOT NULL,
+ Content TEXT NOT NULL,
+ Sentiment TEXT CHECK (Sentiment IN ('Neutral', 'Positive', 'Negative')) NOT NULL,
+ Profane INTEGER NOT NULL CHECK (Profane IN (0, 1)),
+ FOREIGN KEY (FileID) REFERENCES File (ID)
+);
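
A fresh database can be created from `Schema.sql` with the standard library alone; a sketch, with the database file name chosen for illustration:

```python
import sqlite3

# Runs the whole schema script, which also seeds the Topic table with the 'Unknown' row.
with open("src/db/sql/Schema.sql", encoding="utf-8") as f:
    schema_sql = f.read()

conn = sqlite3.connect("callytics.db")
conn.executescript(schema_sql)
conn.close()
```
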
diff --git a/src/db/sql/TopicFetch.sql b/src/db/sql/TopicFetch.sql
new file mode 100644
index 0000000000000000000000000000000000000000..0a617ea35b3ceb4f4095ea8201c6ff47725ce4a6
--- /dev/null
+++ b/src/db/sql/TopicFetch.sql
@@ -0,0 +1,2 @@
+SELECT ID, Name
+FROM Topic;
\ No newline at end of file
diff --git a/src/db/sql/TopicInsert.sql b/src/db/sql/TopicInsert.sql
new file mode 100644
index 0000000000000000000000000000000000000000..6acbda1fba88357905ea86aa0fd3b4988a7d9417
--- /dev/null
+++ b/src/db/sql/TopicInsert.sql
@@ -0,0 +1,2 @@
+INSERT INTO Topic (Name)
+VALUES (?)
\ No newline at end of file
diff --git a/src/db/sql/UtteranceInsert.sql b/src/db/sql/UtteranceInsert.sql
new file mode 100644
index 0000000000000000000000000000000000000000..afc35b64cb184eb5ca34d0738ae5b9256f6503f2
--- /dev/null
+++ b/src/db/sql/UtteranceInsert.sql
@@ -0,0 +1,9 @@
+INSERT INTO Utterance (FileID,
+ Speaker,
+ Sequence,
+ StartTime,
+ EndTime,
+ Content,
+ Sentiment,
+ Profane)
+VALUES (?, ?, ?, ?, ?, ?, ?, ?);
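
The parameter tuple passed to `Database.insert` must follow the column order above; a sketch with made-up values, where `file_id` stands in for the row ID returned by the preceding `File` insert:

```python
from src.db.manager import Database

db = Database("callytics.db")  # illustrative path
file_id = 1  # stand-in for the ID returned by the AudioPropertiesInsert step

utterance_id = db.insert(
    "src/db/sql/UtteranceInsert.sql",
    (file_id, "Customer", 1, 0.0, 2.4, "Hello, I need help with my order.", "Neutral", 0),
)
print("Inserted utterance ID:", utterance_id)
```
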
diff --git a/src/text/__init__.py b/src/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e293aa12ba61bec585e694487b62d06009ebdbd
--- /dev/null
+++ b/src/text/__init__.py
@@ -0,0 +1,5 @@
+from .model import ModelRegistry, LLaMAModel, OpenAIModel, AzureOpenAIModel
+
+ModelRegistry.register("llama", LLaMAModel)
+ModelRegistry.register("openai", OpenAIModel)
+ModelRegistry.register("azure_openai", AzureOpenAIModel)
\ No newline at end of file
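
Importing `src.text` runs the registrations above, so string identifiers resolve to model classes at runtime; a small sketch of the registry pattern, with a hypothetical `EchoModel` added purely for illustration:

```python
import src.text  # triggers the ModelRegistry.register(...) calls above
from src.text.model import LanguageModel, ModelRegistry


class EchoModel(LanguageModel):
    """Hypothetical backend used only to illustrate registration."""

    def generate(self, messages, **kwargs):
        # Echo the last message instead of calling a real model.
        return messages[-1]["content"]


ModelRegistry.register("echo", EchoModel)

print(ModelRegistry.get_model_class("llama"))  # <class 'src.text.model.LLaMAModel'>
print(ModelRegistry.get_model_class("echo"))   # the EchoModel class defined above
```
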
diff --git a/src/text/llm.py b/src/text/llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2deddd259c0df2db9037a941b4525236d18d6f18
--- /dev/null
+++ b/src/text/llm.py
@@ -0,0 +1,333 @@
+# Standard library imports
+import re
+import json
+import asyncio
+from typing import Annotated, Optional, Dict, Any, List
+
+# Related third-party imports
+import yaml
+
+# Local imports
+from src.text.model import LanguageModelManager
+from src.audio.utils import Formatter
+
+
+class LLMOrchestrator:
+ """
+ A handler to perform specific LLM tasks such as classification or sentiment analysis.
+
+ This class uses a language model to perform different tasks by dynamically changing the prompt.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file for the language model manager.
+ prompt_config_path : str
+ Path to the configuration file containing prompts for different tasks.
+ model_id : str, optional
+ Identifier of the model to use. Defaults to "llama".
+ cache_size : int, optional
+ Cache size for the language model manager. Defaults to 2.
+
+ Attributes
+ ----------
+ manager : LanguageModelManager
+ An instance of LanguageModelManager for interacting with the model.
+ model_id : str
+ The identifier of the language model in use.
+ prompts : Dict[str, Dict[str, str]]
+ A dictionary containing prompts for different tasks.
+ """
+
+ def __init__(
+ self,
+ config_path: Annotated[str, "Path to the configuration file"],
+ prompt_config_path: Annotated[str, "Path to the prompt configuration file"],
+ model_id: Annotated[str, "Language model identifier"] = "llama",
+ cache_size: Annotated[int, "Cache size for the language model manager"] = 2,
+ ):
+ """
+ Initializes the LLMOrchestrator with a language model manager and loads prompts.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the configuration file for the language model manager.
+ prompt_config_path : str
+ Path to the configuration file containing prompts for different tasks.
+ model_id : str, optional
+ Identifier of the model to use. Defaults to "llama".
+ cache_size : int, optional
+ Cache size for the language model manager. Defaults to 2.
+ """
+ self.manager = LanguageModelManager(config_path=config_path, cache_size=cache_size)
+ self.model_id = model_id
+ self.prompts = self._load_prompts(prompt_config_path)
+
+ @staticmethod
+ def _load_prompts(prompt_config_path: str) -> Dict[str, Dict[str, str]]:
+ """
+ Loads prompts from the prompt configuration file.
+
+ Parameters
+ ----------
+ prompt_config_path : str
+ Path to the prompt configuration file.
+
+ Returns
+ -------
+ Dict[str, Dict[str, str]]
+ A dictionary containing prompts for different tasks.
+ """
+ with open(prompt_config_path, encoding='utf-8') as f:
+ prompts = yaml.safe_load(f)
+ return prompts
+
+ @staticmethod
+ def extract_json(
+ response: Annotated[str, "The response string to extract JSON from"]
+ ) -> Annotated[Optional[Dict[str, Any]], "Extracted JSON as a dictionary or None if not found"]:
+ """
+ Extracts the last valid JSON object from a given response string.
+
+ Parameters
+ ----------
+ response : str
+ The response string to extract JSON from.
+
+ Returns
+ -------
+ Optional[Dict[str, Any]]
+ The last valid JSON dictionary if successfully extracted and parsed, otherwise None.
+ """
+ json_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
+ matches = re.findall(json_pattern, response)
+ for match in reversed(matches):
+ try:
+ return json.loads(match)
+ except json.JSONDecodeError:
+ continue
+ return None
+
+ async def generate(
+ self,
+ prompt_name: Annotated[str, "The name of the prompt to use (e.g., 'Classification', 'SentimentAnalysis')"],
+ user_input: Annotated[Any, "The user's context or input data"],
+ system_input: Annotated[Optional[Any], "The system's context or input data"] = None
+ ) -> Annotated[Dict[str, Any], "Task results or error dictionary"]:
+ """
+ Performs the specified LLM task using the selected prompt, supporting both user and optional system contexts.
+ """
+ if prompt_name not in self.prompts:
+ return {"error": f"Prompt '{prompt_name}' is not defined in prompt.yaml."}
+
+ system_prompt_template = self.prompts[prompt_name].get('system', '')
+ user_prompt_template = self.prompts[prompt_name].get('user', '')
+
+ if not system_prompt_template or not user_prompt_template:
+ return {"error": f"Prompts for '{prompt_name}' are incomplete."}
+
+ formatted_user_input = Formatter.format_ssm_as_dialogue(user_input)
+
+ if system_input:
+ system_prompt = system_prompt_template.format(system_context=system_input)
+ else:
+ system_prompt = system_prompt_template
+
+ user_prompt = user_prompt_template.format(user_context=formatted_user_input)
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt}
+ ]
+
+ response = await self.manager.generate(
+ model_id=self.model_id,
+ messages=messages,
+ max_new_tokens=10000,
+ )
+ print(response)
+
+ dict_obj = self.extract_json(response)
+ if dict_obj:
+ return dict_obj
+ else:
+ return {"error": "No valid JSON object found in the response."}
+
+
+class LLMResultHandler:
+ """
+    A handler class to process and validate the output from a large language model (LLM)
+ and format structured data.
+
+ This class ensures that the input data conforms to expected formats and applies fallback
+ mechanisms to maintain data integrity.
+
+ Methods
+ -------
+ validate_and_fallback(llm_result, ssm)
+ Validates the LLM result against structured speaker metadata and applies fallback.
+ _fallback(ssm)
+ Applies fallback formatting to the speaker data.
+ log_result(ssm, llm_result)
+ Logs the final processed data and the original LLM result.
+ """
+
+ def __init__(self):
+ """
+ Initializes the LLMResultHandler class.
+ """
+ pass
+
+ def validate_and_fallback(
+ self,
+ llm_result: Annotated[Dict[str, str], "LLM result with customer and CSR speaker identifiers"],
+ ssm: Annotated[List[Dict[str, Any]], "List of sentences with speaker metadata"]
+ ) -> Annotated[List[Dict[str, Any]], "Processed speaker metadata"]:
+ """
+ Validates the LLM result and applies corrections to the speaker metadata.
+
+ Parameters
+ ----------
+ llm_result : dict
+ A dictionary containing speaker identifiers for 'Customer' and 'CSR'.
+ ssm : list of dict
+ A list of dictionaries where each dictionary represents a sentence with
+ metadata, including the 'speaker'.
+
+ Returns
+ -------
+ list of dict
+ The processed speaker metadata with standardized speaker labels.
+
+ Examples
+ --------
+        >>> llm_result = {"Customer": "Speaker 1", "CSR": "Speaker 2"}
+        >>> ssm = [{"speaker": "Speaker 1", "text": "Hello!"}, {"speaker": "Speaker 2", "text": "Hi!"}]
+ >>> handler = LLMResultHandler()
+ >>> handler.validate_and_fallback(llm_result, ssm)
+ [{'speaker': 'Customer', 'text': 'Hello!'}, {'speaker': 'CSR', 'text': 'Hi!'}]
+ """
+ if not isinstance(llm_result, dict):
+ return self._fallback(ssm)
+
+ if "Customer" not in llm_result or "CSR" not in llm_result:
+ return self._fallback(ssm)
+
+ customer_speaker = llm_result["Customer"]
+ csr_speaker = llm_result["CSR"]
+
+ speaker_pattern = r"^Speaker\s+\d+$"
+
+ if (not re.match(speaker_pattern, customer_speaker)) or (not re.match(speaker_pattern, csr_speaker)):
+ return self._fallback(ssm)
+
+ ssm_speakers = {sentence["speaker"] for sentence in ssm}
+ if customer_speaker not in ssm_speakers or csr_speaker not in ssm_speakers:
+ return self._fallback(ssm)
+
+ for sentence in ssm:
+ if sentence["speaker"] == csr_speaker:
+ sentence["speaker"] = "CSR"
+ elif sentence["speaker"] == customer_speaker:
+ sentence["speaker"] = "Customer"
+ else:
+ sentence["speaker"] = "Customer"
+
+ return ssm
+
+ @staticmethod
+ def _fallback(
+ ssm: Annotated[List[Dict[str, Any]], "List of sentences with speaker metadata"]
+ ) -> Annotated[List[Dict[str, Any]], "Fallback speaker metadata"]:
+ """
+ Applies fallback formatting to speaker metadata when validation fails.
+
+ Parameters
+ ----------
+ ssm : list of dict
+ A list of dictionaries representing sentences with speaker metadata.
+
+ Returns
+ -------
+ list of dict
+ The speaker metadata with fallback formatting applied.
+
+ Examples
+ --------
+        >>> ssm = [{"speaker": "Speaker 1", "text": "Hello!"}, {"speaker": "Speaker 2", "text": "Hi!"}]
+ >>> handler = LLMResultHandler()
+ >>> handler._fallback(ssm)
+ [{'speaker': 'CSR', 'text': 'Hello!'}, {'speaker': 'Customer', 'text': 'Hi!'}]
+ """
+ if len(ssm) > 0:
+ first_speaker = ssm[0]["speaker"]
+ for sentence in ssm:
+ if sentence["speaker"] == first_speaker:
+ sentence["speaker"] = "CSR"
+ else:
+ sentence["speaker"] = "Customer"
+ return ssm
+
+ @staticmethod
+ def log_result(
+ ssm: Annotated[List[Dict[str, Any]], "Final processed speaker metadata"],
+ llm_result: Annotated[Dict[str, str], "Original LLM result"]
+ ) -> None:
+ """
+ Logs the final processed speaker metadata and the original LLM result.
+
+ Parameters
+ ----------
+ ssm : list of dict
+ The processed speaker metadata.
+ llm_result : dict
+ The original LLM result.
+
+ Returns
+ -------
+ None
+
+ Examples
+ --------
+        >>> ssm = [{"speaker": "CSR", "text": "Hello!"}, {"speaker": "Customer", "text": "Hi!"}]
+        >>> llm_result = {"Customer": "Speaker 1", "CSR": "Speaker 2"}
+ >>> handler = LLMResultHandler()
+ >>> handler.log_result(ssm, llm_result)
+ Final SSM: [{'speaker': 'CSR', 'text': 'Hello!'}, {'speaker': 'Customer', 'text': 'Hi!'}]
+ LLM Result: {'Customer': 'Speaker 1', 'CSR': 'Speaker 2'}
+ """
+ print("Final SSM:", ssm)
+ print("LLM Result:", llm_result)
+
+
+if __name__ == "__main__":
+ # noinspection PyMissingOrEmptyDocstring
+ async def main():
+ handler = LLMOrchestrator(
+ config_path="config/config.yaml",
+ prompt_config_path="config/prompt.yaml",
+ model_id="openai",
+ )
+
+ conversation = [
+ {"speaker": "Speaker 1", "text": "Hello, I need help with my order."},
+ {"speaker": "Speaker 0", "text": "Sure, I'd be happy to assist you."},
+ {"speaker": "Speaker 1", "text": "I haven't received it yet."},
+ {"speaker": "Speaker 0", "text": "Let me check the status for you."}
+ ]
+
+ speaker_roles = await handler.generate("Classification", conversation)
+ print("Speaker Roles:", speaker_roles)
+ print("Type:", type(speaker_roles))
+
+ sentiment_analyzer = LLMOrchestrator(
+ config_path="config/config.yaml",
+ prompt_config_path="config/prompt.yaml"
+ )
+
+ sentiment = await sentiment_analyzer.generate("SentimentAnalysis", conversation)
+ print("\nSentiment Analysis:", sentiment)
+
+
+ asyncio.run(main())
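
`LLMOrchestrator.extract_json` can be exercised without any model credentials; a minimal sketch showing how the last valid JSON object is recovered from a noisy response:

```python
from src.text.llm import LLMOrchestrator

response = (
    "Reasoning about the speakers...\n"
    '{"draft": true}\n'
    "Final answer:\n"
    '{"Customer": "Speaker 1", "CSR": "Speaker 0"}'
)

# Candidate JSON objects are matched by regex and parsed from the last match backwards.
print(LLMOrchestrator.extract_json(response))
# {'Customer': 'Speaker 1', 'CSR': 'Speaker 0'}
```
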
diff --git a/src/text/model.py b/src/text/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c0c5e9fa46c0b8b478cd3bf6e12244c0dc5890
--- /dev/null
+++ b/src/text/model.py
@@ -0,0 +1,552 @@
+# Standard library imports
+import os
+import json
+import asyncio
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Optional, Any, Annotated
+
+# Related third-party imports
+import yaml
+import torch
+import openai
+from openai import OpenAI
+from dotenv import load_dotenv
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+
+load_dotenv()
+
+
+class LanguageModel(ABC):
+ """
+ Abstract base class for language models.
+
+ This class provides a common interface for language models with methods
+ to generate text and unload resources.
+
+ Parameters
+ ----------
+ config : dict
+ Configuration for the language model.
+ """
+
+ def __init__(self, config: Annotated[dict, "Configuration for the language model"]):
+ self.config = config
+
+ @abstractmethod
+ def generate(
+ self,
+ messages: Annotated[list, "List of message dictionaries"],
+ **kwargs: Annotated[Any, "Additional keyword arguments"]
+ ) -> Annotated[str, "Generated text"]:
+ """
+ Generate text based on the given input messages.
+
+ Parameters
+ ----------
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+ **kwargs : Any
+ Additional keyword arguments.
+
+ Returns
+ -------
+ str
+ Generated text output.
+ """
+ pass
+
+ def unload(self) -> Annotated[None, "Unload resources used by the language model"]:
+ """
+ Unload resources used by the language model.
+ """
+ pass
+
+
+class LLaMAModel(LanguageModel):
+ """
+ LLaMA language model implementation using Hugging Face Transformers.
+
+ Parameters
+ ----------
+ config : dict
+ Configuration for the LLaMA model.
+ """
+
+ def __init__(self, config: Annotated[dict, "Configuration for the LLaMA model"]):
+ super().__init__(config)
+ model_name = config['model_name']
+ compute_type = config.get('compute_type')
+ torch.cuda.empty_cache()
+
+ print(f"Loading LLaMA model: {model_name}")
+ print(f"CUDA available: {torch.cuda.is_available()}")
+ if torch.cuda.is_available():
+ print(f"CUDA Version: {torch.version.cuda}")
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
+ else:
+ print("GPU not available, using CPU.")
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ device_map="auto",
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() and compute_type == "float16" else torch.float32,
+ low_cpu_mem_usage=True
+ )
+ self.pipe = pipeline(
+ "text-generation",
+ model=self.model,
+ tokenizer=self.tokenizer,
+ device_map="auto",
+ )
+
+ def generate(
+ self,
+ messages: Annotated[list, "List of message dictionaries"],
+ max_new_tokens: Annotated[int, "Maximum number of new tokens to generate"] = 10000,
+ truncation: Annotated[bool, "Whether to truncate the input"] = True,
+ batch_size: Annotated[int, "Batch size for generation"] = 1,
+ pad_token_id: Annotated[Optional[int], "Padding token ID"] = None
+ ) -> Annotated[str, "Generated text"]:
+ """
+ Generate text based on input messages using the LLaMA model.
+
+ Parameters
+ ----------
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+ max_new_tokens : int, optional
+ Maximum number of tokens to generate. Default is 10000.
+ truncation : bool, optional
+ Whether to truncate the input. Default is True.
+ batch_size : int, optional
+ Batch size for generation. Default is 1.
+ pad_token_id : int, optional
+ Padding token ID. Defaults to the tokenizer's EOS token ID.
+
+ Returns
+ -------
+ str
+ Generated text.
+ """
+ prompt = self._format_messages_llama(messages)
+ output = self.pipe(
+ prompt,
+ max_new_tokens=max_new_tokens,
+ truncation=truncation,
+ batch_size=batch_size,
+ pad_token_id=pad_token_id if pad_token_id is not None else self.tokenizer.eos_token_id
+ )
+ return output[0]['generated_text']
+
+ @staticmethod
+ def _format_messages_llama(messages: Annotated[list, "List of message dictionaries"]) -> Annotated[
+ str, "Formatted prompt"]:
+ """
+ Format messages into a single prompt for LLaMA.
+
+ Parameters
+ ----------
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+
+ Returns
+ -------
+ str
+ Formatted prompt.
+ """
+ prompt = ""
+ for message in messages:
+ role = message.get("role", "").lower()
+ content = message.get("content", "")
+ if role == "system":
+ prompt += f"System: {content}\n"
+ elif role == "user":
+ prompt += f"User: {content}\n"
+ elif role == "assistant":
+ prompt += f"Assistant: {content}\n"
+ prompt += "Assistant:"
+ return prompt
+
+ def unload(self) -> Annotated[None, "Unload the LLaMA model and release resources"]:
+ """
+ Unload the LLaMA model and release resources.
+ """
+ del self.pipe
+ del self.model
+ del self.tokenizer
+ torch.cuda.empty_cache()
+ print(f"LLaMA model '{self.config['model_name']}' unloaded.")
+
+
+class OpenAIModel(LanguageModel):
+ """
+ OpenAI GPT model integration.
+
+ Parameters
+ ----------
+ config : dict
+ Configuration for the OpenAI model.
+ """
+
+ def __init__(self, config: Annotated[dict, "Configuration for the OpenAI model"]):
+ super().__init__(config)
+ openai_api_key = config.get('openai_api_key')
+ if not openai_api_key:
+ raise ValueError("OpenAI API key must be provided.")
+ self.client = OpenAI(api_key=openai_api_key)
+ self.model_name = config.get('model_name', 'gpt-4')
+
+ def generate(
+ self,
+ messages: Annotated[list, "List of message dictionaries"],
+ max_length: Annotated[int, "Maximum number of tokens for the output"] = 10000,
+ return_as_json: bool = False,
+ **kwargs: Annotated[Any, "Additional keyword arguments"]
+ ) -> Annotated[str, "Generated text"]:
+ """
+ Generate text using OpenAI's API.
+
+ Parameters
+ ----------
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+ max_length : int, optional
+ Maximum number of tokens for the output. Default is 10000.
+ return_as_json : bool, optional
+            If True, the request is sent with response_format={"type": "json_object"} and the returned
+            content is parsed into a dict via json.loads. Defaults to False.
+ **kwargs : Any
+ Additional keyword arguments.
+
+ Returns
+ -------
+ str or dict
+ Generated text as a string if return_as_json=False.
+ If return_as_json=True and the response is in valid JSON format,
+ returns a dict.
+ """
+
+ create_kwargs = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": max_length,
+ "temperature": kwargs.get('temperature', 0.7)
+ }
+
+ if return_as_json is True:
+ create_kwargs["response_format"] = {"type": "json_object"}
+
+ completion = self.client.chat.completions.create(**create_kwargs)
+ response_text = completion.choices[0].message.content
+
+ if return_as_json:
+ try:
+ return json.loads(response_text)
+ except json.JSONDecodeError:
+ return response_text
+
+ return response_text
+
+ def unload(self) -> Annotated[None, "Placeholder for OpenAI model unload (no local resources to release)"]:
+ """
+ Placeholder for OpenAI model unload (no local resources to release).
+ """
+ print(f"OpenAI model '{self.model_name}' unloaded.")
+
+
+class AzureOpenAIModel(LanguageModel):
+ """
+ Azure OpenAI model integration.
+
+ Parameters
+ ----------
+ config : dict
+ Configuration for the Azure OpenAI model.
+ """
+
+ def __init__(self, config: Annotated[dict, "Configuration for the Azure OpenAI model"]):
+ super().__init__(config)
+ self.model_name = config.get('model_name', 'gpt-4o')
+ self.api_key = config.get('azure_openai_api_key')
+ self.api_base = config.get('azure_openai_api_base')
+ self.api_version = config.get('azure_openai_api_version')
+
+ if not all([self.api_key, self.api_base, self.api_version]):
+ raise ValueError("Azure OpenAI API key, base, and version must be provided.")
+
+ openai.api_type = "azure"
+ openai.api_base = self.api_base
+ openai.api_version = self.api_version
+ openai.api_key = self.api_key
+
+ def generate(
+ self,
+ messages: Annotated[list, "List of message dictionaries"],
+ max_length: Annotated[int, "Maximum number of tokens for the output"] = 10000,
+ **kwargs: Annotated[Any, "Additional keyword arguments"]
+ ) -> Annotated[str, "Generated text"]:
+ """
+ Generate text using Azure OpenAI's API.
+
+ Parameters
+ ----------
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+ max_length : int, optional
+ Maximum number of tokens for the output. Default is 10000.
+ **kwargs : Any
+ Additional keyword arguments.
+
+ Returns
+ -------
+ str
+ Generated text.
+ """
+ response = openai.ChatCompletion.create(
+ deployment_id=self.model_name,
+ messages=messages,
+ max_tokens=max_length,
+ temperature=kwargs.get('temperature', 0.7)
+ )
+ return response.choices[0].message['content']
+
+ def unload(self) -> Annotated[None, "Placeholder for Azure OpenAI model unload (no local resources to release)"]:
+ """
+ Placeholder for Azure OpenAI model unload (no local resources to release).
+ """
+ print(f"Azure OpenAI model '{self.model_name}' unloaded.")
+
+
+class ModelRegistry:
+ """
+ Registry to manage language model class registrations.
+
+ This class allows dynamic registration and retrieval of model classes.
+ """
+ _registry = {}
+
+ @classmethod
+ def register(
+ cls,
+ model_id: Annotated[str, "Unique identifier for the model"],
+ model_class: Annotated[type, "The class to register"]
+ ) -> Annotated[None, "Registration completed"]:
+ """
+ Register a model class with the registry.
+
+ Parameters
+ ----------
+ model_id : str
+ Unique identifier for the model class.
+ model_class : type
+ The class to register.
+ """
+ cls._registry[model_id.lower()] = model_class
+
+ @classmethod
+ def get_model_class(cls, model_id: Annotated[str, "Unique identifier for the model"]) -> Annotated[
+ type, "Model class"]:
+ """
+ Retrieve a model class by its unique identifier.
+
+ Parameters
+ ----------
+ model_id : str
+ Unique identifier for the model class.
+
+ Returns
+ -------
+ type
+ The model class corresponding to the identifier.
+
+ Raises
+ ------
+ ValueError
+ If the model ID is not registered.
+ """
+ model_class = cls._registry.get(model_id.lower())
+ if not model_class:
+ raise ValueError(f"No class found for model ID '{model_id}'.")
+ return model_class
+
+
+class ModelFactory:
+ """
+ Factory to create language model instances.
+
+ This class uses the `ModelRegistry` to create instances of registered model classes.
+ """
+
+ @staticmethod
+ def create_model(
+ model_id: Annotated[str, "Unique identifier for the model"],
+ config: Annotated[dict, "Configuration for the model"]
+ ) -> Annotated[LanguageModel, "Instance of the language model"]:
+ """
+ Create a language model instance based on its unique identifier.
+
+ Parameters
+ ----------
+ model_id : str
+ Unique identifier for the model.
+ config : dict
+ Configuration for the model.
+
+ Returns
+ -------
+ LanguageModel
+ An instance of the language model.
+ """
+ model_class = ModelRegistry.get_model_class(model_id)
+ return model_class(config)
+
+
+class LanguageModelManager:
+ """
+ Manages multiple language models with caching and async support.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the YAML configuration file.
+ cache_size : int, optional
+ Maximum number of models to cache. Default is 10.
+ """
+
+ def __init__(
+ self,
+ config_path: Annotated[str, "Path to the YAML configuration file"],
+ cache_size: Annotated[int, "Maximum number of models to cache"] = 10
+ ):
+ self.config_path = config_path
+ self.cache_size = cache_size
+ self.models = OrderedDict()
+ self.full_config = self._load_full_config(config_path)
+ self.runtime_config = self.full_config.get('runtime', {})
+ self.models_config = self.full_config.get('models', {})
+ self.lock = asyncio.Lock()
+
+ @staticmethod
+ def _load_full_config(config_path: Annotated[str, "Path to the YAML configuration file"]) -> Annotated[
+ dict, "Parsed configuration"]:
+ """
+ Load and parse the YAML configuration file.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the YAML file.
+
+ Returns
+ -------
+ dict
+ Parsed configuration.
+ """
+ with open(config_path, encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ for model_id, model_config in config.get('models', {}).items():
+ for key, value in model_config.items():
+ if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
+ env_var = value[2:-1]
+ model_config[key] = os.getenv(env_var, "")
+ return config
+
+ async def get_model(
+ self,
+ model_id: Annotated[str, "Unique identifier for the model"]
+ ) -> Annotated[LanguageModel, "Instance of the language model"]:
+ """
+ Retrieve a language model instance from the cache or create a new one.
+
+ Parameters
+ ----------
+ model_id : str
+ Unique identifier for the model.
+
+ Returns
+ -------
+ LanguageModel
+ An instance of the language model.
+
+ Raises
+ ------
+ ValueError
+ If the model ID is not found in the configuration.
+ """
+ async with self.lock:
+ torch.cuda.empty_cache()
+ if model_id in self.models:
+ self.models.move_to_end(model_id)
+ return self.models[model_id]
+ else:
+ config = self.models_config.get(model_id)
+ if not config:
+ raise ValueError(f"Model ID '{model_id}' not found in configuration.")
+ config['compute_type'] = self.runtime_config.get('compute_type', 'float16')
+ model = ModelFactory.create_model(model_id, config)
+ self.models[model_id] = model
+ if len(self.models) > self.cache_size:
+ oldest_model_id, oldest_model = self.models.popitem(last=False)
+ oldest_model.unload()
+ return model
+
+ async def generate(
+ self,
+ model_id: Annotated[str, "Unique identifier for the model"],
+ messages: Annotated[list, "List of message dictionaries"],
+ **kwargs: Annotated[Any, "Additional keyword arguments"]
+ ) -> Annotated[Optional[str], "Generated text or None if an error occurs"]:
+ """
+ Generate text using a specific language model.
+
+ Parameters
+ ----------
+ model_id : str
+ Unique identifier for the model.
+ messages : list
+ List of message dictionaries with 'role' and 'content'.
+ **kwargs : Any
+ Additional keyword arguments.
+
+ Returns
+ -------
+ str or None
+ Generated text or None if an error occurs.
+ """
+ try:
+ model = await self.get_model(model_id)
+ return model.generate(messages, **kwargs)
+ except Exception as e:
+ print(f"Error with model ({model_id}): {e}")
+ return None
+
+ def unload_all(self) -> Annotated[None, "Unload all cached models and release resources"]:
+ """
+ Unload all cached models and release resources.
+ """
+ for model in self.models.values():
+ model.unload()
+ self.models.clear()
+ print("All models have been unloaded.")
+
+
+if __name__ == "__main__":
+ # noinspection PyMissingOrEmptyDocstring
+ async def main():
+ config_path = 'config/config.yaml'
+
+ manager = LanguageModelManager(config_path=config_path, cache_size=11)
+
+ llama_model_id = "llama"
+ llama_messages = [
+ {"role": "system", "content": "You are a pirate. Answer accordingly!"},
+ {"role": "user", "content": "Who are you?"}
+ ]
+ llama_output = await manager.generate(model_id=llama_model_id, messages=llama_messages)
+ print(f"LLaMA Model Output: {llama_output}")
+
+
+ asyncio.run(main())
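
The keys read by `LanguageModelManager._load_full_config` imply a configuration layout along the lines below; the model names and environment variable are illustrative, and the YAML is written to a temporary file only so the `${ENV_VAR}` substitution can be observed offline:

```python
import os
import tempfile

from src.text.model import LanguageModelManager

CONFIG_YAML = """
runtime:
  compute_type: float16

models:
  openai:
    model_name: gpt-4o
    openai_api_key: "${OPENAI_API_KEY}"
  llama:
    model_name: meta-llama/Llama-3.2-11B-Vision-Instruct
"""

os.environ.setdefault("OPENAI_API_KEY", "sk-demo")  # placeholder value for the sketch

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(CONFIG_YAML)
    config_path = f.name

manager = LanguageModelManager(config_path=config_path, cache_size=2)
print(manager.models_config["openai"]["openai_api_key"])  # resolved from the environment
```
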
diff --git a/src/text/prompt.py b/src/text/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..025984867a84fbb603e1a3b3d011bc3214c97cd9
--- /dev/null
+++ b/src/text/prompt.py
@@ -0,0 +1,139 @@
+# Standard library imports
+import os
+from typing import Annotated, Dict, Any
+
+# Related third-party imports
+import yaml
+
+
+class PromptManager:
+ """
+ A class to manage prompts loaded from a YAML configuration file.
+
+ This class provides methods to load prompts from a YAML file and retrieve
+ them with optional formatting.
+
+ Parameters
+ ----------
+ config_path : str, optional
+ Path to the YAML configuration file. If not provided, the default
+ path is set to `../../config/prompts.yaml` relative to the script's
+ directory.
+
+ Attributes
+ ----------
+ config_path : str
+ Path to the YAML configuration file.
+ prompts : dict
+ Dictionary of prompts loaded from the YAML file.
+ """
+
+ def __init__(self, config_path: Annotated[str, "Path to the YAML configuration file"] = None):
+ """
+ Initializes the PromptManager with a specified or default configuration path.
+
+ Parameters
+ ----------
+ config_path : str, optional
+ Path to the YAML configuration file. Defaults to None.
+ """
+ self.config_path = config_path or os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "../../config/prompts.yaml"
+ )
+ self.prompts = self._load_prompts()
+
+ def _load_prompts(self) -> Annotated[Dict[str, Any], "Loaded prompts from YAML file"]:
+ """
+ Load prompts from the YAML file.
+
+ This method reads the YAML file specified by `self.config_path` and parses its contents.
+
+ Returns
+ -------
+ dict
+ Dictionary containing the prompts.
+
+ Raises
+ ------
+ FileNotFoundError
+ If the specified YAML file does not exist.
+
+ Examples
+ --------
+ >>> manager = PromptManager("config/prompts.yaml")
+ >>> prompts = manager._load_prompts()
+ """
+ if not os.path.exists(self.config_path):
+ raise FileNotFoundError(f"YAML file not found: {self.config_path}")
+
+ with open(self.config_path) as file:
+ loaded_prompts = yaml.safe_load(file)
+ if not isinstance(loaded_prompts, dict):
+ raise TypeError(f"Expected dictionary from YAML, got {type(loaded_prompts).__name__}.")
+ return loaded_prompts
+
+ def get_prompt(
+ self,
+ prompt_name: Annotated[str, "Name of the prompt to retrieve"],
+ **kwargs: Annotated[dict, "Keyword arguments for formatting the prompt"]
+ ) -> Annotated[Any, "Formatted prompt (str or dict)"]:
+ """
+ Retrieve and format a prompt by its name.
+
+ This method fetches the prompt template identified by `prompt_name` from the loaded prompts
+ and formats any placeholders within the prompt using the provided keyword arguments.
+
+ Parameters
+ ----------
+ prompt_name : str
+ Name of the prompt to retrieve.
+ **kwargs : Any
+ Keyword arguments to format the prompt strings.
+
+ Raises
+ ------
+ ValueError
+ If the specified prompt name does not exist in the loaded prompts.
+
+ Returns
+ -------
+        dict or str
+            The formatted prompt with placeholders replaced: a dict of formatted fields if the prompt is a mapping, otherwise a formatted string.
+ """
+ if not isinstance(self.prompts, dict):
+ raise TypeError(f"Internal error: self.prompts is not a dictionary but {type(self.prompts).__name__}.")
+
+ if prompt_name not in self.prompts:
+ raise ValueError(f"Prompt '{prompt_name}' not found.")
+
+ prompt = self.prompts[prompt_name]
+
+ if isinstance(prompt, dict):
+ formatted_prompt = {}
+ for key, value in prompt.items():
+ if isinstance(value, str):
+ formatted_prompt[key] = value.format(**kwargs)
+ else:
+ formatted_prompt[key] = value
+ return formatted_prompt
+
+ if isinstance(prompt, str):
+ return prompt.format(**kwargs)
+
+ raise TypeError(f"Unexpected prompt type: {type(prompt).__name__}")
+
+
+if __name__ == "__main__":
+ try:
+ prompt_manager = PromptManager()
+
+ formatted_prompt_ = prompt_manager.get_prompt(
+ "greeting",
+ name="Ahmet",
+            day="Monday"
+ )
+
+ print("Formatted Prompt:", formatted_prompt_)
+
+ except Exception as e:
+ print(f"Error: {e}")
\ No newline at end of file
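
The `greeting` prompt referenced in the `__main__` block is not part of this changeset; a hypothetical entry with `{name}` and `{day}` placeholders, written to a temporary file so the `get_prompt` formatting can be checked in isolation:

```python
import tempfile

from src.text.prompt import PromptManager

PROMPTS_YAML = """
greeting:
  system: "You are a friendly assistant."
  user: "Greet {name} and mention that today is {day}."
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(PROMPTS_YAML)
    prompts_path = f.name

manager = PromptManager(config_path=prompts_path)
print(manager.get_prompt("greeting", name="Ahmet", day="Monday"))
# {'system': 'You are a friendly assistant.', 'user': 'Greet Ahmet and mention that today is Monday.'}
```
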
diff --git a/src/text/utils.py b/src/text/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..360a3e98f64b553b8f29a7b6faf7b7dabb21d63d
--- /dev/null
+++ b/src/text/utils.py
@@ -0,0 +1,239 @@
+# Standard library imports
+from typing import Annotated, Dict, Any, List
+
+
+class Annotator:
+ """
+    A class to annotate a sentence-speaker mapping (SSM) with various
+    attributes such as sentiment, profanity, summary, conflict, and topic.
+
+ Parameters
+ ----------
+ ssm : list of dict
+        A list of dictionaries representing the sentence-speaker mapping.
+
+ Attributes
+ ----------
+ ssm : list of dict
+        The sentence-speaker mapping to be annotated.
+ global_summary : str
+ The global summary of the annotations.
+ global_conflict : bool
+ The global conflict status of the annotations.
+ global_topic : str
+ The global topic of the annotations.
+ """
+
+    def __init__(self, ssm: Annotated[List[Dict[str, Any]], "Sentence-speaker mapping (SSM)"]):
+ """
+ Initializes the Annotator class with the provided SSM.
+
+ Parameters
+ ----------
+ ssm : list of dict
+            A list of dictionaries representing the sentence-speaker mapping.
+ """
+ self.ssm = ssm
+ self.global_summary = ""
+ self.global_conflict = False
+ self.global_topic = "Unknown"
+
+ def add_sentiment(
+ self,
+ sentiment_results: Annotated[Dict[str, Any], "Sentiment analysis results"]
+ ):
+ """
+ Adds sentiment data to the SSM.
+
+ Parameters
+ ----------
+ sentiment_results : dict
+ A dictionary containing sentiment analysis results, including
+ a "sentiments" key with a list of sentiment dictionaries.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> results = {"sentiments": [{"index": 0, "sentiment": "Positive"}]}
+        >>> annotator.add_sentiment(results)
+ """
+ if len(sentiment_results["sentiments"]) != len(self.ssm):
+ print(f"Mismatch: SSM Length = {len(self.ssm)}, "
+ f"Sentiments Length = {len(sentiment_results['sentiments'])}")
+ print("Adjusting to match lengths...")
+
+ if len(sentiment_results["sentiments"]) < len(self.ssm):
+ for idx in range(len(sentiment_results["sentiments"]), len(self.ssm)):
+ sentiment_results["sentiments"].append({"index": idx, "sentiment": "Neutral"})
+
+ elif len(sentiment_results["sentiments"]) > len(self.ssm):
+ sentiment_results["sentiments"] = sentiment_results["sentiments"][:len(self.ssm)]
+
+ for sentiment_data in sentiment_results["sentiments"]:
+ idx = sentiment_data["index"]
+ if idx < len(self.ssm):
+ self.ssm[idx]["sentiment"] = sentiment_data["sentiment"]
+ else:
+ print(f"Skipping sentiment data at index {idx}, out of range.")
+
+ def add_profanity(
+ self,
+ profane_results: Annotated[Dict[str, Any], "Profanity detection results"]
+ ) -> List[Dict[str, Any]]:
+ """
+ Adds profanity data to the SSM.
+
+ Parameters
+ ----------
+ profane_results : dict
+ A dictionary containing profanity detection results, including
+ a "profanity" key with a list of profanity dictionaries.
+
+ Returns
+ -------
+ list of dict
+ The updated SSM with profanity annotations.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> results = {"profanity": [{"index": 0, "profane": True}]}
+        >>> annotator.add_profanity(results)
+ """
+ if "profanity" not in profane_results:
+ print("Warning: 'profanity' key is missing in profane_results.")
+ return self.ssm
+
+ if len(profane_results["profanity"]) != len(self.ssm):
+ print(f"Mismatch: SSM Length = {len(self.ssm)}, "
+ f"Profanity Length = {len(profane_results['profanity'])}")
+ print("Adjusting to match lengths...")
+
+ if len(profane_results["profanity"]) < len(self.ssm):
+ for idx in range(len(profane_results["profanity"]), len(self.ssm)):
+ profane_results["profanity"].append({"index": idx, "profane": False})
+
+ elif len(profane_results["profanity"]) > len(self.ssm):
+ profane_results["profanity"] = profane_results["profanity"][:len(self.ssm)]
+
+ for profanity_data in profane_results["profanity"]:
+ idx = profanity_data["index"]
+ if idx < len(self.ssm):
+ self.ssm[idx]["profane"] = profanity_data["profane"]
+ else:
+ print(f"Skipping profanity data at index {idx}, out of range.")
+
+ return self.ssm
+
+ def add_summary(
+ self,
+ summary_result: Annotated[Dict[str, str], "Summary results"]
+ ) -> Dict[str, Any]:
+ """
+ Adds a global summary to the annotations.
+
+ Parameters
+ ----------
+ summary_result : dict
+ A dictionary containing a "summary" key with the summary text.
+
+ Returns
+ -------
+ dict
+ The updated SSM and global summary.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> result = {"summary": "This is a summary."}
+        >>> annotator.add_summary(result)
+ """
+ if not summary_result or "summary" not in summary_result:
+ print("Warning: 'summary' key is missing in summary_result.")
+ return {"ssm": self.ssm, "summary": self.global_summary}
+
+ self.global_summary = summary_result["summary"]
+ return {"ssm": self.ssm, "summary": self.global_summary}
+
+ def add_conflict(
+ self,
+ conflict_result: Annotated[Dict[str, bool], "Conflict detection results"]
+ ) -> Dict[str, Any]:
+ """
+ Adds a global conflict status to the annotations.
+
+ Parameters
+ ----------
+ conflict_result : dict
+ A dictionary containing a "conflict" key with a boolean value.
+
+ Returns
+ -------
+ dict
+ The updated SSM and global conflict status.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> result = {"conflict": True}
+        >>> annotator.add_conflict(result)
+ """
+ if not conflict_result or "conflict" not in conflict_result:
+ print("Warning: 'conflict' key is missing in conflict_result.")
+ return {"ssm": self.ssm, "conflict": self.global_conflict}
+
+ self.global_conflict = conflict_result["conflict"]
+ return {"ssm": self.ssm, "conflict": self.global_conflict}
+
+ def add_topic(
+ self,
+ topic_result: Annotated[Dict[str, str], "Topic detection results"]
+ ) -> Dict[str, Any]:
+ """
+ Adds a global topic to the annotations.
+
+ Parameters
+ ----------
+ topic_result : dict
+ A dictionary containing a "topic" key with the topic name.
+
+ Returns
+ -------
+ dict
+ The updated SSM and global topic.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> result = {"topic": "Technology"}
+        >>> annotator.add_topic(result)
+ """
+ if not topic_result or "topic" not in topic_result:
+ print("Warning: 'topic' key is missing in topic_result.")
+ return {"ssm": self.ssm, "topic": self.global_topic}
+
+ self.global_topic = topic_result["topic"]
+ return {"ssm": self.ssm, "topic": self.global_topic}
+
+ def finalize(self) -> Dict[str, Any]:
+ """
+ Finalizes the annotations by returning the updated SSM along with
+ global annotations for summary, conflict, and topic.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the updated SSM and global annotations.
+
+ Examples
+ --------
+ >>> annotator = Annotator([{"text": "example"}])
+ >>> annotator.finalize()
+ {'ssm': [{'text': 'example'}], 'summary': '', 'conflict': False, 'topic': 'Unknown'}
+ """
+ return {
+ "ssm": self.ssm,
+ "summary": self.global_summary,
+ "conflict": self.global_conflict,
+ "topic": self.global_topic
+ }
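
An end-to-end sketch of the `Annotator` flow, using hand-written result dictionaries in place of real LLM output:

```python
from src.text.utils import Annotator

ssm = [
    {"speaker": "Customer", "text": "My invoice is wrong."},
    {"speaker": "CSR", "text": "Let me check that for you."},
]

annotator = Annotator(ssm)
annotator.add_sentiment({"sentiments": [
    {"index": 0, "sentiment": "Negative"},
    {"index": 1, "sentiment": "Neutral"},
]})
annotator.add_profanity({"profanity": [
    {"index": 0, "profane": False},
    {"index": 1, "profane": False},
]})
annotator.add_summary({"summary": "Customer reports an incorrect invoice; CSR investigates."})
annotator.add_conflict({"conflict": False})
annotator.add_topic({"topic": "Billing"})

result = annotator.finalize()
print(result["topic"], result["conflict"])
print(result["ssm"][0])  # now carries 'sentiment' and 'profane' keys
```
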
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..df760f7ea62b89fcf0b93a9a382c961523ca4e9d
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,257 @@
+# Standard library imports
+import os
+import shutil
+import asyncio
+import logging
+from typing import Annotated
+
+# Related third-party imports
+from watchdog.observers import Observer
+from watchdog.events import FileSystemEventHandler
+
+
+class Logger:
+ """
+ Logger class to simplify logging setup and usage.
+
+ This class provides a reusable logging utility that supports customizable
+ log levels, message formatting, and optional console printing.
+
+ Parameters
+ ----------
+ name : str, optional
+ The name of the logger. Defaults to "CallyticsLogger".
+ level : int, optional
+ The logging level (e.g., logging.INFO, logging.DEBUG). Defaults to logging.INFO.
+
+ Attributes
+ ----------
+ logger : logging.Logger
+ The configured logger instance.
+ """
+
+ def __init__(
+ self,
+ name: Annotated[str, "The name of the logger"] = "CallyticsLogger",
+ level: Annotated[int, "The logging level (e.g., logging.INFO)"] = logging.INFO,
+ ) -> None:
+ """
+ Initialize the Logger instance with a specified name and logging level.
+
+ Parameters
+ ----------
+ name : str, optional
+ The name of the logger. Defaults to "CallyticsLogger".
+ level : int, optional
+ The logging level. Defaults to logging.INFO.
+
+ Returns
+ -------
+ None
+ """
+ self.logger = logging.getLogger(name)
+ self.logger.setLevel(level)
+ if not self.logger.hasHandlers():
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ handler.setFormatter(formatter)
+ self.logger.addHandler(handler)
+
+ def log(
+ self,
+ message: Annotated[str, "The message to log"],
+ level: Annotated[int, "The logging level (e.g., logging.INFO)"] = logging.INFO,
+ print_output: Annotated[bool, "Whether to print the message to console"] = True,
+ ) -> None:
+ """
+ Log a message at the specified logging level, with optional console output.
+
+ Parameters
+ ----------
+ message : str
+ The message to log.
+ level : int, optional
+ The logging level. Defaults to logging.INFO.
+ print_output : bool, optional
+ Whether to print the message to the console. Defaults to True.
+
+ Returns
+ -------
+ None
+
+ Examples
+ --------
+ >>> logger = Logger("ExampleLogger", logging.DEBUG)
+ >>> logger.log("This is a debug message", logging.DEBUG)
+ This is a debug message
+ """
+ if print_output:
+ print(message)
+ self.logger.log(level, message)
+
+
+class Cleaner:
+ """
+ A utility class for cleaning up files, directories, or symbolic links at one or multiple specified paths.
+ """
+
+ def __init__(self) -> None:
+ """
+ Initialize the Cleaner class. This method is present for completeness.
+
+ Returns
+ -------
+ None
+ """
+ pass
+
+ @staticmethod
+ def cleanup(*paths: str) -> None:
+ """
+ Deletes files, directories, or symbolic links at the specified paths.
+
+ Parameters
+ ----------
+ *paths : str
+ One or more paths to the files or directories to delete.
+
+ Returns
+ -------
+ None
+
+ Notes
+ -----
+ - Each path will be checked individually.
+ - If the path is a file or symbolic link, it will be deleted.
+ - If the path is a directory, the entire directory and its contents will be deleted.
+ - If the path does not exist or is neither a file nor a directory, a message will be printed.
+
+ Examples
+ --------
+ >>> Cleaner.cleanup("/path/to/file", "/path/to/directory")
+ File /path/to/file has been deleted.
+ Directory /path/to/directory has been deleted.
+ """
+ for path in paths:
+ if os.path.isfile(path) or os.path.islink(path):
+ os.remove(path)
+ print(f"File {path} has been deleted.")
+ elif os.path.isdir(path):
+ shutil.rmtree(path)
+ print(f"Directory {path} has been deleted.")
+ else:
+ print(f"Path {path} is not a file or directory.")
+
+
+class Watcher(FileSystemEventHandler):
+ """
+ A file system event handler that watches a directory for newly created audio files and triggers a callback.
+
+ The Watcher class extends FileSystemEventHandler to monitor a directory for new audio files with specific
+ extensions (.mp3, .wav, .flac). When a new file is detected, it invokes an asynchronous callback function,
+ allowing users to integrate custom processing logic (e.g., transcription, diarization) immediately after
+ the file is created.
+
+ Parameters
+ ----------
+ callback : callable
+ An asynchronous callback function that accepts a single argument (the path to the newly created audio file).
+ """
+
+ def __init__(self, callback) -> None:
+ """
+ Initialize the Watcher with a specified asynchronous callback.
+
+ Parameters
+ ----------
+ callback : callable
+ An async function that will be called with the path of the newly created audio file.
+
+ Returns
+ -------
+ None
+ """
+ super().__init__()
+ self.callback = callback
+
+ def on_created(self, event) -> None:
+ """
+ Handle the creation of a new file event.
+
+ If the newly created file is an audio file with supported extensions (.mp3, .wav, .flac),
+ this method triggers the asynchronous callback function to process the file.
+
+ Parameters
+ ----------
+ event : FileSystemEvent
+ The event object representing the file system change.
+
+ Returns
+ -------
+ None
+ """
+ if not event.is_directory and event.src_path.lower().endswith(('.mp3', '.wav', '.flac')):
+ print(f"New audio file detected: {event.src_path}")
+ asyncio.run(self.callback(event.src_path))
+
+ @classmethod
+ def start_watcher(cls, directory: str, callback) -> None:
+ """
+ Starts the file system watcher on the specified directory.
+
+ If the directory does not exist, it will be created. The Watcher will monitor the directory for newly
+ created audio files and trigger the provided callback function.
+
+ Parameters
+ ----------
+ directory : str
+ The path of the directory to watch.
+ callback : callable
+ An asynchronous callback function that accepts the path to a newly created audio file.
+
+ Returns
+ -------
+ None
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory, exist_ok=True)
+ print(f"Directory '{directory}' created.")
+
+ observer = Observer()
+ event_handler = cls(callback)
+ observer.schedule(event_handler, directory, recursive=False)
+ observer.start()
+ print(f"Watching directory: {directory}")
+
+ import time
+ try:
+ while True:
+                # Synchronous wait to keep the observer thread alive
+ time.sleep(1)
+ except KeyboardInterrupt:
+ observer.stop()
+ observer.join()
+
+
+if __name__ == "__main__":
+ path_to_file = "sample_file.txt"
+ path_to_directory = "sample_directory"
+
+ with open(path_to_file, "w") as file:
+ file.write("This is a sample file for testing the Cleaner class.")
+
+ os.makedirs(path_to_directory, exist_ok=True)
+
+ print(f"Attempting to delete file: {path_to_file}")
+ Cleaner.cleanup(path_to_file)
+
+ print(f"Attempting to delete directory: {path_to_directory}")
+ Cleaner.cleanup(path_to_directory)
+
+ non_existent_path = "non_existent_path"
+ print(f"Attempting to delete non-existent path: {non_existent_path}")
+ Cleaner.cleanup(non_existent_path)
+
+ Cleaner.cleanup(path_to_file, path_to_directory)
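
A minimal sketch of wiring the `Watcher` to the project's `.data/input` watch directory; the callback here is a stand-in for the real pipeline entry point:

```python
from src.utils.utils import Watcher


async def handle_new_audio(path: str) -> None:
    # Stand-in for the real pipeline (transcription, diarization, annotation, DB insert).
    print(f"Would process: {path}")


if __name__ == "__main__":
    # Blocks and watches until interrupted with Ctrl+C.
    Watcher.start_watcher(".data/input", handle_new_audio)
```
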