from smolagents import Tool
import os
import tempfile
import shutil
import torch
import subprocess
from transcription import run_whisper_transcription
from logging_config import logger
from ffmpeg_setup import ensure_ffmpeg_in_path


class TranscriptTool(Tool):
    """smolagents tool that transcribes an audio or video file to text.

    Non-WAV inputs are converted with ffmpeg to a 16 kHz mono WAV file inside a
    temporary directory, then passed to ``run_whisper_transcription`` together
    with the selected torch device ("cuda" when available, else "cpu").
    """

    name = "TranscriptTool"
    # Runtime string shown to the agent; kept byte-identical to the original.
    description = (
        " A smolagent tool for transcribing audio and video files into text."
        " This tool utilises Whisper for transcription and ffmpeg for media"
        " conversion, enabling agents to process multimedia inputs into text."
        " The tool supports robust file handling, including format conversion"
        " to WAV and dynamic device selection for optimal performance. "
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio or video file for transcription.",
        }
    }
    output_type = "string"

    def __init__(self, audio_directory=None):
        """Initialise the tool and make sure ffmpeg can be found.

        Args:
            audio_directory: Directory searched recursively (by basename) when the
                path given to ``forward`` is not an existing file. Falls back to
                the current working directory, read at construction time, when
                falsy.
        """
        super().__init__()
        ensure_ffmpeg_in_path()
        self.audio_directory = audio_directory or os.getcwd()

    def locate_audio_file(self, file_name):
        """Return the first path under ``audio_directory`` whose basename is
        ``file_name``, in ``os.walk`` (top-down) order, or ``None`` if absent."""
        for root, _, files in os.walk(self.audio_directory):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    def convert_audio_to_wav(self, input_file: str, output_file: str, ffmpeg_path: str) -> str:
        """Convert ``input_file`` to a 16 kHz mono WAV at ``output_file`` via ffmpeg.

        Args:
            input_file: Source media path.
            output_file: Destination WAV path (overwritten if present).
            ffmpeg_path: Path to the ffmpeg executable.

        Returns:
            ``output_file``.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status; the original
                ``CalledProcessError`` is chained as the cause.
        """
        logger.info(f"Converting {input_file} to WAV format: {output_file}")
        cmd = [
            ffmpeg_path,
            "-y",  # Overwrite output files without asking
            "-i", input_file,
            "-ar", "16000",  # Set audio sampling rate to 16kHz
            "-ac", "1",  # Set number of audio channels to mono
            output_file,
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # ffmpeg stderr may contain bytes that are not valid UTF-8 (e.g. file
            # names in another encoding); a strict decode would raise
            # UnicodeDecodeError here and hide the real failure.
            ffmpeg_error = (e.stderr or b"").decode(errors="replace")
            logger.error(f"ffmpeg error: {ffmpeg_error}")
            raise RuntimeError("Failed to convert audio to WAV.") from e
        logger.info("Audio conversion to WAV completed successfully.")
        return output_file

    def forward(self, file_path: str) -> str:
        """Transcribe ``file_path`` and return the transcription text.

        If ``file_path`` is not an existing regular file, its basename is
        searched for under ``audio_directory``. The file is copied to a
        temporary directory, converted to WAV when it lacks a ``.wav``
        extension, and handed to ``run_whisper_transcription``.

        Returns:
            The first string yielded by the transcription generator, or an
            error message string. This method never raises; errors are
            reported as text for the agent.
        """
        try:
            logger.info(f"Attempting to transcribe file: {file_path}")
            # Use isfile rather than exists: a directory path would otherwise
            # skip the search and then fail in shutil.copy.
            if not os.path.isfile(file_path):
                file_name = os.path.basename(file_path)
                file_path = self.locate_audio_file(file_name)
                if not file_path:
                    logger.error(f"File '{file_name}' not found in '{self.audio_directory}'.")
                    return f"Error: File '{file_name}' not found in '{self.audio_directory}'."

            with tempfile.TemporaryDirectory() as tmpdir:
                # Work on a copy so conversion never touches the caller's file.
                filename = os.path.basename(file_path)
                input_file_path = os.path.join(tmpdir, filename)
                shutil.copy(file_path, input_file_path)
                logger.info(f"Copied input file to temporary directory: {input_file_path}")

                # Check if already WAV, otherwise convert
                if input_file_path.lower().endswith(".wav"):
                    logger.info(f"File {filename} is already in WAV format. Skipping conversion.")
                    transcription_input_path = input_file_path
                else:
                    wav_file_path = os.path.join(tmpdir, "converted_audio.wav")
                    ffmpeg_path = shutil.which("ffmpeg")
                    if not ffmpeg_path:
                        logger.error("ffmpeg is not accessible in PATH.")
                        raise RuntimeError("ffmpeg is not accessible in PATH.")
                    self.convert_audio_to_wav(input_file_path, wav_file_path, ffmpeg_path)
                    transcription_input_path = wav_file_path

                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device for transcription: {device}")

                # The generator yields the result string (transcription or error).
                # It must be consumed while tmpdir still exists. An empty
                # generator previously made this method return None.
                transcription_generator = run_whisper_transcription(transcription_input_path, device)
                result_string = next(iter(transcription_generator), None)
                if result_string is None:
                    logger.error("Transcription produced no result.")
                    return "Error: Transcription produced no result."
                return result_string
        except Exception as e:
            # Boundary handler: the agent receives the error as text instead of
            # an exception.
            logger.error(f"An unexpected error occurred in TranscriptTool: {str(e)}")
            return f"An error occurred: {str(e)}"