gradio-transcript-mcp / transcription_tool.py
Bismay
Initial commit
83a4e82
from smolagents import Tool
import os
import tempfile
import shutil
import torch
import subprocess
from transcription import run_whisper_transcription
from logging_config import logger
from ffmpeg_setup import ensure_ffmpeg_in_path
class TranscriptTool(Tool):
    """smolagents tool that transcribes an audio or video file to text.

    The input file is copied into a private temporary directory. If its name
    does not end in ``.wav`` it is converted with ffmpeg to 16 kHz mono WAV.
    The result is passed to ``run_whisper_transcription`` on CUDA when
    available, otherwise on CPU. Failures are reported to the agent as an
    error string instead of being raised.
    """

    name = "TranscriptTool"
    # NOTE: the literal's line layout is kept exactly as-is; it is the
    # runtime description shown to the agent.
    description = """
A smolagent tool for transcribing audio and video files into text. This tool utilises Whisper for transcription
and ffmpeg for media conversion, enabling agents to process multimedia inputs into text. The tool supports robust
file handling, including format conversion to WAV and dynamic device selection for optimal performance.
"""
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio or video file for transcription.",
        }
    }
    output_type = "string"

    def __init__(self, audio_directory=None):
        """Initialise the tool and make sure ffmpeg can be found.

        Args:
            audio_directory: Directory searched recursively when the path given
                to ``forward`` does not point at a file. A falsy value means the
                current working directory at construction time.
        """
        super().__init__()
        ensure_ffmpeg_in_path()
        self.audio_directory = audio_directory or os.getcwd()

    def locate_audio_file(self, file_name):
        """Search ``self.audio_directory`` recursively for a file called ``file_name``.

        Returns:
            The path of the first match in ``os.walk`` order, or ``None`` if
            there is no match or ``file_name`` is empty.
        """
        if not file_name:
            # An empty name can never match a directory entry; skip the walk.
            return None
        for root, _, files in os.walk(self.audio_directory):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    def convert_audio_to_wav(self, input_file: str, output_file: str, ffmpeg_path: str) -> str:
        """Convert ``input_file`` to a 16 kHz mono WAV at ``output_file`` using ffmpeg.

        Returns:
            ``output_file``.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status. The original
                ``CalledProcessError`` is chained, and ffmpeg's stderr is logged.
        """
        logger.info(f"Converting {input_file} to WAV format: {output_file}")
        cmd = [
            ffmpeg_path,
            "-y",  # Overwrite output files without asking
            "-i", input_file,
            "-ar", "16000",  # Set audio sampling rate to 16kHz
            "-ac", "1",  # Set number of audio channels to mono
            output_file,
        ]
        try:
            # stdin is detached so ffmpeg cannot read from (or block on) the
            # caller's terminal. Only stderr is captured, for diagnostics.
            subprocess.run(
                cmd,
                check=True,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as e:
            # ffmpeg output is not guaranteed to be valid UTF-8; never let the
            # decode step hide the real failure.
            ffmpeg_error = (e.stderr or b"").decode(errors="replace")
            logger.error(f"ffmpeg error: {ffmpeg_error}")
            raise RuntimeError("Failed to convert audio to WAV.") from e
        logger.info("Audio conversion to WAV completed successfully.")
        return output_file

    def forward(self, file_path: str) -> str:
        """Transcribe the media file at ``file_path`` and return the text.

        If ``file_path`` is not an existing regular file, a file with the same
        base name is searched for under ``self.audio_directory``.

        Returns:
            The first string yielded by ``run_whisper_transcription``, or an
            error message beginning with ``"Error:"`` / ``"An error occurred:"``.
        """
        try:
            logger.info(f"Attempting to transcribe file: {file_path}")
            # Locate the file if the given path is not a regular file
            # (a directory cannot be copied or transcribed).
            if not os.path.isfile(file_path):
                file_name = os.path.basename(file_path)
                file_path = self.locate_audio_file(file_name)
                if not file_path:
                    logger.error(f"File '{file_name}' not found in '{self.audio_directory}'.")
                    return f"Error: File '{file_name}' not found in '{self.audio_directory}'."

            with tempfile.TemporaryDirectory() as tmpdir:
                # Work on a private copy so conversion never touches the source.
                filename = os.path.basename(file_path)
                input_file_path = os.path.join(tmpdir, filename)
                shutil.copy(file_path, input_file_path)
                logger.info(f"Copied input file to temporary directory: {input_file_path}")

                if input_file_path.lower().endswith(".wav"):
                    logger.info(f"File {filename} is already in WAV format. Skipping conversion.")
                    transcription_input_path = input_file_path
                else:
                    wav_file_path = os.path.join(tmpdir, "converted_audio.wav")
                    ffmpeg_path = shutil.which("ffmpeg")
                    if not ffmpeg_path:
                        logger.error("ffmpeg is not accessible in PATH.")
                        raise RuntimeError("ffmpeg is not accessible in PATH.")
                    self.convert_audio_to_wav(input_file_path, wav_file_path, ffmpeg_path)
                    transcription_input_path = wav_file_path

                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device for transcription: {device}")

                # The generator yields a single string (transcription or error).
                results = iter(run_whisper_transcription(transcription_input_path, device))
                try:
                    result_string = next(results, None)
                finally:
                    # Finish the generator while the temp files still exist,
                    # rather than leaving it suspended past directory cleanup.
                    if hasattr(results, "close"):
                        results.close()

            if result_string is None:
                # Previously this path fell through and returned None despite
                # output_type being "string".
                logger.error("Transcription produced no result.")
                return "Error: Transcription produced no result."
            return result_string
        except Exception as e:
            # Tool boundary: report every failure to the agent as text, but keep
            # the traceback in the logs.
            logger.exception(f"An unexpected error occurred in TranscriptTool: {str(e)}")
            return f"An error occurred: {str(e)}"