import atexit
import logging
import os
import shutil
import tempfile
import time
import uuid

import psutil
import PyPDF2
import streamlit as st
import torch
from langchain.chains import RetrievalQA, LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

from blockchain_utils_metamask import BlockchainManagerMetaMask


class AdvancedRAG:
    """Question answering over uploaded PDFs with optional blockchain audit.

    PDFs are split into chunks, embedded and stored in a FAISS index. Queries
    are answered either by returning the raw best-matching chunks ("direct")
    or by letting an LLM synthesise an answer from them ("enhanced"). When
    blockchain support is enabled and MetaMask is connected, processed
    documents are verified and queries are logged through
    ``BlockchainManagerMetaMask``.

    Progress and errors are reported through Streamlit widgets, so the class
    is meant to be used from inside a Streamlit script.
    """

    _logger = logging.getLogger(__name__)

    # Number of chunks fetched from the vector store for every query.
    _TOP_K = 5

    # Source excerpts longer than this are truncated by enhanced_retrieval.
    _SOURCE_PREVIEW_CHARS = 300

    # Prompt for enhanced_retrieval; kept flush-left so that no indentation
    # leaks into the text sent to the model.
    _ENHANCED_PROMPT_TEMPLATE = """
You are an AI research assistant with expertise in analyzing and synthesizing information from documents.

Below are relevant passages from documents that might answer the user's question.

USER QUESTION: {question}

RELEVANT PASSAGES:
{context}

Based on ONLY these passages, provide a comprehensive, accurate and well-structured answer to the question.

Your answer should:
1. Directly address the user's question
2. Synthesize information from multiple passages when applicable
3. Be detailed, precise and factual
4. Include specific examples or evidence from the passages
5. Acknowledge any limitations or gaps in the provided information

If the information to answer the question is not present in the passages, clearly state: "I don't have enough information to answer this question based on the available documents."

ANSWER:
"""

    def __init__(self,
                 llm_model_name="mistralai/Mistral-7B-Instruct-v0.2",
                 embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
                 chunk_size=1000,
                 chunk_overlap=200,
                 use_gpu=True,
                 use_blockchain=False,
                 contract_address=None):
        """
        Initialize the advanced RAG system with multiple retrieval methods.

        Args:
            llm_model_name: The HuggingFace model for text generation
            embedding_model_name: The HuggingFace model for embeddings
            chunk_size: Size of document chunks
            chunk_overlap: Overlap between chunks
            use_gpu: Whether to use GPU acceleration (only honoured when
                CUDA is actually available)
            use_blockchain: Whether to enable blockchain verification
            contract_address: Address of the deployed RAG Document Verifier contract
        """
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.use_blockchain = use_blockchain

        self.device = "cuda" if self.use_gpu else "cpu"
        st.sidebar.info(f"Using device: {self.device}")

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": self.device}
        )

        # Resolved outside the try block so the fallback model below can
        # reuse the same credentials.
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            st.warning("No HuggingFace token found. Using model without authentication.")

        # NOTE(review): HuggingFaceHub is deprecated in recent
        # langchain-community releases; consider HuggingFaceEndpoint.
        try:
            self.llm = HuggingFaceHub(
                repo_id=llm_model_name,
                huggingfacehub_api_token=hf_token,
                model_kwargs={"temperature": 0.7, "max_length": 1024}
            )
        except Exception as e:
            self._logger.exception("Failed to initialise LLM %s", llm_model_name)
            st.error(f"Error initializing LLM: {str(e)}")
            st.info("Trying to initialize with default model...")

            # Pass the token here as well: without it the fallback could
            # fail for the very same authentication reason as the primary.
            self.llm = HuggingFaceHub(
                repo_id="google/flan-t5-small",
                huggingfacehub_api_token=hf_token,
                model_kwargs={"temperature": 0.7, "max_length": 512}
            )

        self.vector_store = None
        self.documents_processed = 0

        # Per-file entries ({"chunks", "time"}) plus the aggregate keys
        # "index_building", "total_time" and "memory_used_gb".
        self.processing_times = {}

        self.blockchain = None
        if use_blockchain:
            try:
                self.blockchain = BlockchainManagerMetaMask(
                    contract_address=contract_address
                )
                st.sidebar.success("Blockchain manager initialized. Please connect MetaMask to continue.")
            except Exception as e:
                self._logger.exception("Failed to initialise blockchain manager")
                st.sidebar.error(f"Failed to initialize blockchain manager: {str(e)}")
                # Degrade gracefully: the rest of the pipeline works without it.
                self.use_blockchain = False

    def update_blockchain_connection(self, metamask_info):
        """Update blockchain connection with MetaMask info.

        Args:
            metamask_info: Mapping with the optional keys "connected",
                "address" and "network_id".

        Returns:
            The manager's ``is_connected`` value after the update, or False
            when there is no blockchain manager or ``metamask_info`` is empty.
        """
        if not (self.blockchain and metamask_info):
            return False

        self.blockchain.update_connection(
            is_connected=metamask_info.get("connected", False),
            user_address=metamask_info.get("address"),
            network_id=metamask_info.get("network_id")
        )
        return self.blockchain.is_connected

    @staticmethod
    def _fresh_temp_dir():
        """Create a new upload directory and record it in the session state.

        The directory recorded by a previous run is deleted first; otherwise
        every call would orphan one temporary directory on disk.
        """
        previous_dir = st.session_state.get('temp_dir')
        if previous_dir and os.path.isdir(previous_dir):
            shutil.rmtree(previous_dir, ignore_errors=True)

        temp_dir = tempfile.mkdtemp()
        st.session_state['temp_dir'] = temp_dir
        return temp_dir

    @staticmethod
    def _used_memory_gb():
        """Return the system-wide used memory in GiB as reported by psutil."""
        return psutil.virtual_memory().used / (1024 * 1024 * 1024)

    @staticmethod
    def _extract_pdf_text(pdf_path):
        """Return the text of every page that yields text, each followed by a blank line."""
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            page_texts = (page.extract_text() for page in reader.pages)
            # Joined inside the ``with`` block: extraction reads from ``f``.
            return "".join(page_text + "\n\n" for page_text in page_texts if page_text)

    def _verify_on_blockchain(self, pdf_file, pdf_path, split_docs, status):
        """Verify one PDF on the blockchain and tag its chunks with the result.

        Does nothing when blockchain support is disabled; only warns when it
        is enabled but MetaMask is not connected. Failures are reported in
        the sidebar and never propagate to the caller.
        """
        if not self.use_blockchain:
            return
        if not (self.blockchain and self.blockchain.is_connected):
            st.sidebar.warning("MetaMask not connected. Document not verified on blockchain.")
            return

        try:
            # Random suffix keeps ids unique when the same file name is
            # uploaded more than once.
            document_id = f"{pdf_file.name}_{uuid.uuid4().hex[:8]}"

            status.update(label=f"Verifying {pdf_file.name} on blockchain...")
            verification = self.blockchain.verify_document(document_id, pdf_path)

            if verification.get('status'):
                st.sidebar.success(f"✅ {pdf_file.name} verified on blockchain")
                if 'tx_hash' in verification:
                    st.sidebar.info(f"Transaction: {verification['tx_hash'][:10]}...")

                for doc in split_docs:
                    doc.metadata["blockchain"] = {
                        "verified": True,
                        "document_id": document_id,
                        "document_hash": verification.get("document_hash", ""),
                        "tx_hash": verification.get("tx_hash", ""),
                        "block_number": verification.get("block_number", 0)
                    }
            else:
                st.sidebar.warning(f"❌ Failed to verify {pdf_file.name} on blockchain")
                if 'error' in verification:
                    st.sidebar.error(f"Error: {verification['error']}")
        except Exception as e:
            self._logger.exception("Blockchain verification failed for %s", pdf_file.name)
            st.sidebar.error(f"Blockchain verification error: {str(e)}")

    def process_pdfs(self, pdf_files):
        """Process PDF files, create a vector store, and verify documents on blockchain.

        Each upload is written to a temporary directory, its text extracted
        and split into chunks. A failure in one file is reported in the
        sidebar and does not stop the remaining files. The FAISS index is
        rebuilt from scratch from the chunks of this call.

        Args:
            pdf_files: Uploaded file objects exposing ``name`` and
                ``getbuffer()``.

        Returns:
            True when the vector store was built, False when no chunks could
            be extracted or index creation failed.
        """
        all_docs = []

        with st.status("Processing PDF files...") as status:
            temp_dir = self._fresh_temp_dir()
            start_time = time.time()
            # System-wide figure: the delta stored later is only a rough
            # estimate and is influenced by other processes.
            mem_before = self._used_memory_gb()

            for i, pdf_file in enumerate(pdf_files):
                try:
                    file_start_time = time.time()

                    # The upload name is client-supplied; keep only its final
                    # component so it cannot escape the temp directory.
                    safe_name = os.path.basename(pdf_file.name) or f"upload_{i}.pdf"
                    pdf_path = os.path.join(temp_dir, safe_name)
                    with open(pdf_path, "wb") as f:
                        f.write(pdf_file.getbuffer())

                    status.update(label=f"Processing {pdf_file.name} ({i+1}/{len(pdf_files)})...")

                    text = self._extract_pdf_text(pdf_path)
                    docs = [Document(page_content=text, metadata={"source": pdf_file.name})]
                    split_docs = self.text_splitter.split_documents(docs)
                    all_docs.extend(split_docs)

                    if not split_docs:
                        # Typically an image-only (scanned) PDF.
                        st.sidebar.warning(f"No extractable text found in {pdf_file.name}")

                    self._verify_on_blockchain(pdf_file, pdf_path, split_docs, status)

                    processing_time = time.time() - file_start_time
                    st.sidebar.success(f"Processed {pdf_file.name}: {len(split_docs)} chunks in {processing_time:.2f}s")
                    self.processing_times[pdf_file.name] = {
                        "chunks": len(split_docs),
                        "time": processing_time
                    }
                except Exception as e:
                    self._logger.exception("Failed to process %s", pdf_file.name)
                    st.sidebar.error(f"Error processing {pdf_file.name}: {str(e)}")

            if not all_docs:
                status.update(label="No content extracted from PDFs", state="error")
                return False

            status.update(label="Building vector index...")
            try:
                index_start_time = time.time()
                self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
                index_time = time.time() - index_start_time

                mem_used = self._used_memory_gb() - mem_before
                total_time = time.time() - start_time

                status.update(label=f"Completed processing {len(all_docs)} chunks in {total_time:.2f}s", state="complete")

                self.processing_times["index_building"] = index_time
                self.processing_times["total_time"] = total_time
                self.processing_times["memory_used_gb"] = mem_used
                # Counts chunks, not source files.
                self.documents_processed = len(all_docs)

                return True
            except Exception as e:
                self._logger.exception("Failed to build the vector store")
                st.error(f"Error creating vector store: {str(e)}")
                status.update(label="Error creating vector store", state="error")
                return False

    def _retrieve(self, query):
        """Return the ``_TOP_K`` chunks most similar to ``query``."""
        retriever = self.vector_store.as_retriever(search_kwargs={"k": self._TOP_K})
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        return retriever.invoke(query)

    @staticmethod
    def _blockchain_info(doc):
        """Return the verification summary stored on a chunk, or None if absent."""
        if "blockchain" not in doc.metadata:
            return None

        record = doc.metadata["blockchain"]
        return {
            "verified": record["verified"],
            "document_id": record["document_id"],
            "tx_hash": record["tx_hash"]
        }

    @classmethod
    def _preview(cls, text):
        """Return ``text`` cut to ``_SOURCE_PREVIEW_CHARS`` plus "..." when longer."""
        if len(text) > cls._SOURCE_PREVIEW_CHARS:
            return text[:cls._SOURCE_PREVIEW_CHARS] + "..."
        return text

    def _log_query_to_blockchain(self, query, answer):
        """Log a query/answer pair on the blockchain when it is connected.

        Returns:
            A dict with "logged", "query_id", "tx_hash" and "block_number" on
            success; None when blockchain logging is unavailable or failed
            (failures are shown via ``st.error``).
        """
        if not (self.use_blockchain and self.blockchain and self.blockchain.is_connected):
            return None

        try:
            with st.status("Logging query to blockchain..."):
                log_result = self.blockchain.log_query(query, answer)

            if log_result.get("status"):
                return {
                    "logged": True,
                    "query_id": log_result.get("query_id", ""),
                    "tx_hash": log_result.get("tx_hash", ""),
                    "block_number": log_result.get("block_number", 0)
                }
            st.error(f"Error logging to blockchain: {log_result.get('error', 'Unknown error')}")
        except Exception as e:
            self._logger.exception("Failed to log query to blockchain")
            st.error(f"Error logging to blockchain: {str(e)}")
        return None

    def direct_retrieval(self, query):
        """
        Direct retrieval method: simply returns the most relevant document chunks without LLM processing.

        Args:
            query: User's question

        Returns:
            dict: On success, with the keys "answer" (all passages formatted
            as markdown), "sources" (full chunk text, source name and
            blockchain info), "query_time", "blockchain_log" and "method".
            str: A plain message when no documents have been processed yet
            or when retrieval raised an error.
        """
        if self.vector_store is None:
            return "Please upload and process PDF files first."

        try:
            query_start_time = time.time()
            with st.status("Searching documents..."):
                docs = self._retrieve(query)
            query_time = time.time() - query_start_time

            sources = []
            answer_parts = ["Here are the most relevant passages for your query:\n\n"]
            for i, doc in enumerate(docs):
                source_text = doc.page_content
                answer_parts.append(
                    f"**Passage {i+1}** (from {doc.metadata.get('source', 'Unknown')}):\n{source_text}\n\n"
                )
                sources.append({
                    "content": source_text,
                    "source": doc.metadata.get("source", "Unknown"),
                    "blockchain": self._blockchain_info(doc)
                })
            answer = "".join(answer_parts)

            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": self._log_query_to_blockchain(query, answer),
                "method": "direct"
            }
        except Exception as e:
            self._logger.exception("Direct retrieval failed")
            st.error(f"Error in direct retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def enhanced_retrieval(self, query):
        """
        Enhanced retrieval method: uses an LLM to process the retrieved documents and generate a comprehensive answer.

        Args:
            query: User's question

        Returns:
            dict: On success, with the keys "answer" (LLM output), "sources"
            (chunk excerpts of at most 300 characters, source name and
            blockchain info), "query_time" (retrieval plus generation),
            "blockchain_log" and "method".
            str: A plain message when no documents have been processed yet
            or when retrieval/generation raised an error.
        """
        if self.vector_store is None:
            return "Please upload and process PDF files first."

        try:
            prompt = PromptTemplate(
                template=self._ENHANCED_PROMPT_TEMPLATE,
                input_variables=["context", "question"]
            )

            query_start_time = time.time()

            with st.status("Retrieving relevant documents..."):
                docs = self._retrieve(query)

            sources = [
                {
                    "content": self._preview(doc.page_content),
                    "source": doc.metadata.get("source", "Unknown"),
                    "blockchain": self._blockchain_info(doc)
                }
                for doc in docs
            ]

            # The chain formats ``docs`` into the prompt's {context} slot.
            document_chain = create_stuff_documents_chain(self.llm, prompt)

            with st.status("Generating enhanced answer..."):
                answer = document_chain.invoke({
                    "question": query,
                    "context": docs
                })

            query_time = time.time() - query_start_time

            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": self._log_query_to_blockchain(query, answer),
                "method": "enhanced"
            }
        except Exception as e:
            self._logger.exception("Enhanced retrieval failed")
            st.error(f"Error in enhanced retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def ask(self, query, method="enhanced"):
        """
        Ask a question using the specified retrieval method.

        Args:
            query: User's question
            method: Retrieval method ("direct" or "enhanced"); any value
                other than "direct" selects the enhanced method

        Returns:
            The result of the selected retrieval method: a dict on success,
            a plain string message otherwise.
        """
        if method == "direct":
            return self.direct_retrieval(query)
        return self.enhanced_retrieval(query)

    def get_performance_metrics(self):
        """Return performance metrics for the RAG system.

        Returns:
            None while nothing has been recorded in ``processing_times``;
            otherwise a dict with chunk count, timings, memory delta, device,
            embedding model and blockchain status. Timings and memory default
            to 0 when the index has not been built successfully.
        """
        if not self.processing_times:
            return None

        times = self.processing_times
        return {
            "documents_processed": self.documents_processed,
            "index_building_time": times.get("index_building", 0),
            "total_processing_time": times.get("total_time", 0),
            "memory_used_gb": times.get("memory_used_gb", 0),
            "device": self.device,
            "embedding_model": self.embedding_model_name,
            "blockchain_enabled": self.use_blockchain,
            "blockchain_connected": self.blockchain.is_connected if self.blockchain else False
        }