userisanillusion committed
Commit 7c711a3 · verified · 1 Parent(s): 0a6fdfd

Create app.py

Files changed (1): app.py (+926, -0)
app.py ADDED
@@ -0,0 +1,926 @@
import os
import re
import numpy as np
import gc
import torch
import time
import shutil
import hashlib
import pickle
import traceback
from typing import List, Dict, Any, Tuple, Optional, Union, Generator
from dataclasses import dataclass
import gradio as gr

# Install dependencies at runtime (workaround for Spaces)
os.system("pip install -q pymupdf langchain langchain_community sentence-transformers faiss-cpu huggingface_hub")
os.system("pip install -q llama-cpp-python transformers rank_bm25 nltk")
os.system("pip install -q git+https://github.com/chroma-core/chroma.git")

# Import dependencies after installation
import fitz  # PyMuPDF
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from llama_cpp import Llama
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from huggingface_hub import hf_hub_download

# Set up directories for Spaces
os.makedirs("pdfs", exist_ok=True)
os.makedirs("models", exist_ok=True)
os.makedirs("pdf_cache", exist_ok=True)

# Download NLTK resources
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    print("Failed to download NLTK resources, continuing without them")

# Download the model from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="TheBloke/phi-2-GGUF",
    filename="phi-2.Q8_0.gguf",
    repo_type="model",
    local_dir="models"
)

# === MEMORY MANAGEMENT UTILITIES ===
def clear_memory():
    """Clear memory to prevent OOM errors"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# === PDF PROCESSING ===
@dataclass
class PDFChunk:
    """Class to represent a chunk of text extracted from a PDF"""
    text: str
    source: str
    page_num: int
    chunk_id: int

class PDFProcessor:
    def __init__(self, pdf_dir: str = "pdfs"):
        """Initialize PDF processor

        Args:
            pdf_dir: Directory containing PDF files
        """
        self.pdf_dir = pdf_dir
        # Smaller chunk size with more overlap for better retrieval
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=384,
            chunk_overlap=288,  # 75% overlap for better context preservation
            length_function=len,
            is_separator_regex=False,
        )

        # Create cache directory
        self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
        os.makedirs(self.cache_dir, exist_ok=True)

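    # Aside (illustrative arithmetic, not part of the original commit): with
    # chunk_size=384 and chunk_overlap=288 the splitter advances by roughly
    # 384 - 288 = 96 characters per chunk, so most passages appear in about
    # four overlapping chunks, which favors recall at the cost of a larger index.
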
    def list_pdfs(self) -> List[str]:
        """List all PDF files in the directory"""
        if not os.path.exists(self.pdf_dir):
            return []
        return [f for f in os.listdir(self.pdf_dir) if f.lower().endswith('.pdf')]

    def _get_cache_path(self, pdf_path: str) -> str:
        """Get the cache file path for a PDF"""
        with open(pdf_path, 'rb') as f:
            pdf_hash = hashlib.md5(f.read(8192)).hexdigest()
        return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")

    def _is_cached(self, pdf_path: str) -> bool:
        """Check if a PDF is cached"""
        cache_path = self._get_cache_path(pdf_path)
        return os.path.exists(cache_path)

    def _load_from_cache(self, pdf_path: str) -> Optional[List[PDFChunk]]:
        """Load chunks from cache"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            return None

    def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
        """Save chunks to cache"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(chunks, f)
        except Exception as e:
            print(f"Warning: Failed to cache PDF {pdf_path}: {str(e)}")

    def clean_text(self, text: str) -> str:
        """Clean extracted text"""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove header/footer patterns (common in PDFs)
        text = re.sub(r'(?<!\w)page \d+(?!\w)', '', text, flags=re.IGNORECASE)
        return text

    def extract_text_from_pdf(self, pdf_path: str) -> List[PDFChunk]:
        """Extract text content from a PDF file with improved extraction

        Args:
            pdf_path: Path to the PDF file

        Returns:
            List of PDFChunk objects extracted from the PDF
        """
        # Check cache first
        if self._is_cached(pdf_path):
            cached_chunks = self._load_from_cache(pdf_path)
            if cached_chunks:
                print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
                return cached_chunks

        try:
            doc = fitz.open(pdf_path)
            pdf_chunks = []
            pdf_name = os.path.basename(pdf_path)

            for page_num in range(len(doc)):
                page = doc.load_page(page_num)

                # Extract text with more options for better quality
                page_text = page.get_text("text", sort=True)
                # Try alternative layout analysis if the extracted text is too short
                if len(page_text) < 100:
                    try:
                        page_dict = page.get_text("dict", sort=True)
                        # Convert dict structure to plain text
                        if isinstance(page_dict, dict) and "blocks" in page_dict:
                            extracted_text = ""
                            for block in page_dict["blocks"]:
                                if "lines" in block:
                                    for line in block["lines"]:
                                        if "spans" in line:
                                            for span in line["spans"]:
                                                if "text" in span:
                                                    extracted_text += span["text"] + " "
                            page_text = extracted_text
                    except Exception:
                        # Fallback to default extraction
                        page_text = page.get_text("text")

                # Clean the text
                page_text = self.clean_text(page_text)

                # Extract tables
                try:
                    tables = page.find_tables()
                    if tables and hasattr(tables, "tables"):
                        for table in tables.tables:
                            table_text = ""
                            for i, row in enumerate(table.rows):
                                row_cells = []
                                for cell in row.cells:
                                    if hasattr(cell, "rect"):
                                        cell_text = page.get_text("text", clip=cell.rect)
                                        cell_text = self.clean_text(cell_text)
                                        row_cells.append(cell_text)
                                if row_cells:
                                    table_text += " | ".join(row_cells) + "\n"

                            # Add table text to page text
                            if table_text.strip():
                                page_text += "\n\nTABLE:\n" + table_text
                except Exception as table_err:
                    print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")

                # Split the page text into chunks
                if page_text.strip():
                    page_chunks = self.text_splitter.split_text(page_text)

                    # Create PDFChunk objects
                    for i, chunk_text in enumerate(page_chunks):
                        pdf_chunks.append(PDFChunk(
                            text=chunk_text,
                            source=pdf_name,
                            page_num=page_num + 1,  # 1-based page numbering for humans
                            chunk_id=i
                        ))

                # Clear memory periodically
                if page_num % 10 == 0:
                    clear_memory()

            doc.close()

            # Cache the results
            self._save_to_cache(pdf_path, pdf_chunks)

            return pdf_chunks
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {str(e)}")
            return []

    def process_pdf(self, pdf_name: str) -> List[PDFChunk]:
        """Process a single PDF file and extract chunks

        Args:
            pdf_name: Name of the PDF file in the pdf_dir

        Returns:
            List of PDFChunk objects from the PDF
        """
        pdf_path = os.path.join(self.pdf_dir, pdf_name)
        return self.extract_text_from_pdf(pdf_path)

    def process_all_pdfs(self, batch_size: int = 2) -> List[PDFChunk]:
        """Process all PDFs in batches to manage memory

        Args:
            batch_size: Number of PDFs to process in each batch

        Returns:
            List of all PDFChunk objects from all PDFs
        """
        all_chunks = []
        pdf_files = self.list_pdfs()

        if not pdf_files:
            print("No PDF files found in the directory.")
            return []

        # Process PDFs in batches
        for i in range(0, len(pdf_files), batch_size):
            batch = pdf_files[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(pdf_files)-1)//batch_size + 1}")

            for pdf_name in batch:
                print(f"Processing {pdf_name}")
                chunks = self.process_pdf(pdf_name)
                all_chunks.extend(chunks)
                print(f"Extracted {len(chunks)} chunks from {pdf_name}")

            # Clear memory after each batch
            clear_memory()

        return all_chunks

# === VECTOR DATABASE SETUP ===
class VectorDBManager:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize vector database manager

        Args:
            model_name: Name of the embedding model
        """
        # Initialize embedding model with normalization
        try:
            self.embedding_model = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )
        except Exception as e:
            print(f"Error initializing embedding model {model_name}: {str(e)}")
            print("Falling back to all-MiniLM-L6-v2 model")
            self.embedding_model = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )

        self.vectordb = None
        # BM25 index for hybrid search
        self.bm25_index = None
        self.chunks = []
        self.tokenized_chunks = []

    def _prepare_bm25(self, chunks: List[PDFChunk]):
        """Prepare BM25 index for hybrid search"""
        # Tokenize chunks for BM25
        try:
            tokenized_chunks = []
            for chunk in chunks:
                # Tokenize and remove stopwords
                tokens = word_tokenize(chunk.text.lower())
                stop_words = set(stopwords.words('english'))
                filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
                tokenized_chunks.append(filtered_tokens)

            # Create BM25 index
            self.bm25_index = BM25Okapi(tokenized_chunks)
        except Exception as e:
            print(f"Error creating BM25 index: {str(e)}")
            print(traceback.format_exc())
            self.bm25_index = None

    def create_vector_db(self, chunks: List[PDFChunk]) -> None:
        """Create vector database from text chunks

        Args:
            chunks: List of PDFChunk objects
        """
        try:
            if not chunks or len(chunks) == 0:
                print("ERROR: No chunks provided to create vector database")
                return

            print(f"Creating vector DB with {len(chunks)} chunks")

            # Store chunks for hybrid search
            self.chunks = chunks

            # Prepare data for vector DB
            chunk_texts = [chunk.text for chunk in chunks]

            # Create BM25 index for hybrid search
            print("Creating BM25 index for hybrid search")
            self._prepare_bm25(chunks)

            # Process in smaller batches to manage memory
            batch_size = 16  # Reduced for Spaces
            all_embeddings = []

            for i in range(0, len(chunk_texts), batch_size):
                batch = chunk_texts[i:i+batch_size]
                print(f"Embedding batch {i//batch_size + 1}/{(len(chunk_texts)-1)//batch_size + 1}")

                # Generate embeddings for the batch
                batch_embeddings = self.embedding_model.embed_documents(batch)
                all_embeddings.extend(batch_embeddings)

                # Clear memory after each batch
                clear_memory()

            # Create FAISS index
            print(f"Creating FAISS index with {len(all_embeddings)} embeddings")
            self.vectordb = FAISS.from_embeddings(
                text_embeddings=list(zip(chunk_texts, all_embeddings)),
                embedding=self.embedding_model
            )

            print(f"Vector database created with {len(chunks)} documents")

        except Exception as e:
            print(f"Error creating vector database: {str(e)}")
            print(traceback.format_exc())
            raise

    def _format_chunk_with_metadata(self, chunk: PDFChunk) -> str:
        """Format a chunk with its metadata for better context"""
        return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"

    def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
        """Hybrid search combining vector search and BM25

        Args:
            query: Query text
            k: Number of results to return
            alpha: Weight for vector search (1-alpha for BM25)

        Returns:
            List of formatted documents
        """
        if self.vectordb is None:
            print("Vector database not initialized")
            return []

        try:
            # Get vector search results
            vector_results = self.vectordb.similarity_search(query, k=k*2)
            vector_texts = [doc.page_content for doc in vector_results]

            final_results = []

            # Combine with BM25 if available
            if self.bm25_index is not None:
                try:
                    # Tokenize query for BM25
                    query_tokens = word_tokenize(query.lower())
                    stop_words = set(stopwords.words('english'))
                    filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]

                    # Get BM25 scores
                    bm25_scores = self.bm25_index.get_scores(filtered_query)

                    # Combine scores (normalized)
                    combined_results = []
                    seen_texts = set()

                    # First add vector results with their positions as scores
                    for i, text in enumerate(vector_texts):
                        if text not in seen_texts:
                            seen_texts.add(text)
                            # Find corresponding chunk
                            for j, chunk in enumerate(self.chunks):
                                if chunk.text == text:
                                    # Combine scores: alpha * vector_score + (1-alpha) * bm25_score
                                    # For vector, use inverse of position as score (normalized)
                                    vector_score = 1.0 - (i / len(vector_texts))
                                    # Normalize BM25 score
                                    bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
                                    combined_score = alpha * vector_score + (1 - alpha) * bm25_score

                                    combined_results.append((chunk, combined_score))
                                    break

                    # Sort by combined score
                    combined_results.sort(key=lambda x: x[1], reverse=True)

                    # Get top k results
                    top_chunks = [item[0] for item in combined_results[:k]]

                    # Format results with metadata
                    final_results = [self._format_chunk_with_metadata(chunk) for chunk in top_chunks]
                except Exception as e:
                    print(f"Error in BM25 scoring: {str(e)}")
                    # Fallback to vector search results
                    final_results = vector_texts[:k]
            else:
                # Just use vector search results if BM25 is not available
                final_results = vector_texts[:k]

            return final_results
        except Exception as e:
            print(f"Error during hybrid search: {str(e)}")
            return []

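# Worked example of the score fusion above (illustrative numbers, not from the
# source): with alpha = 0.7, a chunk ranked first of ten by the vector search
# (vector_score = 1.0) whose normalized BM25 score is 0.5 gets
# 0.7 * 1.0 + 0.3 * 0.5 = 0.85, so dense similarity dominates the ordering
# while exact keyword overlap still nudges close calls.
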
# === QUERY EXPANSION ===
class QueryExpander:
    def __init__(self, llm_model):
        """Initialize query expander

        Args:
            llm_model: LLM model for query expansion
        """
        self.llm = llm_model

    def expand_query(self, query: str) -> str:
        """Expand the query using the LLM to improve retrieval

        Args:
            query: Original query

        Returns:
            Expanded query
        """
        try:
            prompt = f"""I need to search for documents related to this question: "{query}"

Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
Return only the expanded search query, without any explanations or additional text.

Expanded query:"""

            expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)

            # Combine original and expanded
            combined = f"{query} {expanded}"

            # Limit length
            if len(combined) > 300:
                combined = combined[:300]

            return combined
        except Exception:
            # Return original query if expansion fails
            return query

# === LLM SETUP ===
class Phi2Model:
    def __init__(self, model_path: str = model_path):
        """Initialize Phi-2 model

        Args:
            model_path: Path to the model file
        """
        try:
            # Initialize Phi-2 with llama.cpp - optimized for Spaces
            self.llm = Llama(
                model_path=model_path,
                n_ctx=1024,       # Reduced context window for Spaces
                n_batch=64,       # Reduced batch size
                n_gpu_layers=0,   # Run on CPU for compatibility
                verbose=False
            )
        except Exception as e:
            print(f"Error initializing Phi-2 model: {str(e)}")
            raise

    def generate(self, prompt: str,
                 max_tokens: int = 512,
                 temperature: float = 0.7,
                 top_p: float = 0.9,
                 stream: bool = False) -> Union[str, Generator[str, None, None]]:
        """Generate text using Phi-2

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            stream: Whether to stream the output

        Returns:
            Generated text, or a generator if streaming
        """
        try:
            if stream:
                return self._generate_stream(prompt, max_tokens, temperature, top_p)
            else:
                output = self.llm(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    echo=False
                )
                return output["choices"][0]["text"]
        except Exception as e:
            print(f"Error generating text: {str(e)}")
            return "Error: Could not generate response."

    def _generate_stream(self, prompt: str,
                         max_tokens: int = 512,
                         temperature: float = 0.7,
                         top_p: float = 0.9) -> Generator[str, None, None]:
        """Stream text generation using Phi-2

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Yields:
            The accumulated response text, updated one token at a time
        """
        response = ""
        for output in self.llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            echo=False,
            stream=True
        ):
            token = output["choices"][0]["text"]
            response += token
            yield response

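# Minimal usage sketch for the wrapper above (illustrative only; nothing here
# runs at import time, and it assumes the GGUF download above succeeded):
#
#   model = Phi2Model()
#   print(model.generate("Summarize retrieval-augmented generation in one sentence.",
#                        max_tokens=64, temperature=0.3))
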
# === RAG SYSTEM ===
class RAGSystem:
    def __init__(self, pdf_processor: PDFProcessor,
                 vector_db: VectorDBManager,
                 model: Phi2Model):
        """Initialize RAG system

        Args:
            pdf_processor: PDF processor instance
            vector_db: Vector database manager instance
            model: LLM model instance
        """
        self.pdf_processor = pdf_processor
        self.vector_db = vector_db
        self.model = model
        self.query_expander = QueryExpander(model)
        self.is_initialized = False

    def process_documents(self) -> bool:
        """Process all documents and create vector database

        Returns:
            True if successful, False otherwise
        """
        try:
            # Process PDFs
            chunks = self.pdf_processor.process_all_pdfs()
            if not chunks:
                print("No chunks were extracted from PDFs")
                return False

            print(f"Total chunks extracted: {len(chunks)}")

            # Create vector database
            print("Creating vector database...")
            self.vector_db.create_vector_db(chunks)

            # Verify success
            if self.vector_db.vectordb is None:
                print("Failed to create vector database")
                return False

            # Set initialization flag
            self.is_initialized = True
            return True

        except Exception as e:
            print(f"Error processing documents: {str(e)}")
            print(traceback.format_exc())
            return False

    def generate_prompt(self, query: str, contexts: List[str]) -> str:
        """Generate prompt for the LLM with better instructions

        Args:
            query: User query
            contexts: Retrieved contexts

        Returns:
            Formatted prompt
        """
        # Format contexts with numbering for better reference
        formatted_contexts = ""
        for i, context in enumerate(contexts):
            formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"

        # Create prompt with better instructions
        prompt = f"""You are an AI assistant that answers questions based on the provided context information.

User Query: {query}

Below are relevant passages from documents that might help answer the query:

{formatted_contexts}

Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."

Do not use any prior knowledge that is not contained in the provided context.
If quoting from the context, mention the source document and page number.
Organize your answer in a clear, coherent manner.

Answer:"""
        return prompt

    def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
                     temperature: float = 0.7, stream: bool = False) -> Union[str, Generator[str, None, None]]:
        """Answer a query using RAG with query expansion

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Temperature for generation
            stream: Whether to stream the output

        Returns:
            Answer text, or a generator if streaming
        """
        # Check if system is initialized
        if not self.is_initialized or self.vector_db.vectordb is None:
            return "Error: Documents have not been processed yet. Please process documents first."

        try:
            # Expand query for better retrieval
            expanded_query = self.query_expander.expand_query(query)
            print(f"Expanded query: {expanded_query}")

            # Retrieve relevant contexts using hybrid search
            contexts = self.vector_db.hybrid_search(expanded_query, k=k)

            if not contexts:
                return "No relevant information found in the documents. Please try a different query or check if documents were processed correctly."

            # Generate prompt with improved instructions
            prompt = self.generate_prompt(query, contexts)

            # Generate answer
            return self.model.generate(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream
            )
        except Exception as e:
            print(f"Error answering query: {str(e)}")
            print(traceback.format_exc())
            return f"Error processing your query: {str(e)}"

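# End-to-end sketch of the pipeline without the UI (illustrative; assumes PDFs
# are already present in ./pdfs and that the models above loaded successfully):
#
#   rag = RAGSystem(PDFProcessor("pdfs"), VectorDBManager(), Phi2Model())
#   if rag.process_documents():
#       print(rag.answer_query("What does the report conclude?", k=3, max_tokens=256))
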
# === GRADIO INTERFACE ===
class RAGInterface:
    def __init__(self, rag_system: RAGSystem):
        """Initialize Gradio interface

        Args:
            rag_system: RAG system instance
        """
        self.rag_system = rag_system
        self.interface = None
        self.is_processing = False

    def upload_file(self, files):
        """Upload PDF files"""
        try:
            os.makedirs("pdfs", exist_ok=True)
            uploaded_files = []

            for file in files:
                destination = os.path.join("pdfs", os.path.basename(file.name))
                shutil.copy(file.name, destination)
                uploaded_files.append(os.path.basename(file.name))

            # Verify files exist in the directory
            pdf_files = [f for f in os.listdir("pdfs") if f.lower().endswith('.pdf')]

            if not pdf_files:
                return "No PDF files were uploaded successfully."

            return f"Successfully uploaded {len(uploaded_files)} files: {', '.join(uploaded_files)}"
        except Exception as e:
            return f"Error uploading files: {str(e)}"

    def process_documents(self):
        """Process all documents

        Returns:
            Status message
        """
        if self.is_processing:
            return "Document processing is already in progress. Please wait."

        try:
            self.is_processing = True
            start_time = time.time()

            success = self.rag_system.process_documents()

            elapsed = time.time() - start_time
            self.is_processing = False

            if success:
                return f"Documents processed successfully in {elapsed:.2f} seconds."
            else:
                return "Failed to process documents. Check the logs for more information."
        except Exception as e:
            self.is_processing = False
            return f"Error processing documents: {str(e)}"

    def answer_query(self, query, k, max_tokens, temperature):
        """Answer a query

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            Answer
        """
        if not query.strip():
            return "Please enter a question."

        try:
            return self.rag_system.answer_query(
                query,
                k=k,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False
            )
        except Exception as e:
            return f"Error answering query: {str(e)}"

    def answer_query_stream(self, query, k, max_tokens, temperature):
        """Stream answer to a query

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Yields:
            Generated text
        """
        if not query.strip():
            yield "Please enter a question."
            return

        try:
            yield from self.rag_system.answer_query(
                query,
                k=k,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True
            )
        except Exception as e:
            yield f"Error answering query: {str(e)}"

    def create_interface(self):
        """Create Gradio interface"""
        with gr.Blocks(title="PDF RAG System") as interface:
            gr.Markdown("# PDF RAG System with Phi-2")
            gr.Markdown("Upload your PDF documents, process them, and ask questions to get answers based on the content.")

            with gr.Tab("Upload & Process"):
                with gr.Row():
                    pdf_files = gr.File(
                        file_count="multiple",
                        label="Upload PDF Files",
                        file_types=[".pdf"]
                    )
                    upload_button = gr.Button("Upload", variant="primary")

                upload_output = gr.Textbox(label="Upload Status", lines=2)
                upload_button.click(self.upload_file, inputs=[pdf_files], outputs=upload_output)

                process_button = gr.Button("Process Documents", variant="primary")
                process_output = gr.Textbox(label="Processing Status", lines=2)
                process_button.click(self.process_documents, inputs=[], outputs=process_output)

            with gr.Tab("Query"):
                with gr.Row():
                    with gr.Column():
                        query_input = gr.Textbox(
                            label="Question",
                            lines=3,
                            placeholder="Ask a question about your documents..."
                        )
                        with gr.Row():
                            k_slider = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="Number of Contexts"
                            )
                            max_tokens_slider = gr.Slider(
                                minimum=100,
                                maximum=800,
                                value=400,
                                step=50,
                                label="Max Tokens"
                            )
                            temperature_slider = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.7,
                                step=0.1,
                                label="Temperature"
                            )
                        submit_button = gr.Button("Submit", variant="primary")

                answer_output = gr.Textbox(label="Answer", lines=10)

                submit_button.click(
                    self.answer_query,
                    inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
                    outputs=answer_output
                )

                # Add streaming capability
                stream_button = gr.Button("Submit (Streaming)", variant="secondary")
                stream_button.click(
                    self.answer_query_stream,
                    inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
                    outputs=answer_output
                )

            gr.Markdown("""
            ## Instructions
            1. Upload PDF files in the 'Upload & Process' tab.
            2. Click the 'Process Documents' button to extract and index content.
            3. Switch to the 'Query' tab to ask questions about your documents.
            4. Adjust parameters as needed:
                - Number of Contexts: More contexts provide more information but may be less focused.
                - Max Tokens: Controls the length of the response.
                - Temperature: Lower values (0.1-0.5) give more focused answers, higher values (0.6-1.0) give more creative answers.
            """)

        self.interface = interface
        return interface

    def launch(self, **kwargs):
        """Launch the Gradio interface"""
        if self.interface is None:
            self.create_interface()
        self.interface.launch(**kwargs)

# === MAIN APPLICATION ===
def main():
    """Main function to set up and launch the application"""
    try:
        # Initialize components
        pdf_processor = PDFProcessor(pdf_dir="pdfs")
        vector_db = VectorDBManager()
        phi2_model = Phi2Model()

        # Initialize RAG system
        rag_system = RAGSystem(pdf_processor, vector_db, phi2_model)

        # Create interface
        interface = RAGInterface(rag_system)

        # Launch application
        interface.launch(share=True)

    except Exception as e:
        print(f"Error initializing application: {str(e)}")
        print(traceback.format_exc())

if __name__ == "__main__":
    main()