# insuranceai/app.py
from fastapi import FastAPI, Request, Depends, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from helpmate_ai import get_system_msg, retreive_results, rerank_with_cross_encoder, generate_response, intro_message
import google.generativeai as genai
import os
from dotenv import load_dotenv
import re
import speech_recognition as sr
from io import BytesIO
import wave
# Load environment variables
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)
# Define a secret API key (use environment variables in production)
API_KEY = os.getenv("API_KEY")
# Initialize FastAPI app
app = FastAPI()
# # Enable CORS
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# Pydantic models for request/response validation
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    message: str

class ChatResponse(BaseModel):
    response: str
    conversation: List[Message]

class Report(BaseModel):
    response: str
    message: str
    timestamp: str
# Initialize conversation and model
conversation_bot = []
conversation = get_system_msg()
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=conversation)
# Initialize speech recognizer
recognizer = sr.Recognizer()
# Dependency to check the API key
async def verify_api_key(x_api_key: str = Header(...)):
    if x_api_key != API_KEY:
        raise HTTPException(status_code=403, detail="Unauthorized")
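# Clients pass the key in an "x-api-key" request header (FastAPI derives the header
# name from the x_api_key parameter). Illustrative call, assuming the server runs
# locally on port 8000 and API_KEY is exported in the shell:
#   curl -H "x-api-key: $API_KEY" http://localhost:8000/init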
def get_gemini_completions(conversation: str) -> str:
    response = model.generate_content(conversation)
    return response.text
# @app.get("/secure-endpoint", dependencies=[Depends(verify_api_key)])
# async def secure_endpoint():
# return {"message": "Access granted!"}
# Initialize conversation endpoint
@app.get("/init", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def initialize_chat():
    global conversation_bot
    # conversation = "Hi"
    # introduction = get_gemini_completions(conversation)
    conversation_bot = [Message(role="bot", content=intro_message)]
    return ChatResponse(
        response=intro_message,
        conversation=conversation_bot
    )
# Chat endpoint
@app.post("/chat", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def chat(request: ChatRequest):
    global conversation_bot
    # Add user message to conversation
    user_message = Message(role="user", content=request.message)
    conversation_bot.append(user_message)
    # Generate response: retrieve, rerank, build the prompt, then call Gemini
    results_df = retreive_results(request.message)
    top_docs = rerank_with_cross_encoder(request.message, results_df)
    messages = generate_response(request.message, top_docs)
    response_assistant = get_gemini_completions(messages)
    # formatted_response = format_rag_response(response_assistant)
    # Add bot response to conversation
    bot_message = Message(role="bot", content=response_assistant)
    conversation_bot.append(bot_message)
    return ChatResponse(
        response=response_assistant,
        conversation=conversation_bot
    )
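# Illustrative /chat request (field names come from the ChatRequest/ChatResponse
# models above; the URL and key are placeholders):
#   curl -X POST http://localhost:8000/chat \
#        -H "x-api-key: $API_KEY" -H "Content-Type: application/json" \
#        -d '{"message": "What does my policy cover for hospitalization?"}'
# The JSON reply carries "response" (the generated answer) and "conversation"
# (the accumulated list of role/content messages).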
# Voice processing endpoint
@app.post("/process-voice")
async def process_voice(audio_file: UploadFile = File(...), dependencies=[Depends(verify_api_key)]):
# async def process_voice(name: str):
try:
# Read the audio file
contents = await audio_file.read()
audio_data = BytesIO(contents)
# Convert audio to wav format for speech recognition
with sr.AudioFile(audio_data) as source:
audio = recognizer.record(source)
# Perform speech recognition
text = recognizer.recognize_google(audio)
# print(text)
# Process the text through the chat pipeline
results_df = retreive_results(text)
top_docs = rerank_with_cross_encoder(text, results_df)
messages = generate_response(text, top_docs)
response_assistant = get_gemini_completions(messages)
return {
"transcribed_text": text,
"response": response_assistant
}
except Exception as e:
return {"error": f"Error processing voice input: {str(e)}"}
@app.post("/report")
async def handle_feedback(
request: Report,
dependencies=[Depends(verify_api_key)]
):
# if x_api_key != VALID_API_KEY:
# raise HTTPException(status_code=403, detail="Invalid API key")
# Here you can store the feedback in your database
# For example:
# await db.store_feedback(message, is_positive)
return {"status": "success"}
# Reset conversation endpoint
@app.post("/reset", dependencies=[Depends(verify_api_key)])
async def reset_conversation():
    global conversation_bot, conversation
    conversation_bot = []
    # conversation = "Hi"
    # introduction = get_gemini_completions(conversation)
    conversation_bot.append(Message(role="bot", content=intro_message))
    return {"status": "success", "message": "Conversation reset"}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
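# Alternatively, the app can be served with the uvicorn CLI:
#   uvicorn app:app --host 0.0.0.0 --port 8000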