File size: 5,776 Bytes
698c6c9
a2c10b6
 
 
 
 
 
 
698c6c9
 
 
 
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
698c6c9
 
a2c10b6
 
 
 
698c6c9
 
a2c10b6
698c6c9
 
 
 
 
 
 
 
 
a2c10b6
698c6c9
 
 
 
 
a2c10b6
698c6c9
 
a2c10b6
 
 
 
 
698c6c9
 
 
 
 
 
72801e6
698c6c9
 
 
 
 
 
 
 
a2c10b6
 
698c6c9
 
 
 
72801e6
698c6c9
 
 
 
 
a2c10b6
698c6c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698c6c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from fastapi import FastAPI, Request, Form, File, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from main import process_query
from voice.speech_to_text import SpeechToText
import os
import logging
import tempfile
import shutil
import wave
from pydub import AudioSegment

# Set up logging: configure the root logger at INFO and create this module's
# logger, which every handler and helper below uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; the route decorators below register on it.
app = FastAPI()

# Mount static files (for CSS, JS, etc.): files in the local "static"
# directory are served under the "/static" URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Set up templates: Jinja2 templates are loaded from the local "templates"
# directory (the handlers below render "index.html").
templates = Jinja2Templates(directory="templates")


# Vosk model path. This is a relative path, so it is resolved against the
# process's current working directory, not this file's location.
vosk_model_path = "./vosk-model-small-en-us-0.15"

# Initialize SpeechToText once at import time; the instance is shared by all
# requests. NOTE(review): presumably this loads the Vosk model eagerly, which
# would make a missing model directory fail at startup -- confirm in
# voice/speech_to_text.py.
stt = SpeechToText(model_path=vosk_model_path)

def is_valid_wav(file_path):
    """Return True if *file_path* can be opened as a WAV file, else False.

    On success the channel count, frame rate and frame count are logged at
    INFO level. Any failure to open or parse the file is logged at ERROR
    level and reported as ``False``; this function does not raise for
    unreadable, empty, truncated or non-WAV input.

    Args:
        file_path: Path of the file to check.

    Returns:
        bool: ``True`` when the ``wave`` module accepts the file, ``False``
        otherwise.
    """
    try:
        with wave.open(file_path, 'rb') as wf:
            logger.info(f"WAV file validated: channels={wf.getnchannels()}, rate={wf.getframerate()}, frames={wf.getnframes()}")
            return True
    except (wave.Error, EOFError, OSError) as e:
        # wave.open signals an empty or truncated header with EOFError (not
        # wave.Error) and an unopenable path with OSError; all of these mean
        # "not a usable WAV file" for the caller.
        logger.error(f"Invalid WAV file: {str(e)}")
        return False

def convert_to_wav(input_path, output_path):
    """Re-encode the audio at *input_path* as a 16 kHz, mono, 16-bit WAV.

    The result is written to *output_path* and its size is logged. Every
    failure (decoding, export, or reading the output size) is logged and
    reported through the return value instead of being raised.

    Returns:
        bool: ``True`` if the WAV file was written, ``False`` on any error.
    """
    try:
        decoded = AudioSegment.from_file(input_path)
        mono = decoded.set_channels(1)
        resampled = mono.set_frame_rate(16000)
        pcm16 = resampled.set_sample_width(2)  # 2 bytes per sample = 16-bit PCM
        pcm16.export(output_path, format="wav")
        written_bytes = os.path.getsize(output_path)
        logger.info(f"Converted audio to WAV: {output_path}, size: {written_bytes} bytes")
    except Exception as exc:
        logger.error(f"Failed to convert audio to WAV: {str(exc)}")
        return False
    return True

