import os
import sys
import asyncio
import logging
import traceback
from typing import List

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Make sure aimakerspace is importable (it lives in this file's parent directory)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import from local aimakerspace module
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client with API key: {'valid key' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'}")
class RetrievalAugmentedQAPipeline:
    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        """
        Run the RAG pipeline for the given user query.

        Returns a dict whose "response" value is an async generator of
        response chunks.
        """
        try:
            # 1. Retrieve relevant documents
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)

            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # Each result is a (text, score) tuple; join the texts
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])

                # Debug similarity scores
                doc_scores = [f"{i+1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
                logger.info(f"Document similarity scores: {', '.join(doc_scores)}")

            # 2. Create the chat payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
If the answer is not in the context, say that you don't know based on the available information.
Use the following document extracts to answer the user's question:
{documents_context}"""},
                {"role": "user", "content": user_query}
            ]

            # 3. Call the LLM and stream the output.
            # Note: the synchronous OpenAI client is used here, so each chunk
            # read briefly blocks the event loop.
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True
                    )
                    for chunk in stream:
                        if chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"

            return {
                "response": generate_response()
            }
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())
            # Capture the message now: `e` is cleared when the except block
            # ends, so the generator below must not close over it.
            error_message = f"I apologize, but an error occurred: {str(e)}"

            async def error_response():
                yield error_message

            return {
                "response": error_response()
            }
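
# Usage sketch (illustrative only; assumes a VectorDatabase has already been
# built, e.g. via setup_vector_db below). The "response" value is an async
# generator, so it is consumed with `async for`:
#
#     async def demo(vector_db: VectorDatabase) -> None:
#         pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
#         result = await pipeline.arun_pipeline("What is this document about?")
#         async for token in result["response"]:
#             print(token, end="", flush=True)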
def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks - optimized for speed."""
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    try:
        # Pick a loader based on the file extension
        if file_name.lower().endswith('.txt'):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
            loader.load()
        elif file_name.lower().endswith('.pdf'):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
            loader.load()
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]

        # Get documents from the loader
        documents = loader.documents
        if documents:
            logger.info(f"Loaded document with {len(documents[0])} characters")
        else:
            logger.warning("No document content loaded")
            return ["No content found in the document"]

        # Split text into chunks using parallel processing
        logger.info("Splitting document with parallel processing")
        chunk_size = 1500    # increased from 1000 for fewer chunks
        chunk_overlap = 150  # increased from 100 for better context

        # Use 8 workers for parallel splitting
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, max_workers=8)
        text_chunks = text_splitter.split_texts(documents)

        # Cap the number of chunks to keep processing fast
        max_chunks = 40  # reduced from default
        if len(text_chunks) > max_chunks:
            logger.warning(f"Too many chunks ({len(text_chunks)}), limiting to {max_chunks} for faster processing")
            text_chunks = text_chunks[:max_chunks]

        logger.info(f"Split document into {len(text_chunks)} chunks")
        return text_chunks
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]
async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create a vector database from text chunks - optimized with parallel processing."""
    logger.info(f"Setting up vector database with {len(texts)} text chunks")

    # Create the embedding model to use with VectorDatabase
    embedding_model = EmbeddingModel()
    # Use a batch size of 20 for better parallelization
    vector_db = VectorDatabase(embedding_model=embedding_model, batch_size=20)

    try:
        # Limit the number of chunks for faster processing
        max_chunks = 40
        if len(texts) > max_chunks:
            logger.warning(f"Limiting {len(texts)} chunks to {max_chunks} for vector embedding")
            texts = texts[:max_chunks]

        # Build the vector database with batch processing; give up after
        # 300 seconds so the TimeoutError handler below can actually fire
        logger.info("Building vector database with batch processing")
        await asyncio.wait_for(vector_db.abuild_from_list(texts), timeout=300)

        # Add a documents attribute for compatibility
        vector_db.documents = texts
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except asyncio.TimeoutError:
        logger.error("Vector database creation timed out after 300 seconds")
        # Create a minimal fallback DB with just a few documents
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        if texts:
            # Use just the first few texts for minimal functionality
            minimal_texts = texts[:3]
            for text in minimal_texts:
                fallback_db.insert(text, [0.0] * 1536)  # zero vectors skip embedding calls
            fallback_db.documents = minimal_texts
        else:
            error_text = "I'm sorry, but there was a timeout during document processing."
            fallback_db.insert(error_text, [0.0] * 1536)
            fallback_db.documents = [error_text]
        return fallback_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())
        # Create a fallback DB for this error case
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db
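
# End-to-end sketch (illustrative; the file path and query are hypothetical).
# It ties the helpers together: load and chunk a file, embed it, then stream
# an answer from the pipeline.
#
#     async def main() -> None:
#         chunks = process_file("sample.pdf", "sample.pdf")
#         vector_db = await setup_vector_db(chunks)
#         pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
#         result = await pipeline.arun_pipeline("Summarize the document.")
#         async for token in result["response"]:
#             print(token, end="", flush=True)
#
#     if __name__ == "__main__":
#         asyncio.run(main())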