# FastAPI backend for the email AI assistant: chat, voice transcription, and feedback endpoints.
# Standard library
import logging
import os
import wave
from io import BytesIO
from typing import List, Dict, Optional

# Third-party
import google.generativeai as genai
import speech_recognition as sr
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException, Header, File, UploadFile
from pydantic import BaseModel

# Local
from email_ai import initialize_conversation, intro_message, get_chat_model_completions
# ---------------------------------------------------------------------------
# Application-wide setup
# ---------------------------------------------------------------------------

# Pull configuration from a local .env file into the process environment.
load_dotenv()

# Gemini configuration is currently disabled; kept for reference.
# gemini_api_key = os.getenv("GEMINI_API_KEY")
# genai.configure(api_key=gemini_api_key)

# Shared secret checked by verify_api_key (supply via environment in production).
API_KEY = os.getenv("API_KEY")

app = FastAPI()

# One recognizer instance, reused across all voice requests.
recognizer = sr.Recognizer()
class Message(BaseModel):
    """One chat turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str
class ChatRequest(BaseModel):
    """Incoming chat payload: the user's free-text message."""

    message: str
class ChatResponse(BaseModel):
    """Chat reply: the bot's latest response plus the full conversation history."""

    response: str
    conversation: List[Message]
class Report(BaseModel):
    """Feedback payload tying a bot response to the user message and a timestamp."""

    response: str
    message: str
    timestamp: str
# Dependency that gates protected endpoints on a shared-secret header.
async def verify_api_key(x_api_key: str = Header(...)):
    """Reject the request with 403 unless the X-API-Key header matches API_KEY."""
    if x_api_key == API_KEY:
        return
    raise HTTPException(status_code=403, detail="Unauthorized")
async def get_conversation():
    """Initialize a fresh conversation and return the bot's greeting.

    Re-creates the module-level LLM, retriever, and history, so calling this
    again effectively restarts the chat for all subsequent requests.
    """
    global llm, chroma_retriever, conversation_bot

    llm, chroma_retriever = initialize_conversation()
    conversation_bot = [Message(role="bot", content=intro_message)]
    return ChatResponse(response=intro_message, conversation=conversation_bot)
async def send_message(request: ChatRequest):
    """Append the user's message, query the chat model, and return the updated chat.

    Parameters:
        request: validated body carrying the user's free-text message.

    Returns:
        ChatResponse with the assistant's reply and the full history.

    Raises:
        HTTPException(400): if the conversation was never initialized —
            previously this path crashed with a bare NameError (HTTP 500)
            because `llm`/`chroma_retriever`/`conversation_bot` only exist
            after get_conversation() has run.
    """
    global conversation_bot
    if "llm" not in globals() or "conversation_bot" not in globals():
        raise HTTPException(
            status_code=400,
            detail="Conversation not initialized; call the start endpoint first.",
        )
    conversation_bot.append(Message(role="user", content=request.message))
    response_assistant = get_chat_model_completions(llm, chroma_retriever, request.message)
    conversation_bot.append(Message(role="bot", content=response_assistant.content))
    return ChatResponse(
        response=response_assistant.content,
        conversation=conversation_bot,
    )
# Voice processing endpoint
async def process_voice(audio_file: UploadFile = File(...), dependencies=[Depends(verify_api_key)]):
    """Transcribe an uploaded audio file and run the text through the chat model.

    Parameters:
        audio_file: uploaded audio; SpeechRecognition accepts WAV/AIFF/FLAC
            file-like sources.

    Returns:
        On success: {"transcribed_text": ..., "response": ...}.
        On failure: {"error": ...} — the original best-effort contract is
        preserved so existing callers keep working, but failures are now
        logged and common recognizer errors get specific messages instead
        of one silent blanket handler.

    NOTE(review): `dependencies=[Depends(verify_api_key)]` as a parameter
    default does NOT enforce the API key in FastAPI — dependencies belong in
    the route decorator (e.g. @app.post(..., dependencies=[...])). The
    decorators appear to have been lost from this file; confirm and move it
    when they are restored. Signature kept unchanged for compatibility.
    """
    global conversation_bot
    logger = logging.getLogger(__name__)
    try:
        contents = await audio_file.read()
        audio_data = BytesIO(contents)
        # Convert the in-memory bytes into an AudioFile source for recognition.
        with sr.AudioFile(audio_data) as source:
            audio = recognizer.record(source)
        text = recognizer.recognize_google(audio)
        logger.info("Transcribed voice input: %s", text)
        conversation_bot.append(Message(role="user", content=text))
        response_assistant = get_chat_model_completions(llm, chroma_retriever, text)
        conversation_bot.append(Message(role="bot", content=response_assistant.content))
        return {
            "transcribed_text": text,
            "response": response_assistant.content,
        }
    except sr.UnknownValueError:
        # str(UnknownValueError()) is empty, so give callers a real message.
        return {"error": "Error processing voice input: speech was unintelligible"}
    except sr.RequestError as e:
        logger.warning("Speech recognition service error: %s", e)
        return {"error": f"Error processing voice input: {str(e)}"}
    except Exception as e:
        # Last-resort fallback: keep the best-effort contract, but never silently.
        logger.exception("Unexpected failure while processing voice input")
        return {"error": f"Error processing voice input: {str(e)}"}
async def handle_feedback(
    request: Report,
    dependencies=[Depends(verify_api_key)]
):
    """Acknowledge user feedback on a bot response.

    Currently a stub: the payload is accepted but not persisted.
    Persistence (e.g. ``await db.store_feedback(...)``) is still TODO.
    """
    return {"status": "success"}
async def reset_conversation():
    """Reset the conversation state back to just the intro greeting.

    Bug fix: the history previously restarted with a raw dict
    ``{'bot': intro_message}``, inconsistent with the ``Message`` entries
    used everywhere else (see get_conversation / send_message); the next
    ``ChatResponse(conversation=...)`` would then fail pydantic validation.
    """
    global conversation_bot, llm, chroma_retriever
    conversation_bot = [Message(role="bot", content=intro_message)]
    llm, chroma_retriever = initialize_conversation()
    return {"status": "conversation reset"}
if __name__ == "__main__":
    # Allow running the API directly: python <this file>
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)