File size: 1,901 Bytes
9333f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from typing import Optional
from langchain.tools import BaseTool
import requests
from pydub import AudioSegment
import tempfile

class AudioToTextTool(BaseTool):
    name = "audio_to_text"
    description = "Convertit un fichier audio en texte en utilisant l'API Hugging Face"
    
    def _run(self, audio_path: str) -> str:
        """Convertit un fichier audio en texte"""
        try:
            # Vérifier si le fichier existe
            if not os.path.exists(audio_path):
                return f"Erreur: Le fichier {audio_path} n'existe pas"
            
            # Convertir le fichier audio en format WAV si nécessaire
            audio = AudioSegment.from_file(audio_path)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                audio.export(temp_file.name, format="wav")
                temp_path = temp_file.name
            
            # Appeler l'API Hugging Face pour la transcription
            API_URL = "https://api-inference.huggingface.co/models/facebook/wav2vec2-large-960h-lv60-self"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
            
            with open(temp_path, "rb") as f:
                data = f.read()
            
            response = requests.post(API_URL, headers=headers, data=data)
            
            # Nettoyer le fichier temporaire
            os.unlink(temp_path)
            
            if response.status_code != 200:
                return f"Erreur lors de la transcription: {response.text}"
            
            return response.json().get("text", "Aucun texte transcrit")
            
        except Exception as e:
            return f"Erreur lors de la conversion audio en texte: {str(e)}"
    
    async def _arun(self, audio_path: str) -> str:
        """Version asynchrone de l'outil"""
        return self._run(audio_path)