Spaces:
Running
Running
File size: 5,776 Bytes
698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 72801e6 698c6c9 a2c10b6 698c6c9 72801e6 698c6c9 a2c10b6 698c6c9 a2c10b6 698c6c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from fastapi import FastAPI, Request, Form, File, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from main import process_query
from voice.speech_to_text import SpeechToText
import os
import logging
import tempfile
import shutil
import wave
from pydub import AudioSegment
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application instance; routes are registered below via decorators.
app = FastAPI()
# Mount static files (for CSS, JS, etc.) — served under /static from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Set up Jinja2 templates; every route renders templates/index.html.
templates = Jinja2Templates(directory="templates")
# Vosk model path — relative to the working directory; the model directory
# must exist before startup or SpeechToText initialization will fail.
vosk_model_path = "./vosk-model-small-en-us-0.15"
# Initialize SpeechToText once at import time so the (slow) model load is
# not repeated per request.
stt = SpeechToText(model_path=vosk_model_path)
def is_valid_wav(file_path):
    """Return True if *file_path* opens as a readable WAV file, else False.

    On success, logs the channel count, sample rate, and frame count so the
    audio parameters can be checked against what the recognizer expects.

    Args:
        file_path: Path to the candidate WAV file.

    Returns:
        bool: True when the file is a valid WAV, False on any read failure.
    """
    try:
        with wave.open(file_path, 'rb') as wf:
            logger.info(f"WAV file validated: channels={wf.getnchannels()}, rate={wf.getframerate()}, frames={wf.getnframes()}")
            return True
    except (wave.Error, EOFError, OSError) as e:
        # wave.Error: malformed header; EOFError: truncated file;
        # OSError: missing/unreadable path. Previously only wave.Error was
        # caught, so a truncated or missing file crashed the caller instead
        # of being reported as invalid.
        logger.error(f"Invalid WAV file: {str(e)}")
        return False
def convert_to_wav(input_path, output_path):
    """Re-encode an audio file to WAV (16 kHz, mono, 16-bit PCM).

    pydub/ffmpeg auto-detects the source container (WebM, OGG, ...), so
    *input_path* may be in any decodable format.

    Args:
        input_path: Path to the source audio file.
        output_path: Destination path for the normalized WAV file.

    Returns:
        bool: True when the conversion succeeded, False otherwise.
    """
    try:
        normalized = (
            AudioSegment.from_file(input_path)
            .set_channels(1)
            .set_frame_rate(16000)
            .set_sample_width(2)  # 2 bytes per sample == 16-bit PCM
        )
        normalized.export(output_path, format="wav")
        logger.info(f"Converted audio to WAV: {output_path}, size: {os.path.getsize(output_path)} bytes")
        return True
    except Exception as e:
        # Best-effort boundary: any decode/encode failure (bad container,
        # missing ffmpeg codec, I/O error) is reported as a False result.
        logger.error(f"Failed to convert audio to WAV: {str(e)}")
        return False
@app.get("/", response_class=HTMLResponse)
async def get_index(request: Request):
    """Serve the landing page with an empty context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _error_page(request: Request, message: str):
    """Render index.html with an "error" context variable.

    Shared by every failure path of upload_audio so the error-response shape
    is defined in exactly one place.
    """
    return templates.TemplateResponse("index.html", {
        "request": request,
        "error": message
    })


@app.post("/upload_audio", response_class=HTMLResponse)
async def upload_audio(request: Request, audio_file: UploadFile = File(...)):
    """Accept a recorded audio upload, normalize it to WAV, and transcribe it.

    Pipeline: save upload -> convert to 16 kHz mono WAV -> validate WAV ->
    transcribe with Vosk. Every failure renders index.html with an "error"
    context variable instead of raising, so the browser always gets a page.
    """
    try:
        # The temporary directory (and every file written into it) is removed
        # automatically when the with-block exits, even on error.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save the uploaded file (could be WebM, OGG, etc. — the .webm
            # name is only a placeholder; pydub detects the real container).
            input_file_path = os.path.join(temp_dir, "input_audio.webm")
            with open(input_file_path, "wb") as buffer:
                shutil.copyfileobj(audio_file.file, buffer)
            # getsize on a missing path would raise, so guard with exists();
            # a missing file and an empty file both yield size 0 below.
            file_size = os.path.getsize(input_file_path) if os.path.exists(input_file_path) else 0
            logger.info(f"Uploaded audio saved to {input_file_path}, size: {file_size} bytes")
            # Verify the upload actually produced bytes before converting.
            if file_size == 0:
                logger.error("Uploaded audio file was not saved correctly")
                return _error_page(request, "Failed to save uploaded audio file.")
            # Convert to the 16 kHz mono 16-bit PCM layout Vosk expects.
            wav_file_path = os.path.join(temp_dir, "temp_audio.wav")
            if not convert_to_wav(input_file_path, wav_file_path):
                logger.error("Audio conversion to WAV failed")
                return _error_page(request, "Failed to convert audio to WAV format.")
            # Validate the converted file really is a readable WAV.
            if not is_valid_wav(wav_file_path):
                logger.error("Converted file is not a valid WAV file")
                return _error_page(request, "Converted audio is not a valid WAV file.")
            # Transcribe the WAV file.
            try:
                text = stt.transcribe_audio(wav_file_path)
                logger.info(f"Transcription result: '{text}'")
            except Exception as e:
                logger.error(f"Transcription error: {str(e)}")
                return _error_page(request, f"Transcription error: {str(e)}")
            # Empty transcription means the recognizer found no speech.
            if not text:
                logger.warning("Transcription returned no text")
                return _error_page(request, "Could not understand the audio. Please try speaking clearly.")
            return templates.TemplateResponse("index.html", {
                "request": request,
                "transcribed_text": text
            })
    except Exception as e:
        # Last-resort boundary: never leak a traceback to the browser.
        logger.error(f"Error processing uploaded audio: {str(e)}")
        return _error_page(request, f"Error processing audio: {str(e)}")
@app.post("/query", response_class=HTMLResponse)
async def handle_query(request: Request, query_text: str = Form(...), use_retriever: str = Form("no")):
    """Run the text query pipeline and render the results page.

    Args:
        request: Incoming request (required by Jinja2Templates).
        query_text: The user's query string.
        use_retriever: Form flag; "yes"/"y"/"true"/"1" (any case) enables the
            retriever (RAG) path. Anything else disables it.
    """
    # Accept the common truthy spellings; "yes"/"y" keep working as before.
    retriever_enabled = use_retriever.strip().lower() in {"yes", "y", "true", "1"}
    result = await process_query(vosk_model_path, query_text=query_text, use_retriever=retriever_enabled)
    # result.get(...) tolerates pipeline results that omit optional keys,
    # where direct indexing would raise KeyError and surface as a 500 page.
    return templates.TemplateResponse("index.html", {
        "request": request,
        "User_Query": query_text,
        "Intent": result.get("intent"),
        "Entities": result.get("entities"),
        "API_Response": result.get("base_response"),
        "RAG_Response": result.get("retriever_response"),
        "Web_Search_Response": result.get("web_search_response"),
        "Final_Response": result.get("final_response"),
        "Error": result.get("error")
    })
|