userisanillusion committed (verified)
Commit fa9395d · 1 parent: 736f9b1

Update app.py

Files changed (1):
  1. app.py  +464 -285

app.py CHANGED
@@ -1,5 +1,12 @@
1
  import os
2
  import re
 
3
  import numpy as np
4
  import gc
5
  import torch
@@ -10,31 +17,16 @@ import pickle
10
  import traceback
11
  from typing import List, Dict, Any, Tuple, Optional, Union, Generator
12
  from dataclasses import dataclass
13
- import gradio as gr
14
-
15
- # Import dependencies (no need for pip install commands)
16
- import fitz # PyMuPDF
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
18
  from langchain_community.vectorstores import FAISS
19
  from langchain_community.embeddings import HuggingFaceEmbeddings
20
  from llama_cpp import Llama
 
21
  from rank_bm25 import BM25Okapi
22
  import nltk
23
  from nltk.tokenize import word_tokenize
24
  from nltk.corpus import stopwords
25
- from huggingface_hub import hf_hub_download
26
-
27
- # Download nltk resources
28
- try:
29
- nltk.download('punkt', quiet=True)
30
- nltk.download('stopwords', quiet=True)
31
- except:
32
- print("Failed to download NLTK resources, continuing without them")
33
-
34
- # Setup directories for Spaces
35
- os.makedirs("pdfs", exist_ok=True)
36
- os.makedirs("models", exist_ok=True)
37
- os.makedirs("pdf_cache", exist_ok=True)
38
 
39
  # Download nltk resources
40
  try:
@@ -43,14 +35,6 @@ try:
43
  except:
44
  print("Failed to download NLTK resources, continuing without them")
45
 
46
- # Download model from Hugging Face Hub
47
- model_path = hf_hub_download(
48
- repo_id="TheBloke/phi-2-GGUF",
49
- filename="phi-2.Q8_0.gguf",
50
- repo_type="model",
51
- local_dir="models"
52
- )
53
-
54
  # === MEMORY MANAGEMENT UTILITIES ===
55
  def clear_memory():
56
  """Clear memory to prevent OOM errors"""
@@ -81,7 +65,7 @@ class PDFProcessor:
81
  length_function=len,
82
  is_separator_regex=False,
83
  )
84
-
85
  # Create cache directory
86
  self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
87
  os.makedirs(self.cache_dir, exist_ok=True)
@@ -96,12 +80,12 @@ class PDFProcessor:
96
  """Get the cache file path for a PDF"""
97
  pdf_hash = hashlib.md5(open(pdf_path, 'rb').read(8192)).hexdigest()
98
  return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")
99
-
100
  def _is_cached(self, pdf_path: str) -> bool:
101
  """Check if a PDF is cached"""
102
  cache_path = self._get_cache_path(pdf_path)
103
  return os.path.exists(cache_path)
104
-
105
  def _load_from_cache(self, pdf_path: str) -> List[PDFChunk]:
106
  """Load chunks from cache"""
107
  cache_path = self._get_cache_path(pdf_path)
@@ -110,7 +94,7 @@ class PDFProcessor:
110
  return pickle.load(f)
111
  except:
112
  return None
113
-
114
  def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
115
  """Save chunks to cache"""
116
  cache_path = self._get_cache_path(pdf_path)
@@ -143,15 +127,15 @@ class PDFProcessor:
143
  if cached_chunks:
144
  print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
145
  return cached_chunks
146
-
147
  try:
148
  doc = fitz.open(pdf_path)
149
  pdf_chunks = []
150
  pdf_name = os.path.basename(pdf_path)
151
-
152
  for page_num in range(len(doc)):
153
  page = doc.load_page(page_num)
154
-
155
  # Extract text with more options for better quality
156
  page_text = page.get_text("text", sort=True)
157
  # Try to extract text with alternative layout analysis if the text is too short
@@ -172,10 +156,10 @@ class PDFProcessor:
172
  except:
173
  # Fallback to default extraction
174
  page_text = page.get_text("text")
175
-
176
  # Clean the text
177
  page_text = self.clean_text(page_text)
178
-
179
  # Extract tables
180
  try:
181
  tables = page.find_tables()
@@ -191,17 +175,17 @@ class PDFProcessor:
191
  row_cells.append(cell_text)
192
  if row_cells:
193
  table_text += " | ".join(row_cells) + "\n"
194
-
195
  # Add table text to page text
196
  if table_text.strip():
197
  page_text += "\n\nTABLE:\n" + table_text
198
  except Exception as table_err:
199
  print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
200
-
201
  # Split the page text into chunks
202
  if page_text.strip():
203
  page_chunks = self.text_splitter.split_text(page_text)
204
-
205
  # Create PDFChunk objects
206
  for i, chunk_text in enumerate(page_chunks):
207
  pdf_chunks.append(PDFChunk(
@@ -210,16 +194,16 @@ class PDFProcessor:
210
  page_num=page_num + 1, # 1-based page numbering for humans
211
  chunk_id=i
212
  ))
213
-
214
  # Clear memory periodically
215
  if page_num % 10 == 0:
216
  clear_memory()
217
-
218
  doc.close()
219
-
220
  # Cache the results
221
  self._save_to_cache(pdf_path, pdf_chunks)
222
-
223
  return pdf_chunks
224
  except Exception as e:
225
  print(f"Error extracting text from {pdf_path}: {str(e)}")
@@ -237,7 +221,7 @@ class PDFProcessor:
237
  pdf_path = os.path.join(self.pdf_dir, pdf_name)
238
  return self.extract_text_from_pdf(pdf_path)
239
 
240
- def process_all_pdfs(self, batch_size: int = 2) -> List[PDFChunk]:
241
  """Process all PDFs in batches to manage memory
242
 
243
  Args:
@@ -292,7 +276,14 @@ class VectorDBManager:
292
  model_kwargs={"device": "cpu"},
293
  encode_kwargs={"normalize_embeddings": True}
294
  )
295
-
296
  self.vectordb = None
297
  # BM25 index for hybrid search
298
  self.bm25_index = None
@@ -310,14 +301,15 @@ class VectorDBManager:
310
  stop_words = set(stopwords.words('english'))
311
  filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
312
  tokenized_chunks.append(filtered_tokens)
313
-
314
  # Create BM25 index
315
  self.bm25_index = BM25Okapi(tokenized_chunks)
 
316
  except Exception as e:
317
  print(f"Error creating BM25 index: {str(e)}")
318
  print(traceback.format_exc())
319
  self.bm25_index = None
320
-
321
  def create_vector_db(self, chunks: List[PDFChunk]) -> None:
322
  """Create vector database from text chunks
323
 
@@ -330,19 +322,19 @@ class VectorDBManager:
330
  return
331
 
332
  print(f"Creating vector DB with {len(chunks)} chunks")
333
-
334
  # Store chunks for hybrid search
335
  self.chunks = chunks
336
-
337
  # Prepare data for vector DB
338
  chunk_texts = [chunk.text for chunk in chunks]
339
-
340
  # Create BM25 index for hybrid search
341
  print("Creating BM25 index for hybrid search")
342
  self._prepare_bm25(chunks)
343
-
344
  # Process in smaller batches to manage memory
345
- batch_size = 16 # Reduced for Spaces
346
  all_embeddings = []
347
 
348
  for i in range(0, len(chunk_texts), batch_size):
@@ -362,7 +354,7 @@ class VectorDBManager:
362
  text_embeddings=list(zip(chunk_texts, all_embeddings)),
363
  embedding=self.embedding_model
364
  )
365
-
366
  print(f"Vector database created with {len(chunks)} documents")
367
 
368
  except Exception as e:
@@ -374,8 +366,51 @@ class VectorDBManager:
374
  """Format a chunk with its metadata for better context"""
375
  return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"
376
 
377
  def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
378
- """Hybrid search combining vector search and BM25
379
 
380
  Args:
381
  query: Query text
@@ -391,11 +426,11 @@ class VectorDBManager:
391
 
392
  try:
393
  # Get vector search results
394
- vector_results = self.vectordb.similarity_search(query, k=k*2)
395
  vector_texts = [doc.page_content for doc in vector_results]
396
-
397
- final_results = []
398
-
399
  # Combine with BM25 if available
400
  if self.bm25_index is not None:
401
  try:
@@ -403,14 +438,14 @@ class VectorDBManager:
403
  query_tokens = word_tokenize(query.lower())
404
  stop_words = set(stopwords.words('english'))
405
  filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
406
-
407
  # Get BM25 scores
408
  bm25_scores = self.bm25_index.get_scores(filtered_query)
409
-
410
  # Combine scores (normalized)
411
  combined_results = []
412
  seen_texts = set()
413
-
414
  # First add vector results with their positions as scores
415
  for i, text in enumerate(vector_texts):
416
  if text not in seen_texts:
@@ -424,26 +459,34 @@ class VectorDBManager:
424
  # Normalize BM25 score
425
  bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
426
  combined_score = alpha * vector_score + (1-alpha) * bm25_score
427
-
428
  combined_results.append((chunk, combined_score))
429
  break
430
-
431
  # Sort by combined score
432
  combined_results.sort(key=lambda x: x[1], reverse=True)
433
-
434
- # Get top k results
435
- top_chunks = [item[0] for item in combined_results[:k]]
436
-
437
- # Format results with metadata
438
- final_results = [self._format_chunk_with_metadata(chunk) for chunk in top_chunks]
439
  except Exception as e:
440
  print(f"Error in BM25 scoring: {str(e)}")
441
  # Fallback to vector search results
442
- final_results = vector_texts[:k]
 
443
  else:
444
  # Just use vector search results if BM25 is not available
445
  final_results = vector_texts[:k]
446
-
447
  return final_results
448
  except Exception as e:
449
  print(f"Error during hybrid search: {str(e)}")
@@ -458,7 +501,7 @@ class QueryExpander:
458
  llm_model: LLM model for query expansion
459
  """
460
  self.llm = llm_model
461
-
462
  def expand_query(self, query: str) -> str:
463
  """Expand the query using the LLM to improve retrieval
464
 
@@ -469,46 +512,44 @@ class QueryExpander:
469
  Expanded query
470
  """
471
  try:
472
- prompt = f"""I need to search for documents related to this question: "{query}"
473
-
474
  Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
475
- Return only the expanded search query, without any explanations or additional text.
476
 
477
- Expanded query:"""
478
-
479
  expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
480
-
481
  # Combine original and expanded
482
  combined = f"{query} {expanded}"
483
-
484
  # Limit length
485
  if len(combined) > 300:
486
  combined = combined[:300]
487
-
488
  return combined
489
  except:
490
  # Return original query if expansion fails
491
  return query
492
 
493
  # === LLM SETUP ===
494
- class Phi2Model:
495
- def __init__(self, model_path: str = model_path):
496
- """Initialize Phi-2 model
497
 
498
  Args:
499
  model_path: Path to the model file
500
  """
501
  try:
502
- # Initialize Phi-2 with llama.cpp - optimized for Spaces
503
  self.llm = Llama(
504
  model_path=model_path,
505
- n_ctx=1024, # Reduced context window for Spaces
506
- n_batch=64, # Reduced batch size
507
- n_gpu_layers=0, # Run on CPU for compatibility
508
  verbose=False
509
  )
510
  except Exception as e:
511
- print(f"Error initializing Phi-2 model: {str(e)}")
512
  raise
513
 
514
  def generate(self, prompt: str,
@@ -516,7 +557,7 @@ class Phi2Model:
516
  temperature: float = 0.7,
517
  top_p: float = 0.9,
518
  stream: bool = False) -> Union[str, Generator[str, None, None]]:
519
- """Generate text using Phi-2
520
 
521
  Args:
522
  prompt: Input prompt
@@ -539,7 +580,7 @@ class Phi2Model:
539
  top_p=top_p,
540
  echo=False
541
  )
542
- return output["choices"][0]["text"]
543
  except Exception as e:
544
  print(f"Error generating text: {str(e)}")
545
  return "Error: Could not generate response."
@@ -548,7 +589,7 @@ class Phi2Model:
548
  max_tokens: int = 512,
549
  temperature: float = 0.7,
550
  top_p: float = 0.9) -> Generator[str, None, None]:
551
- """Stream text generation using Phi-2
552
 
553
  Args:
554
  prompt: Input prompt
@@ -572,11 +613,96 @@ class Phi2Model:
572
  response += token
573
  yield response
574
 
575
  # === RAG SYSTEM ===
576
  class RAGSystem:
577
  def __init__(self, pdf_processor: PDFProcessor,
578
  vector_db: VectorDBManager,
579
- model: Phi2Model):
580
  """Initialize RAG system
581
 
582
  Args:
@@ -588,11 +714,12 @@ class RAGSystem:
588
  self.vector_db = vector_db
589
  self.model = model
590
  self.query_expander = QueryExpander(model)
 
591
  self.is_initialized = False
592
 
593
  def process_documents(self) -> bool:
594
  """Process all documents and create vector database
595
-
596
  Returns:
597
  True if successful, False otherwise
598
  """
@@ -602,22 +729,22 @@ class RAGSystem:
602
  if not chunks:
603
  print("No chunks were extracted from PDFs")
604
  return False
605
-
606
  print(f"Total chunks extracted: {len(chunks)}")
607
 
608
  # Create vector database
609
  print("Creating vector database...")
610
  self.vector_db.create_vector_db(chunks)
611
-
612
  # Verify success
613
  if self.vector_db.vectordb is None:
614
  print("Failed to create vector database")
615
  return False
616
-
617
  # Set initialization flag
618
  self.is_initialized = True
619
  return True
620
-
621
  except Exception as e:
622
  print(f"Error processing documents: {str(e)}")
623
  print(traceback.format_exc())
@@ -638,8 +765,8 @@ class RAGSystem:
638
  for i, context in enumerate(contexts):
639
  formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
640
 
641
- # Create prompt with better instructions
642
- prompt = f"""You are an AI assistant that answers questions based on the provided context information.
643
 
644
  User Query: {query}
645
 
@@ -647,27 +774,26 @@ Below are relevant passages from documents that might help answer the query:
647
 
648
  {formatted_contexts}
649
 
650
- Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
651
  If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
652
 
653
  Do not use any prior knowledge that is not contained in the provided context.
654
  If quoting from the context, mention the source document and page number.
655
- Organize your answer in a clear, coherent manner.
656
-
657
- Answer:"""
658
  return prompt
659
 
660
  def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
661
- temperature: float = 0.7, stream: bool = False) -> Union[str, Generator[str, None, None]]:
662
- """Answer a query using RAG with query expansion
663
-
664
  Args:
665
  query: User query
666
  k: Number of contexts to retrieve
667
  max_tokens: Maximum number of tokens to generate
668
  temperature: Temperature for generation
669
  stream: Whether to stream the output
670
-
 
671
  Returns:
672
  Answer text or generator if streaming
673
  """
@@ -679,7 +805,7 @@ Answer:"""
679
  # Expand query for better retrieval
680
  expanded_query = self.query_expander.expand_query(query)
681
  print(f"Expanded query: {expanded_query}")
682
-
683
  # Retrieve relevant contexts using hybrid search
684
  contexts = self.vector_db.hybrid_search(expanded_query, k=k)
685
 
@@ -689,240 +815,293 @@ Answer:"""
689
  # Generate prompt with improved instructions
690
  prompt = self.generate_prompt(query, contexts)
691
 
692
- # Generate answer
693
- return self.model.generate(
694
  prompt,
695
  max_tokens=max_tokens,
696
  temperature=temperature,
697
- stream=stream
698
  )
699
  except Exception as e:
700
  print(f"Error answering query: {str(e)}")
701
  print(traceback.format_exc())
702
  return f"Error processing your query: {str(e)}"
703
 
704
- # === GRADIO INTERFACE ===
705
- class RAGInterface:
706
  def __init__(self, rag_system: RAGSystem):
707
- """Initialize Gradio interface
708
 
709
  Args:
710
  rag_system: RAG system instance
711
  """
712
  self.rag_system = rag_system
 
713
  self.interface = None
714
- self.is_processing = False
715
 
716
- def upload_file(self, files):
717
- """Upload PDF files"""
718
  try:
719
- os.makedirs("pdfs", exist_ok=True)
720
- uploaded_files = []
721
 
 
722
  for file in files:
723
- destination = os.path.join("pdfs", os.path.basename(file.name))
724
- shutil.copy(file.name, destination)
725
- uploaded_files.append(os.path.basename(file.name))
726
 
727
- # Verify files exist in the directory
728
- pdf_files = [f for f in os.listdir("pdfs") if f.lower().endswith('.pdf')]
729
-
730
- if not pdf_files:
731
- return "No PDF files were uploaded successfully."
732
-
733
- return f"Successfully uploaded {len(uploaded_files)} files: {', '.join(uploaded_files)}"
734
  except Exception as e:
735
  return f"Error uploading files: {str(e)}"
736
 
737
- def process_documents(self):
738
- """Process all documents
739
 
740
  Returns:
741
  Status message
742
  """
743
- if self.is_processing:
744
- return "Document processing is already in progress. Please wait."
745
-
746
  try:
747
- self.is_processing = True
 
 
748
  start_time = time.time()
749
-
750
  success = self.rag_system.process_documents()
751
-
752
- elapsed = time.time() - start_time
753
- self.is_processing = False
754
-
755
  if success:
756
- return f"Documents processed successfully in {elapsed:.2f} seconds."
757
  else:
758
- return "Failed to process documents. Check the logs for more information."
759
  except Exception as e:
760
- self.is_processing = False
761
  return f"Error processing documents: {str(e)}"
762
 
763
- def answer_query(self, query, k, max_tokens, temperature):
764
- """Answer a query
 
 
765
 
766
  Args:
767
  query: User query
 
768
  k: Number of contexts to retrieve
769
- max_tokens: Maximum number of tokens to generate
770
- temperature: Sampling temperature
771
 
772
  Returns:
773
- Answer
774
  """
775
- if not query.strip():
776
- return "Please enter a question."
777
-
778
- try:
779
- return self.rag_system.answer_query(
780
- query,
781
- k=k,
782
- max_tokens=max_tokens,
783
- temperature=temperature,
784
- stream=False
785
- )
786
- except Exception as e:
787
- return f"Error answering query: {str(e)}"
788
 
789
- def answer_query_stream(self, query, k, max_tokens, temperature):
790
- """Stream answer to a query
791
-
792
- Args:
793
- query: User query
794
- k: Number of contexts to retrieve
795
- max_tokens: Maximum number of tokens to generate
796
- temperature: Sampling temperature
797
 
798
- Yields:
799
- Generated text
800
- """
801
- if not query.strip():
802
- yield "Please enter a question."
803
- return
804
-
805
  try:
806
- yield from self.rag_system.answer_query(
807
- query,
808
- k=k,
809
- max_tokens=max_tokens,
810
- temperature=temperature,
811
- stream=True
812
- )
813
  except Exception as e:
814
- yield f"Error answering query: {str(e)}"
815
-
816
- def create_interface(self):
817
- """Create Gradio interface"""
818
- with gr.Blocks(title="PDF RAG System") as interface:
819
- gr.Markdown("# PDF RAG System with Phi-2")
820
- gr.Markdown("Upload your PDF documents, process them, and ask questions to get answers based on the content.")
821
-
822
- with gr.Tab("Upload & Process"):
823
- with gr.Row():
824
- pdf_files = gr.File(
825
- file_count="multiple",
826
- label="Upload PDF Files",
827
- file_types=[".pdf"]
828
- )
829
- upload_button = gr.Button("Upload", variant="primary")
830
-
831
- upload_output = gr.Textbox(label="Upload Status", lines=2)
832
- upload_button.click(self.upload_file, inputs=[pdf_files], outputs=upload_output)
833
-
834
- process_button = gr.Button("Process Documents", variant="primary")
835
- process_output = gr.Textbox(label="Processing Status", lines=2)
836
- process_button.click(self.process_documents, inputs=[], outputs=process_output)
837
-
838
- with gr.Tab("Query"):
839
- with gr.Row():
840
- with gr.Column():
841
- query_input = gr.Textbox(
842
- label="Question",
843
- lines=3,
844
- placeholder="Ask a question about your documents..."
845
- )
846
- with gr.Row():
847
- k_slider = gr.Slider(
848
- minimum=1,
849
- maximum=10,
850
- value=3,
851
- step=1,
852
- label="Number of Contexts"
853
  )
854
- max_tokens_slider = gr.Slider(
855
- minimum=100,
856
- maximum=800,
857
- value=400,
858
- step=50,
859
- label="Max Tokens"
860
  )
861
- temperature_slider = gr.Slider(
862
- minimum=0.1,
863
- maximum=1.0, value=0.7,
864
- step=0.1,
865
- label="Temperature"
 
 
866
  )
867
- submit_button = gr.Button("Submit", variant="primary")
868
-
869
- answer_output = gr.Textbox(label="Answer", lines=10)
870
-
871
- submit_button.click(
872
- self.answer_query,
873
- inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
874
- outputs=answer_output
875
  )
876
-
877
- # Add streaming capability
878
- stream_button = gr.Button("Submit (Streaming)", variant="secondary")
879
- stream_button.click(
880
- self.answer_query_stream,
881
- inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
882
- outputs=answer_output
883
  )
884
-
885
- gr.Markdown("""
886
- ## Instructions
887
- 1. Upload PDF files in the 'Upload & Process' tab.
888
- 2. Click the 'Process Documents' button to extract and index content.
889
- 3. Switch to the 'Query' tab to ask questions about your documents.
890
- 4. Adjust parameters as needed:
891
- - Number of Contexts: More contexts provide more information but may be less focused.
892
- - Max Tokens: Controls the length of the response.
893
- - Temperature: Lower values (0.1-0.5) give more focused answers, higher values (0.6-1.0) give more creative answers.
894
- """)
895
-
896
- self.interface = interface
897
- return interface
898
-
899
- def launch(self, **kwargs):
900
- """Launch the Gradio interface"""
901
- if self.interface is None:
902
- self.create_interface()
903
- self.interface.launch(**kwargs)
904
 
905
  # === MAIN APPLICATION ===
906
  def main():
907
- """Main function to set up and launch the application"""
908
- try:
909
- # Initialize components
910
- pdf_processor = PDFProcessor(pdf_dir="pdfs")
911
- vector_db = VectorDBManager()
912
- phi2_model = Phi2Model()
913
-
914
- # Initialize RAG system
915
- rag_system = RAGSystem(pdf_processor, vector_db, phi2_model)
916
-
917
- # Create interface
918
- interface = RAGInterface(rag_system)
919
-
920
- # Launch application
921
- interface.launch(share=True)
922
-
923
- except Exception as e:
924
- print(f"Error initializing application: {str(e)}")
925
- print(traceback.format_exc())
926
 
927
  if __name__ == "__main__":
928
  main()
 
1
+ !pip install -q pymupdf langchain langchain_community sentence-transformers faiss-cpu llama-cpp-python gradio transformers rank_bm25
2
+ !pip install -q git+https://github.com/chroma-core/chroma.git
3
+ !mkdir -p pdfs
4
+ !mkdir -p models
5
+ !wget -q -O models/mistral-7b-instruct-v0.3.Q8_0.gguf https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.Q8_0.gguf
6
+
7
  import os
8
  import re
9
+ import fitz # PyMuPDF
10
  import numpy as np
11
  import gc
12
  import torch
 
17
  import traceback
18
  from typing import List, Dict, Any, Tuple, Optional, Union, Generator
19
  from dataclasses import dataclass
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
21
  from langchain_community.vectorstores import FAISS
22
  from langchain_community.embeddings import HuggingFaceEmbeddings
23
  from llama_cpp import Llama
24
+ import gradio as gr
25
  from rank_bm25 import BM25Okapi
26
  import nltk
27
  from nltk.tokenize import word_tokenize
28
  from nltk.corpus import stopwords
29
+ from sentence_transformers import CrossEncoder
30
 
31
  # Download nltk resources
32
  try:
 
35
  except:
36
  print("Failed to download NLTK resources, continuing without them")
37
 
38
  # === MEMORY MANAGEMENT UTILITIES ===
39
  def clear_memory():
40
  """Clear memory to prevent OOM errors"""
 
65
  length_function=len,
66
  is_separator_regex=False,
67
  )
68
+
69
  # Create cache directory
70
  self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
71
  os.makedirs(self.cache_dir, exist_ok=True)
 
80
  """Get the cache file path for a PDF"""
81
  pdf_hash = hashlib.md5(open(pdf_path, 'rb').read(8192)).hexdigest()
82
  return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")
83
+
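A note on the caching scheme above: the cache key hashes only the first 8 KB of the PDF, so it is cheap to compute but will not notice edits deeper in the file. A minimal standalone restatement of that scheme (the `cache_key` helper name is illustrative, not part of the app):

    import hashlib
    import os

    def cache_key(pdf_path: str, cache_dir: str = "pdf_cache") -> str:
        # Hash only the first 8 KB, exactly as _get_cache_path does above
        with open(pdf_path, "rb") as fh:
            digest = hashlib.md5(fh.read(8192)).hexdigest()
        return os.path.join(cache_dir, f"{os.path.basename(pdf_path)}_{digest}.pkl")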
84
  def _is_cached(self, pdf_path: str) -> bool:
85
  """Check if a PDF is cached"""
86
  cache_path = self._get_cache_path(pdf_path)
87
  return os.path.exists(cache_path)
88
+
89
  def _load_from_cache(self, pdf_path: str) -> List[PDFChunk]:
90
  """Load chunks from cache"""
91
  cache_path = self._get_cache_path(pdf_path)
 
94
  return pickle.load(f)
95
  except:
96
  return None
97
+
98
  def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
99
  """Save chunks to cache"""
100
  cache_path = self._get_cache_path(pdf_path)
 
127
  if cached_chunks:
128
  print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
129
  return cached_chunks
130
+
131
  try:
132
  doc = fitz.open(pdf_path)
133
  pdf_chunks = []
134
  pdf_name = os.path.basename(pdf_path)
135
+
136
  for page_num in range(len(doc)):
137
  page = doc.load_page(page_num)
138
+
139
  # Extract text with more options for better quality
140
  page_text = page.get_text("text", sort=True)
141
  # Try to extract text with alternative layout analysis if the text is too short
 
156
  except:
157
  # Fallback to default extraction
158
  page_text = page.get_text("text")
159
+
160
  # Clean the text
161
  page_text = self.clean_text(page_text)
162
+
163
  # Extract tables
164
  try:
165
  tables = page.find_tables()
 
175
  row_cells.append(cell_text)
176
  if row_cells:
177
  table_text += " | ".join(row_cells) + "\n"
178
+
179
  # Add table text to page text
180
  if table_text.strip():
181
  page_text += "\n\nTABLE:\n" + table_text
182
  except Exception as table_err:
183
  print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
184
+
185
  # Split the page text into chunks
186
  if page_text.strip():
187
  page_chunks = self.text_splitter.split_text(page_text)
188
+
189
  # Create PDFChunk objects
190
  for i, chunk_text in enumerate(page_chunks):
191
  pdf_chunks.append(PDFChunk(
 
194
  page_num=page_num + 1, # 1-based page numbering for humans
195
  chunk_id=i
196
  ))
197
+
198
  # Clear memory periodically
199
  if page_num % 10 == 0:
200
  clear_memory()
201
+
202
  doc.close()
203
+
204
  # Cache the results
205
  self._save_to_cache(pdf_path, pdf_chunks)
206
+
207
  return pdf_chunks
208
  except Exception as e:
209
  print(f"Error extracting text from {pdf_path}: {str(e)}")
 
221
  pdf_path = os.path.join(self.pdf_dir, pdf_name)
222
  return self.extract_text_from_pdf(pdf_path)
223
 
224
+ def process_all_pdfs(self, batch_size: int = 3) -> List[PDFChunk]:
225
  """Process all PDFs in batches to manage memory
226
 
227
  Args:
 
276
  model_kwargs={"device": "cpu"},
277
  encode_kwargs={"normalize_embeddings": True}
278
  )
279
+
280
+ # Initialize cross-encoder for re-ranking
281
+ try:
282
+ self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
283
+ except Exception as e:
284
+ print(f"Error initializing cross-encoder: {str(e)}")
285
+ self.cross_encoder = None
286
+
287
  self.vectordb = None
288
  # BM25 index for hybrid search
289
  self.bm25_index = None
 
301
  stop_words = set(stopwords.words('english'))
302
  filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
303
  tokenized_chunks.append(filtered_tokens)
304
+
305
  # Create BM25 index
306
  self.bm25_index = BM25Okapi(tokenized_chunks)
307
+ self.tokenized_chunks = tokenized_chunks
308
  except Exception as e:
309
  print(f"Error creating BM25 index: {str(e)}")
310
  print(traceback.format_exc())
311
  self.bm25_index = None
312
+
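The `_prepare_bm25` method above builds a BM25Okapi index over NLTK-tokenised, stopword-filtered chunk texts; `hybrid_search` later calls `get_scores` on it. A minimal self-contained sketch of that pattern (the two example documents are invented for illustration):

    import nltk
    from nltk.corpus import stopwords
    from nltk.tokenize import word_tokenize
    from rank_bm25 import BM25Okapi

    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)

    docs = ["The cache key hashes the first 8 KB of the PDF.",
            "BM25 is a lexical ranking function over token statistics."]
    stop_words = set(stopwords.words('english'))

    def tokenize(text):
        # Same preprocessing as _prepare_bm25: lowercase, alphanumeric tokens, no stopwords
        return [w for w in word_tokenize(text.lower()) if w.isalnum() and w not in stop_words]

    bm25 = BM25Okapi([tokenize(d) for d in docs])
    print(bm25.get_scores(tokenize("what is bm25")))  # one relevance score per document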
313
  def create_vector_db(self, chunks: List[PDFChunk]) -> None:
314
  """Create vector database from text chunks
315
 
 
322
  return
323
 
324
  print(f"Creating vector DB with {len(chunks)} chunks")
325
+
326
  # Store chunks for hybrid search
327
  self.chunks = chunks
328
+
329
  # Prepare data for vector DB
330
  chunk_texts = [chunk.text for chunk in chunks]
331
+
332
  # Create BM25 index for hybrid search
333
  print("Creating BM25 index for hybrid search")
334
  self._prepare_bm25(chunks)
335
+
336
  # Process in smaller batches to manage memory
337
+ batch_size = 32
338
  all_embeddings = []
339
 
340
  for i in range(0, len(chunk_texts), batch_size):
 
354
  text_embeddings=list(zip(chunk_texts, all_embeddings)),
355
  embedding=self.embedding_model
356
  )
357
+
358
  print(f"Vector database created with {len(chunks)} documents")
359
 
360
  except Exception as e:
 
366
  """Format a chunk with its metadata for better context"""
367
  return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"
368
 
369
+ def _rerank_with_cross_encoder(self, query: str, chunks: List[PDFChunk], k: int = 5) -> List[PDFChunk]:
370
+ """Re-rank chunks using cross-encoder
371
+
372
+ Args:
373
+ query: User query
374
+ chunks: List of retrieved chunks
375
+ k: Number of top chunks to return
376
+
377
+ Returns:
378
+ Re-ranked chunks
379
+ """
380
+ if not self.cross_encoder or not chunks:
381
+ return chunks[:k] if len(chunks) > k else chunks
382
+
383
+ try:
384
+ # Prepare passage pairs for re-ranking
385
+ pairs = [[query, chunk.text] for chunk in chunks]
386
+
387
+ # Score passages in smaller batches to prevent OOM
388
+ batch_size = 16
389
+ all_scores = []
390
+
391
+ for i in range(0, len(pairs), batch_size):
392
+ batch_pairs = pairs[i:i+batch_size]
393
+ batch_scores = self.cross_encoder.predict(batch_pairs)
394
+ all_scores.extend(batch_scores)
395
+
396
+ # Clear memory
397
+ clear_memory()
398
+
399
+ # Create chunk-score pairs
400
+ scored_chunks = list(zip(chunks, all_scores))
401
+
402
+ # Sort by score
403
+ scored_chunks.sort(key=lambda x: x[1], reverse=True)
404
+
405
+ # Return top k chunks
406
+ return [chunk for chunk, score in scored_chunks[:k]]
407
+ except Exception as e:
408
+ print(f"Error during cross-encoder re-ranking: {str(e)}")
409
+ # Fallback to original chunks
410
+ return chunks[:k] if len(chunks) > k else chunks
411
+
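`_rerank_with_cross_encoder` above scores (query, passage) pairs with the `cross-encoder/ms-marco-MiniLM-L-6-v2` checkpoint and keeps the top k by score. A standalone sketch of the same idea (the query and passages are invented):

    from sentence_transformers import CrossEncoder

    query = "How are tables handled during extraction?"
    passages = [
        "Detected tables are appended to the page text before chunking.",
        "The model runs on CPU with a 4096-token context window.",
    ]

    model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    scores = model.predict([[query, p] for p in passages])  # higher score = more relevant

    k = 1
    top_k = [p for p, _ in sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)[:k]]
    print(top_k)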
412
  def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
413
+ """Hybrid search combining vector search and BM25 with cross-encoder re-ranking
414
 
415
  Args:
416
  query: Query text
 
426
 
427
  try:
428
  # Get vector search results
429
+ vector_results = self.vectordb.similarity_search(query, k=k*3) # Get more for re-ranking
430
  vector_texts = [doc.page_content for doc in vector_results]
431
+
432
+ retrieved_chunks = []
433
+
434
  # Combine with BM25 if available
435
  if self.bm25_index is not None:
436
  try:
 
438
  query_tokens = word_tokenize(query.lower())
439
  stop_words = set(stopwords.words('english'))
440
  filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
441
+
442
  # Get BM25 scores
443
  bm25_scores = self.bm25_index.get_scores(filtered_query)
444
+
445
  # Combine scores (normalized)
446
  combined_results = []
447
  seen_texts = set()
448
+
449
  # First add vector results with their positions as scores
450
  for i, text in enumerate(vector_texts):
451
  if text not in seen_texts:
 
459
  # Normalize BM25 score
460
  bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
461
  combined_score = alpha * vector_score + (1-alpha) * bm25_score
462
+
463
  combined_results.append((chunk, combined_score))
464
  break
465
+
466
  # Sort by combined score
467
  combined_results.sort(key=lambda x: x[1], reverse=True)
468
+
469
+ # Get top k*2 results for re-ranking
470
+ retrieved_chunks = [item[0] for item in combined_results[:k*2]]
471
  except Exception as e:
472
  print(f"Error in BM25 scoring: {str(e)}")
473
  # Fallback to vector search results
474
+ retrieved_chunks = [self.chunks[i] for i, text in enumerate(self.chunks)
475
+ if text.text in vector_texts[:k*2]]
476
  else:
477
  # Just use vector search results if BM25 is not available
478
+ retrieved_chunks = [self.chunks[i] for i, chunk in enumerate(self.chunks)
479
+ if chunk.text in vector_texts[:k*2]]
480
+
481
+ # Re-rank with cross-encoder
482
+ if retrieved_chunks:
483
+ reranked_chunks = self._rerank_with_cross_encoder(query, retrieved_chunks, k)
484
+ # Format results with metadata
485
+ final_results = [self._format_chunk_with_metadata(chunk) for chunk in reranked_chunks]
486
+ else:
487
+ # Fallback to basic results
488
  final_results = vector_texts[:k]
489
+
490
  return final_results
491
  except Exception as e:
492
  print(f"Error during hybrid search: {str(e)}")
 
501
  llm_model: LLM model for query expansion
502
  """
503
  self.llm = llm_model
504
+
505
  def expand_query(self, query: str) -> str:
506
  """Expand the query using the LLM to improve retrieval
507
 
 
512
  Expanded query
513
  """
514
  try:
515
+ prompt = f"""<s>[INST] I need to search for documents related to this question: "{query}"
516
+
517
  Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
518
+ Return only the expanded search query, without any explanations or additional text. [/INST]"""
519
 
 
 
520
  expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
521
+
522
  # Combine original and expanded
523
  combined = f"{query} {expanded}"
524
+
525
  # Limit length
526
  if len(combined) > 300:
527
  combined = combined[:300]
528
+
529
  return combined
530
  except:
531
  # Return original query if expansion fails
532
  return query
533
 
534
  # === LLM SETUP ===
535
+ class MistralModel:
536
+ def __init__(self, model_path: str = "models/mistral-7b-instruct-v0.3.Q8_0.gguf"):
537
+ """Initialize Mistral model
538
 
539
  Args:
540
  model_path: Path to the model file
541
  """
542
  try:
543
+ # Initialize Mistral with llama.cpp
544
  self.llm = Llama(
545
  model_path=model_path,
546
+ n_ctx=4096, # Increased context window for better reasoning
547
+ n_batch=256, # Batch size to save memory
548
+ n_gpu_layers=0, # Run on CPU only for Colab free tier
549
  verbose=False
550
  )
551
  except Exception as e:
552
+ print(f"Error initializing Mistral model: {str(e)}")
553
  raise
554
 
555
  def generate(self, prompt: str,
 
557
  temperature: float = 0.7,
558
  top_p: float = 0.9,
559
  stream: bool = False) -> Union[str, Generator[str, None, None]]:
560
+ """Generate text using Mistral
561
 
562
  Args:
563
  prompt: Input prompt
 
580
  top_p=top_p,
581
  echo=False
582
  )
583
+ return output["choices"][0]["text"].strip()
584
  except Exception as e:
585
  print(f"Error generating text: {str(e)}")
586
  return "Error: Could not generate response."
 
589
  max_tokens: int = 512,
590
  temperature: float = 0.7,
591
  top_p: float = 0.9) -> Generator[str, None, None]:
592
+ """Stream text generation using Mistral
593
 
594
  Args:
595
  prompt: Input prompt
 
613
  response += token
614
  yield response
615
 
616
+ # === SELF-CHECKING ===
617
+ class SelfChecker:
618
+ def __init__(self, llm_model):
619
+ """Initialize self-checker for improved response quality
620
+
621
+ Args:
622
+ llm_model: LLM model to use for checking
623
+ """
624
+ self.llm = llm_model
625
+
626
+ def check_answer(self, query: str, initial_answer: str, contexts: List[str]) -> str:
627
+ """Check if answer is correct and complete based on the contexts
628
+
629
+ Args:
630
+ query: User query
631
+ initial_answer: Initial generated answer
632
+ contexts: Retrieved contexts used to generate the answer
633
+
634
+ Returns:
635
+ Improved answer after reflection
636
+ """
637
+ # Guard against very long inputs that could cause runtime disconnection
638
+ # Limit contexts to prevent excessive token usage
639
+ max_contexts_len = 4000
640
+ contexts_text = "\n\n".join(contexts)
641
+ if len(contexts_text) > max_contexts_len:
642
+ # Truncate while keeping as many complete contexts as possible
643
+ truncated_contexts = []
644
+ current_len = 0
645
+ for ctx in contexts:
646
+ if current_len + len(ctx) + 2 <= max_contexts_len:
647
+ truncated_contexts.append(ctx)
648
+ current_len += len(ctx) + 2
649
+ else:
650
+ break
651
+ contexts_text = "\n\n".join(truncated_contexts)
652
+
653
+ # Check if we should skip reflection to prevent disconnection
654
+ if len(initial_answer) + len(contexts_text) + len(query) > 6000:
655
+ print("Skipping reflection due to excessive input length")
656
+ return initial_answer
657
+
658
+ try:
659
+ prompt = f"""<s>[INST] You're an AI assistant tasked with evaluating and improving an answer to a user query.
660
+
661
+ QUERY: {query}
662
+
663
+ INITIAL ANSWER: {initial_answer}
664
+
665
+ AVAILABLE CONTEXTS:
666
+ {contexts_text}
667
+
668
+ First, carefully check if the initial answer:
669
+ 1. Is factually accurate based on the provided contexts
670
+ 2. Addresses all aspects of the user's query
671
+ 3. Contains any information not supported by the contexts
672
+ 4. Misses important information from the contexts
673
+
674
+ Then improve the answer to fix any issues identified. The final answer should:
675
+ - Be comprehensive and accurate based ONLY on the contexts
676
+ - Not include any unsupported information
677
+ - Be well-structured and clear
678
+ - Cite specific sources when appropriate (e.g., "According to [Source, Page X]...")
679
+
680
+ Provide ONLY the improved answer without explanations about your reasoning process. [/INST]"""
681
+
682
+ # We use slightly lower temperature for more focused reflection
683
+ improved_answer = self.llm.generate(
684
+ prompt,
685
+ max_tokens=1024,
686
+ temperature=0.3,
687
+ stream=False
688
+ )
689
+
690
+ # If reflection produced nothing useful, return original answer
691
+ if not improved_answer or len(improved_answer) < 10:
692
+ return initial_answer
693
+
694
+ return improved_answer
695
+
696
+ except Exception as e:
697
+ # On any error, return the original answer to ensure robustness
698
+ print(f"Self-check error: {str(e)}")
699
+ return initial_answer
700
+
701
  # === RAG SYSTEM ===
702
  class RAGSystem:
703
  def __init__(self, pdf_processor: PDFProcessor,
704
  vector_db: VectorDBManager,
705
+ model: MistralModel):
706
  """Initialize RAG system
707
 
708
  Args:
 
714
  self.vector_db = vector_db
715
  self.model = model
716
  self.query_expander = QueryExpander(model)
717
+ self.self_checker = SelfChecker(model)
718
  self.is_initialized = False
719
 
720
  def process_documents(self) -> bool:
721
  """Process all documents and create vector database
722
+
723
  Returns:
724
  True if successful, False otherwise
725
  """
 
729
  if not chunks:
730
  print("No chunks were extracted from PDFs")
731
  return False
732
+
733
  print(f"Total chunks extracted: {len(chunks)}")
734
 
735
  # Create vector database
736
  print("Creating vector database...")
737
  self.vector_db.create_vector_db(chunks)
738
+
739
  # Verify success
740
  if self.vector_db.vectordb is None:
741
  print("Failed to create vector database")
742
  return False
743
+
744
  # Set initialization flag
745
  self.is_initialized = True
746
  return True
747
+
748
  except Exception as e:
749
  print(f"Error processing documents: {str(e)}")
750
  print(traceback.format_exc())
 
765
  for i, context in enumerate(contexts):
766
  formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
767
 
768
+ # Create prompt with Mistral's chat format
769
+ prompt = f"""<s>[INST] You are an AI assistant that answers questions based on the provided context information.
770
 
771
  User Query: {query}
772
 
 
774
 
775
  {formatted_contexts}
776
 
777
+ Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
778
  If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
779
 
780
  Do not use any prior knowledge that is not contained in the provided context.
781
  If quoting from the context, mention the source document and page number.
782
+ Organize your answer in a clear, coherent manner. [/INST]"""
 
 
783
  return prompt
784
 
785
  def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
786
+ temperature: float = 0.7, stream: bool = False, enable_reflection: bool = True) -> Union[str, Generator[str, None, None]]:
787
+ """Answer a query using RAG with query expansion and self-checking
788
+
789
  Args:
790
  query: User query
791
  k: Number of contexts to retrieve
792
  max_tokens: Maximum number of tokens to generate
793
  temperature: Temperature for generation
794
  stream: Whether to stream the output
795
+ enable_reflection: Whether to enable self-reflection for better answers
796
+
797
  Returns:
798
  Answer text or generator if streaming
799
  """
 
805
  # Expand query for better retrieval
806
  expanded_query = self.query_expander.expand_query(query)
807
  print(f"Expanded query: {expanded_query}")
808
+
809
  # Retrieve relevant contexts using hybrid search
810
  contexts = self.vector_db.hybrid_search(expanded_query, k=k)
811
 
 
815
  # Generate prompt with improved instructions
816
  prompt = self.generate_prompt(query, contexts)
817
 
818
+ # For streaming, we can't do self-checking
819
+ if stream:
820
+ return self.model.generate(
821
+ prompt,
822
+ max_tokens=max_tokens,
823
+ temperature=temperature,
824
+ stream=True
825
+ )
826
+
827
+ # Generate initial answer
828
+ initial_answer = self.model.generate(
829
  prompt,
830
  max_tokens=max_tokens,
831
  temperature=temperature,
832
+ stream=False
833
  )
834
+
835
+ # Perform self-checking if enabled and initial answer exists
836
+ if enable_reflection and initial_answer and len(initial_answer) > 10:
837
+ try:
838
+ print("Performing self-checking to improve answer quality...")
839
+ improved_answer = self.self_checker.check_answer(query, initial_answer, contexts)
840
+ return improved_answer
841
+ except Exception as e:
842
+ print(f"Error during self-checking: {str(e)}")
843
+ # Fallback to initial answer if self-checking fails
844
+ return initial_answer
845
+ else:
846
+ return initial_answer
847
+
848
  except Exception as e:
849
  print(f"Error answering query: {str(e)}")
850
  print(traceback.format_exc())
851
  return f"Error processing your query: {str(e)}"
852
 
853
+ # === GRADIO UI ===
854
+ class RAGUI:
855
  def __init__(self, rag_system: RAGSystem):
856
+ """Initialize RAG UI
857
 
858
  Args:
859
  rag_system: RAG system instance
860
  """
861
  self.rag_system = rag_system
862
+ self.pdf_dir = rag_system.pdf_processor.pdf_dir
863
  self.interface = None
 
864
 
865
+ def _list_uploaded_pdfs(self) -> str:
866
+ """List all uploaded PDFs"""
867
+ pdfs = self.rag_system.pdf_processor.list_pdfs()
868
+ if not pdfs:
869
+ return "No PDFs uploaded yet."
870
+ return "\n".join([f"- {pdf}" for pdf in pdfs])
871
+
872
+ def upload_pdf(self, files) -> str:
873
+ """Upload PDF files
874
+
875
+ Args:
876
+ files: File objects
877
+
878
+ Returns:
879
+ Status message
880
+ """
881
  try:
882
+ # Create directory if it doesn't exist
883
+ os.makedirs(self.pdf_dir, exist_ok=True)
884
 
885
+ # Copy files to pdf directory
886
  for file in files:
887
+ shutil.copy(file.name, os.path.join(self.pdf_dir, os.path.basename(file.name)))
 
 
888
 
889
+ return f"Successfully uploaded {len(files)} file(s). Please process documents to make them searchable."
 
 
890
  except Exception as e:
891
  return f"Error uploading files: {str(e)}"
892
 
893
+ def process_documents(self) -> str:
894
+ """Process documents and create vector database
895
 
896
  Returns:
897
  Status message
898
  """
 
 
 
899
  try:
900
+ # Check if there are PDFs
901
+ pdf_files = self.rag_system.pdf_processor.list_pdfs()
902
+ if not pdf_files:
903
+ return "No PDF files uploaded. Please upload PDFs first."
904
+
905
+ # Process PDFs
906
  start_time = time.time()
 
907
  success = self.rag_system.process_documents()
908
+ process_time = time.time() - start_time
909
+
 
 
910
  if success:
911
+ return f"Successfully processed {len(pdf_files)} PDF file(s) in {process_time:.2f} seconds. You can now ask questions."
912
  else:
913
+ return "Failed to process documents. Check the logs for details."
914
  except Exception as e:
 
915
  return f"Error processing documents: {str(e)}"
916
 
917
+ def answer_query(self, query: str, stream_output: bool = True,
918
+ k: int = 4, temperature: float = 0.7,
919
+ enable_reflection: bool = True) -> str:
920
+ """Answer a query using RAG
921
 
922
  Args:
923
  query: User query
924
+ stream_output: Whether to stream the output
925
  k: Number of contexts to retrieve
926
+ temperature: Temperature for text generation
927
+ enable_reflection: Whether to use reflection to improve answers
928
 
929
  Returns:
930
+ Answer text
931
  """
932
+ if not query or query.strip() == "":
933
+ return "Please enter a query."
 
 
934
 
935
+ # Check if system is initialized
936
+ if not self.rag_system.is_initialized:
937
+ return "Documents have not been processed yet. Please process documents first."
 
 
938
 
 
 
939
  try:
940
+ # For streaming, we need to handle Gradio's output differently (no self-checking pass)
941
+ if stream_output:
942
+ # We can't stream with reflection
943
+ return self.rag_system.answer_query(
944
+ query,
945
+ k=k,
946
+ max_tokens=1024,
947
+ temperature=temperature,
948
+ stream=True,
949
+ enable_reflection=False
950
+ )
951
+ else:
952
+ return self.rag_system.answer_query(
953
+ query,
954
+ k=k,
955
+ max_tokens=1024,
956
+ temperature=temperature,
957
+ stream=False,
958
+ enable_reflection=enable_reflection
959
+ )
960
  except Exception as e:
961
+ print(f"Error in answer_query: {str(e)}")
962
+ print(traceback.format_exc())
963
+ return f"Error processing your query: {str(e)}"
964
+
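One caveat with the streaming path above: `RAGSystem.answer_query(stream=True)` hands back a generator (per its type hint), and Gradio only streams into a Textbox when the event handler itself is a generator function. A hedged sketch of a yielding wrapper that would make the streaming checkbox take effect (the function name and wiring are illustrative, not part of this commit):

    def answer_query_stream(query, k=4, temperature=0.7):
        # Illustrative handler; assumes `rag_system` is an initialised RAGSystem.
        if not query or not query.strip():
            yield "Please enter a query."
            return
        # generate_stream yields the growing response string, so re-yielding each
        # value lets Gradio refresh the Textbox incrementally.
        for partial in rag_system.answer_query(query, k=k, max_tokens=1024,
                                               temperature=temperature,
                                               stream=True, enable_reflection=False):
            yield partial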
965
+ def launch(self):
966
+ """Launch Gradio UI"""
967
+ try:
968
+ with gr.Blocks(title="Document Q&A System") as self.interface:
969
+ gr.Markdown("# PDF Question Answering System")
970
+ gr.Markdown("Upload PDF documents and ask questions about their content.")
971
+
972
+ with gr.Tab("Upload & Process"):
973
+ with gr.Row():
974
+ with gr.Column():
975
+ upload_button = gr.File(
976
+ label="Upload PDF Files",
977
+ file_count="multiple",
978
+ file_types=[".pdf"]
979
  )
980
+ upload_output = gr.Textbox(
981
+ label="Upload Status",
982
+ interactive=False
983
  )
984
+ upload_btn = gr.Button("Upload Files")
985
+
986
+ with gr.Column():
987
+ pdf_list = gr.Textbox(
988
+ label="Uploaded PDFs",
989
+ value=self._list_uploaded_pdfs(),
990
+ interactive=False
991
  )
992
+ refresh_btn = gr.Button("Refresh List")
993
+
994
+ process_btn = gr.Button("Process Documents")
995
+ process_output = gr.Textbox(
996
+ label="Processing Status",
997
+ interactive=False
998
+ )
999
+
1000
+ with gr.Tab("Ask Questions"):
1001
+ with gr.Row():
1002
+ with gr.Column():
1003
+ query_input = gr.Textbox(
1004
+ label="Enter your question",
1005
+ placeholder="What are the main findings of the report?",
1006
+ lines=2
1007
+ )
1008
+
1009
+ with gr.Row():
1010
+ k_slider = gr.Slider(
1011
+ minimum=1,
1012
+ maximum=10,
1013
+ value=4,
1014
+ step=1,
1015
+ label="Number of contexts to retrieve"
1016
+ )
1017
+ temp_slider = gr.Slider(
1018
+ minimum=0.1,
1019
+ maximum=1.0,
1020
+ value=0.7,
1021
+ step=0.1,
1022
+ label="Temperature"
1023
+ )
1024
+
1025
+ with gr.Row():
1026
+ stream_checkbox = gr.Checkbox(
1027
+ label="Stream output",
1028
+ value=True
1029
+ )
1030
+ reflection_checkbox = gr.Checkbox(
1031
+ label="Use self-reflection (disables streaming)",
1032
+ value=True
1033
+ )
1034
+
1035
+ query_btn = gr.Button("Submit Question")
1036
+
1037
+ answer_output = gr.Textbox(
1038
+ label="Answer",
1039
+ interactive=False,
1040
+ lines=15
1041
+ )
1042
+
1043
+ # Event handlers
1044
+ upload_btn.click(
1045
+ fn=self.upload_pdf,
1046
+ inputs=[upload_button],
1047
+ outputs=[upload_output]
1048
  )
1049
+
1050
+ refresh_btn.click(
1051
+ fn=lambda: self._list_uploaded_pdfs(),
1052
+ inputs=[],
1053
+ outputs=[pdf_list]
1054
+ )
1055
+
1056
+ process_btn.click(
1057
+ fn=self.process_documents,
1058
+ inputs=[],
1059
+ outputs=[process_output]
1060
  )
1061
+
1062
+ query_btn.click(
1063
+ fn=self.answer_query,
1064
+ inputs=[query_input, stream_checkbox, k_slider, temp_slider, reflection_checkbox],
1065
+ outputs=[answer_output]
1066
+ )
1067
+
1068
+ # Checkbox dependency
1069
+ def update_stream_state(reflection_enabled):
1070
+ return not reflection_enabled if reflection_enabled else gr.update()
1071
+
1072
+ reflection_checkbox.change(
1073
+ fn=update_stream_state,
1074
+ inputs=[reflection_checkbox],
1075
+ outputs=[stream_checkbox]
1076
+ )
1077
+
1078
+ # Launch UI
1079
+ self.interface.launch(share=True)
1080
+
1081
+ except Exception as e:
1082
+ print(f"Error launching UI: {str(e)}")
1083
+ print(traceback.format_exc())
1084
 
1085
  # === MAIN APPLICATION ===
1086
  def main():
1087
+ # Initialize components
1088
+ print("Initializing PDF processor...")
1089
+ pdf_processor = PDFProcessor()
1090
+
1091
+ print("Initializing vector database manager...")
1092
+ vector_db = VectorDBManager()
1093
+
1094
+ print("Initializing Mistral model...")
1095
+ model = MistralModel()
1096
+
1097
+ print("Initializing RAG system...")
1098
+ rag_system = RAGSystem(pdf_processor, vector_db, model)
1099
+
1100
+ print("Initializing UI...")
1101
+ ui = RAGUI(rag_system)
1102
+
1103
+ print("Launching UI...")
1104
+ ui.launch()
 
1105
 
1106
  if __name__ == "__main__":
1107
  main()
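Worth noting when reading this commit: the new file now opens with `!pip install`, `!mkdir`, and `!wget` lines, which are notebook shell magics; they work when the file is pasted into a Colab/Jupyter cell but are a syntax error under a plain `python app.py`. If the script ever needs to run outside a notebook, a plain-Python equivalent of that setup block might look like the sketch below (the `subprocess`/`urllib` approach is an assumption, not part of the commit; the package list and model URL are taken from the diff):

    import os
    import subprocess
    import urllib.request

    subprocess.run(
        ["pip", "install", "-q", "pymupdf", "langchain", "langchain_community",
         "sentence-transformers", "faiss-cpu", "llama-cpp-python", "gradio",
         "transformers", "rank_bm25"],
        check=True,
    )

    os.makedirs("pdfs", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    model_url = ("https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/"
                 "resolve/main/Mistral-7B-Instruct-v0.3.Q8_0.gguf")
    model_path = "models/mistral-7b-instruct-v0.3.Q8_0.gguf"
    if not os.path.exists(model_path):
        urllib.request.urlretrieve(model_url, model_path)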