# chat / app/main.py
# Author: ariansyahdedy
# Commit 7b2511b: "Add memory"
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.responses import Response
from fastapi.exceptions import HTTPException
from fastapi.background import BackgroundTasks
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
from typing import Dict, List
from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, ValidationError
from app.services.message import generate_reply, send_reply
import logging
import httpx
from datetime import datetime
from sentence_transformers import SentenceTransformer
from app.search.rag_pipeline import RAGSystem
from contextlib import asynccontextmanager
# from app.db.database import create_indexes, init_db
# from app.services.webhook_handler import verify_webhook
from app.handlers.message_handler import MessageHandler
from app.handlers.webhook_handler import WebhookHandler
from app.handlers.media_handler import WhatsAppMediaHandler
from app.services.cache import MessageCache
from app.services.chat_manager import ChatManager
from app.api.api_prompt import prompt_router
from app.api.api_file import file_router, load_file_with_markdown_function
from app.utils.load_env import ACCESS_TOKEN, WHATSAPP_API_URL, GEMINI_API
from fastapi.staticfiles import StaticFiles
from vidavox.core import RAG_Engine
from app.memory import AgentMemory
from app.settings import settings
from markitdown import MarkItDown
# Configure root logging once for the whole process: INFO and above,
# each record prefixed with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Handler singletons; both stay None until ``lifespan`` assigns them at startup.
message_handler = None
webhook_handler = None

# Public service pages (currently only referenced by commented-out indexing code).
indexed_links = [
    "https://sswalfa.surabaya.go.id/info/detail/izin-pengumpulan-sumbangan-bencana",
    "https://sswalfa.surabaya.go.id/info/detail/izin-pemakaian-ruang-terbuka-hijau",
    "https://sswalfa.surabaya.go.id/info/detail/pengganti-ipt",
    "https://sswalfa.surabaya.go.id/info/detail/arahan-sistem-drainase",
    "https://sswalfa.surabaya.go.id/info/detail/rangkaian-pelayanan-surat-pernyataan-belum-menikah-lagi-bagi-jandaduda",
]
async def setup_message_handler():
    """Build a ``MessageHandler`` wired to freshly created collaborators.

    Every call constructs a new ``MessageCache``, ``ChatManager`` and
    ``WhatsAppMediaHandler`` and hands them, together with this module's
    logger, to a new ``MessageHandler``, which is returned.
    """
    handler_logger = logging.getLogger(__name__)
    # Collaborators are created in the same order as before: cache, chat, media.
    collaborators = dict(
        message_cache=MessageCache(),
        chat_manager=ChatManager(),
        media_handler=WhatsAppMediaHandler(),
    )
    return MessageHandler(**collaborators, logger=handler_logger)
