import os
import sys
import asyncio
import logging
import traceback
from typing import List

from openai import OpenAI

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Make sure the local aimakerspace package is importable
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Import from the local aimakerspace module
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel

# Initialize the OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client with API key: {'valid key' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'}")
class RetrievalAugmentedQAPipeline:
    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        """
        Run the RAG pipeline with the given user query.
        Returns a dict whose "response" value is an async stream of response chunks.
        """
        try:
            # 1. Retrieve relevant documents
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)

            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # Format documents
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])
                # Debug similarity scores
                doc_scores = [f"{i + 1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
                logger.info(f"Document similarity scores: {', '.join(doc_scores)}")

            # 2. Create the messaging payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
If the answer is not in the context, say that you don't know based on the available information.
Use the following document extracts to answer the user's question:
{documents_context}"""},
                {"role": "user", "content": user_query},
            ]

            # 3. Call the LLM and stream the output
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True,
                    )
                    for chunk in stream:
                        # Guard against chunks with an empty choices list
                        # (e.g. a trailing usage chunk) before reading the delta
                        if chunk.choices and chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"

            return {
                "response": generate_response()
            }
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())

            # Return an async generator here too, so callers can always
            # consume the "response" value with `async for`
            async def error_response():
                yield f"I apologize, but an error occurred: {str(e)}"

            return {
                "response": error_response()
            }
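
# A minimal sketch of how a caller consumes the pipeline, assuming an already
# built `vector_db` (see setup_vector_db below). This is illustrative only and
# not part of the application flow:
#
#     pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
#     result = await pipeline.arun_pipeline("What is this document about?")
#     async for token in result["response"]:
#         print(token, end="", flush=True)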
def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks - optimized for speed."""
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    try:
        # Determine the loader based on the file extension
        if file_name.lower().endswith(".txt"):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
            loader.load()
        elif file_name.lower().endswith(".pdf"):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
            loader.load()
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]

        # Get the documents from the loader
        documents = loader.documents
        if documents:
            logger.info(f"Loaded document with {len(documents[0])} characters")
        else:
            logger.warning("No document content loaded")
            return ["No content found in the document"]

        # Split text into chunks with parallel processing
        logger.info("Splitting document with parallel processing")
        chunk_size = 1500    # Increased from 1000 for fewer chunks
        chunk_overlap = 150  # Increased from 100 for better context
        # Use 8 workers for parallel processing
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, max_workers=8)
        text_chunks = text_splitter.split_texts(documents)

        # Cap the number of chunks to keep processing fast
        max_chunks = 40  # Reduced from default
        if len(text_chunks) > max_chunks:
            logger.warning(f"Too many chunks ({len(text_chunks)}), limiting to {max_chunks} for faster processing")
            text_chunks = text_chunks[:max_chunks]

        logger.info(f"Split document into {len(text_chunks)} chunks")
        return text_chunks
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]
async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create a vector database from text chunks - optimized with parallel processing."""
    logger.info(f"Setting up vector database with {len(texts)} text chunks")

    # Create the embedding model to use with the VectorDatabase
    embedding_model = EmbeddingModel()
    # Use a batch size of 20 for better parallelization
    vector_db = VectorDatabase(embedding_model=embedding_model, batch_size=20)

    try:
        # Limit the number of chunks for faster processing
        max_chunks = 40
        if len(texts) > max_chunks:
            logger.warning(f"Limiting {len(texts)} chunks to {max_chunks} for vector embedding")
            texts = texts[:max_chunks]

        # Build the vector database with batch processing, enforcing the
        # 300-second timeout that the except clause below reports
        logger.info("Building vector database with batch processing")
        await asyncio.wait_for(vector_db.abuild_from_list(texts), timeout=300)

        # Add a documents property for compatibility
        vector_db.documents = texts
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except asyncio.TimeoutError:
        logger.error("Vector database creation timed out after 300 seconds")
        # Create a minimal fallback DB with just a few documents
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        if texts:
            # Use just the first few texts for minimal functionality
            minimal_texts = texts[:3]
            for text in minimal_texts:
                fallback_db.insert(text, [0.0] * 1536)  # Use zero vectors for speed
            fallback_db.documents = minimal_texts
        else:
            error_text = "I'm sorry, but there was a timeout during document processing."
            fallback_db.insert(error_text, [0.0] * 1536)
            fallback_db.documents = [error_text]
        return fallback_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())
        # Create a fallback DB for this error case
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db
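
# A minimal end-to-end smoke test, assuming a local sample file and a valid
# OPENAI_API_KEY. The file path is hypothetical; the real app wires these
# functions up from its upload handler instead.
if __name__ == "__main__":
    async def _demo() -> None:
        sample_path = "sample.txt"  # hypothetical input file
        chunks = process_file(sample_path, os.path.basename(sample_path))
        vector_db = await setup_vector_db(chunks)
        pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
        result = await pipeline.arun_pipeline("What is this document about?")
        async for token in result["response"]:
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())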