import os
from typing import List
import logging
import traceback
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Make sure the parent directory (which contains aimakerspace) is importable
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import from local aimakerspace module
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client with API key: {'valid key' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'}")

class RetrievalAugmentedQAPipeline:
    """Retrieves relevant chunks from a VectorDatabase and streams an LLM answer."""

    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        self.vector_db_retriever = vector_db_retriever
    
    async def arun_pipeline(self, user_query: str):
        """
        Run the RAG pipeline for the given user query.
        Returns a dict whose "response" key holds an async generator of text chunks.
        """
        try:
            # 1. Retrieve relevant documents
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)
            
            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # Format documents
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])
            
            # Debug similarity scores
            doc_scores = [f"{i+1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
            logger.info(f"Document similarity scores: {', '.join(doc_scores) if doc_scores else 'No documents'}")
            
            # 2. Create messaging payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
                If the answer is not in the context, say that you don't know based on the available information. 
                Use the following document extracts to answer the user's question:
                
                {documents_context}"""},
                {"role": "user", "content": user_query}
            ]
            
            # 3. Call LLM and stream the output
            # Note: the synchronous OpenAI client is iterated inside an async
            # generator here, which blocks the event loop while streaming;
            # AsyncOpenAI would avoid that, at the cost of a second client.
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True
                    )
                    
                    for chunk in stream:
                        # Some chunks (e.g. the final one) carry no content delta
                        if chunk.choices and chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"
            
            return {
                "response": generate_response()
            }
        
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())
            # Capture the message now (`e` is unbound once the except block exits)
            # and return an async generator, so callers can `async for` both paths
            error_message = f"I apologize, but an error occurred: {str(e)}"

            async def error_response():
                yield error_message

            return {"response": error_response()}

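# Usage sketch (illustrative): how a caller consumes the streamed response.
# `demo_query_stream` is hypothetical and not used elsewhere in this module;
# it assumes a VectorDatabase already populated via setup_vector_db() below.
async def demo_query_stream(vector_db: VectorDatabase, user_query: str) -> str:
    """Run one query through the pipeline, printing and collecting the streamed tokens."""
    pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
    result = await pipeline.arun_pipeline(user_query)
    answer = ""
    async for token in result["response"]:
        print(token, end="", flush=True)
        answer += token
    return answer
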
def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks - optimized for speed"""
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    
    try:
        # Determine loader based on file extension
        if file_name.lower().endswith('.txt'):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
            loader.load()
        elif file_name.lower().endswith('.pdf'):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
            loader.load()
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]
        
        # Get documents from loader
        documents = loader.documents
        if documents:
            logger.info(f"Loaded document with {len(documents[0])} characters")
        else:
            logger.warning("No document content loaded")
            return ["No content found in the document"]
        
        # Split text into chunks - use parallel processing
        logger.info("Splitting document with parallel processing")
        chunk_size = 1500  # Increased from 1000 for fewer chunks
        chunk_overlap = 150  # Increased from 100 for better context
        # Use 8 workers for parallel processing
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, max_workers=8)
        text_chunks = text_splitter.split_texts(documents)
        
        # Limit chunks to avoid processing too many for speed
        max_chunks = 40  # Reduced from default
        if len(text_chunks) > max_chunks:
            logger.warning(f"Too many chunks ({len(text_chunks)}), limiting to {max_chunks} for faster processing")
            text_chunks = text_chunks[:max_chunks]
        
        logger.info(f"Split document into {len(text_chunks)} chunks")
        return text_chunks
        
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]

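# Quick self-check (illustrative): chunk a local file and inspect the pieces.
# The path is a placeholder; this helper is not called by the application.
def demo_chunking(path: str = "sample.txt") -> None:
    chunks = process_file(path, os.path.basename(path))
    print(f"Produced {len(chunks)} chunks")
    for i, chunk in enumerate(chunks[:3]):
        # With chunk_overlap=150, the tail of one chunk reappears at the head of the next
        print(f"chunk {i}: {len(chunk)} chars, starts with {chunk[:60]!r}")
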
async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create vector database from text chunks - optimized with parallel processing"""
    logger.info(f"Setting up vector database with {len(texts)} text chunks")
    
    # Create embedding model to use with VectorDatabase
    embedding_model = EmbeddingModel()
    # Use batch size of 20 for better parallelization
    vector_db = VectorDatabase(embedding_model=embedding_model, batch_size=20)
    
    try:
        # Limit number of chunks for faster processing
        max_chunks = 40
        if len(texts) > max_chunks:
            logger.warning(f"Limiting {len(texts)} chunks to {max_chunks} for vector embedding")
            texts = texts[:max_chunks]
        
        # Build vector database with batch processing; cap the wait so the
        # asyncio.TimeoutError handler below can actually fire
        logger.info("Building vector database with batch processing")
        await asyncio.wait_for(vector_db.abuild_from_list(texts), timeout=300)
        
        # Add documents property for compatibility
        vector_db.documents = texts
        
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except asyncio.TimeoutError:
        logger.error(f"Vector database creation timed out after 300 seconds")
        # Create minimal fallback DB with just a few documents
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        if texts:
            # Use just first few texts for minimal functionality
            minimal_texts = texts[:3]
            for text in minimal_texts:
                # Zero vectors (1536 dims, the size of OpenAI ada-002 embeddings) for speed
                fallback_db.insert(text, [0.0] * 1536)
            fallback_db.documents = minimal_texts
        else:
            error_text = "I'm sorry, but there was a timeout during document processing."
            fallback_db.insert(error_text, [0.0] * 1536)
            fallback_db.documents = [error_text]
        return fallback_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())
        
        # Create fallback DB for this error case
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db
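
# End-to-end sketch: wire the helpers above together when the module is run
# directly. "sample.txt" is a placeholder file name, and OPENAI_API_KEY must
# be set for the embedding and chat completion calls to succeed.
if __name__ == "__main__":
    async def _main() -> None:
        chunks = process_file("sample.txt", "sample.txt")
        vector_db = await setup_vector_db(chunks)
        pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
        result = await pipeline.arun_pipeline("Summarize the document.")
        async for token in result["response"]:
            print(token, end="", flush=True)
        print()

    asyncio.run(_main())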