# advanced_rag.py
import os
import tempfile
import shutil
import PyPDF2
import streamlit as st
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
import time
import psutil
import uuid
import atexit
from blockchain_utils_metamask import BlockchainManagerMetaMask


class AdvancedRAG:
    def __init__(self, 
                 llm_model_name="mistralai/Mistral-7B-Instruct-v0.2",
                 embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
                 chunk_size=1000,
                 chunk_overlap=200,
                 use_gpu=True,
                 use_blockchain=False,
                 contract_address=None):
        """
        Initialize the advanced RAG system with multiple retrieval methods.
        
        Args:
            llm_model_name: The HuggingFace model for text generation
            embedding_model_name: The HuggingFace model for embeddings
            chunk_size: Size of document chunks
            chunk_overlap: Overlap between chunks
            use_gpu: Whether to use GPU acceleration
            use_blockchain: Whether to enable blockchain verification
            contract_address: Address of the deployed RAG Document Verifier contract
        """
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.use_blockchain = use_blockchain
        
        # Device selection for embeddings
        self.device = "cuda" if self.use_gpu else "cpu"
        st.sidebar.info(f"Using device: {self.device}")
        
        # Initialize text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        
        # Initialize embeddings model
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": self.device}
        )
        
        # Initialize LLM using HuggingFaceHub
        try:
            # Use HF_TOKEN from environment variables
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                st.warning("No HuggingFace token found. Using model without authentication.")
                
            self.llm = HuggingFaceHub(
                repo_id=llm_model_name,
                huggingfacehub_api_token=hf_token,
                model_kwargs={"temperature": 0.7, "max_length": 1024}
            )
        except Exception as e:
            st.error(f"Error initializing LLM: {str(e)}")
            st.info("Trying to initialize with default model...")
            # Fallback to a smaller model
            self.llm = HuggingFaceHub(
                repo_id="google/flan-t5-small",
                model_kwargs={"temperature": 0.7, "max_length": 512}
            )
        
        # Initialize vector store
        self.vector_store = None
        self.documents_processed = 0
        
        # Monitoring stats
        self.processing_times = {}
        
        # Initialize blockchain manager if enabled
        self.blockchain = None
        if use_blockchain:
            try:
                self.blockchain = BlockchainManagerMetaMask(
                    contract_address=contract_address
                )
                st.sidebar.success("Blockchain manager initialized. Please connect MetaMask to continue.")
            except Exception as e:
                st.sidebar.error(f"Failed to initialize blockchain manager: {str(e)}")
                self.use_blockchain = False

    def update_blockchain_connection(self, metamask_info):
        """Update blockchain connection with MetaMask info."""
        if self.blockchain and metamask_info:
            self.blockchain.update_connection(
                is_connected=metamask_info.get("connected", False),
                user_address=metamask_info.get("address"),
                network_id=metamask_info.get("network_id")
            )
            return self.blockchain.is_connected
        return False

    def process_pdfs(self, pdf_files):
        """Process PDF files, create a vector store, and verify documents on blockchain."""
        all_docs = []
        
        with st.status("Processing PDF files...") as status:
            # Create temporary directory for file storage
            temp_dir = tempfile.mkdtemp()
            st.session_state['temp_dir'] = temp_dir
            
            # Monitor processing time and memory usage
            start_time = time.time()
            
            # Track memory before processing
            mem_before = psutil.virtual_memory().used / (1024 * 1024 * 1024)  # GB
            
            # Process each PDF file
            for i, pdf_file in enumerate(pdf_files):
                try:
                    file_start_time = time.time()
                    
                    # Save uploaded file to temp directory
                    pdf_path = os.path.join(temp_dir, pdf_file.name)
                    with open(pdf_path, "wb") as f:
                        f.write(pdf_file.getbuffer())
                    
                    status.update(label=f"Processing {pdf_file.name} ({i+1}/{len(pdf_files)})...")
                    
                    # Extract text from PDF
                    text = ""
                    with open(pdf_path, "rb") as f:
                        pdf = PyPDF2.PdfReader(f)
                        for page in pdf.pages:
                            page_text = page.extract_text()
                            if page_text:
                                text += page_text + "\n\n"
                    
                    # Create documents
                    docs = [Document(page_content=text, metadata={"source": pdf_file.name})]
                    
                    # Split documents into chunks
                    split_docs = self.text_splitter.split_documents(docs)
                    
                    all_docs.extend(split_docs)
                    
                    # Verify document on blockchain if enabled and connected
                    if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
                        try:
                            # Create a unique document ID
                            document_id = f"{pdf_file.name}_{uuid.uuid4().hex[:8]}"
                            
                            # Verify document on blockchain
                            status.update(label=f"Verifying {pdf_file.name} on blockchain...")
                            verification = self.blockchain.verify_document(document_id, pdf_path)
                            
                            if verification.get('status'):  # Success
                                st.sidebar.success(f"✅ {pdf_file.name} verified on blockchain")
                                if 'tx_hash' in verification:
                                    st.sidebar.info(f"Transaction: {verification['tx_hash'][:10]}...")
                                
                                # Add blockchain metadata to documents
                                for doc in split_docs:
                                    doc.metadata["blockchain"] = {
                                        "verified": True,
                                        "document_id": document_id,
                                        "document_hash": verification.get("document_hash", ""),
                                        "tx_hash": verification.get("tx_hash", ""),
                                        "block_number": verification.get("block_number", 0)
                                    }
                            else:
                                st.sidebar.warning(f"❌ Failed to verify {pdf_file.name} on blockchain")
                                if 'error' in verification:
                                    st.sidebar.error(f"Error: {verification['error']}")
                        except Exception as e:
                            st.sidebar.error(f"Blockchain verification error: {str(e)}")
                    elif self.use_blockchain:
                        st.sidebar.warning("MetaMask not connected. Document not verified on blockchain.")
                    
                    file_end_time = time.time()
                    processing_time = file_end_time - file_start_time
                    
                    st.sidebar.success(f"Processed {pdf_file.name}: {len(split_docs)} chunks in {processing_time:.2f}s")
                    self.processing_times[pdf_file.name] = {
                        "chunks": len(split_docs),
                        "time": processing_time
                    }
                    
                except Exception as e:
                    st.sidebar.error(f"Error processing {pdf_file.name}: {str(e)}")
            
            # Create vector store if we have documents
            if all_docs:
                status.update(label="Building vector index...")
                try:
                    # Record the time taken to build the index
                    index_start_time = time.time()
                    
                    # Create the vector store using FAISS
                    self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
                    
                    index_end_time = time.time()
                    index_time = index_end_time - index_start_time
                    
                    # Track memory after processing
                    mem_after = psutil.virtual_memory().used / (1024 * 1024 * 1024)  # GB
                    mem_used = mem_after - mem_before
                    
                    total_time = time.time() - start_time
                    
                    status.update(label=f"Completed processing {len(all_docs)} chunks in {total_time:.2f}s", state="complete")
                    
                    # Save performance metrics
                    self.processing_times["index_building"] = index_time
                    self.processing_times["total_time"] = total_time
                    self.processing_times["memory_used_gb"] = mem_used
                    self.documents_processed = len(all_docs)
                    
                    return True
                except Exception as e:
                    st.error(f"Error creating vector store: {str(e)}")
                    status.update(label="Error creating vector store", state="error")
                    return False
            else:
                status.update(label="No content extracted from PDFs", state="error")
                return False

    def direct_retrieval(self, query):
        """
        Direct retrieval method: simply returns the most relevant document chunks without LLM processing.
        
        Args:
            query: User's question
            
        Returns:
            dict: Results with the raw document chunks, or an error string if
                no documents have been processed yet
        """
        if not self.vector_store:
            return "Please upload and process PDF files first."
            
        try:
            # Start timing the query
            query_start_time = time.time()
            
            # Retrieve the most relevant documents
            with st.status("Searching documents..."):
                retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
                docs = retriever.get_relevant_documents(query)
            
            # Calculate query time
            query_time = time.time() - query_start_time
            
            # Format sources and create answer from sources directly
            sources = []
            answer = f"Here are the most relevant passages for your query:\n\n"
            
            for i, doc in enumerate(docs):
                # Extract blockchain verification info if available
                blockchain_info = None
                if "blockchain" in doc.metadata:
                    blockchain_info = {
                        "verified": doc.metadata["blockchain"]["verified"],
                        "document_id": doc.metadata["blockchain"]["document_id"],
                        "tx_hash": doc.metadata["blockchain"]["tx_hash"]
                    }
                
                source_text = doc.page_content
                answer += f"**Passage {i+1}** (from {doc.metadata.get('source', 'Unknown')}):\n{source_text}\n\n"
                
                sources.append({
                    "content": source_text,
                    "source": doc.metadata.get("source", "Unknown"),
                    "blockchain": blockchain_info
                })
            
            # Log query to blockchain if enabled and connected
            blockchain_log = None
            if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
                try:
                    with st.status("Logging query to blockchain..."):
                        log_result = self.blockchain.log_query(query, answer)
                        
                        if log_result.get("status"):  # Success
                            blockchain_log = {
                                "logged": True,
                                "query_id": log_result.get("query_id", ""),
                                "tx_hash": log_result.get("tx_hash", ""),
                                "block_number": log_result.get("block_number", 0)
                            }
                        else:
                            st.error(f"Error logging to blockchain: {log_result.get('error', 'Unknown error')}")
                except Exception as e:
                    st.error(f"Error logging to blockchain: {str(e)}")
                
            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": blockchain_log,
                "method": "direct"
            }
                
        except Exception as e:
            st.error(f"Error in direct retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def enhanced_retrieval(self, query):
        """
        Enhanced retrieval method: uses an LLM to process the retrieved documents and generate a comprehensive answer.
        
        Args:
            query: User's question
            
        Returns:
            dict: Results with the LLM-enhanced answer, or an error string if
                no documents have been processed yet
        """
        if not self.vector_store:
            return "Please upload and process PDF files first."
            
        try:
            # Custom prompt for advanced processing
            prompt_template = """
            You are an AI research assistant with expertise in analyzing and synthesizing information from documents.
            
            Below are relevant passages from documents that might answer the user's question.
            
            USER QUESTION: {question}
            
            RELEVANT PASSAGES:
            {context}
            
            Based on ONLY these passages, provide a comprehensive, accurate and well-structured answer to the question.
            
            Your answer should:
            1. Directly address the user's question
            2. Synthesize information from multiple passages when applicable
            3. Be detailed, precise and factual
            4. Include specific examples or evidence from the passages
            5. Acknowledge any limitations or gaps in the provided information
            
            If the information to answer the question is not present in the passages, clearly state: "I don't have enough information to answer this question based on the available documents."
            
            ANSWER:
            """
            
            PROMPT = PromptTemplate(
                template=prompt_template, 
                input_variables=["context", "question"]
            )
            
            # Start timing the query
            query_start_time = time.time()
            
            # Create QA chain
            retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
            
            # Get documents first to track sources
            with st.status("Retrieving relevant documents..."):
                docs = retriever.get_relevant_documents(query)
            
            # Format sources
            sources = []
            for i, doc in enumerate(docs):
                # Extract blockchain verification info if available
                blockchain_info = None
                if "blockchain" in doc.metadata:
                    blockchain_info = {
                        "verified": doc.metadata["blockchain"]["verified"],
                        "document_id": doc.metadata["blockchain"]["document_id"],
                        "tx_hash": doc.metadata["blockchain"]["tx_hash"]
                    }
                
                sources.append({
                    "content": doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content,
                    "source": doc.metadata.get("source", "Unknown"),
                    "blockchain": blockchain_info
                })
            
            # Create document chain
            document_chain = create_stuff_documents_chain(self.llm, PROMPT)
            
            # Generate answer
            with st.status("Generating enhanced answer..."):
                answer = document_chain.invoke({
                    "question": query,
                    "context": docs
                })
            
            # Calculate query time
            query_time = time.time() - query_start_time
            
            # Log query to blockchain if enabled and connected
            blockchain_log = None
            if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
                try:
                    with st.status("Logging query to blockchain..."):
                        log_result = self.blockchain.log_query(query, answer)
                        
                        if log_result.get("status"):  # Success
                            blockchain_log = {
                                "logged": True,
                                "query_id": log_result.get("query_id", ""),
                                "tx_hash": log_result.get("tx_hash", ""),
                                "block_number": log_result.get("block_number", 0)
                            }
                        else:
                            st.error(f"Error logging to blockchain: {log_result.get('error', 'Unknown error')}")
                except Exception as e:
                    st.error(f"Error logging to blockchain: {str(e)}")
                
            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": blockchain_log,
                "method": "enhanced"
            }
                
        except Exception as e:
            st.error(f"Error in enhanced retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def ask(self, query, method="enhanced"):
        """
        Ask a question using the specified retrieval method.
        
        Args:
            query: User's question
            method: Retrieval method ("direct" or "enhanced")
            
        Returns:
            dict: Results from the specified retrieval method
        """
        if method == "direct":
            return self.direct_retrieval(query)
        else:
            return self.enhanced_retrieval(query)

    def get_performance_metrics(self):
        """Return performance metrics for the RAG system."""
        if not self.processing_times:
            return None
            
        return {
            "documents_processed": self.documents_processed,
            "index_building_time": self.processing_times.get("index_building", 0),
            "total_processing_time": self.processing_times.get("total_time", 0),
            "memory_used_gb": self.processing_times.get("memory_used_gb", 0),
            "device": self.device,
            "embedding_model": self.embedding_model_name,
            "blockchain_enabled": self.use_blockchain,
            "blockchain_connected": self.blockchain.is_connected if self.blockchain else False
        }
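

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the class API): one way a
# Streamlit app might drive AdvancedRAG. Assumes this file is launched with
# `streamlit run advanced_rag.py`, HF_TOKEN is exported in the environment,
# and blockchain verification is disabled; the widget labels are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Keep one instance across Streamlit reruns so the FAISS index persists.
    if "rag" not in st.session_state:
        st.session_state["rag"] = AdvancedRAG(use_blockchain=False)
    rag = st.session_state["rag"]

    uploaded_pdfs = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True)
    if uploaded_pdfs and st.button("Process documents"):
        rag.process_pdfs(uploaded_pdfs)

    question = st.text_input("Ask a question about the uploaded documents")
    if question and rag.vector_store is not None:
        result = rag.ask(question, method="enhanced")  # or method="direct"
        if isinstance(result, dict):
            st.write(result["answer"])
            st.caption(f"Answered in {result['query_time']:.2f}s via {result['method']} retrieval")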