import sys
import os
import boto3
import hashlib
import json
import logging
import threading

# Add the project root directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import OpenAIEmbeddings
import pinecone
from tqdm.auto import tqdm
from langchain.schema import Document
from config import get_settings
from dotenv import load_dotenv
from io import BytesIO
from PyPDF2 import PdfReader

load_dotenv()

class RAGPrep:
    def __init__(self, processed_hashes_file="processed_hashes.json"):
        self.settings = get_settings()
        self.index_name = self.settings.INDEX_NAME
        self.pc = self.init_pinecone()
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.OPENAI_API_KEY)
        self.processed_hashes_file = processed_hashes_file
        self.processed_hashes = self.load_processed_hashes()

    def init_pinecone(self):
        """Initialize Pinecone client"""
        return pinecone.Pinecone(api_key=self.settings.PINECONE_API_KEY)

    def create_or_connect_index(self, index_name, dimension):
        """Create or connect to an existing Pinecone index"""
        spec = pinecone.ServerlessSpec(
            cloud=self.settings.CLOUD,
            region=self.settings.REGION
        )
        print(f"All indexes: {self.pc.list_indexes()}")
        if index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                name=index_name,
                dimension=dimension,
                metric='cosine',  # 'dotproduct' or other metrics can be used if needed
                spec=spec
            )
        return self.pc.Index(index_name)

    def load_processed_hashes(self):
        """Load previously processed hashes from a file."""
        if os.path.exists(self.processed_hashes_file):
            with open(self.processed_hashes_file, "r") as f:
                return set(json.load(f))
        return set()

    def save_processed_hashes(self):
        """Save processed hashes to a file."""
        with open(self.processed_hashes_file, "w") as f:
            json.dump(list(self.processed_hashes), f)

    def generate_pdf_hash(self, pdf_content: bytes):
        """Generate a hash for the given PDF content."""
        hasher = hashlib.md5()
        hasher.update(pdf_content)
        return hasher.hexdigest()

    def load_and_split_pdfs(self, chunk_from=50, chunk_to=100) -> List[Document]:
        """Load PDFs from S3, extract text, and split into chunks."""
        # Initialize S3 client
        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.settings.AWS_ACCESS_KEY,
            aws_secret_access_key=self.settings.AWS_SECRET_KEY,
            region_name=self.settings.AWS_REGION
        )

        # List objects in the S3 bucket (note: list_objects_v2 returns at most
        # 1000 keys per call; pagination is not handled here)
        print(f"Listing files in S3 bucket: {self.settings.AWS_BUCKET_NAME}")
        response = s3_client.list_objects_v2(Bucket=self.settings.AWS_BUCKET_NAME, Prefix="")
        s3_keys = [obj['Key'] for obj in response.get('Contents', [])]
        print(f"Found {len(s3_keys)} objects in S3")

        documents = []
        # Process each PDF file in the requested window
        for s3_key in s3_keys[chunk_from:chunk_to]:
            print(f"Processing file: {s3_key}")
            if not s3_key.lower().endswith(".pdf"):
                print("Not a PDF file, skipping.")
                continue
            try:
                # Read file from S3
                obj = s3_client.get_object(Bucket=self.settings.AWS_BUCKET_NAME, Key=s3_key)
                pdf_content = obj['Body'].read()

                # Generate hash and check for duplicates
                pdf_hash = self.generate_pdf_hash(pdf_content)
                if pdf_hash in self.processed_hashes:
                    print(f"Duplicate PDF detected: {s3_key}, skipping.")
                    continue

                # Extract text from PDF (extract_text() may return None for some pages)
                pdf_reader = PdfReader(BytesIO(pdf_content))
                text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

                # Add document with metadata
                documents.append(Document(page_content=text, metadata={"source": s3_key}))
                self.processed_hashes.add(pdf_hash)
            except Exception as e:
                print(f"Error processing {s3_key}: {e}")

        print(f"Extracted text from {len(documents)} documents")

        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.CHUNK_SIZE,
            chunk_overlap=self.settings.CHUNK_OVERLAP
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks")

        # Save updated hashes
        self.save_processed_hashes()
        return chunks

    def process_and_upload(self, total_batch=200):
        """Process PDFs and upload to Pinecone"""
        # Create or connect to index
        index = self.create_or_connect_index(self.index_name, self.settings.DIMENSIONS)

        # Load and split documents in windows of 50 S3 objects at a time
        print('//////// chunking: ////////')
        current_batch = 0
        total_chunks = 0
        for _ in range(0, total_batch, 50):
            batch_size = 50  # Adjust based on your needs
            window_start = current_batch
            chunks = self.load_and_split_pdfs(window_start, window_start + batch_size)
            current_batch += batch_size
            total_chunks += len(chunks)

            # Prepare for batch processing
            max_threads = 4  # Adjust based on your hardware

            def process_batch(batch, batch_index):
                """Process a single batch of chunks"""
                print(f"Processing batch {batch_index} on thread: {threading.current_thread().name}")
                print(f"Active threads: {threading.active_count()}")

                # Create ids for the batch; including the window offset keeps ids
                # unique across windows so earlier vectors are not overwritten
                ids = [f"chunk_{window_start}_{batch_index}_{j}" for j in range(len(batch))]

                # Get texts and generate embeddings
                texts = [doc.page_content for doc in batch]
                embeddings = self.embeddings.embed_documents(texts)

                # Create metadata
                metadata = [
                    {
                        "text": doc.page_content,
                        "source": doc.metadata.get("source", "unknown"),
                        "page": doc.metadata.get("page", 0)
                    }
                    for doc in batch
                ]

                # Create upsert batch
                return list(zip(ids, embeddings, metadata))

            with ThreadPoolExecutor(max_threads) as executor:
                futures = []
                print(f"Batch size being used: {batch_size}")
                for i in range(0, len(chunks), batch_size):
                    batch = chunks[i:i + batch_size]
                    futures.append(executor.submit(process_batch, batch, i))

                # Gather results and upsert to Pinecone
                for future in tqdm(as_completed(futures), total=len(futures), desc="Uploading batches"):
                    try:
                        to_upsert = future.result()
                        index.upsert(vectors=to_upsert)
                    except Exception as e:
                        print(f"Error processing batch: {e}")

        print(f"Successfully processed and uploaded {total_chunks} chunks to Pinecone")

    def cleanup_index(self) -> bool:
        """
        Delete all vectors from the Pinecone index.

        Returns:
            bool: True if cleanup succeeded or the index does not exist,
            False if an unexpected error occurred (the error is also logged).
        """
        try:
            # Try to get the index
            if self.index_name in self.pc.list_indexes().names():
                print(f"Index name found in {self.pc.list_indexes().names()}")
                # Attempt to delete all vectors
                index = self.pc.Index(self.index_name)
                index.delete(delete_all=True)
                print(f"Successfully cleaned up index: {self.index_name}")
                return True
            print("Index doesn't exist.")
            return True
        except Exception as e:
            print(f"Unexpected error during index cleanup: {str(e)}")
            logging.error(f"Failed to cleanup index {self.index_name}. Error: {str(e)}")
            return False
        finally:
            # Runs regardless of success or failure
            print("Cleanup operation completed.")

# Example usage:
if __name__ == "__main__":
    # Example .env file content:
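    # A minimal sketch of the environment variables this script reads, assuming
    # config.get_settings() maps them one-to-one from the environment. The names
    # come from the self.settings.* references above; all values are placeholders.
    #
    #   OPENAI_API_KEY=sk-...
    #   PINECONE_API_KEY=your-pinecone-key
    #   INDEX_NAME=your-index-name
    #   DIMENSIONS=1536
    #   CLOUD=aws
    #   REGION=us-east-1
    #   AWS_ACCESS_KEY=your-aws-access-key
    #   AWS_SECRET_KEY=your-aws-secret-key
    #   AWS_REGION=us-east-1
    #   AWS_BUCKET_NAME=your-bucket-name
    #   CHUNK_SIZE=1000
    #   CHUNK_OVERLAP=200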
    rag_prep = RAGPrep()
    rag_prep.process_and_upload()
    # rag_prep.cleanup_index()