import os
import subprocess

# Install Python dependencies quietly
subprocess.run(["pip", "install", "-q", "pymupdf", "langchain", "langchain_community",
                "sentence-transformers", "faiss-cpu", "llama-cpp-python", "gradio",
                "transformers", "rank_bm25"], check=True)
# The rustup installer must be piped through a shell; -y keeps it non-interactive
subprocess.run("curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
               shell=True, check=True)
# Note: sourcing the cargo env only affects this spawned shell, not the Python process
subprocess.run("source $HOME/.cargo/env", shell=True, check=True, executable="/bin/bash")
subprocess.run(["pip", "install", "-q", "git+https://github.com/chroma-core/chroma.git"], check=True)

# Create working directories before downloading the model into models/
os.makedirs("pdfs", exist_ok=True)
os.makedirs("models", exist_ok=True)
subprocess.run(["wget", "-q", "-O", "models/mistral-7b-instruct-v0.3.Q8_0.gguf",
                "https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.Q8_0.gguf"])
import re
import fitz  # PyMuPDF
import numpy as np
import gc
import torch
import time
import shutil
import hashlib
import pickle
import traceback
from typing import List, Dict, Any, Tuple, Optional, Union, Generator
from dataclasses import dataclass
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from llama_cpp import Llama
import gradio as gr
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from sentence_transformers import CrossEncoder
# Download nltk resources
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    print("Failed to download NLTK resources, continuing without them")


# === MEMORY MANAGEMENT UTILITIES ===
def clear_memory():
    """Clear memory to prevent OOM errors"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# === PDF PROCESSING ===
@dataclass
class PDFChunk:
    """Class to represent a chunk of text extracted from a PDF"""
    text: str
    source: str
    page_num: int
    chunk_id: int
class PDFProcessor:
    def __init__(self, pdf_dir: str = "pdfs"):
        """Initialize PDF processor

        Args:
            pdf_dir: Directory containing PDF files
        """
        self.pdf_dir = pdf_dir
        # Smaller chunk size with more overlap for better retrieval
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=384,
            chunk_overlap=288,  # 75% overlap for better context preservation
            length_function=len,
            is_separator_regex=False,
        )
        # Create cache directory
        self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
        os.makedirs(self.cache_dir, exist_ok=True)

    def list_pdfs(self) -> List[str]:
        """List all PDF files in the directory"""
        if not os.path.exists(self.pdf_dir):
            return []
        return [f for f in os.listdir(self.pdf_dir) if f.lower().endswith('.pdf')]
    def _get_cache_path(self, pdf_path: str) -> str:
        """Get the cache file path for a PDF"""
        # Hash only the first 8 KB of the file to keep cache keys cheap to compute
        with open(pdf_path, 'rb') as f:
            pdf_hash = hashlib.md5(f.read(8192)).hexdigest()
        return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")

    def _is_cached(self, pdf_path: str) -> bool:
        """Check if a PDF is cached"""
        cache_path = self._get_cache_path(pdf_path)
        return os.path.exists(cache_path)

    def _load_from_cache(self, pdf_path: str) -> Optional[List[PDFChunk]]:
        """Load chunks from cache"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            return None

    def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
        """Save chunks to cache"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(chunks, f)
        except Exception as e:
            print(f"Warning: Failed to cache PDF {pdf_path}: {str(e)}")
    def clean_text(self, text: str) -> str:
        """Clean extracted text"""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove header/footer patterns (common in PDFs)
        text = re.sub(r'(?<!\w)page \d+(?!\w)', '', text, flags=re.IGNORECASE)
        return text
    def extract_text_from_pdf(self, pdf_path: str) -> List[PDFChunk]:
        """Extract text content from a PDF file with improved extraction

        Args:
            pdf_path: Path to the PDF file
        Returns:
            List of PDFChunk objects extracted from the PDF
        """
        # Check cache first
        if self._is_cached(pdf_path):
            cached_chunks = self._load_from_cache(pdf_path)
            if cached_chunks:
                print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
                return cached_chunks
        try:
            doc = fitz.open(pdf_path)
            pdf_chunks = []
            pdf_name = os.path.basename(pdf_path)
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                # Extract text with more options for better quality
                page_text = page.get_text("text", sort=True)
                # Try alternative layout analysis if the extracted text is too short
                if len(page_text) < 100:
                    try:
                        page_text = page.get_text("dict", sort=True)
                        # Convert dict to text
                        if isinstance(page_text, dict) and "blocks" in page_text:
                            extracted_text = ""
                            for block in page_text["blocks"]:
                                if "lines" in block:
                                    for line in block["lines"]:
                                        if "spans" in line:
                                            for span in line["spans"]:
                                                if "text" in span:
                                                    extracted_text += span["text"] + " "
                            page_text = extracted_text
                    except Exception:
                        # Fallback to default extraction
                        page_text = page.get_text("text")
                # Clean the text
                page_text = self.clean_text(page_text)
                # Extract tables
                try:
                    tables = page.find_tables()
                    if tables and hasattr(tables, "tables"):
                        for table in tables.tables:
                            table_text = ""
                            for i, row in enumerate(table.rows):
                                row_cells = []
                                for cell in row.cells:
                                    if hasattr(cell, "rect"):
                                        cell_text = page.get_text("text", clip=cell.rect)
                                        cell_text = self.clean_text(cell_text)
                                        row_cells.append(cell_text)
                                if row_cells:
                                    table_text += " | ".join(row_cells) + "\n"
                            # Add table text to page text
                            if table_text.strip():
                                page_text += "\n\nTABLE:\n" + table_text
                except Exception as table_err:
                    print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
                # Split the page text into chunks
                if page_text.strip():
                    page_chunks = self.text_splitter.split_text(page_text)
                    # Create PDFChunk objects
                    for i, chunk_text in enumerate(page_chunks):
                        pdf_chunks.append(PDFChunk(
                            text=chunk_text,
                            source=pdf_name,
                            page_num=page_num + 1,  # 1-based page numbering for humans
                            chunk_id=i
                        ))
                # Clear memory periodically
                if page_num % 10 == 0:
                    clear_memory()
            doc.close()
            # Cache the results
            self._save_to_cache(pdf_path, pdf_chunks)
            return pdf_chunks
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {str(e)}")
            return []
    def process_pdf(self, pdf_name: str) -> List[PDFChunk]:
        """Process a single PDF file and extract chunks

        Args:
            pdf_name: Name of the PDF file in the pdf_dir
        Returns:
            List of PDFChunk objects from the PDF
        """
        pdf_path = os.path.join(self.pdf_dir, pdf_name)
        return self.extract_text_from_pdf(pdf_path)

    def process_all_pdfs(self, batch_size: int = 3) -> List[PDFChunk]:
        """Process all PDFs in batches to manage memory

        Args:
            batch_size: Number of PDFs to process in each batch
        Returns:
            List of all PDFChunk objects from all PDFs
        """
        all_chunks = []
        pdf_files = self.list_pdfs()
        if not pdf_files:
            print("No PDF files found in the directory.")
            return []
        # Process PDFs in batches
        for i in range(0, len(pdf_files), batch_size):
            batch = pdf_files[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(pdf_files)-1)//batch_size + 1}")
            for pdf_name in batch:
                print(f"Processing {pdf_name}")
                chunks = self.process_pdf(pdf_name)
                all_chunks.extend(chunks)
                print(f"Extracted {len(chunks)} chunks from {pdf_name}")
            # Clear memory after each batch
            clear_memory()
        return all_chunks
# === VECTOR DATABASE SETUP ===
class VectorDBManager:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize vector database manager

        Args:
            model_name: Name of the embedding model
        """
        # Initialize embedding model with normalization
        try:
            self.embedding_model = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )
        except Exception as e:
            print(f"Error initializing embedding model {model_name}: {str(e)}")
            print("Falling back to all-MiniLM-L6-v2 model")
            self.embedding_model = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )
        # Initialize cross-encoder for re-ranking
        try:
            self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        except Exception as e:
            print(f"Error initializing cross-encoder: {str(e)}")
            self.cross_encoder = None
        self.vectordb = None
        # BM25 index for hybrid search
        self.bm25_index = None
        self.chunks = []
        self.tokenized_chunks = []

    def _prepare_bm25(self, chunks: List[PDFChunk]):
        """Prepare BM25 index for hybrid search"""
        # Tokenize chunks for BM25
        try:
            stop_words = set(stopwords.words('english'))
            tokenized_chunks = []
            for chunk in chunks:
                # Tokenize and remove stopwords
                tokens = word_tokenize(chunk.text.lower())
                filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
                tokenized_chunks.append(filtered_tokens)
            # Create BM25 index
            self.bm25_index = BM25Okapi(tokenized_chunks)
            self.tokenized_chunks = tokenized_chunks
        except Exception as e:
            print(f"Error creating BM25 index: {str(e)}")
            print(traceback.format_exc())
            self.bm25_index = None
    def create_vector_db(self, chunks: List[PDFChunk]) -> None:
        """Create vector database from text chunks

        Args:
            chunks: List of PDFChunk objects
        """
        try:
            if not chunks or len(chunks) == 0:
                print("ERROR: No chunks provided to create vector database")
                return
            print(f"Creating vector DB with {len(chunks)} chunks")
            # Store chunks for hybrid search
            self.chunks = chunks
            # Prepare data for vector DB
            chunk_texts = [chunk.text for chunk in chunks]
            # Create BM25 index for hybrid search
            print("Creating BM25 index for hybrid search")
            self._prepare_bm25(chunks)
            # Process in smaller batches to manage memory
            batch_size = 32
            all_embeddings = []
            for i in range(0, len(chunk_texts), batch_size):
                batch = chunk_texts[i:i+batch_size]
                print(f"Embedding batch {i//batch_size + 1}/{(len(chunk_texts)-1)//batch_size + 1}")
                # Generate embeddings for the batch
                batch_embeddings = self.embedding_model.embed_documents(batch)
                all_embeddings.extend(batch_embeddings)
                # Clear memory after each batch
                clear_memory()
            # Create FAISS index
            print(f"Creating FAISS index with {len(all_embeddings)} embeddings")
            self.vectordb = FAISS.from_embeddings(
                text_embeddings=list(zip(chunk_texts, all_embeddings)),
                embedding=self.embedding_model
            )
            print(f"Vector database created with {len(chunks)} documents")
        except Exception as e:
            print(f"Error creating vector database: {str(e)}")
            print(traceback.format_exc())
            raise

    def _format_chunk_with_metadata(self, chunk: PDFChunk) -> str:
        """Format a chunk with its metadata for better context"""
        return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"
    def _rerank_with_cross_encoder(self, query: str, chunks: List[PDFChunk], k: int = 5) -> List[PDFChunk]:
        """Re-rank chunks using cross-encoder

        Args:
            query: User query
            chunks: List of retrieved chunks
            k: Number of top chunks to return
        Returns:
            Re-ranked chunks
        """
        if not self.cross_encoder or not chunks:
            return chunks[:k] if len(chunks) > k else chunks
        try:
            # Prepare passage pairs for re-ranking
            pairs = [[query, chunk.text] for chunk in chunks]
            # Score passages in smaller batches to prevent OOM
            batch_size = 16
            all_scores = []
            for i in range(0, len(pairs), batch_size):
                batch_pairs = pairs[i:i+batch_size]
                batch_scores = self.cross_encoder.predict(batch_pairs)
                all_scores.extend(batch_scores)
                # Clear memory
                clear_memory()
            # Create chunk-score pairs
            scored_chunks = list(zip(chunks, all_scores))
            # Sort by score
            scored_chunks.sort(key=lambda x: x[1], reverse=True)
            # Return top k chunks
            return [chunk for chunk, score in scored_chunks[:k]]
        except Exception as e:
            print(f"Error during cross-encoder re-ranking: {str(e)}")
            # Fallback to original chunks
            return chunks[:k] if len(chunks) > k else chunks
    def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
        """Hybrid search combining vector search and BM25 with cross-encoder re-ranking

        Args:
            query: Query text
            k: Number of results to return
            alpha: Weight for vector search (1-alpha for BM25)
        Returns:
            List of formatted documents
        """
        if self.vectordb is None:
            print("Vector database not initialized")
            return []
        try:
            # Get vector search results (retrieve extra candidates for re-ranking)
            vector_results = self.vectordb.similarity_search(query, k=k*3)
            vector_texts = [doc.page_content for doc in vector_results]
            retrieved_chunks = []
            # Combine with BM25 if available
            if self.bm25_index is not None:
                try:
                    # Tokenize query for BM25
                    query_tokens = word_tokenize(query.lower())
                    stop_words = set(stopwords.words('english'))
                    filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
                    # Get BM25 scores
                    bm25_scores = self.bm25_index.get_scores(filtered_query)
                    # Combine scores (normalized)
                    combined_results = []
                    seen_texts = set()
                    # First add vector results with their positions as scores
                    for i, text in enumerate(vector_texts):
                        if text not in seen_texts:
                            seen_texts.add(text)
                            # Find corresponding chunk
                            for j, chunk in enumerate(self.chunks):
                                if chunk.text == text:
                                    # Combine scores: alpha * vector_score + (1-alpha) * bm25_score
                                    # For vector, use inverse of position as score (normalized)
                                    vector_score = 1.0 - (i / len(vector_texts))
                                    # Normalize BM25 score
                                    bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
                                    combined_score = alpha * vector_score + (1 - alpha) * bm25_score
                                    combined_results.append((chunk, combined_score))
                                    break
                    # Sort by combined score
                    combined_results.sort(key=lambda x: x[1], reverse=True)
                    # Get top k*2 results for re-ranking
                    retrieved_chunks = [item[0] for item in combined_results[:k*2]]
                except Exception as e:
                    print(f"Error in BM25 scoring: {str(e)}")
                    # Fallback to vector search results
                    retrieved_chunks = [chunk for chunk in self.chunks
                                        if chunk.text in vector_texts[:k*2]]
            else:
                # Just use vector search results if BM25 is not available
                retrieved_chunks = [chunk for chunk in self.chunks
                                    if chunk.text in vector_texts[:k*2]]
            # Re-rank with cross-encoder
            if retrieved_chunks:
                reranked_chunks = self._rerank_with_cross_encoder(query, retrieved_chunks, k)
                # Format results with metadata
                final_results = [self._format_chunk_with_metadata(chunk) for chunk in reranked_chunks]
            else:
                # Fallback to basic results
                final_results = vector_texts[:k]
            return final_results
        except Exception as e:
            print(f"Error during hybrid search: {str(e)}")
            return []
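
# A small, self-contained sketch of the score fusion used in hybrid_search above:
# each vector hit gets an inverse-rank score, BM25 scores are max-normalized, and
# the two are blended with weight alpha. The numbers and the helper name
# _hybrid_score_demo are illustrative only and not part of the application logic.
def _hybrid_score_demo(alpha: float = 0.7):
    vector_ranks = [0, 1, 2]        # positions of three chunks in the vector results
    bm25_raw = [4.2, 0.8, 2.9]      # made-up raw BM25 scores for the same chunks
    n = len(vector_ranks)
    best_bm25 = max(bm25_raw)
    for rank, bm25 in zip(vector_ranks, bm25_raw):
        vector_score = 1.0 - rank / n
        bm25_score = bm25 / best_bm25 if best_bm25 > 0 else 0
        print(round(alpha * vector_score + (1 - alpha) * bm25_score, 3))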
# === QUERY EXPANSION ===
class QueryExpander:
    def __init__(self, llm_model):
        """Initialize query expander

        Args:
            llm_model: LLM model for query expansion
        """
        self.llm = llm_model

    def expand_query(self, query: str) -> str:
        """Expand the query using the LLM to improve retrieval

        Args:
            query: Original query
        Returns:
            Expanded query
        """
        try:
            prompt = f"""<s>[INST] I need to search for documents related to this question: "{query}"
Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
Return only the expanded search query, without any explanations or additional text. [/INST]"""
            expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
            # Combine original and expanded
            combined = f"{query} {expanded}"
            # Limit length
            if len(combined) > 300:
                combined = combined[:300]
            return combined
        except Exception:
            # Return original query if expansion fails
            return query
# === LLM SETUP ===
class MistralModel:
    def __init__(self, model_path: str = "models/mistral-7b-instruct-v0.3.Q8_0.gguf"):
        """Initialize Mistral model

        Args:
            model_path: Path to the model file
        """
        try:
            # Initialize Mistral with llama.cpp
            self.llm = Llama(
                model_path=model_path,
                n_ctx=4096,       # Increased context window for better reasoning
                n_batch=256,      # Batch size to save memory
                n_gpu_layers=0,   # Run on CPU only for Colab free tier
                verbose=False
            )
        except Exception as e:
            print(f"Error initializing Mistral model: {str(e)}")
            raise

    def generate(self, prompt: str,
                 max_tokens: int = 512,
                 temperature: float = 0.7,
                 top_p: float = 0.9,
                 stream: bool = False) -> Union[str, Generator[str, None, None]]:
        """Generate text using Mistral

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            stream: Whether to stream the output
        Returns:
            Generated text or generator if streaming
        """
        try:
            if stream:
                return self._generate_stream(prompt, max_tokens, temperature, top_p)
            else:
                output = self.llm(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    echo=False
                )
                return output["choices"][0]["text"].strip()
        except Exception as e:
            print(f"Error generating text: {str(e)}")
            return "Error: Could not generate response."

    def _generate_stream(self, prompt: str,
                         max_tokens: int = 512,
                         temperature: float = 0.7,
                         top_p: float = 0.9) -> Generator[str, None, None]:
        """Stream text generation using Mistral

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
        Yields:
            The accumulated response text after each new token
        """
        response = ""
        for output in self.llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            echo=False,
            stream=True
        ):
            token = output["choices"][0]["text"]
            response += token
            yield response
# === SELF-CHECKING ===
class SelfChecker:
    def __init__(self, llm_model):
        """Initialize self-checker for improved response quality

        Args:
            llm_model: LLM model to use for checking
        """
        self.llm = llm_model

    def check_answer(self, query: str, initial_answer: str, contexts: List[str]) -> str:
        """Check if answer is correct and complete based on the contexts

        Args:
            query: User query
            initial_answer: Initial generated answer
            contexts: Retrieved contexts used to generate the answer
        Returns:
            Improved answer after reflection
        """
        # Guard against very long inputs that could cause runtime disconnection:
        # limit contexts to prevent excessive token usage
        max_contexts_len = 4000
        contexts_text = "\n\n".join(contexts)
        if len(contexts_text) > max_contexts_len:
            # Truncate while keeping as many complete contexts as possible
            truncated_contexts = []
            current_len = 0
            for ctx in contexts:
                if current_len + len(ctx) + 2 <= max_contexts_len:
                    truncated_contexts.append(ctx)
                    current_len += len(ctx) + 2
                else:
                    break
            contexts_text = "\n\n".join(truncated_contexts)
        # Skip reflection entirely if the combined input is still too long
        if len(initial_answer) + len(contexts_text) + len(query) > 6000:
            print("Skipping reflection due to excessive input length")
            return initial_answer
        try:
            prompt = f"""<s>[INST] You're an AI assistant tasked with evaluating and improving an answer to a user query.
QUERY: {query}
INITIAL ANSWER: {initial_answer}
AVAILABLE CONTEXTS:
{contexts_text}
First, carefully check if the initial answer:
1. Is factually accurate based on the provided contexts
2. Addresses all aspects of the user's query
3. Contains any information not supported by the contexts
4. Misses important information from the contexts
Then improve the answer to fix any issues identified. The final answer should:
- Be comprehensive and accurate based ONLY on the contexts
- Not include any unsupported information
- Be well-structured and clear
- Cite specific sources when appropriate (e.g., "According to [Source, Page X]...")
Provide ONLY the improved answer without explanations about your reasoning process. [/INST]"""
            # Use a slightly lower temperature for more focused reflection
            improved_answer = self.llm.generate(
                prompt,
                max_tokens=1024,
                temperature=0.3,
                stream=False
            )
            # If reflection produced nothing useful, return original answer
            if not improved_answer or len(improved_answer) < 10:
                return initial_answer
            return improved_answer
        except Exception as e:
            # On any error, return the original answer to ensure robustness
            print(f"Self-check error: {str(e)}")
            return initial_answer
# === RAG SYSTEM ===
class RAGSystem:
    def __init__(self, pdf_processor: PDFProcessor,
                 vector_db: VectorDBManager,
                 model: MistralModel):
        """Initialize RAG system

        Args:
            pdf_processor: PDF processor instance
            vector_db: Vector database manager instance
            model: LLM model instance
        """
        self.pdf_processor = pdf_processor
        self.vector_db = vector_db
        self.model = model
        self.query_expander = QueryExpander(model)
        self.self_checker = SelfChecker(model)
        self.is_initialized = False

    def process_documents(self) -> bool:
        """Process all documents and create vector database

        Returns:
            True if successful, False otherwise
        """
        try:
            # Process PDFs
            chunks = self.pdf_processor.process_all_pdfs()
            if not chunks:
                print("No chunks were extracted from PDFs")
                return False
            print(f"Total chunks extracted: {len(chunks)}")
            # Create vector database
            print("Creating vector database...")
            self.vector_db.create_vector_db(chunks)
            # Verify success
            if self.vector_db.vectordb is None:
                print("Failed to create vector database")
                return False
            # Set initialization flag
            self.is_initialized = True
            return True
        except Exception as e:
            print(f"Error processing documents: {str(e)}")
            print(traceback.format_exc())
            return False

    def generate_prompt(self, query: str, contexts: List[str]) -> str:
        """Generate prompt for the LLM with better instructions

        Args:
            query: User query
            contexts: Retrieved contexts
        Returns:
            Formatted prompt
        """
        # Format contexts with numbering for better reference
        formatted_contexts = ""
        for i, context in enumerate(contexts):
            formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
        # Create prompt with Mistral's chat format
        prompt = f"""<s>[INST] You are an AI assistant that answers questions based on the provided context information.
User Query: {query}
Below are relevant passages from documents that might help answer the query:
{formatted_contexts}
Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
Do not use any prior knowledge that is not contained in the provided context.
If quoting from the context, mention the source document and page number.
Organize your answer in a clear, coherent manner. [/INST]"""
        return prompt

    def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
                     temperature: float = 0.7, stream: bool = False,
                     enable_reflection: bool = True) -> Union[str, Generator[str, None, None]]:
        """Answer a query using RAG with query expansion and self-checking

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Temperature for generation
            stream: Whether to stream the output
            enable_reflection: Whether to enable self-reflection for better answers
        Returns:
            Answer text or generator if streaming
        """
        # Check if system is initialized
        if not self.is_initialized or self.vector_db.vectordb is None:
            return "Error: Documents have not been processed yet. Please process documents first."
        try:
            # Expand query for better retrieval
            expanded_query = self.query_expander.expand_query(query)
            print(f"Expanded query: {expanded_query}")
            # Retrieve relevant contexts using hybrid search
            contexts = self.vector_db.hybrid_search(expanded_query, k=k)
            if not contexts:
                return "No relevant information found in the documents. Please try a different query or check if documents were processed correctly."
            # Generate prompt with improved instructions
            prompt = self.generate_prompt(query, contexts)
            # For streaming, we can't do self-checking
            if stream:
                return self.model.generate(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stream=True
                )
            # Generate initial answer
            initial_answer = self.model.generate(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False
            )
            # Perform self-checking if enabled and initial answer exists
            if enable_reflection and initial_answer and len(initial_answer) > 10:
                try:
                    print("Performing self-checking to improve answer quality...")
                    improved_answer = self.self_checker.check_answer(query, initial_answer, contexts)
                    return improved_answer
                except Exception as e:
                    print(f"Error during self-checking: {str(e)}")
                    # Fallback to initial answer if self-checking fails
                    return initial_answer
            else:
                return initial_answer
        except Exception as e:
            print(f"Error answering query: {str(e)}")
            print(traceback.format_exc())
            return f"Error processing your query: {str(e)}"
# === GRADIO UI ===
class RAGUI:
    def __init__(self, rag_system: RAGSystem):
        """Initialize RAG UI

        Args:
            rag_system: RAG system instance
        """
        self.rag_system = rag_system
        self.pdf_dir = rag_system.pdf_processor.pdf_dir
        self.interface = None

    def _list_uploaded_pdfs(self) -> str:
        """List all uploaded PDFs"""
        pdfs = self.rag_system.pdf_processor.list_pdfs()
        if not pdfs:
            return "No PDFs uploaded yet."
        return "\n".join([f"- {pdf}" for pdf in pdfs])

    def upload_pdf(self, files) -> str:
        """Upload PDF files

        Args:
            files: File objects
        Returns:
            Status message
        """
        try:
            # Create directory if it doesn't exist
            os.makedirs(self.pdf_dir, exist_ok=True)
            # Copy files to pdf directory
            for file in files:
                shutil.copy(file.name, os.path.join(self.pdf_dir, os.path.basename(file.name)))
            return f"Successfully uploaded {len(files)} file(s). Please process documents to make them searchable."
        except Exception as e:
            return f"Error uploading files: {str(e)}"

    def process_documents(self) -> str:
        """Process documents and create vector database

        Returns:
            Status message
        """
        try:
            # Check if there are PDFs
            pdf_files = self.rag_system.pdf_processor.list_pdfs()
            if not pdf_files:
                return "No PDF files uploaded. Please upload PDFs first."
            # Process PDFs
            start_time = time.time()
            success = self.rag_system.process_documents()
            process_time = time.time() - start_time
            if success:
                return f"Successfully processed {len(pdf_files)} PDF file(s) in {process_time:.2f} seconds. You can now ask questions."
            else:
                return "Failed to process documents. Check the logs for details."
        except Exception as e:
            return f"Error processing documents: {str(e)}"

    def answer_query(self, query: str, stream_output: bool = True,
                     k: int = 4, temperature: float = 0.7,
                     enable_reflection: bool = True) -> str:
        """Answer a query using RAG

        Args:
            query: User query
            stream_output: Whether to stream the output
            k: Number of contexts to retrieve
            temperature: Temperature for text generation
            enable_reflection: Whether to use reflection to improve answers
        Returns:
            Answer text
        """
        if not query or query.strip() == "":
            return "Please enter a query."
        # Check if system is initialized
        if not self.rag_system.is_initialized:
            return "Documents have not been processed yet. Please process documents first."
        try:
            # Streaming and self-reflection are mutually exclusive:
            # when streaming is requested, reflection is disabled
            if stream_output:
                return self.rag_system.answer_query(
                    query,
                    k=k,
                    max_tokens=1024,
                    temperature=temperature,
                    stream=True,
                    enable_reflection=False
                )
            else:
                return self.rag_system.answer_query(
                    query,
                    k=k,
                    max_tokens=1024,
                    temperature=temperature,
                    stream=False,
                    enable_reflection=enable_reflection
                )
        except Exception as e:
            print(f"Error in answer_query: {str(e)}")
            print(traceback.format_exc())
            return f"Error processing your query: {str(e)}"
    def launch(self):
        """Launch Gradio UI"""
        try:
            with gr.Blocks(title="Document Q&A System") as self.interface:
                gr.Markdown("# PDF Question Answering System")
                gr.Markdown("Upload PDF documents and ask questions about their content.")
                with gr.Tab("Upload & Process"):
                    with gr.Row():
                        with gr.Column():
                            upload_button = gr.File(
                                label="Upload PDF Files",
                                file_count="multiple",
                                file_types=[".pdf"]
                            )
                            upload_output = gr.Textbox(
                                label="Upload Status",
                                interactive=False
                            )
                            upload_btn = gr.Button("Upload Files")
                        with gr.Column():
                            pdf_list = gr.Textbox(
                                label="Uploaded PDFs",
                                value=self._list_uploaded_pdfs(),
                                interactive=False
                            )
                            refresh_btn = gr.Button("Refresh List")
                            process_btn = gr.Button("Process Documents")
                            process_output = gr.Textbox(
                                label="Processing Status",
                                interactive=False
                            )
                with gr.Tab("Ask Questions"):
                    with gr.Row():
                        with gr.Column():
                            query_input = gr.Textbox(
                                label="Enter your question",
                                placeholder="What are the main findings of the report?",
                                lines=2
                            )
                            with gr.Row():
                                k_slider = gr.Slider(
                                    minimum=1,
                                    maximum=10,
                                    value=4,
                                    step=1,
                                    label="Number of contexts to retrieve"
                                )
                                temp_slider = gr.Slider(
                                    minimum=0.1,
                                    maximum=1.0,
                                    value=0.7,
                                    step=0.1,
                                    label="Temperature"
                                )
                            with gr.Row():
                                stream_checkbox = gr.Checkbox(
                                    label="Stream output",
                                    value=True
                                )
                                reflection_checkbox = gr.Checkbox(
                                    label="Use self-reflection (disables streaming)",
                                    value=True
                                )
                            query_btn = gr.Button("Submit Question")
                            answer_output = gr.Textbox(
                                label="Answer",
                                interactive=False,
                                lines=15
                            )
                # Event handlers
                upload_btn.click(
                    fn=self.upload_pdf,
                    inputs=[upload_button],
                    outputs=[upload_output]
                )
                refresh_btn.click(
                    fn=lambda: self._list_uploaded_pdfs(),
                    inputs=[],
                    outputs=[pdf_list]
                )
                process_btn.click(
                    fn=self.process_documents,
                    inputs=[],
                    outputs=[process_output]
                )
                query_btn.click(
                    fn=self.answer_query,
                    inputs=[query_input, stream_checkbox, k_slider, temp_slider, reflection_checkbox],
                    outputs=[answer_output]
                )

                # Checkbox dependency: uncheck "Stream output" whenever self-reflection is enabled
                def update_stream_state(reflection_enabled):
                    return False if reflection_enabled else gr.update()

                reflection_checkbox.change(
                    fn=update_stream_state,
                    inputs=[reflection_checkbox],
                    outputs=[stream_checkbox]
                )
            # Launch UI
            self.interface.launch(share=True)
        except Exception as e:
            print(f"Error launching UI: {str(e)}")
            print(traceback.format_exc())
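
# A minimal sketch of driving the pipeline programmatically, without the Gradio UI.
# It assumes PDFs are already present in the "pdfs" directory and that the GGUF model
# file exists at the default path used by MistralModel; the query string and the
# function name run_headless_example are illustrative only and are never called here.
def run_headless_example():
    processor = PDFProcessor(pdf_dir="pdfs")
    vector_db = VectorDBManager()
    llm = MistralModel()
    rag = RAGSystem(processor, vector_db, llm)
    # Build the FAISS vector store and BM25 index from every PDF in the directory
    if rag.process_documents():
        # Hybrid retrieval + generation + self-checking, returned as a single string
        answer = rag.answer_query("What are the main findings of the report?",
                                  k=4, temperature=0.3, enable_reflection=True)
        print(answer)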
# === MAIN APPLICATION ===
def main():
    # Initialize components
    print("Initializing PDF processor...")
    pdf_processor = PDFProcessor()
    print("Initializing vector database manager...")
    vector_db = VectorDBManager()
    print("Initializing Mistral model...")
    model = MistralModel()
    print("Initializing RAG system...")
    rag_system = RAGSystem(pdf_processor, vector_db, model)
    print("Initializing UI...")
    ui = RAGUI(rag_system)
    print("Launching UI...")
    ui.launch()


if __name__ == "__main__":
    main()