@app.get("/", response_class=HTMLResponse)
async def get_index(request: Request):
    """Serve the landing page by rendering ``index.html``.

    The template receives only the incoming ``request`` in its context.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)

def _render_index_error(request: Request, message: str):
    """Render ``index.html`` with *message* under the ``error`` context key."""
    return templates.TemplateResponse("index.html", {
        "request": request,
        "error": message
    })


@app.post("/upload_audio", response_class=HTMLResponse)
async def upload_audio(request: Request, audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio clip and render the result in ``index.html``.

    Steps, all performed inside a temporary directory that is removed when
    the request finishes:

    1. Save the upload to disk.
    2. Convert it to a 16 kHz mono 16-bit WAV with ``convert_to_wav``.
    3. Check the converted file with ``is_valid_wav``.
    4. Transcribe it with the shared ``stt`` instance.

    Args:
        request: The incoming request, passed through to the template.
        audio_file: The uploaded audio (the browser may send WebM, OGG, etc.).

    Returns:
        The rendered ``index.html``. On success the context carries
        ``transcribed_text``; on any failure it carries ``error`` with a
        human-readable message. Exceptions are caught and reported through
        ``error`` rather than propagated.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save the uploaded file (could be WebM, OGG, etc.). The ".webm"
            # name is only a placeholder for whatever container was uploaded.
            input_file_path = os.path.join(temp_dir, "input_audio.webm")
            with open(input_file_path, "wb") as buffer:
                shutil.copyfileobj(audio_file.file, buffer)

            # Measure the size only after the handle is closed: while the file
            # is still open, data may sit in the write buffer, so a small
            # upload would be reported as 0 bytes and wrongly rejected.
            file_size = os.path.getsize(input_file_path) if os.path.exists(input_file_path) else 0
            logger.info(f"Uploaded audio saved to {input_file_path}, size: {file_size} bytes")

            # Verify file was saved (file_size is 0 when the file is missing).
            if file_size == 0:
                logger.error("Uploaded audio file was not saved correctly")
                return _render_index_error(request, "Failed to save uploaded audio file.")

            # Convert to WAV
            wav_file_path = os.path.join(temp_dir, "temp_audio.wav")
            if not convert_to_wav(input_file_path, wav_file_path):
                logger.error("Audio conversion to WAV failed")
                return _render_index_error(request, "Failed to convert audio to WAV format.")

            # Validate WAV file
            if not is_valid_wav(wav_file_path):
                logger.error("Converted file is not a valid WAV file")
                return _render_index_error(request, "Converted audio is not a valid WAV file.")

            # Transcribe the WAV file
            try:
                text = stt.transcribe_audio(wav_file_path)
            except Exception as e:
                # logger.exception keeps the traceback, which the plain
                # message alone does not provide.
                logger.exception(f"Transcription error: {str(e)}")
                return _render_index_error(request, f"Transcription error: {str(e)}")

            logger.info(f"Transcription result: '{text}'")
            if not text:
                logger.warning("Transcription returned no text")
                return _render_index_error(
                    request,
                    "Could not understand the audio. Please try speaking clearly."
                )
            return templates.TemplateResponse("index.html", {
                "request": request,
                "transcribed_text": text
            })

    except Exception as e:
        # Last-resort boundary: report the failure in the page instead of
        # returning a 500, and keep the traceback in the log.
        logger.exception(f"Error processing uploaded audio: {str(e)}")
        return _render_index_error(request, f"Error processing audio: {str(e)}")

@app.post("/query", response_class=HTMLResponse)
async def handle_query(request: Request, query_text: str = Form(...), use_retriever: str = Form("no")):
    """Run a text query through ``process_query`` and render the results.

    Args:
        request: The incoming request, passed through to the template.
        query_text: The user's query from the submitted form.
        use_retriever: Form flag; "yes" or "y" (case-insensitive) turns the
            retriever on, anything else turns it off.

    Returns:
        The rendered ``index.html`` with the query and each field of the
        ``process_query`` result exposed under its template key.
    """
    # The form delivers a string; reduce it to the boolean process_query takes.
    retriever_enabled = use_retriever.lower() in ("yes", "y")
    result = await process_query(
        vosk_model_path,
        query_text=query_text,
        use_retriever=retriever_enabled,
    )

    # Template context key -> key in the process_query result.
    result_fields = (
        ("Intent", "intent"),
        ("Entities", "entities"),
        ("API_Response", "base_response"),
        ("RAG_Response", "retriever_response"),
        ("Web_Search_Response", "web_search_response"),
        ("Final_Response", "final_response"),
        ("Error", "error"),
    )
    page_context = {"request": request, "User_Query": query_text}
    for template_key, result_key in result_fields:
        page_context[template_key] = result[result_key]

    return templates.TemplateResponse("index.html", page_context)