from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import uvicorn
import os
import tempfile
import shutil
from typing import List, Optional, Dict, Any, Iterator
import pathlib
import asyncio
import logging
import time
import traceback
import uuid
import json

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Import our RAG components
from rag import RetrievalAugmentedQAPipeline, process_file, setup_vector_db

# Add the local aimakerspace module to the path
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), ""))

# Import from the local aimakerspace module
from aimakerspace.utils.session_manager import SessionManager

# Load environment variables
from dotenv import load_dotenv

load_dotenv()

app = FastAPI(
    title="RAG Application",
    description="Retrieval Augmented Generation with FastAPI and React",
    version="0.1.0",
    root_path="",  # Important for proxy environments
)

# More robust middleware for handling HTTPS
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import RedirectResponse, JSONResponse


class ProxyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        # Log request details for debugging
        logger.info(f"Request path: {request.url.path}")
        logger.info(f"Request headers: {request.headers}")

        # Time the request and turn unhandled errors into a JSON 500 response
        try:
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            response.headers["X-Process-Time"] = str(process_time)
            return response
        except Exception as e:
            logger.error(f"Request failed: {str(e)}")
            logger.error(traceback.format_exc())
            return JSONResponse(
                status_code=500,
                content={"detail": f"Internal server error: {str(e)}"},
            )


# Add custom middleware
app.add_middleware(ProxyMiddleware)

# Configure CORS - more specific configuration for Hugging Face
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, you should restrict this
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["Content-Length", "X-Process-Time"],
    max_age=600,  # 10 minutes cache for preflight requests
)

# Initialize session manager
session_manager = SessionManager()


class QueryRequest(BaseModel):
    session_id: str
    query: str


class QueryResponse(BaseModel):
    response: str
    session_id: str


# Set file size limit to 10MB - adjust as needed
FILE_SIZE_LIMIT = 10 * 1024 * 1024  # 10MB


async def process_file_background(temp_path: str, filename: str, session_id: str):
    """Process file in background and set up the RAG pipeline"""
    try:
        start_time = time.time()
        logger.info(f"Background processing started for file: {filename} (session: {session_id})")

        # Set max processing time (5 minutes)
        max_processing_time = 300  # seconds

        # Process the file
        logger.info(f"Starting text extraction for file: {filename}")
        try:
            texts = process_file(temp_path, filename)
            logger.info(f"Processed file into {len(texts)} text chunks (took {time.time() - start_time:.2f}s)")

            # Check if processing is taking too long already
            if time.time() - start_time > max_processing_time / 2:
                logger.warning("Text extraction took more than half the allowed time. Limiting chunks...")
                # Limit to a smaller number if extraction took a long time
                max_chunks = 50
                if len(texts) > max_chunks:
                    logger.warning(f"Limiting text chunks from {len(texts)} to {max_chunks}")
                    texts = texts[:max_chunks]
        except Exception as e:
            logger.error(f"Error during text extraction: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
            os.unlink(temp_path)
            return

        # Set up vector database - this is the part that might be hanging
        logger.info(f"Starting vector DB creation for {len(texts)} chunks")
        embedding_start = time.time()

        # Create a task with an overall timeout
        try:
            async def setup_with_timeout():
                return await setup_vector_db(texts)

            # Wait for vector DB setup with timeout
            vector_db = await asyncio.wait_for(
                setup_with_timeout(),
                timeout=max_processing_time - (time.time() - start_time),
            )

            # Get document count - check if the documents property is available
            if hasattr(vector_db, 'documents'):
                doc_count = len(vector_db.documents)
            else:
                # If using the original VectorDatabase implementation that uses a vectors dict
                doc_count = len(vector_db.vectors) if hasattr(vector_db, 'vectors') else 0

            logger.info(f"Created vector database with {doc_count} documents (took {time.time() - embedding_start:.2f}s)")

            # Create RAG pipeline
            logger.info(f"Creating RAG pipeline for session {session_id}")
            rag_pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)

            # Store pipeline in session manager
            session_manager.update_session(session_id, rag_pipeline)
            logger.info(f"Updated session {session_id} with processed pipeline (total time: {time.time() - start_time:.2f}s)")
        except asyncio.TimeoutError:
            logger.error(f"Vector database creation timed out after {time.time() - embedding_start:.2f}s")
            session_manager.update_session(session_id, "failed")
        except Exception as e:
            logger.error(f"Error in vector database creation: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")

        # Clean up temp file
        os.unlink(temp_path)
        logger.info(f"Removed temporary file: {temp_path}")
    except Exception as e:
        logger.error(f"Error in background processing for session {session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        # Mark the session as failed rather than removing it
        session_manager.update_session(session_id, "failed")
        # Try to clean up the temp file if it exists
        try:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"Cleaned up temporary file after error: {temp_path}")
        except Exception as cleanup_error:
            logger.error(f"Error cleaning up temp file: {str(cleanup_error)}")


# Route decorator assumed (none is shown in the source); the path is illustrative
# and may differ in the deployed app.
@app.post("/upload")
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    try:
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content type: {file.content_type}")

        # Create a unique ID for this upload
        upload_id = str(uuid.uuid4())
        logger.info(f"Assigned upload ID: {upload_id}")

        # Check file size first
        file_size = 0
        chunk_size = 1024 * 1024  # 1MB chunks for reading
        contents = bytearray()

        # Read file in chunks to avoid memory issues
        try:
            while True:
                chunk = await asyncio.wait_for(file.read(chunk_size), timeout=60.0)
                if not chunk:
                    break
                file_size += len(chunk)
                contents.extend(chunk)

                # Check size limit
                if file_size > FILE_SIZE_LIMIT:
                    logger.warning(f"File too large: {file_size/1024/1024:.2f}MB exceeds limit of {FILE_SIZE_LIMIT/1024/1024}MB")
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large. Maximum size is {FILE_SIZE_LIMIT/1024/1024}MB",
                    )

                # Log progress for large files
                if file_size % (5 * 1024 * 1024) == 0:  # Log every 5MB
                    logger.info(f"Upload progress: {file_size/1024/1024:.2f}MB read so far...")
        except asyncio.TimeoutError:
            logger.error(f"Timeout reading file: {file.filename}")
            raise HTTPException(
                status_code=408,
                detail="Request timeout while reading file. Please try again.",
            )

        logger.info(f"File size: {file_size/1024/1024:.2f}MB")

        # Convert the buffered chunks into immutable bytes for writing to disk
        file_content = bytes(contents)

        # Create a temporary file
        suffix = f".{file.filename.split('.')[-1]}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Write file content to temp file
            temp_file.write(file_content)
            temp_path = temp_file.name
            logger.info(f"Created temporary file at: {temp_path}")

        # Generate session ID and create session
        session_id = session_manager.create_session("processing")
        logger.info(f"Created session ID: {session_id}")

        # Process file in background
        background_tasks.add_task(
            process_file_background,
            temp_path,
            file.filename,
            session_id,
        )

        return {"session_id": session_id, "message": "File uploaded and processing started", "upload_id": upload_id}
    except HTTPException:
        # Let deliberate HTTP errors (413, 408) propagate instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing upload: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")


# Route decorator assumed (none is shown in the source); the path is illustrative.
@app.post("/query")
async def process_query(request: QueryRequest):
    logger.info(f"Received query request for session: {request.session_id}")

    # Check if session exists
    if not session_manager.session_exists(request.session_id):
        logger.warning(f"Session not found: {request.session_id}")
        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")

    # Get session data
    session_data = session_manager.get_session(request.session_id)

    # Check if processing is still ongoing
    if session_data == "processing":
        logger.info(f"Document still processing for session: {request.session_id}")
        raise HTTPException(status_code=409, detail="Document is still being processed. Please try again in a moment.")

    # Check if processing failed
    if session_data == "failed":
        logger.error(f"Processing failed for session: {request.session_id}")
        raise HTTPException(status_code=500, detail="Document processing failed. Please try uploading again.")

    try:
        logger.info(f"Processing query: '{request.query}' for session: {request.session_id}")

        # Get response from the RAG pipeline
        start_time = time.time()
        result = await session_data.arun_pipeline(request.query)

        # Stream the response - this is key for the Star Wars effect
        async def stream_response():
            try:
                async for chunk in result["response"]:
                    # Add a small delay between chunks for dramatic effect
                    await asyncio.sleep(0.01)
                    # Stream each chunk as plain text
                    yield chunk
                logger.info(f"Completed streaming response (took {time.time() - start_time:.2f}s)")
            except Exception as e:
                logger.error(f"Error in streaming: {str(e)}")
                yield f"Error during streaming: {str(e)}"

        # Return streaming response
        return StreamingResponse(
            stream_response(),
            media_type="text/plain",
        )
    except Exception as e:
        logger.error(f"Error processing query for session {request.session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")


# Route decorators assumed (none are shown in the source); the paths are illustrative.
@app.get("/health")
def health_check():
    return {"status": "healthy"}


@app.get("/test")
def test_endpoint():
    return {"message": "Backend is accessible"}


# Route decorator assumed (none is shown in the source); the path is illustrative.
@app.get("/session/{session_id}")
async def session_status(session_id: str):
    """Check if a session exists and its processing status"""
    logger.info(f"Checking status for session: {session_id}")

    if not session_manager.session_exists(session_id):
        logger.warning(f"Session not found: {session_id}")
        return {"exists": False, "status": "not_found"}

    session_data = session_manager.get_session(session_id)

    if session_data == "processing":
        logger.info(f"Session {session_id} is still processing")
        return {"exists": True, "status": "processing"}

    if session_data == "failed":
        logger.error(f"Session {session_id} processing failed")
        return {"exists": True, "status": "failed"}

    logger.info(f"Session {session_id} is ready")
    return {"exists": True, "status": "ready"}


# Route decorator assumed (none is shown in the source); the path is illustrative.
@app.get("/debug/sessions")
async def debug_sessions():
    """Return debug information about all sessions - for diagnostic use only"""
    logger.info("Accessed debug sessions endpoint")
    # Get summary of all sessions
    sessions_summary = session_manager.get_sessions_summary()
    return sessions_summary


# For Hugging Face Spaces deployment, serve the static files from the React build
frontend_path = pathlib.Path(__file__).parent.parent / "frontend" / "build"
if frontend_path.exists():
    app.mount("/", StaticFiles(directory=str(frontend_path), html=True), name="frontend")

    async def serve_frontend():
        return FileResponse(str(frontend_path / "index.html"))


if __name__ == "__main__":
    # Get the port from environment variable or use default
    port = int(os.environ.get("PORT", 8000))

    # For Hugging Face Spaces deployment
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        proxy_headers=True,  # This tells uvicorn to trust the X-Forwarded-* headers
        forwarded_allow_ips="*",  # Allow forwarded requests from any IP
    )
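

# ---------------------------------------------------------------------------
# Example client usage (a minimal sketch, not part of the app). It assumes the
# route paths added above (/upload, /session/{session_id}, /query), a local
# server on port 8000, and the third-party `requests` library; adjust paths
# and host to match the actual deployment.
#
#     import time
#     import requests
#
#     BASE = "http://localhost:8000"
#
#     # 1. Upload a document; the multipart field name must be "file"
#     with open("example.pdf", "rb") as f:
#         resp = requests.post(f"{BASE}/upload", files={"file": f})
#     session_id = resp.json()["session_id"]
#
#     # 2. Poll the session status until background processing finishes
#     while requests.get(f"{BASE}/session/{session_id}").json()["status"] == "processing":
#         time.sleep(2)
#
#     # 3. Ask a question and stream the plain-text answer
#     with requests.post(
#         f"{BASE}/query",
#         json={"session_id": session_id, "query": "What is this document about?"},
#         stream=True,
#     ) as r:
#         for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#             print(chunk, end="", flush=True)
# ---------------------------------------------------------------------------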