# Standard library imports
import os
import asyncio

# Related third-party imports
import gradio as gr
from omegaconf import OmegaConf
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from huggingface_hub import login

# Local imports
from src.audio.utils import Formatter
from src.audio.metrics import SilenceStats
from src.audio.error import DialogueDetecting
from src.audio.alignment import ForcedAligner
from src.audio.effect import DemucsVocalSeparator
from src.audio.preprocessing import SpeechEnhancement
from src.audio.io import SpeakerTimestampReader, TranscriptWriter
from src.audio.analysis import WordSpeakerMapper, SentenceSpeakerMapper, Audio
from src.audio.processing import AudioProcessor, Transcriber, PunctuationRestorer
from src.text.utils import Annotator
from src.text.llm import LLMOrchestrator, LLMResultHandler
from src.utils.utils import Cleaner


async def main(audio_file_path: str):
    """
    Process an audio file to perform diarization, transcription, punctuation
    restoration, and speaker role classification.

    Parameters
    ----------
    audio_file_path : str
        Path to the input audio file to be processed.

    Returns
    -------
    dict
        The annotated `final_output` dictionary, or an error dictionary if no
        dialogue is detected.
    """
    # Paths
    config_nemo = "config/nemo/diar_infer_telephonic.yaml"
    manifest_path = ".temp/manifest.json"
    temp_dir = ".temp"
    rttm_file_path = os.path.join(temp_dir, "pred_rttms", "mono_file.rttm")
    transcript_output_path = ".temp/output.txt"
    srt_output_path = ".temp/output.srt"
    config_path = "config/config.yaml"
    prompt_path = "config/prompt.yaml"

    # Configuration
    config = OmegaConf.load(config_path)
    device = config.runtime.device
    compute_type = config.runtime.compute_type
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config.runtime.cuda_alloc_conf
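    # HF_TOKEN is expected to be provided via the environment (for example as a
    # Space secret); login() authenticates against the Hugging Face Hub so that
    # any gated model checkpoints used by the pipeline can be downloaded.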
    hf_token = os.getenv("HF_TOKEN")
    login(token=hf_token)
    # Initialize Classes
    dialogue_detector = DialogueDetecting(delete_original=True)
    enhancer = SpeechEnhancement(config_path=config_path, output_dir=temp_dir)
    separator = DemucsVocalSeparator()
    processor = AudioProcessor(audio_path=audio_file_path, temp_dir=temp_dir)
    transcriber = Transcriber(device=device, compute_type=compute_type)
    aligner = ForcedAligner(device=device)
    llm_handler = LLMOrchestrator(config_path=config_path, prompt_config_path=prompt_path, model_id="openai")
    llm_result_handler = LLMResultHandler()
    cleaner = Cleaner()
    formatter = Formatter()
    # Step 1: Detect Dialogue
    has_dialogue = dialogue_detector.process(audio_file_path)
    if not has_dialogue:
        return {"error": "No dialogue detected in this audio."}

    # Step 2: Speech Enhancement
    audio_path = enhancer.enhance_audio(
        input_path=audio_file_path,
        output_path=os.path.join(temp_dir, "enhanced.wav"),
        noise_threshold=0.0001,
        verbose=True
    )

    # Step 3: Vocal Separation
    vocal_path = separator.separate_vocals(audio_file=audio_path, output_dir=temp_dir)

    # Step 4: Transcription
    transcript, info = transcriber.transcribe(audio_path=vocal_path)
    detected_language = info["language"]

    # Step 5: Forced Alignment
    word_timestamps = aligner.align(
        audio_path=vocal_path,
        transcript=transcript,
        language=detected_language
    )

    # Step 6: Diarization
    processor.audio_path = vocal_path
    mono_audio_path = processor.convert_to_mono()
    processor.audio_path = mono_audio_path
    processor.create_manifest(manifest_path)
    cfg = OmegaConf.load(config_nemo)
    cfg.diarizer.manifest_filepath = manifest_path
    cfg.diarizer.out_dir = temp_dir
    msdd_model = NeuralDiarizer(cfg=cfg)
    msdd_model.diarize()
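    # NeMo's NeuralDiarizer writes its predictions to <out_dir>/pred_rttms/,
    # which is why rttm_file_path above points at .temp/pred_rttms/mono_file.rttm
    # (this assumes convert_to_mono() names the converted file "mono_file").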
    # Step 7: Processing Transcript
    # Step 7.1: Speaker Timestamps
    speaker_reader = SpeakerTimestampReader(rttm_path=rttm_file_path)
    speaker_ts = speaker_reader.read_speaker_timestamps()

    # Step 7.2: Mapping Words
    word_speaker_mapper = WordSpeakerMapper(word_timestamps, speaker_ts)
    wsm = word_speaker_mapper.get_words_speaker_mapping()

    # Step 7.3: Punctuation Restoration
    punct_restorer = PunctuationRestorer(language=detected_language)
    wsm = punct_restorer.restore_punctuation(wsm)
    word_speaker_mapper.word_speaker_mapping = wsm
    word_speaker_mapper.realign_with_punctuation()
    wsm = word_speaker_mapper.word_speaker_mapping
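    # Realigning after punctuation restoration keeps word/speaker boundaries
    # consistent with the restored punctuation before sentences are grouped
    # per speaker in the next step.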
    # Step 7.4: Mapping Sentences
    sentence_mapper = SentenceSpeakerMapper()
    ssm = sentence_mapper.get_sentences_speaker_mapping(wsm)

    # Step 8 (Optional): Write Transcript and SRT Files
    writer = TranscriptWriter()
    writer.write_transcript(ssm, transcript_output_path)
    writer.write_srt(ssm, srt_output_path)

    # Step 9: Classify Speaker Roles
    speaker_roles = await llm_handler.generate("Classification", user_input=ssm)

    # Step 9.1: Validate LLM results and fall back if needed
    ssm = llm_result_handler.validate_and_fallback(speaker_roles, ssm)
    llm_result_handler.log_result(ssm, speaker_roles)

    # Step 10: Sentiment Analysis
    ssm_with_indices = formatter.add_indices_to_ssm(ssm)
    annotator = Annotator(ssm_with_indices)
    sentiment_results = await llm_handler.generate("SentimentAnalysis", user_input=ssm)
    annotator.add_sentiment(sentiment_results)

    # Step 11: Profanity Word Detection
    profane_results = await llm_handler.generate("ProfanityWordDetection", user_input=ssm)
    annotator.add_profanity(profane_results)

    # Step 12: Summary
    summary_result = await llm_handler.generate("Summary", user_input=ssm)
    annotator.add_summary(summary_result)

    # Step 13: Conflict Detection
    conflict_result = await llm_handler.generate("ConflictDetection", user_input=ssm)
    annotator.add_conflict(conflict_result)

    # Step 14: Topic Detection
    topics = [
        "Complaint",
        "Technical Support",
        "Billing",
        "Order Status",
    ]
    topic_result = await llm_handler.generate(
        "TopicDetection",
        user_input=ssm,
        system_input=topics
    )
    annotator.add_topic(topic_result)
    final_output = annotator.finalize()

    # Step 15: Total Silence Calculation
    stats = SilenceStats.from_segments(final_output["ssm"])
    t_std = stats.threshold_std(factor=0.99)
    final_output["silence"] = t_std
print("Final_Output:", final_output) | |
# Step 16: Clean Up | |
cleaner.cleanup(temp_dir, audio_file_path) | |
return final_output | |


def process_audio(uploaded_audio):
    """
    Synchronous wrapper for Gradio.

    1. Receive the temporary file path of the uploaded audio from Gradio.
    2. Run the async `main` pipeline via `asyncio.run`.
    3. Return the result so Gradio can display it.
    """
    if uploaded_audio is None:
        return {"error": "No audio provided."}
    in_file_path = uploaded_audio
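    # gr.Audio(type="filepath") hands over the path of a file already on disk, so
    # it can be passed straight into the pipeline. asyncio.run() starts a fresh
    # event loop per request, which is fine for a single-user demo like this one.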
    try:
        result = asyncio.run(main(in_file_path))
        return result
    except Exception as e:
        return {"error": str(e)}


def transform_output_to_tables(final_output: dict):
    """
    Convert the pipeline output into two table-friendly lists.

    Parameters
    ----------
    final_output : dict
        Dictionary containing the processed results.

    Returns
    -------
    tuple
        Two lists, `(ssm_data, file_data)`, for the utterance and file tables.
    """
    if "error" in final_output:
        return [], []

    # Utterance Table
    ssm_data = []
    if "ssm" in final_output:
        for item in final_output["ssm"]:
            ssm_data.append([
                item.get("speaker", ""),
                item.get("start_time", ""),
                item.get("end_time", ""),
                item.get("text", ""),
                item.get("index", ""),
                item.get("sentiment", ""),
                item.get("profane", "")
            ])

    # File Table
    file_data = []
    for key in ["summary", "conflict", "topic", "silence"]:
        file_data.append([key, final_output.get(key, "")])
    return ssm_data, file_data


with gr.Blocks() as demo:
    gr.Markdown("Callytics Demo")

    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload your audio")
        submit_btn = gr.Button("Process")

    with gr.Row():
        utterance_table = gr.Dataframe(
            headers=["Speaker", "Start Time", "End Time", "Text", "Index", "Sentiment", "Profane"],
            label="Utterance Table"
        )

    with gr.Row():
        file_table = gr.Dataframe(
            headers=["Key", "Value"],
            label="File Table"
        )

    output_display = gr.JSON(label="Final Output (JSON)")

    gr.Examples(
        examples=[
            [".data/example/tr.mp3"],
            [".data/example/en.mp3"],
            [".data/example/jp.mp3"],
            [".data/example/fr.mp3"],
            [".data/example/de.mp3"],
        ],
        inputs=audio_input,
        outputs=[utterance_table, file_table, output_display],
        label="Example Call Center Call"
    )
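    # Note: without `fn` and `cache_examples=True`, clicking an example only
    # populates the audio input; the user still has to press "Process" to run
    # the pipeline. The example clips are assumed to exist under .data/example/.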

    def process_and_show_tables(uploaded_audio):
        """
        Call `process_audio` and return the table-friendly data plus the raw
        JSON output for display.
        """
        final_output = process_audio(uploaded_audio)
        ssm_data, file_data = transform_output_to_tables(final_output)
        return ssm_data, file_data, final_output

    submit_btn.click(
        fn=process_and_show_tables,
        inputs=audio_input,
        outputs=[utterance_table, file_table, output_display]
    )


if __name__ == "__main__":
    demo.launch()