# async def setup_rag_system():
# embedding_model = SentenceTransformer('all-MiniLM-L6-v2') # Replace with your model if different
# rag_system = RAGSystem(embedding_model)
# return rag_system
# Application lifespan: build shared resources before the app starts serving
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create shared resources at startup and publish them for request handlers.

    Startup steps:
      * create ``AgentMemory`` for ``settings.POSTGRES_DB_URL`` and await its
        ``initialize()``; both the object and the awaited result are stored on
        ``app.state`` (``agentMemory`` / ``memory``);
      * build the RAG engine from ``./docs/coretax_telegram.csv`` (``answer``
        as text, ``question`` / ``images_path`` as metadata) and store it as
        ``app.state.rag_system``;
      * assign the module-level ``message_handler`` and ``webhook_handler``.

    Any startup failure is logged with its traceback and re-raised so the
    server refuses to start. Previously the exception was swallowed, and the
    context manager then failed with an opaque "generator didn't yield" error.
    """
    global message_handler, webhook_handler
    try:
        agentMemory = AgentMemory(db_url=settings.POSTGRES_DB_URL)
        memory = await agentMemory.initialize()
        # The MongoDB init is disabled; what actually succeeded here is the agent memory.
        logger.info("Agent memory initialized")

        file_paths = ['./docs/coretax_telegram.csv']
        engine = RAG_Engine(
            embedding_model='Snowflake/snowflake-arctic-embed-l-v2.0'
        ).from_paths(
            file_paths,
            load_csv_as_pandas_dataframe=True,
            text_col='answer',
            metadata_cols=['question', 'images_path'],
        )
        app.state.rag_system = engine
        app.state.agentMemory = agentMemory
        app.state.memory = memory

        message_handler = await setup_message_handler()
        webhook_handler = WebhookHandler(message_handler)
    except Exception:
        logger.exception("Application startup failed")
        raise

    # The yield stays outside the try so errors raised while the app is
    # serving are not silently swallowed by the startup error handler.
    yield
    # NOTE(review): no shutdown cleanup is done for agentMemory; add it here if
    # AgentMemory exposes a close/dispose method.
# Initialize Limiter and Prometheus Metrics
# Rate limiter keyed by the client address that slowapi's get_remote_address returns.
limiter = Limiter(key_func=get_remote_address)
# ``lifespan`` builds the RAG engine, agent memory and handlers at startup.
app = FastAPI(lifespan=lifespan)
# Mount the 'images' directory so its files are available under the /images URL path
app.mount("/images", StaticFiles(directory="images"), name="images")
# slowapi looks the limiter up on app.state; RateLimitExceeded is rendered by slowapi's handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add SlowAPI Middleware
# NOTE(review): Limiter is created without default limits and the per-route
# @limiter.limit on /webhook is commented out, so presumably no rate limit is
# enforced by anything in this file — confirm.
app.add_middleware(SlowAPIMiddleware)
# app.include_router(users.router, prefix="/users", tags=["Users"])
app.include_router(prompt_router, prefix="/prompts", tags=["Prompts"])
app.include_router(file_router, prefix="/file_load", tags=["File Load"])
# Prometheus metrics (registered in prometheus_client's default registry, served by /metrics).
# NOTE(review): webhook_processing_time is never observed in this file — confirm
# it is used elsewhere or remove it.
webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
webhook_processing_time = Histogram('webhook_processing_seconds', 'Time spent processing webhook')
def get_image_links(image_paths: List[str]) -> List[str]:
    """Turn stringified image-path lists into URLs served by the ``/images`` mount.

    Each entry is expected to be a bracketed, comma-separated list such as
    ``"[images/a.jpg, images/b.jpg]"``. Items may also be quoted, as in
    ``"['images/a.jpg', 'images/b.jpg']"`` (the ``str()`` of a Python list).
    Items starting with ``images/`` become ``/images/<rest>``; any other item
    is passed through unchanged.

    Non-string entries (e.g. ``None`` or a NaN for an empty CSV cell) and empty
    items are skipped instead of raising.

    Args:
        image_paths: Stringified lists of relative image paths.

    Returns:
        The URLs, in input order, flattened into a single list.
    """
    prefix = "images/"
    links: List[str] = []
    for path in image_paths:
        if not isinstance(path, str):
            continue
        # Trim whitespace first so surrounding spaces don't block bracket stripping.
        cleaned = path.strip().strip("[]").strip()
        for raw_part in cleaned.split(","):
            # Drop whitespace and any quotes around an individual item.
            part = raw_part.strip().strip("'\"").strip()
            if not part:
                continue
            if part.startswith(prefix):
                links.append(f"/images/{part[len(prefix):]}")
            else:
                links.append(part)  # Fallback if the format is unexpected
    return links
# @app.get("/image-links")
# async def image_links_endpoint():
# image_paths = ['[images/photo_3.jpg, images/photo_16.jpg]']
# links = get_image_links(image_paths)
# return {"links": links}
# Start Prometheus metrics server on port 8002
# start_http_server(8002)
# Register webhook routes
# app.post("/webhook")(webhook)
# Define Pydantic schema for request validation
class WebhookPayload(BaseModel):
    """Request schema for ``POST /webhook``: a payload with an ``entry`` list of dicts.

    NOTE(review): nothing in this file instantiates this model, so incoming
    webhook payloads are currently not validated against it.
    """
    entry: List[Dict]
@app.post("/webhook")
# @limiter.limit("20/minute")
async def webhook(request: Request, background_tasks: BackgroundTasks):
    """Accept a webhook delivery and process it in the background.

    The JSON body is queued for ``webhook_handler.process_webhook`` with the
    credentials loaded from the environment (``ACCESS_TOKEN``,
    ``WHATSAPP_API_URL``, ``GEMINI_API``). The task also receives the RAG
    engine and agent memory that ``lifespan`` stored on ``app.state``. The
    endpoint answers immediately instead of waiting for processing.

    Returns:
        200 ``{"status": "received"}`` once the task is queued;
        400 if the body is not valid JSON or not a JSON object;
        500 if queuing fails for any other reason. Details are logged, not
        returned, because this endpoint is publicly reachable.
    """
    webhook_requests.inc()

    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        logger.warning("Rejected webhook with malformed JSON body: %s", exc)
        return JSONResponse(
            content={"status": "error", "detail": "Request body must be valid JSON"},
            status_code=status.HTTP_400_BAD_REQUEST,
        )
    if not isinstance(payload, dict):
        logger.warning(
            "Rejected webhook whose JSON body is a %s, not an object",
            type(payload).__name__,
        )
        return JSONResponse(
            content={"status": "error", "detail": "Request body must be a JSON object"},
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # NOTE(review): the request signature (e.g. Meta's X-Hub-Signature-256
    # header) is not verified, so any caller can enqueue payloads.
    # Payloads carry end-user messages, so they are only logged at DEBUG level.
    logger.debug("Webhook payload: %s", payload)

    try:
        state = request.app.state
        # Credentials come from the environment; per-request query params are not used.
        background_tasks.add_task(
            webhook_handler.process_webhook,
            payload=payload,
            whatsapp_token=ACCESS_TOKEN,
            whatsapp_url=WHATSAPP_API_URL,
            gemini_api=GEMINI_API,
            rag_system=state.rag_system,
            agentMemory=state.agentMemory,
            memory=state.memory,
        )
    except Exception:
        logger.exception("Unexpected error while queuing webhook processing")
        return JSONResponse(
            content={"status": "error", "detail": "Internal server error"},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # Return HTTP 200 immediately; processing continues in the background task.
    return JSONResponse(
        content={"status": "received"},
        status_code=status.HTTP_200_OK,
    )
@app.get("/webhook")
async def verify_webhook(request: Request):
    """Answer the webhook subscription handshake sent via ``hub.*`` query params.

    Echoes ``hub.challenge`` as plain text when ``hub.mode`` is ``subscribe``,
    ``hub.verify_token`` matches the expected token, and a challenge is
    present. Otherwise it responds 403.

    The expected token comes from the ``WHATSAPP_VERIFY_TOKEN`` environment
    variable. It falls back to ``'test'`` (the previously hard-coded value) so
    existing deployments keep working. Set the variable in production, because
    a well-known token defeats the purpose of the check.

    Raises:
        HTTPException: 403 when verification fails.
    """
    import os
    import secrets

    mode = request.query_params.get('hub.mode')
    token = request.query_params.get('hub.verify_token')
    challenge = request.query_params.get('hub.challenge')
    expected_token = os.environ.get('WHATSAPP_VERIFY_TOKEN', 'test')

    # Constant-time comparison, done on bytes because compare_digest rejects
    # non-ASCII str arguments.
    token_ok = token is not None and secrets.compare_digest(
        token.encode('utf-8'), expected_token.encode('utf-8')
    )
    if mode == 'subscribe' and token_ok and challenge is not None:
        return Response(content=challenge, media_type="text/plain")
    raise HTTPException(status_code=403, detail="Verification failed")
@app.post("/load_file")
async def load_file_with_markitdown(file_path:str, llm_client:str=None, model:str=None):
    """Convert a document to Markdown with MarkItDown and return the result.

    Args:
        file_path: Location of the document, passed unchanged to
            ``MarkItDown.convert``.
        llm_client: Optional LLM client, forwarded to MarkItDown as ``llm_client``.
        model: Optional model name, forwarded to MarkItDown as ``llm_model``.
            The LLM options are used only when both are provided.

    Returns:
        Whatever ``MarkItDown.convert`` returns.
    """
    from fastapi.concurrency import run_in_threadpool

    # SECURITY NOTE(review): file_path comes straight from the HTTP request, so
    # any caller can make the server read arbitrary local files. Per its docs,
    # MarkItDown.convert also accepts URLs, which would allow fetching remote
    # ones. Restrict to an allow-listed directory or require auth on this route.
    # NOTE(review): over HTTP, llm_client arrives as a plain string, while
    # MarkItDown presumably expects a client object — confirm intended usage.
    if llm_client and model:
        # Pass by keyword: MarkItDown's first positional parameter is not the
        # LLM client, and recent versions accept keyword arguments only.
        markitdown = MarkItDown(llm_client=llm_client, llm_model=model)
    else:
        markitdown = MarkItDown()

    # convert() is synchronous; run it in a worker thread so it does not
    # block the event loop for other requests.
    documents = await run_in_threadpool(markitdown.convert, file_path)
    logger.debug("documents: %s", documents)
    return documents
# Add a route for Prometheus metrics (optional, if not using a separate Prometheus server)
@app.get("/metrics")
async def metrics():
    """Expose prometheus_client's default registry in the text exposition format."""
    from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

    # CONTENT_TYPE_LATEST carries the exposition-format version (and charset)
    # that matches generate_latest()'s output; a bare "text/plain" omits the version.
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
# In-memory cache with timestamp cleanup
# class MessageCache:
# def __init__(self, max_age_hours: int = 24):
# self.messages: Dict[str, float] = {}
# self.max_age_seconds = max_age_hours * 3600
# def add(self, message_id: str) -> None:
# self.cleanup()
# self.messages[message_id] = time.time()
# def exists(self, message_id: str) -> bool:
# self.cleanup()
# return message_id in self.messages
# def cleanup(self) -> None:
# current_time = time.time()
# self.messages = {
# msg_id: timestamp
# for msg_id, timestamp in self.messages.items()
# if current_time - timestamp < self.max_age_seconds
# }
# message_cache = MessageCache()
# user_chats = {}