from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import uvicorn
import os
import tempfile
import pathlib
import asyncio
import logging
import time
import traceback
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import our RAG components
from rag import RetrievalAugmentedQAPipeline, process_file, setup_vector_db

# Make the local aimakerspace module importable by adding the repo root to the path
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import from local aimakerspace module
from aimakerspace.utils.session_manager import SessionManager

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

app = FastAPI(
    title="RAG Application",
    description="Retrieval Augmented Generation with FastAPI and React",
    version="0.1.0",
    root_path="",  # Important for proxy environments
)

# More robust middleware for handling HTTPS
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class ProxyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        # Log request details for debugging
        logger.info(f"Request path: {request.url.path}")
        logger.info(f"Request headers: {request.headers}")
        
        # Validate request before processing
        try:
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            response.headers["X-Process-Time"] = str(process_time)
            return response
        except Exception as e:
            logger.error(f"Request failed: {str(e)}")
            logger.error(traceback.format_exc())
            return JSONResponse(
                status_code=500,
                content={"detail": f"Internal server error: {str(e)}"}
            )

# Add custom middleware
app.add_middleware(ProxyMiddleware)

# Configure CORS - more specific configuration for Hugging Face
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, you should restrict this
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["Content-Length", "X-Process-Time"],
    max_age=600,  # 10 minutes cache for preflight requests
)
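
# Hedged sketch (assumption, not current behavior): in production the wildcard
# origin above could be driven from an environment variable instead, e.g.
#   allowed_origins = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
#   # ...then pass allow_origins=allowed_origins to CORSMiddleware above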

# Initialize session manager
session_manager = SessionManager()
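# Session-state convention used throughout this module: the stored value is the
# string "processing" while indexing runs, the string "failed" if indexing
# errored, or the ready RetrievalAugmentedQAPipeline once querying can begin.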

class QueryRequest(BaseModel):
    session_id: str
    query: str

class QueryResponse(BaseModel):
    response: str
    session_id: str
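
# Illustrative request body for POST /query/ (values are made up):
#   {"session_id": "<id returned by /upload/>", "query": "What does the document say about X?"}
# Note: QueryResponse is not referenced below; /query/ streams plain text instead.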

# Set file size limit to 10MB - adjust as needed
FILE_SIZE_LIMIT = 10 * 1024 * 1024  # 10MB

async def process_file_background(temp_path: str, filename: str, session_id: str):
    """Process file in background and set up the RAG pipeline"""
    try:
        start_time = time.time()
        logger.info(f"Background processing started for file: {filename} (session: {session_id})")
        
        # Set max processing time (5 minutes)
        max_processing_time = 300  # seconds
        
        # Process the file
        logger.info(f"Starting text extraction for file: {filename}")
        try:
            texts = process_file(temp_path, filename)
            logger.info(f"Processed file into {len(texts)} text chunks (took {time.time() - start_time:.2f}s)")
            
            # Check if processing is taking too long already
            if time.time() - start_time > max_processing_time / 2:
                logger.warning(f"Text extraction took more than half the allowed time. Limiting chunks...")
                # Limit to a smaller number if extraction took a long time
                max_chunks = 50
                if len(texts) > max_chunks:
                    logger.warning(f"Limiting text chunks from {len(texts)} to {max_chunks}")
                    texts = texts[:max_chunks]
        except Exception as e:
            logger.error(f"Error during text extraction: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
            os.unlink(temp_path)
            return
        
        # Setup vector database - This is the part that might be hanging
        logger.info(f"Starting vector DB creation for {len(texts)} chunks")
        embedding_start = time.time()
        
        # Enforce an overall timeout using whatever processing budget remains
        # (asyncio.wait_for accepts the coroutine directly; a non-positive
        # timeout fails fast with TimeoutError, which is the desired behavior)
        try:
            remaining_time = max_processing_time - (time.time() - start_time)
            vector_db = await asyncio.wait_for(setup_vector_db(texts), timeout=remaining_time)
            
            # Get document count - check if documents property is available
            if hasattr(vector_db, 'documents'):
                doc_count = len(vector_db.documents)
            else:
                # If using the original VectorDatabase implementation that uses vectors dict
                doc_count = len(vector_db.vectors) if hasattr(vector_db, 'vectors') else 0
                
            logger.info(f"Created vector database with {doc_count} documents (took {time.time() - embedding_start:.2f}s)")
            
            # Create RAG pipeline
            logger.info(f"Creating RAG pipeline for session {session_id}")
            rag_pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
            
            # Store pipeline in session manager
            session_manager.update_session(session_id, rag_pipeline)
            logger.info(f"Updated session {session_id} with processed pipeline (total time: {time.time() - start_time:.2f}s)")
            
        except asyncio.TimeoutError:
            logger.error(f"Vector database creation timed out after {time.time() - embedding_start:.2f}s")
            session_manager.update_session(session_id, "failed")
        except Exception as e:
            logger.error(f"Error in vector database creation: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
        
        # Clean up temp file
        os.unlink(temp_path)
        logger.info(f"Removed temporary file: {temp_path}")
        
    except Exception as e:
        logger.error(f"Error in background processing for session {session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        # Mark the session as failed rather than removing it
        session_manager.update_session(session_id, "failed")
        # Try to clean up temp file if it exists
        try:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"Cleaned up temporary file after error: {temp_path}")
        except Exception as cleanup_error:
            logger.error(f"Error cleaning up temp file: {str(cleanup_error)}")

@app.post("/upload/")
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    try:
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content type: {file.content_type}")
        
        # Create a unique ID for this upload
        upload_id = str(uuid.uuid4())
        logger.info(f"Assigned upload ID: {upload_id}")
        
        # Check file size first
        file_size = 0
        chunk_size = 1024 * 1024  # 1MB chunks for reading
        contents = bytearray()
        
        # Read file in chunks to avoid memory issues
        try:
            while True:
                chunk = await asyncio.wait_for(file.read(chunk_size), timeout=60.0)
                if not chunk:
                    break
                file_size += len(chunk)
                contents.extend(chunk)
                
                # Check size limit
                if file_size > FILE_SIZE_LIMIT:
                    logger.warning(f"File too large: {file_size/1024/1024:.2f}MB exceeds limit of {FILE_SIZE_LIMIT/1024/1024}MB")
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large. Maximum size is {FILE_SIZE_LIMIT/1024/1024}MB"
                    )
                
                # Log progress every 5MB (exact only while reads return full 1MB chunks)
                if file_size % (5 * chunk_size) == 0:
                    logger.info(f"Upload progress: {file_size/1024/1024:.2f}MB read so far...")
                    
        except asyncio.TimeoutError:
            logger.error(f"Timeout reading file: {file.filename}")
            raise HTTPException(
                status_code=408,
                detail="Request timeout while reading file. Please try again."
            )
        
        logger.info(f"File size: {file_size/1024/1024:.2f}MB")
        
        # Materialize the buffered chunks as immutable bytes for writing to disk
        file_content = bytes(contents)
        
        # Create a temporary file, preserving the original extension when present
        suffix = pathlib.Path(file.filename).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Write file content to temp file
            temp_file.write(file_content)
            temp_path = temp_file.name
            logger.info(f"Created temporary file at: {temp_path}")
        
        # Generate session ID and create session
        session_id = session_manager.create_session("processing")
        logger.info(f"Created session ID: {session_id}")
        
        # Process file in background
        background_tasks.add_task(
            process_file_background, 
            temp_path, 
            file.filename, 
            session_id
        )
        
        return {"session_id": session_id, "message": "File uploaded and processing started", "upload_id": upload_id}
    
    except HTTPException:
        # Let deliberate HTTP errors (e.g. 413, 408) keep their status codes
        raise
    except Exception as e:
        logger.error(f"Error processing upload: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
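
# Example usage (illustrative):
#   curl -F "file=@document.pdf" http://localhost:8000/upload/
#   -> {"session_id": "...", "message": "File uploaded and processing started", "upload_id": "..."}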

@app.post("/query/")
async def process_query(request: QueryRequest):
    logger.info(f"Received query request for session: {request.session_id}")
    
    # Check if session exists
    if not session_manager.session_exists(request.session_id):
        logger.warning(f"Session not found: {request.session_id}")
        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")
    
    # Get session data
    session_data = session_manager.get_session(request.session_id)
    
    # Check if processing is still ongoing
    if session_data == "processing":
        logger.info(f"Document still processing for session: {request.session_id}")
        raise HTTPException(status_code=409, detail="Document is still being processed. Please try again in a moment.")
    
    # Check if processing failed
    if session_data == "failed":
        logger.error(f"Processing failed for session: {request.session_id}")
        raise HTTPException(status_code=500, detail="Document processing failed. Please try uploading again.")
    
    try:
        logger.info(f"Processing query: '{request.query}' for session: {request.session_id}")
        
        # Get response from RAG pipeline
        start_time = time.time()
        result = await session_data.arun_pipeline(request.query)
        
        # Stream the response - this is key for the Star Wars effect
        async def stream_response():
            try:
                async for chunk in result["response"]:
                    # Add a small delay between chunks for dramatic effect
                    await asyncio.sleep(0.01)
                    # Stream each raw text chunk (served as text/plain below)
                    yield chunk
                
                logger.info(f"Completed streaming response (took {time.time() - start_time:.2f}s)")
            except Exception as e:
                logger.error(f"Error in streaming: {str(e)}")
                yield f"Error during streaming: {str(e)}"
        
        # Return streaming response
        return StreamingResponse(
            stream_response(),
            media_type="text/plain",
        )
    
    except Exception as e:
        logger.error(f"Error processing query for session {request.session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
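
# Example usage (illustrative; -N disables curl's buffering so chunks appear live):
#   curl -N -X POST http://localhost:8000/query/ \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<session_id>", "query": "Summarize the document"}'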

@app.get("/health")
def health_check():
    return {"status": "healthy"}

@app.get("/test")
def test_endpoint():
    return {"message": "Backend is accessible"}

@app.get("/session/{session_id}/status")
async def session_status(session_id: str):
    """Check if a session exists and its processing status"""
    logger.info(f"Checking status for session: {session_id}")
    
    if not session_manager.session_exists(session_id):
        logger.warning(f"Session not found: {session_id}")
        return {"exists": False, "status": "not_found"}
    
    session_data = session_manager.get_session(session_id)
    
    if session_data == "processing":
        logger.info(f"Session {session_id} is still processing")
        return {"exists": True, "status": "processing"}
    
    if session_data == "failed":
        logger.error(f"Session {session_id} processing failed")
        return {"exists": True, "status": "failed"}
    
    logger.info(f"Session {session_id} is ready")
    return {"exists": True, "status": "ready"}

@app.get("/debug/sessions")
async def debug_sessions():
    """Return debug information about all sessions - for diagnostic use only"""
    logger.info("Accessed debug sessions endpoint")
    
    # Get summary of all sessions
    sessions_summary = session_manager.get_sessions_summary()
    
    return sessions_summary

# For Hugging Face Spaces deployment, serve the static files from the React build.
# Register the explicit "/" route before mounting StaticFiles at "/": Starlette
# matches routes in registration order, so a mount added first would shadow it.
frontend_path = pathlib.Path(__file__).parent.parent / "frontend" / "build"
if frontend_path.exists():
    @app.get("/", include_in_schema=False)
    async def serve_frontend():
        return FileResponse(str(frontend_path / "index.html"))

    app.mount("/", StaticFiles(directory=str(frontend_path), html=True), name="frontend")

if __name__ == "__main__":
    # Get the port from environment variable or use default
    port = int(os.environ.get("PORT", 8000))
    
    # For Hugging Face Spaces deployment
    uvicorn.run(
        "main:app", 
        host="0.0.0.0", 
        port=port,
        proxy_headers=True,  # This tells uvicorn to trust the X-Forwarded-* headers
        forwarded_allow_ips="*"  # Allow forwarded requests from any IP
    )
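
# To run locally (illustrative; assumes the project's dependencies are installed):
#   python main.py
# or, with auto-reload during development:
#   uvicorn main:app --reload --port 8000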