Update app.py
app.py (changed)
@@ -1,5 +1,12 @@
 import os
 import re
 import numpy as np
 import gc
 import torch
@@ -10,31 +17,16 @@ import pickle
 import traceback
 from typing import List, Dict, Any, Tuple, Optional, Union, Generator
 from dataclasses import dataclass
-import gradio as gr
-
-# Import dependencies (no need for pip install commands)
-import fitz # PyMuPDF
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from llama_cpp import Llama
 from rank_bm25 import BM25Okapi
 import nltk
 from nltk.tokenize import word_tokenize
 from nltk.corpus import stopwords
-from huggingface_hub import hf_hub_download
-
-# Download nltk resources
-try:
-    nltk.download('punkt', quiet=True)
-    nltk.download('stopwords', quiet=True)
-except:
-    print("Failed to download NLTK resources, continuing without them")
-
-# Setup directories for Spaces
-os.makedirs("pdfs", exist_ok=True)
-os.makedirs("models", exist_ok=True)
-os.makedirs("pdf_cache", exist_ok=True)
 
 # Download nltk resources
 try:
@@ -43,14 +35,6 @@ try:
 except:
     print("Failed to download NLTK resources, continuing without them")
 
-# Download model from Hugging Face Hub
-model_path = hf_hub_download(
-    repo_id="TheBloke/phi-2-GGUF",
-    filename="phi-2.Q8_0.gguf",
-    repo_type="model",
-    local_dir="models"
-)
-
 # === MEMORY MANAGEMENT UTILITIES ===
 def clear_memory():
     """Clear memory to prevent OOM errors"""
@@ -81,7 +65,7 @@ class PDFProcessor:
|
|
81 |
length_function=len,
|
82 |
is_separator_regex=False,
|
83 |
)
|
84 |
-
|
85 |
# Create cache directory
|
86 |
self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
|
87 |
os.makedirs(self.cache_dir, exist_ok=True)
|
@@ -96,12 +80,12 @@ class PDFProcessor:
|
|
96 |
"""Get the cache file path for a PDF"""
|
97 |
pdf_hash = hashlib.md5(open(pdf_path, 'rb').read(8192)).hexdigest()
|
98 |
return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")
|
99 |
-
|
100 |
def _is_cached(self, pdf_path: str) -> bool:
|
101 |
"""Check if a PDF is cached"""
|
102 |
cache_path = self._get_cache_path(pdf_path)
|
103 |
return os.path.exists(cache_path)
|
104 |
-
|
105 |
def _load_from_cache(self, pdf_path: str) -> List[PDFChunk]:
|
106 |
"""Load chunks from cache"""
|
107 |
cache_path = self._get_cache_path(pdf_path)
|
@@ -110,7 +94,7 @@ class PDFProcessor:
|
|
110 |
return pickle.load(f)
|
111 |
except:
|
112 |
return None
|
113 |
-
|
114 |
def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
|
115 |
"""Save chunks to cache"""
|
116 |
cache_path = self._get_cache_path(pdf_path)
|
@@ -143,15 +127,15 @@ class PDFProcessor:
|
|
143 |
if cached_chunks:
|
144 |
print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
|
145 |
return cached_chunks
|
146 |
-
|
147 |
try:
|
148 |
doc = fitz.open(pdf_path)
|
149 |
pdf_chunks = []
|
150 |
pdf_name = os.path.basename(pdf_path)
|
151 |
-
|
152 |
for page_num in range(len(doc)):
|
153 |
page = doc.load_page(page_num)
|
154 |
-
|
155 |
# Extract text with more options for better quality
|
156 |
page_text = page.get_text("text", sort=True)
|
157 |
# Try to extract text with alternative layout analysis if the text is too short
|
@@ -172,10 +156,10 @@ class PDFProcessor:
|
|
172 |
except:
|
173 |
# Fallback to default extraction
|
174 |
page_text = page.get_text("text")
|
175 |
-
|
176 |
# Clean the text
|
177 |
page_text = self.clean_text(page_text)
|
178 |
-
|
179 |
# Extract tables
|
180 |
try:
|
181 |
tables = page.find_tables()
|
@@ -191,17 +175,17 @@ class PDFProcessor:
|
|
191 |
row_cells.append(cell_text)
|
192 |
if row_cells:
|
193 |
table_text += " | ".join(row_cells) + "\n"
|
194 |
-
|
195 |
# Add table text to page text
|
196 |
if table_text.strip():
|
197 |
page_text += "\n\nTABLE:\n" + table_text
|
198 |
except Exception as table_err:
|
199 |
print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
|
200 |
-
|
201 |
# Split the page text into chunks
|
202 |
if page_text.strip():
|
203 |
page_chunks = self.text_splitter.split_text(page_text)
|
204 |
-
|
205 |
# Create PDFChunk objects
|
206 |
for i, chunk_text in enumerate(page_chunks):
|
207 |
pdf_chunks.append(PDFChunk(
|
@@ -210,16 +194,16 @@ class PDFProcessor:
|
|
210 |
page_num=page_num + 1, # 1-based page numbering for humans
|
211 |
chunk_id=i
|
212 |
))
|
213 |
-
|
214 |
# Clear memory periodically
|
215 |
if page_num % 10 == 0:
|
216 |
clear_memory()
|
217 |
-
|
218 |
doc.close()
|
219 |
-
|
220 |
# Cache the results
|
221 |
self._save_to_cache(pdf_path, pdf_chunks)
|
222 |
-
|
223 |
return pdf_chunks
|
224 |
except Exception as e:
|
225 |
print(f"Error extracting text from {pdf_path}: {str(e)}")
|
@@ -237,7 +221,7 @@ class PDFProcessor:
|
|
237 |
pdf_path = os.path.join(self.pdf_dir, pdf_name)
|
238 |
return self.extract_text_from_pdf(pdf_path)
|
239 |
|
240 |
-
def process_all_pdfs(self, batch_size: int =
|
241 |
"""Process all PDFs in batches to manage memory
|
242 |
|
243 |
Args:
|
@@ -292,7 +276,14 @@ class VectorDBManager:
|
|
292 |
model_kwargs={"device": "cpu"},
|
293 |
encode_kwargs={"normalize_embeddings": True}
|
294 |
)
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
self.vectordb = None
|
297 |
# BM25 index for hybrid search
|
298 |
self.bm25_index = None
|
@@ -310,14 +301,15 @@ class VectorDBManager:
|
|
310 |
stop_words = set(stopwords.words('english'))
|
311 |
filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
|
312 |
tokenized_chunks.append(filtered_tokens)
|
313 |
-
|
314 |
# Create BM25 index
|
315 |
self.bm25_index = BM25Okapi(tokenized_chunks)
|
|
|
316 |
except Exception as e:
|
317 |
print(f"Error creating BM25 index: {str(e)}")
|
318 |
print(traceback.format_exc())
|
319 |
self.bm25_index = None
|
320 |
-
|
321 |
def create_vector_db(self, chunks: List[PDFChunk]) -> None:
|
322 |
"""Create vector database from text chunks
|
323 |
|
@@ -330,19 +322,19 @@ class VectorDBManager:
|
|
330 |
return
|
331 |
|
332 |
print(f"Creating vector DB with {len(chunks)} chunks")
|
333 |
-
|
334 |
# Store chunks for hybrid search
|
335 |
self.chunks = chunks
|
336 |
-
|
337 |
# Prepare data for vector DB
|
338 |
chunk_texts = [chunk.text for chunk in chunks]
|
339 |
-
|
340 |
# Create BM25 index for hybrid search
|
341 |
print("Creating BM25 index for hybrid search")
|
342 |
self._prepare_bm25(chunks)
|
343 |
-
|
344 |
# Process in smaller batches to manage memory
|
345 |
-
batch_size =
|
346 |
all_embeddings = []
|
347 |
|
348 |
for i in range(0, len(chunk_texts), batch_size):
|
@@ -362,7 +354,7 @@ class VectorDBManager:
|
|
362 |
text_embeddings=list(zip(chunk_texts, all_embeddings)),
|
363 |
embedding=self.embedding_model
|
364 |
)
|
365 |
-
|
366 |
print(f"Vector database created with {len(chunks)} documents")
|
367 |
|
368 |
except Exception as e:
|
@@ -374,8 +366,51 @@ class VectorDBManager:
|
|
374 |
"""Format a chunk with its metadata for better context"""
|
375 |
return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"
|
376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
|
378 |
-
"""Hybrid search combining vector search and BM25
|
379 |
|
380 |
Args:
|
381 |
query: Query text
|
@@ -391,11 +426,11 @@ class VectorDBManager:
|
|
391 |
|
392 |
try:
|
393 |
# Get vector search results
|
394 |
-
vector_results = self.vectordb.similarity_search(query, k=k*
|
395 |
vector_texts = [doc.page_content for doc in vector_results]
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
# Combine with BM25 if available
|
400 |
if self.bm25_index is not None:
|
401 |
try:
|
@@ -403,14 +438,14 @@ class VectorDBManager:
|
|
403 |
query_tokens = word_tokenize(query.lower())
|
404 |
stop_words = set(stopwords.words('english'))
|
405 |
filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
|
406 |
-
|
407 |
# Get BM25 scores
|
408 |
bm25_scores = self.bm25_index.get_scores(filtered_query)
|
409 |
-
|
410 |
# Combine scores (normalized)
|
411 |
combined_results = []
|
412 |
seen_texts = set()
|
413 |
-
|
414 |
# First add vector results with their positions as scores
|
415 |
for i, text in enumerate(vector_texts):
|
416 |
if text not in seen_texts:
|
@@ -424,26 +459,34 @@ class VectorDBManager:
|
|
424 |
# Normalize BM25 score
|
425 |
bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
|
426 |
combined_score = alpha * vector_score + (1-alpha) * bm25_score
|
427 |
-
|
428 |
combined_results.append((chunk, combined_score))
|
429 |
break
|
430 |
-
|
431 |
# Sort by combined score
|
432 |
combined_results.sort(key=lambda x: x[1], reverse=True)
|
433 |
-
|
434 |
-
# Get top k results
|
435 |
-
|
436 |
-
|
437 |
-
# Format results with metadata
|
438 |
-
final_results = [self._format_chunk_with_metadata(chunk) for chunk in top_chunks]
|
439 |
except Exception as e:
|
440 |
print(f"Error in BM25 scoring: {str(e)}")
|
441 |
# Fallback to vector search results
|
442 |
-
|
|
|
443 |
else:
|
444 |
# Just use vector search results if BM25 is not available
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
final_results = vector_texts[:k]
|
446 |
-
|
447 |
return final_results
|
448 |
except Exception as e:
|
449 |
print(f"Error during hybrid search: {str(e)}")
|
@@ -458,7 +501,7 @@ class QueryExpander:
|
|
458 |
llm_model: LLM model for query expansion
|
459 |
"""
|
460 |
self.llm = llm_model
|
461 |
-
|
462 |
def expand_query(self, query: str) -> str:
|
463 |
"""Expand the query using the LLM to improve retrieval
|
464 |
|
@@ -469,46 +512,44 @@ class QueryExpander:
|
|
469 |
Expanded query
|
470 |
"""
|
471 |
try:
|
472 |
-
prompt = f"""I need to search for documents related to this question: "{query}"
|
473 |
-
|
474 |
Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
|
475 |
-
Return only the expanded search query, without any explanations or additional text.
|
476 |
|
477 |
-
Expanded query:"""
|
478 |
-
|
479 |
expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
|
480 |
-
|
481 |
# Combine original and expanded
|
482 |
combined = f"{query} {expanded}"
|
483 |
-
|
484 |
# Limit length
|
485 |
if len(combined) > 300:
|
486 |
combined = combined[:300]
|
487 |
-
|
488 |
return combined
|
489 |
except:
|
490 |
# Return original query if expansion fails
|
491 |
return query
|
492 |
|
493 |
# === LLM SETUP ===
|
494 |
-
class
|
495 |
-
def __init__(self, model_path: str =
|
496 |
-
"""Initialize
|
497 |
|
498 |
Args:
|
499 |
model_path: Path to the model file
|
500 |
"""
|
501 |
try:
|
502 |
-
# Initialize
|
503 |
self.llm = Llama(
|
504 |
model_path=model_path,
|
505 |
-
n_ctx=
|
506 |
-
n_batch=
|
507 |
-
n_gpu_layers=0, # Run on CPU for
|
508 |
verbose=False
|
509 |
)
|
510 |
except Exception as e:
|
511 |
-
print(f"Error initializing
|
512 |
raise
|
513 |
|
514 |
def generate(self, prompt: str,
|
@@ -516,7 +557,7 @@ class Phi2Model:
|
|
516 |
temperature: float = 0.7,
|
517 |
top_p: float = 0.9,
|
518 |
stream: bool = False) -> Union[str, Generator[str, None, None]]:
|
519 |
-
"""Generate text using
|
520 |
|
521 |
Args:
|
522 |
prompt: Input prompt
|
@@ -539,7 +580,7 @@ class Phi2Model:
|
|
539 |
top_p=top_p,
|
540 |
echo=False
|
541 |
)
|
542 |
-
return output["choices"][0]["text"]
|
543 |
except Exception as e:
|
544 |
print(f"Error generating text: {str(e)}")
|
545 |
return "Error: Could not generate response."
|
@@ -548,7 +589,7 @@ class Phi2Model:
|
|
548 |
max_tokens: int = 512,
|
549 |
temperature: float = 0.7,
|
550 |
top_p: float = 0.9) -> Generator[str, None, None]:
|
551 |
-
"""Stream text generation using
|
552 |
|
553 |
Args:
|
554 |
prompt: Input prompt
|
@@ -572,11 +613,96 @@ class Phi2Model:
|
|
572 |
response += token
|
573 |
yield response
|
574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
# === RAG SYSTEM ===
|
576 |
class RAGSystem:
|
577 |
def __init__(self, pdf_processor: PDFProcessor,
|
578 |
vector_db: VectorDBManager,
|
579 |
-
model:
|
580 |
"""Initialize RAG system
|
581 |
|
582 |
Args:
|
@@ -588,11 +714,12 @@ class RAGSystem:
|
|
588 |
self.vector_db = vector_db
|
589 |
self.model = model
|
590 |
self.query_expander = QueryExpander(model)
|
|
|
591 |
self.is_initialized = False
|
592 |
|
593 |
def process_documents(self) -> bool:
|
594 |
"""Process all documents and create vector database
|
595 |
-
|
596 |
Returns:
|
597 |
True if successful, False otherwise
|
598 |
"""
|
@@ -602,22 +729,22 @@ class RAGSystem:
|
|
602 |
if not chunks:
|
603 |
print("No chunks were extracted from PDFs")
|
604 |
return False
|
605 |
-
|
606 |
print(f"Total chunks extracted: {len(chunks)}")
|
607 |
|
608 |
# Create vector database
|
609 |
print("Creating vector database...")
|
610 |
self.vector_db.create_vector_db(chunks)
|
611 |
-
|
612 |
# Verify success
|
613 |
if self.vector_db.vectordb is None:
|
614 |
print("Failed to create vector database")
|
615 |
return False
|
616 |
-
|
617 |
# Set initialization flag
|
618 |
self.is_initialized = True
|
619 |
return True
|
620 |
-
|
621 |
except Exception as e:
|
622 |
print(f"Error processing documents: {str(e)}")
|
623 |
print(traceback.format_exc())
|
@@ -638,8 +765,8 @@ class RAGSystem:
|
|
638 |
for i, context in enumerate(contexts):
|
639 |
formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
|
640 |
|
641 |
-
# Create prompt with
|
642 |
-
prompt = f"""You are an AI assistant that answers questions based on the provided context information.
|
643 |
|
644 |
User Query: {query}
|
645 |
|
@@ -647,27 +774,26 @@ Below are relevant passages from documents that might help answer the query:
|
|
647 |
|
648 |
{formatted_contexts}
|
649 |
|
650 |
-
Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
|
651 |
If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
|
652 |
|
653 |
Do not use any prior knowledge that is not contained in the provided context.
|
654 |
If quoting from the context, mention the source document and page number.
|
655 |
-
Organize your answer in a clear, coherent manner.
|
656 |
-
|
657 |
-
Answer:"""
|
658 |
return prompt
|
659 |
|
660 |
def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
|
661 |
-
temperature: float = 0.7, stream: bool = False) -> Union[str, Generator[str, None, None]]:
|
662 |
-
"""Answer a query using RAG with query expansion
|
663 |
-
|
664 |
Args:
|
665 |
query: User query
|
666 |
k: Number of contexts to retrieve
|
667 |
max_tokens: Maximum number of tokens to generate
|
668 |
temperature: Temperature for generation
|
669 |
stream: Whether to stream the output
|
670 |
-
|
|
|
671 |
Returns:
|
672 |
Answer text or generator if streaming
|
673 |
"""
|
@@ -679,7 +805,7 @@ Answer:"""
|
|
679 |
# Expand query for better retrieval
|
680 |
expanded_query = self.query_expander.expand_query(query)
|
681 |
print(f"Expanded query: {expanded_query}")
|
682 |
-
|
683 |
# Retrieve relevant contexts using hybrid search
|
684 |
contexts = self.vector_db.hybrid_search(expanded_query, k=k)
|
685 |
|
@@ -689,240 +815,293 @@ Answer:"""
|
|
689 |
# Generate prompt with improved instructions
|
690 |
prompt = self.generate_prompt(query, contexts)
|
691 |
|
692 |
-
#
|
693 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
694 |
prompt,
|
695 |
max_tokens=max_tokens,
|
696 |
temperature=temperature,
|
697 |
-
stream=
|
698 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
699 |
except Exception as e:
|
700 |
print(f"Error answering query: {str(e)}")
|
701 |
print(traceback.format_exc())
|
702 |
return f"Error processing your query: {str(e)}"
|
703 |
|
704 |
-
# === GRADIO
|
705 |
-
class
|
706 |
def __init__(self, rag_system: RAGSystem):
|
707 |
-
"""Initialize
|
708 |
|
709 |
Args:
|
710 |
rag_system: RAG system instance
|
711 |
"""
|
712 |
self.rag_system = rag_system
|
|
|
713 |
self.interface = None
|
714 |
-
self.is_processing = False
|
715 |
|
716 |
-
def
|
717 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
718 |
try:
|
719 |
-
|
720 |
-
|
721 |
|
|
|
722 |
for file in files:
|
723 |
-
|
724 |
-
shutil.copy(file.name, destination)
|
725 |
-
uploaded_files.append(os.path.basename(file.name))
|
726 |
|
727 |
-
|
728 |
-
pdf_files = [f for f in os.listdir("pdfs") if f.lower().endswith('.pdf')]
|
729 |
-
|
730 |
-
if not pdf_files:
|
731 |
-
return "No PDF files were uploaded successfully."
|
732 |
-
|
733 |
-
return f"Successfully uploaded {len(uploaded_files)} files: {', '.join(uploaded_files)}"
|
734 |
except Exception as e:
|
735 |
return f"Error uploading files: {str(e)}"
|
736 |
|
737 |
-
def process_documents(self):
|
738 |
-
"""Process
|
739 |
|
740 |
Returns:
|
741 |
Status message
|
742 |
"""
|
743 |
-
if self.is_processing:
|
744 |
-
return "Document processing is already in progress. Please wait."
|
745 |
-
|
746 |
try:
|
747 |
-
|
|
|
|
|
|
|
|
|
|
|
748 |
start_time = time.time()
|
749 |
-
|
750 |
success = self.rag_system.process_documents()
|
751 |
-
|
752 |
-
|
753 |
-
self.is_processing = False
|
754 |
-
|
755 |
if success:
|
756 |
-
return f"
|
757 |
else:
|
758 |
-
return "Failed to process documents. Check the logs for
|
759 |
except Exception as e:
|
760 |
-
self.is_processing = False
|
761 |
return f"Error processing documents: {str(e)}"
|
762 |
|
763 |
-
def answer_query(self, query
|
764 |
-
|
|
|
|
|
765 |
|
766 |
Args:
|
767 |
query: User query
|
|
|
768 |
k: Number of contexts to retrieve
|
769 |
-
|
770 |
-
|
771 |
|
772 |
Returns:
|
773 |
-
Answer
|
774 |
"""
|
775 |
-
if not query.strip():
|
776 |
-
return "Please enter a
|
777 |
-
|
778 |
-
try:
|
779 |
-
return self.rag_system.answer_query(
|
780 |
-
query,
|
781 |
-
k=k,
|
782 |
-
max_tokens=max_tokens,
|
783 |
-
temperature=temperature,
|
784 |
-
stream=False
|
785 |
-
)
|
786 |
-
except Exception as e:
|
787 |
-
return f"Error answering query: {str(e)}"
|
788 |
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
Args:
|
793 |
-
query: User query
|
794 |
-
k: Number of contexts to retrieve
|
795 |
-
max_tokens: Maximum number of tokens to generate
|
796 |
-
temperature: Sampling temperature
|
797 |
|
798 |
-
Yields:
|
799 |
-
Generated text
|
800 |
-
"""
|
801 |
-
if not query.strip():
|
802 |
-
yield "Please enter a question."
|
803 |
-
return
|
804 |
-
|
805 |
try:
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
except Exception as e:
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
upload_button.click(self.upload_file, inputs=[pdf_files], outputs=upload_output)
|
833 |
-
|
834 |
-
process_button = gr.Button("Process Documents", variant="primary")
|
835 |
-
process_output = gr.Textbox(label="Processing Status", lines=2)
|
836 |
-
process_button.click(self.process_documents, inputs=[], outputs=process_output)
|
837 |
-
|
838 |
-
with gr.Tab("Query"):
|
839 |
-
with gr.Row():
|
840 |
-
with gr.Column():
|
841 |
-
query_input = gr.Textbox(
|
842 |
-
label="Question",
|
843 |
-
lines=3,
|
844 |
-
placeholder="Ask a question about your documents..."
|
845 |
-
)
|
846 |
-
with gr.Row():
|
847 |
-
k_slider = gr.Slider(
|
848 |
-
minimum=1,
|
849 |
-
maximum=10,
|
850 |
-
value=3,
|
851 |
-
step=1,
|
852 |
-
label="Number of Contexts"
|
853 |
)
|
854 |
-
|
855 |
-
|
856 |
-
|
857 |
-
value=400,
|
858 |
-
step=50,
|
859 |
-
label="Max Tokens"
|
860 |
)
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
865 |
-
label="
|
|
|
|
|
866 |
)
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
875 |
)
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
|
|
|
|
|
|
|
|
883 |
)
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
-
|
900 |
-
|
901 |
-
|
902 |
-
self.
|
903 |
-
|
|
|
|
|
|
|
904 |
|
905 |
# === MAIN APPLICATION ===
|
906 |
def main():
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
print(traceback.format_exc())
|
926 |
|
927 |
if __name__ == "__main__":
|
928 |
main()
|
|
|
+!pip install -q pymupdf langchain langchain_community sentence-transformers faiss-cpu llama-cpp-python gradio transformers rank_bm25
+!pip install -q git+https://github.com/chroma-core/chroma.git
+!mkdir -p pdfs
+!mkdir -p models
+!wget -q -O models/mistral-7b-instruct-v0.3.Q8_0.gguf https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.Q8_0.gguf
+
 import os
 import re
+import fitz # PyMuPDF
 import numpy as np
 import gc
 import torch
 import traceback
 from typing import List, Dict, Any, Tuple, Optional, Union, Generator
 from dataclasses import dataclass
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from llama_cpp import Llama
+import gradio as gr
 from rank_bm25 import BM25Okapi
 import nltk
 from nltk.tokenize import word_tokenize
 from nltk.corpus import stopwords
+from sentence_transformers import CrossEncoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# Download nltk resources
|
32 |
try:
|
|
|
35 |
except:
|
36 |
print("Failed to download NLTK resources, continuing without them")
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# === MEMORY MANAGEMENT UTILITIES ===
|
39 |
def clear_memory():
|
40 |
"""Clear memory to prevent OOM errors"""
|
|
|
65 |
length_function=len,
|
66 |
is_separator_regex=False,
|
67 |
)
|
68 |
+
|
69 |
# Create cache directory
|
70 |
self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
|
71 |
os.makedirs(self.cache_dir, exist_ok=True)
|
|
|
80 |
"""Get the cache file path for a PDF"""
|
81 |
pdf_hash = hashlib.md5(open(pdf_path, 'rb').read(8192)).hexdigest()
|
82 |
return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")
|
83 |
+
|
84 |
def _is_cached(self, pdf_path: str) -> bool:
|
85 |
"""Check if a PDF is cached"""
|
86 |
cache_path = self._get_cache_path(pdf_path)
|
87 |
return os.path.exists(cache_path)
|
88 |
+
|
89 |
def _load_from_cache(self, pdf_path: str) -> List[PDFChunk]:
|
90 |
"""Load chunks from cache"""
|
91 |
cache_path = self._get_cache_path(pdf_path)
|
|
|
94 |
return pickle.load(f)
|
95 |
except:
|
96 |
return None
|
97 |
+
|
98 |
def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
|
99 |
"""Save chunks to cache"""
|
100 |
cache_path = self._get_cache_path(pdf_path)
|
|
|
127 |
if cached_chunks:
|
128 |
print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
|
129 |
return cached_chunks
|
130 |
+
|
131 |
try:
|
132 |
doc = fitz.open(pdf_path)
|
133 |
pdf_chunks = []
|
134 |
pdf_name = os.path.basename(pdf_path)
|
135 |
+
|
136 |
for page_num in range(len(doc)):
|
137 |
page = doc.load_page(page_num)
|
138 |
+
|
139 |
# Extract text with more options for better quality
|
140 |
page_text = page.get_text("text", sort=True)
|
141 |
# Try to extract text with alternative layout analysis if the text is too short
|
|
|
156 |
except:
|
157 |
# Fallback to default extraction
|
158 |
page_text = page.get_text("text")
|
159 |
+
|
160 |
# Clean the text
|
161 |
page_text = self.clean_text(page_text)
|
162 |
+
|
163 |
# Extract tables
|
164 |
try:
|
165 |
tables = page.find_tables()
|
|
|
175 |
row_cells.append(cell_text)
|
176 |
if row_cells:
|
177 |
table_text += " | ".join(row_cells) + "\n"
|
178 |
+
|
179 |
# Add table text to page text
|
180 |
if table_text.strip():
|
181 |
page_text += "\n\nTABLE:\n" + table_text
|
182 |
except Exception as table_err:
|
183 |
print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
|
184 |
+
|
185 |
# Split the page text into chunks
|
186 |
if page_text.strip():
|
187 |
page_chunks = self.text_splitter.split_text(page_text)
|
188 |
+
|
189 |
# Create PDFChunk objects
|
190 |
for i, chunk_text in enumerate(page_chunks):
|
191 |
pdf_chunks.append(PDFChunk(
|
|
|
194 |
page_num=page_num + 1, # 1-based page numbering for humans
|
195 |
chunk_id=i
|
196 |
))
|
197 |
+
|
198 |
# Clear memory periodically
|
199 |
if page_num % 10 == 0:
|
200 |
clear_memory()
|
201 |
+
|
202 |
doc.close()
|
203 |
+
|
204 |
# Cache the results
|
205 |
self._save_to_cache(pdf_path, pdf_chunks)
|
206 |
+
|
207 |
return pdf_chunks
|
208 |
except Exception as e:
|
209 |
print(f"Error extracting text from {pdf_path}: {str(e)}")
|
|
|
221 |
pdf_path = os.path.join(self.pdf_dir, pdf_name)
|
222 |
return self.extract_text_from_pdf(pdf_path)
|
223 |
|
224 |
+
def process_all_pdfs(self, batch_size: int = 3) -> List[PDFChunk]:
|
225 |
"""Process all PDFs in batches to manage memory
|
226 |
|
227 |
Args:
|
|
|
             model_kwargs={"device": "cpu"},
             encode_kwargs={"normalize_embeddings": True}
         )
+
+        # Initialize cross-encoder for re-ranking
+        try:
+            self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+        except Exception as e:
+            print(f"Error initializing cross-encoder: {str(e)}")
+            self.cross_encoder = None
+
         self.vectordb = None
         # BM25 index for hybrid search
         self.bm25_index = None
                 stop_words = set(stopwords.words('english'))
                 filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
                 tokenized_chunks.append(filtered_tokens)
+
             # Create BM25 index
             self.bm25_index = BM25Okapi(tokenized_chunks)
+            self.tokenized_chunks = tokenized_chunks
         except Exception as e:
             print(f"Error creating BM25 index: {str(e)}")
             print(traceback.format_exc())
             self.bm25_index = None
+
     def create_vector_db(self, chunks: List[PDFChunk]) -> None:
         """Create vector database from text chunks
315 |
|
|
|
322 |
return
|
323 |
|
324 |
print(f"Creating vector DB with {len(chunks)} chunks")
|
325 |
+
|
326 |
# Store chunks for hybrid search
|
327 |
self.chunks = chunks
|
328 |
+
|
329 |
# Prepare data for vector DB
|
330 |
chunk_texts = [chunk.text for chunk in chunks]
|
331 |
+
|
332 |
# Create BM25 index for hybrid search
|
333 |
print("Creating BM25 index for hybrid search")
|
334 |
self._prepare_bm25(chunks)
|
335 |
+
|
336 |
# Process in smaller batches to manage memory
|
337 |
+
batch_size = 32
|
338 |
all_embeddings = []
|
339 |
|
340 |
for i in range(0, len(chunk_texts), batch_size):
|
|
|
354 |
text_embeddings=list(zip(chunk_texts, all_embeddings)),
|
355 |
embedding=self.embedding_model
|
356 |
)
|
357 |
+
|
358 |
print(f"Vector database created with {len(chunks)} documents")
|
359 |
|
360 |
except Exception as e:
|
|
|
         """Format a chunk with its metadata for better context"""
         return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"
 
+    def _rerank_with_cross_encoder(self, query: str, chunks: List[PDFChunk], k: int = 5) -> List[PDFChunk]:
+        """Re-rank chunks using cross-encoder
+
+        Args:
+            query: User query
+            chunks: List of retrieved chunks
+            k: Number of top chunks to return
+
+        Returns:
+            Re-ranked chunks
+        """
+        if not self.cross_encoder or not chunks:
+            return chunks[:k] if len(chunks) > k else chunks
+
+        try:
+            # Prepare passage pairs for re-ranking
+            pairs = [[query, chunk.text] for chunk in chunks]
+
+            # Score passages in smaller batches to prevent OOM
+            batch_size = 16
+            all_scores = []
+
+            for i in range(0, len(pairs), batch_size):
+                batch_pairs = pairs[i:i+batch_size]
+                batch_scores = self.cross_encoder.predict(batch_pairs)
+                all_scores.extend(batch_scores)
+
+                # Clear memory
+                clear_memory()
+
+            # Create chunk-score pairs
+            scored_chunks = list(zip(chunks, all_scores))
+
+            # Sort by score
+            scored_chunks.sort(key=lambda x: x[1], reverse=True)
+
+            # Return top k chunks
+            return [chunk for chunk, score in scored_chunks[:k]]
+        except Exception as e:
+            print(f"Error during cross-encoder re-ranking: {str(e)}")
+            # Fallback to original chunks
+            return chunks[:k] if len(chunks) > k else chunks
+
     def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
+        """Hybrid search combining vector search and BM25 with cross-encoder re-ranking
 
         Args:
             query: Query text
 
         try:
             # Get vector search results
+            vector_results = self.vectordb.similarity_search(query, k=k*3) # Get more for re-ranking
             vector_texts = [doc.page_content for doc in vector_results]
+
+            retrieved_chunks = []
+
             # Combine with BM25 if available
             if self.bm25_index is not None:
                 try:
                     query_tokens = word_tokenize(query.lower())
                     stop_words = set(stopwords.words('english'))
                     filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
+
                     # Get BM25 scores
                     bm25_scores = self.bm25_index.get_scores(filtered_query)
+
                     # Combine scores (normalized)
                     combined_results = []
                     seen_texts = set()
+
                     # First add vector results with their positions as scores
                     for i, text in enumerate(vector_texts):
                         if text not in seen_texts:
                                 # Normalize BM25 score
                                 bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
                                 combined_score = alpha * vector_score + (1-alpha) * bm25_score
+
                                 combined_results.append((chunk, combined_score))
                                 break
+
                     # Sort by combined score
                     combined_results.sort(key=lambda x: x[1], reverse=True)
+
+                    # Get top k*2 results for re-ranking
+                    retrieved_chunks = [item[0] for item in combined_results[:k*2]]
                 except Exception as e:
                     print(f"Error in BM25 scoring: {str(e)}")
                     # Fallback to vector search results
+                    retrieved_chunks = [self.chunks[i] for i, text in enumerate(self.chunks)
+                                        if text.text in vector_texts[:k*2]]
             else:
                 # Just use vector search results if BM25 is not available
+                retrieved_chunks = [self.chunks[i] for i, chunk in enumerate(self.chunks)
+                                    if chunk.text in vector_texts[:k*2]]
+
+            # Re-rank with cross-encoder
+            if retrieved_chunks:
+                reranked_chunks = self._rerank_with_cross_encoder(query, retrieved_chunks, k)
+                # Format results with metadata
+                final_results = [self._format_chunk_with_metadata(chunk) for chunk in reranked_chunks]
+            else:
+                # Fallback to basic results
+                final_results = vector_texts[:k]
+
             return final_results
         except Exception as e:
             print(f"Error during hybrid search: {str(e)}")
|
|
             llm_model: LLM model for query expansion
         """
         self.llm = llm_model
+
     def expand_query(self, query: str) -> str:
         """Expand the query using the LLM to improve retrieval
 
             Expanded query
         """
         try:
+            prompt = f"""<s>[INST] I need to search for documents related to this question: "{query}"
+
 Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
+Return only the expanded search query, without any explanations or additional text. [/INST]"""
 
             expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
+
             # Combine original and expanded
             combined = f"{query} {expanded}"
+
             # Limit length
             if len(combined) > 300:
                 combined = combined[:300]
+
             return combined
         except:
             # Return original query if expansion fails
             return query
 
 # === LLM SETUP ===
+class MistralModel:
+    def __init__(self, model_path: str = "models/mistral-7b-instruct-v0.3.Q8_0.gguf"):
+        """Initialize Mistral model
 
         Args:
             model_path: Path to the model file
         """
         try:
+            # Initialize Mistral with llama.cpp
             self.llm = Llama(
                 model_path=model_path,
+                n_ctx=4096, # Increased context window for better reasoning
+                n_batch=256, # Batch size to save memory
+                n_gpu_layers=0, # Run on CPU only for Colab free tier
                 verbose=False
             )
         except Exception as e:
+            print(f"Error initializing Mistral model: {str(e)}")
             raise
 
     def generate(self, prompt: str,
                  temperature: float = 0.7,
                  top_p: float = 0.9,
                  stream: bool = False) -> Union[str, Generator[str, None, None]]:
+        """Generate text using Mistral
 
         Args:
             prompt: Input prompt
                 top_p=top_p,
                 echo=False
             )
+            return output["choices"][0]["text"].strip()
         except Exception as e:
             print(f"Error generating text: {str(e)}")
             return "Error: Could not generate response."
                        max_tokens: int = 512,
                        temperature: float = 0.7,
                        top_p: float = 0.9) -> Generator[str, None, None]:
+        """Stream text generation using Mistral
 
         Args:
             prompt: Input prompt
                 response += token
                 yield response
 
+# === SELF-CHECKING ===
+class SelfChecker:
+    def __init__(self, llm_model):
+        """Initialize self-checker for improved response quality
+
+        Args:
+            llm_model: LLM model to use for checking
+        """
+        self.llm = llm_model
+
+    def check_answer(self, query: str, initial_answer: str, contexts: List[str]) -> str:
+        """Check if answer is correct and complete based on the contexts
+
+        Args:
+            query: User query
+            initial_answer: Initial generated answer
+            contexts: Retrieved contexts used to generate the answer
+
+        Returns:
+            Improved answer after reflection
+        """
+        # Guard against very long inputs that could cause runtime disconnection
+        # Limit contexts to prevent excessive token usage
+        max_contexts_len = 4000
+        contexts_text = "\n\n".join(contexts)
+        if len(contexts_text) > max_contexts_len:
+            # Truncate while keeping as many complete contexts as possible
+            truncated_contexts = []
+            current_len = 0
+            for ctx in contexts:
+                if current_len + len(ctx) + 2 <= max_contexts_len:
+                    truncated_contexts.append(ctx)
+                    current_len += len(ctx) + 2
+                else:
+                    break
+            contexts_text = "\n\n".join(truncated_contexts)
+
+        # Check if we should skip reflection to prevent disconnection
+        if len(initial_answer) + len(contexts_text) + len(query) > 6000:
+            print("Skipping reflection due to excessive input length")
+            return initial_answer
+
+        try:
+            prompt = f"""<s>[INST] You're an AI assistant tasked with evaluating and improving an answer to a user query.
+
+QUERY: {query}
+
+INITIAL ANSWER: {initial_answer}
+
+AVAILABLE CONTEXTS:
+{contexts_text}
+
+First, carefully check if the initial answer:
+1. Is factually accurate based on the provided contexts
+2. Addresses all aspects of the user's query
+3. Contains any information not supported by the contexts
+4. Misses important information from the contexts
+
+Then improve the answer to fix any issues identified. The final answer should:
+- Be comprehensive and accurate based ONLY on the contexts
+- Not include any unsupported information
+- Be well-structured and clear
+- Cite specific sources when appropriate (e.g., "According to [Source, Page X]...")
+
+Provide ONLY the improved answer without explanations about your reasoning process. [/INST]"""
+
+            # We use slightly lower temperature for more focused reflection
+            improved_answer = self.llm.generate(
+                prompt,
+                max_tokens=1024,
+                temperature=0.3,
+                stream=False
+            )
+
+            # If reflection produced nothing useful, return original answer
+            if not improved_answer or len(improved_answer) < 10:
+                return initial_answer
+
+            return improved_answer
+
+        except Exception as e:
+            # On any error, return the original answer to ensure robustness
+            print(f"Self-check error: {str(e)}")
+            return initial_answer
+
 # === RAG SYSTEM ===
 class RAGSystem:
     def __init__(self, pdf_processor: PDFProcessor,
                  vector_db: VectorDBManager,
+                 model: MistralModel):
         """Initialize RAG system
 
         Args:
         self.vector_db = vector_db
         self.model = model
         self.query_expander = QueryExpander(model)
+        self.self_checker = SelfChecker(model)
         self.is_initialized = False
 
     def process_documents(self) -> bool:
         """Process all documents and create vector database
+
         Returns:
             True if successful, False otherwise
         """
|
729 |
if not chunks:
|
730 |
print("No chunks were extracted from PDFs")
|
731 |
return False
|
732 |
+
|
733 |
print(f"Total chunks extracted: {len(chunks)}")
|
734 |
|
735 |
# Create vector database
|
736 |
print("Creating vector database...")
|
737 |
self.vector_db.create_vector_db(chunks)
|
738 |
+
|
739 |
# Verify success
|
740 |
if self.vector_db.vectordb is None:
|
741 |
print("Failed to create vector database")
|
742 |
return False
|
743 |
+
|
744 |
# Set initialization flag
|
745 |
self.is_initialized = True
|
746 |
return True
|
747 |
+
|
748 |
except Exception as e:
|
749 |
print(f"Error processing documents: {str(e)}")
|
750 |
print(traceback.format_exc())
|
|
|
         for i, context in enumerate(contexts):
             formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
 
+        # Create prompt with Mistral's chat format
+        prompt = f"""<s>[INST] You are an AI assistant that answers questions based on the provided context information.
 
 User Query: {query}
 
 
 {formatted_contexts}
 
+Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
 If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
 
 Do not use any prior knowledge that is not contained in the provided context.
 If quoting from the context, mention the source document and page number.
+Organize your answer in a clear, coherent manner. [/INST]"""
         return prompt
 
     def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
+                     temperature: float = 0.7, stream: bool = False, enable_reflection: bool = True) -> Union[str, Generator[str, None, None]]:
+        """Answer a query using RAG with query expansion and self-checking
+
         Args:
             query: User query
             k: Number of contexts to retrieve
             max_tokens: Maximum number of tokens to generate
             temperature: Temperature for generation
             stream: Whether to stream the output
+            enable_reflection: Whether to enable self-reflection for better answers
+
         Returns:
             Answer text or generator if streaming
         """
             # Expand query for better retrieval
             expanded_query = self.query_expander.expand_query(query)
             print(f"Expanded query: {expanded_query}")
+
             # Retrieve relevant contexts using hybrid search
             contexts = self.vector_db.hybrid_search(expanded_query, k=k)
 
             # Generate prompt with improved instructions
             prompt = self.generate_prompt(query, contexts)
 
+            # For streaming, we can't do self-checking
+            if stream:
+                return self.model.generate(
+                    prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    stream=True
+                )
+
+            # Generate initial answer
+            initial_answer = self.model.generate(
                 prompt,
                 max_tokens=max_tokens,
                 temperature=temperature,
+                stream=False
             )
+
+            # Perform self-checking if enabled and initial answer exists
+            if enable_reflection and initial_answer and len(initial_answer) > 10:
+                try:
+                    print("Performing self-checking to improve answer quality...")
+                    improved_answer = self.self_checker.check_answer(query, initial_answer, contexts)
+                    return improved_answer
+                except Exception as e:
+                    print(f"Error during self-checking: {str(e)}")
+                    # Fallback to initial answer if self-checking fails
+                    return initial_answer
+            else:
+                return initial_answer
+
         except Exception as e:
             print(f"Error answering query: {str(e)}")
             print(traceback.format_exc())
             return f"Error processing your query: {str(e)}"
852 |
|
853 |
+
# === GRADIO UI ===
|
854 |
+
class RAGUI:
|
855 |
def __init__(self, rag_system: RAGSystem):
|
856 |
+
"""Initialize RAG UI
|
857 |
|
858 |
Args:
|
859 |
rag_system: RAG system instance
|
860 |
"""
|
861 |
self.rag_system = rag_system
|
862 |
+
self.pdf_dir = rag_system.pdf_processor.pdf_dir
|
863 |
self.interface = None
|
|
|
864 |
|
865 |
+
def _list_uploaded_pdfs(self) -> str:
|
866 |
+
"""List all uploaded PDFs"""
|
867 |
+
pdfs = self.rag_system.pdf_processor.list_pdfs()
|
868 |
+
if not pdfs:
|
869 |
+
return "No PDFs uploaded yet."
|
870 |
+
return "\n".join([f"- {pdf}" for pdf in pdfs])
|
871 |
+
|
872 |
+
def upload_pdf(self, files) -> str:
|
873 |
+
"""Upload PDF files
|
874 |
+
|
875 |
+
Args:
|
876 |
+
files: File objects
|
877 |
+
|
878 |
+
Returns:
|
879 |
+
Status message
|
880 |
+
"""
|
881 |
try:
|
882 |
+
# Create directory if it doesn't exist
|
883 |
+
os.makedirs(self.pdf_dir, exist_ok=True)
|
884 |
|
885 |
+
# Copy files to pdf directory
|
886 |
for file in files:
|
887 |
+
shutil.copy(file.name, os.path.join(self.pdf_dir, os.path.basename(file.name)))
|
|
|
|
|
888 |
|
889 |
+
return f"Successfully uploaded {len(files)} file(s). Please process documents to make them searchable."
|
|
|
|
|
|
|
|
|
|
|
|
|
890 |
except Exception as e:
|
891 |
return f"Error uploading files: {str(e)}"
|
892 |
|
893 |
+
def process_documents(self) -> str:
|
894 |
+
"""Process documents and create vector database
|
895 |
|
896 |
Returns:
|
897 |
Status message
|
898 |
"""
|
|
|
|
|
|
|
899 |
try:
|
900 |
+
# Check if there are PDFs
|
901 |
+
pdf_files = self.rag_system.pdf_processor.list_pdfs()
|
902 |
+
if not pdf_files:
|
903 |
+
return "No PDF files uploaded. Please upload PDFs first."
|
904 |
+
|
905 |
+
# Process PDFs
|
906 |
start_time = time.time()
|
|
|
907 |
success = self.rag_system.process_documents()
|
908 |
+
process_time = time.time() - start_time
|
909 |
+
|
|
|
|
|
910 |
if success:
|
911 |
+
return f"Successfully processed {len(pdf_files)} PDF file(s) in {process_time:.2f} seconds. You can now ask questions."
|
912 |
else:
|
913 |
+
return "Failed to process documents. Check the logs for details."
|
914 |
except Exception as e:
|
|
|
915 |
return f"Error processing documents: {str(e)}"
|
916 |
|
917 |
+
def answer_query(self, query: str, stream_output: bool = True,
|
918 |
+
k: int = 4, temperature: float = 0.7,
|
919 |
+
enable_reflection: bool = True) -> str:
|
920 |
+
"""Answer a query using RAG
|
921 |
|
922 |
Args:
|
923 |
query: User query
|
924 |
+
stream_output: Whether to stream the output
|
925 |
k: Number of contexts to retrieve
|
926 |
+
temperature: Temperature for text generation
|
927 |
+
enable_reflection: Whether to use reflection to improve answers
|
928 |
|
929 |
Returns:
|
930 |
+
Answer text
|
931 |
"""
|
932 |
+
if not query or query.strip() == "":
|
933 |
+
return "Please enter a query."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
934 |
|
935 |
+
# Check if system is initialized
|
936 |
+
if not self.rag_system.is_initialized:
|
937 |
+
return "Documents have not been processed yet. Please process documents first."
|
|
|
|
|
|
|
|
|
|
|
938 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
939 |
try:
|
940 |
+
# For streaming, we need to handle gradio uniqueness
|
941 |
+
if stream_output:
|
942 |
+
# We can't stream with reflection
|
943 |
+
return self.rag_system.answer_query(
|
944 |
+
query,
|
945 |
+
k=k,
|
946 |
+
max_tokens=1024,
|
947 |
+
temperature=temperature,
|
948 |
+
stream=True,
|
949 |
+
enable_reflection=False
|
950 |
+
)
|
951 |
+
else:
|
952 |
+
return self.rag_system.answer_query(
|
953 |
+
query,
|
954 |
+
k=k,
|
955 |
+
max_tokens=1024,
|
956 |
+
temperature=temperature,
|
957 |
+
stream=False,
|
958 |
+
enable_reflection=enable_reflection
|
959 |
+
)
|
960 |
except Exception as e:
|
961 |
+
print(f"Error in answer_query: {str(e)}")
|
962 |
+
print(traceback.format_exc())
|
963 |
+
return f"Error processing your query: {str(e)}"
|
964 |
+
|
965 |
+
def launch(self):
|
966 |
+
"""Launch Gradio UI"""
|
967 |
+
try:
|
968 |
+
with gr.Blocks(title="Document Q&A System") as self.interface:
|
969 |
+
gr.Markdown("# PDF Question Answering System")
|
970 |
+
gr.Markdown("Upload PDF documents and ask questions about their content.")
|
971 |
+
|
972 |
+
with gr.Tab("Upload & Process"):
|
973 |
+
with gr.Row():
|
974 |
+
with gr.Column():
|
975 |
+
upload_button = gr.File(
|
976 |
+
label="Upload PDF Files",
|
977 |
+
file_count="multiple",
|
978 |
+
file_types=[".pdf"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
979 |
)
|
980 |
+
upload_output = gr.Textbox(
|
981 |
+
label="Upload Status",
|
982 |
+
interactive=False
|
|
|
|
|
|
|
983 |
)
|
984 |
+
upload_btn = gr.Button("Upload Files")
|
985 |
+
|
986 |
+
with gr.Column():
|
987 |
+
pdf_list = gr.Textbox(
|
988 |
+
label="Uploaded PDFs",
|
989 |
+
value=self._list_uploaded_pdfs(),
|
990 |
+
interactive=False
|
991 |
)
|
992 |
+
refresh_btn = gr.Button("Refresh List")
|
993 |
+
|
994 |
+
process_btn = gr.Button("Process Documents")
|
995 |
+
process_output = gr.Textbox(
|
996 |
+
label="Processing Status",
|
997 |
+
interactive=False
|
998 |
+
)
|
999 |
+
|
1000 |
+
with gr.Tab("Ask Questions"):
|
1001 |
+
with gr.Row():
|
1002 |
+
with gr.Column():
|
1003 |
+
query_input = gr.Textbox(
|
1004 |
+
label="Enter your question",
|
1005 |
+
placeholder="What are the main findings of the report?",
|
1006 |
+
lines=2
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
with gr.Row():
|
1010 |
+
k_slider = gr.Slider(
|
1011 |
+
minimum=1,
|
1012 |
+
maximum=10,
|
1013 |
+
value=4,
|
1014 |
+
step=1,
|
1015 |
+
label="Number of contexts to retrieve"
|
1016 |
+
)
|
1017 |
+
temp_slider = gr.Slider(
|
1018 |
+
minimum=0.1,
|
1019 |
+
maximum=1.0,
|
1020 |
+
value=0.7,
|
1021 |
+
step=0.1,
|
1022 |
+
label="Temperature"
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
with gr.Row():
|
1026 |
+
stream_checkbox = gr.Checkbox(
|
1027 |
+
label="Stream output",
|
1028 |
+
value=True
|
1029 |
+
)
|
1030 |
+
reflection_checkbox = gr.Checkbox(
|
1031 |
+
label="Use self-reflection (disables streaming)",
|
1032 |
+
value=True
|
1033 |
+
)
|
1034 |
+
|
1035 |
+
query_btn = gr.Button("Submit Question")
|
1036 |
+
|
1037 |
+
answer_output = gr.Textbox(
|
1038 |
+
label="Answer",
|
1039 |
+
interactive=False,
|
1040 |
+
lines=15
|
1041 |
+
)
|
1042 |
+
|
1043 |
+
# Event handlers
|
1044 |
+
upload_btn.click(
|
1045 |
+
fn=self.upload_pdf,
|
1046 |
+
inputs=[upload_button],
|
1047 |
+
outputs=[upload_output]
|
1048 |
)
|
1049 |
+
|
1050 |
+
refresh_btn.click(
|
1051 |
+
fn=lambda: self._list_uploaded_pdfs(),
|
1052 |
+
inputs=[],
|
1053 |
+
outputs=[pdf_list]
|
1054 |
+
)
|
1055 |
+
|
1056 |
+
process_btn.click(
|
1057 |
+
fn=self.process_documents,
|
1058 |
+
inputs=[],
|
1059 |
+
outputs=[process_output]
|
1060 |
)
|
1061 |
+
|
1062 |
+
query_btn.click(
|
1063 |
+
fn=self.answer_query,
|
1064 |
+
inputs=[query_input, stream_checkbox, k_slider, temp_slider, reflection_checkbox],
|
1065 |
+
outputs=[answer_output]
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
# Checkbox dependency
|
1069 |
+
def update_stream_state(reflection_enabled):
|
1070 |
+
return not reflection_enabled if reflection_enabled else gr.update()
|
1071 |
+
|
1072 |
+
reflection_checkbox.change(
|
1073 |
+
fn=update_stream_state,
|
1074 |
+
inputs=[reflection_checkbox],
|
1075 |
+
outputs=[stream_checkbox]
|
1076 |
+
)
|
1077 |
+
|
1078 |
+
# Launch UI
|
1079 |
+
self.interface.launch(share=True)
|
1080 |
+
|
1081 |
+
except Exception as e:
|
1082 |
+
print(f"Error launching UI: {str(e)}")
|
1083 |
+
print(traceback.format_exc())
|
1084 |
|
1085 |
# === MAIN APPLICATION ===
|
1086 |
def main():
|
1087 |
+
# Initialize components
|
1088 |
+
print("Initializing PDF processor...")
|
1089 |
+
pdf_processor = PDFProcessor()
|
1090 |
+
|
1091 |
+
print("Initializing vector database manager...")
|
1092 |
+
vector_db = VectorDBManager()
|
1093 |
+
|
1094 |
+
print("Initializing Mistral model...")
|
1095 |
+
model = MistralModel()
|
1096 |
+
|
1097 |
+
print("Initializing RAG system...")
|
1098 |
+
rag_system = RAGSystem(pdf_processor, vector_db, model)
|
1099 |
+
|
1100 |
+
print("Initializing UI...")
|
1101 |
+
ui = RAGUI(rag_system)
|
1102 |
+
|
1103 |
+
print("Launching UI...")
|
1104 |
+
ui.launch()
|
|
|
1105 |
|
1106 |
if __name__ == "__main__":
|
1107 |
main()
|
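
For readers following the retrieval changes in this commit, below is a minimal, runnable sketch of the idea introduced in hybrid_search and _rerank_with_cross_encoder: BM25 scores fused with a rank-based vector score, then cross-encoder re-ranking. The toy corpus, the assumed rank-derived vector score, and the candidate count are illustrative only; the rank_bm25 and sentence-transformers calls, the BM25 normalisation, and the alpha-weighted fusion mirror app.py above.

# Hedged sketch: hybrid BM25 + vector-rank fusion followed by cross-encoder
# re-ranking. corpus, vector_order, and the rank-based vector score are
# illustrative assumptions, not taken from app.py.
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder

corpus = [
    "FAISS builds an index over dense embeddings for similarity search.",
    "BM25 ranks passages by term frequency and inverse document frequency.",
    "A cross-encoder scores a query and a passage jointly, which suits re-ranking.",
]
bm25 = BM25Okapi([doc.lower().split() for doc in corpus])

query = "how does cross-encoder re-ranking work"
bm25_scores = bm25.get_scores(query.lower().split())
max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1.0

# Assume a vector store returned these indices, best first (app.py gets this
# ordering from FAISS similarity_search instead).
vector_order = [2, 0, 1]
alpha = 0.7  # same weighting idea as hybrid_search(alpha=0.7)
combined = []
for rank, idx in enumerate(vector_order):
    vector_score = 1.0 - rank / len(vector_order)   # assumed rank-based score
    bm25_norm = bm25_scores[idx] / max_bm25         # normalisation as in app.py
    combined.append((idx, alpha * vector_score + (1 - alpha) * bm25_norm))
combined.sort(key=lambda x: x[1], reverse=True)
candidates = [corpus[idx] for idx, _ in combined[:2]]  # keep top fused candidates

# Re-rank the fused candidates with the same cross-encoder checkpoint app.py loads.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
scores = reranker.predict([[query, passage] for passage in candidates])
best = max(zip(candidates, scores), key=lambda pair: pair[1])[0]
print(best)

The full implementation additionally batches the cross-encoder calls and falls back to the fused ordering when the model is unavailable, as shown in _rerank_with_cross_encoder above.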