# sehatech-demo/core/rag_engine.py
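"""RAG ingestion utilities: load PDF documents from S3, split them into chunks,
embed the chunks with OpenAI, and upsert the resulting vectors into a Pinecone index."""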
import sys
import os
import boto3
import hashlib
import json
import threading
# Add the project root directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import OpenAIEmbeddings
import pinecone
from tqdm.auto import tqdm
from langchain.schema import Document
from config import get_settings
from dotenv import load_dotenv
from io import BytesIO
from PyPDF2 import PdfReader
load_dotenv()
class RAGPrep:
def __init__(self, processed_hashes_file="processed_hashes.json"):
self.settings = get_settings()
self.index_name = self.settings.INDEX_NAME
self.pc = self.init_pinecone()
self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.OPENAI_API_KEY)
self.processed_hashes_file = processed_hashes_file
self.processed_hashes = self.load_processed_hashes()
def init_pinecone(self):
"""Initialize Pinecone client"""
        pc = pinecone.Pinecone(api_key=self.settings.PINECONE_API_KEY)
return pc
# Define function to create or connect to an existing index
    def create_or_connect_index(self, index_name, dimension):
"""Create or connect to existing Pinecone index"""
spec = pinecone.ServerlessSpec(
cloud=self.settings.CLOUD,
region=self.settings.REGION
)
        existing_indexes = self.pc.list_indexes().names()
        print(f'all indexes: {existing_indexes}')
        if index_name not in existing_indexes:
self.pc.create_index(
name=index_name,
dimension=dimension,
metric='cosine', # You can use 'dotproduct' or other metrics if needed
spec=spec
)
return self.pc.Index(index_name)
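    # Note: `dimension` must match the embedding model's output size. The
    # OpenAIEmbeddings instance created in __init__ defaults to
    # text-embedding-ada-002, which produces 1536-dimensional vectors, so
    # settings.DIMENSIONS is presumably 1536 unless another model is configured.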
def load_processed_hashes(self):
"""Load previously processed hashes from a file."""
if os.path.exists(self.processed_hashes_file):
with open(self.processed_hashes_file, "r") as f:
return set(json.load(f))
return set()
def save_processed_hashes(self):
"""Save processed hashes to a file."""
with open(self.processed_hashes_file, "w") as f:
json.dump(list(self.processed_hashes), f)
def generate_pdf_hash(self, pdf_content: bytes):
"""Generate a hash for the given PDF content."""
hasher = hashlib.md5()
hasher.update(pdf_content)
return hasher.hexdigest()
    def load_and_split_pdfs(self, chunk_from=50, chunk_to=100) -> List[Document]:
        """Load PDFs from S3, extract text, and split into chunks."""
        print(f"Loading S3 objects {chunk_from} to {chunk_to}")
# Initialize S3 client
s3_client = boto3.client(
's3',
aws_access_key_id=self.settings.AWS_ACCESS_KEY,
aws_secret_access_key=self.settings.AWS_SECRET_KEY,
region_name=self.settings.AWS_REGION
)
# List all PDF files in the S3 bucket and prefix
print(f"Listing files in S3 bucket: {self.settings.AWS_BUCKET_NAME}")
        # NOTE: list_objects_v2 returns at most 1000 keys per call; for larger
        # buckets see the paginated listing sketch after this method.
        response = s3_client.list_objects_v2(Bucket=self.settings.AWS_BUCKET_NAME, Prefix="")
        s3_keys = [obj['Key'] for obj in response.get('Contents', [])]
        print(f"Found {len(s3_keys)} objects in S3 (non-PDFs are skipped below)")
documents = []
# Process each PDF file
for s3_key in s3_keys[chunk_from:chunk_to]:
print(f"Processing file: {s3_key}")
if not s3_key.lower().endswith(".pdf"):
print("Not a PDF file, skipping.")
continue
try:
# Read file from S3
obj = s3_client.get_object(Bucket=self.settings.AWS_BUCKET_NAME, Key=s3_key)
pdf_content = obj['Body'].read()
# Generate hash and check for duplicates
pdf_hash = self.generate_pdf_hash(pdf_content)
if pdf_hash in self.processed_hashes:
print(f"Duplicate PDF detected: {s3_key}, skipping.")
continue
# Extract text from PDF
pdf_file = BytesIO(pdf_content)
pdf_reader = PdfReader(pdf_file)
text = "".join(page.extract_text() for page in pdf_reader.pages)
# Add document with metadata
documents.append(Document(page_content=text, metadata={"source": s3_key}))
self.processed_hashes.add(pdf_hash)
except Exception as e:
print(f"Error processing {s3_key}: {e}")
print(f"Extracted text from {len(documents)} documents")
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.settings.CHUNK_SIZE,
chunk_overlap=self.settings.CHUNK_OVERLAP
)
chunks = text_splitter.split_documents(documents)
print(f"Created {len(chunks)} chunks")
# Save updated hashes
self.save_processed_hashes()
return chunks
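    # The listing in load_and_split_pdfs fetches a single list_objects_v2 page
    # (at most 1000 keys). Below is a minimal sketch of a paginated alternative;
    # the method name is illustrative and not part of the original pipeline.
    def list_all_pdf_keys(self, s3_client) -> List[str]:
        """Sketch: iterate every PDF key in the bucket using an S3 paginator."""
        keys = []
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.settings.AWS_BUCKET_NAME, Prefix=""):
            for obj in page.get("Contents", []):
                if obj["Key"].lower().endswith(".pdf"):
                    keys.append(obj["Key"])
        return keys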
def process_and_upload(self, total_batch=200):
"""Process PDFs and upload to Pinecone"""
# Create or connect to index
index = self.create_or_connect_index(self.index_name, self.settings.DIMENSIONS)
# Load and split documents
        print("//////// chunking ////////")
        current_batch = 0
        for _ in range(0, total_batch, 50):
            batch_size = 50  # Adjust based on your needs
            range_start = current_batch  # S3 range offset; keeps vector ids unique across ranges
            chunks = self.load_and_split_pdfs(current_batch, current_batch + batch_size)
            current_batch = current_batch + batch_size
# Prepare for batch processing
max_threads = 4 # Adjust based on your hardware
def process_batch(batch, batch_index):
"""Process a single batch of chunks"""
print(f"Processing batch {batch_index} on thread: {threading.current_thread().name}")
print(f"Active threads: {threading.active_count()}")
                # Create ids for batch; the range offset prevents ids from different
                # S3 ranges colliding and silently overwriting each other in Pinecone
                ids = [f"chunk_{range_start}_{batch_index}_{j}" for j in range(len(batch))]
# Get texts and generate embeddings
texts = [doc.page_content for doc in batch]
embeddings = self.embeddings.embed_documents(texts)
# Create metadata
metadata = [
{
"text": doc.page_content,
"source": doc.metadata.get("source", "unknown"),
"page": doc.metadata.get("page", 0)
}
for doc in batch
]
# Create upsert batch
return list(zip(ids, embeddings, metadata))
with ThreadPoolExecutor(max_threads) as executor:
futures = []
print(f"Batch size being used: {batch_size}")
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
futures.append(executor.submit(process_batch, batch, i))
# Gather results and upsert to Pinecone
for future in tqdm(as_completed(futures), total=len(futures), desc="Uploading batches"):
try:
to_upsert = future.result()
index.upsert(vectors=to_upsert)
except Exception as e:
print(f"Error processing batch: {e}")
print(f"Successfully processed and uploaded {len(chunks)} chunks to Pinecone")
def cleanup_index(self) -> bool:
"""
Delete all vectors from the Pinecone index.
Returns:
bool: True if cleanup was successful, False otherwise
Raises:
Exception: Logs any unexpected errors during cleanup
"""
try:
            # Check whether the index exists before attempting deletion
            existing_indexes = self.pc.list_indexes().names()
            if self.index_name in existing_indexes:
                print(f"Index found in {existing_indexes}")
                # Delete all vectors from the index
                index = self.pc.Index(self.index_name)
                index.delete(delete_all=True)
                print(f"Successfully cleaned up index: {self.index_name}")
                return True
            print("Index doesn't exist; nothing to clean up.")
            return True
except Exception as e:
print(f"Unexpected error during index cleanup: {str(e)}")
# You might want to log this error as well
import logging
logging.error(f"Failed to cleanup index {self.index_name}. Error: {str(e)}")
return False
finally:
# Any cleanup code that should run regardless of success/failure
print("Cleanup operation completed.")
# Example usage:
if __name__ == "__main__":
# Example .env file content:
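    # A hedged sketch of the variables this module appears to rely on, based on the
    # settings accessed above; it assumes config.get_settings() maps them 1:1 from
    # the environment. Values are placeholders / illustrative defaults, not real ones:
    #   OPENAI_API_KEY=<your-openai-api-key>
    #   PINECONE_API_KEY=<your-pinecone-api-key>
    #   INDEX_NAME=<pinecone-index-name>
    #   DIMENSIONS=1536
    #   CLOUD=aws
    #   REGION=us-east-1
    #   AWS_ACCESS_KEY=<aws-access-key-id>
    #   AWS_SECRET_KEY=<aws-secret-access-key>
    #   AWS_REGION=<aws-region>
    #   AWS_BUCKET_NAME=<bucket-containing-pdfs>
    #   CHUNK_SIZE=1000
    #   CHUNK_OVERLAP=200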
rag_prep = RAGPrep()
rag_prep.process_and_upload()
# rag_prep.cleanup_index()