import os
import asyncio
from concurrent.futures import ThreadPoolExecutor

from model2vec import StaticModel
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from src.utils.api_key_manager import APIKeyManager
from src.helpers.helper import chunk_text


class LateChunker:
    """Select query-relevant context from long texts by embedding chunks with a
    late-chunking strategy and keeping the most similar chunks within a token budget.
    """

    def __init__(
        self,
        model_name='minishlab/potion-base-8M',
        max_workers=os.cpu_count() * 2,
        verbose=False
    ):
        self.verbose = verbose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.llm = APIKeyManager().get_llm()
        self.model_name = model_name

        # Initialize model using the fallback strategy
        self.model, self.context_length = self._initialize_model()

        # Initialize ThreadPoolExecutor
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _initialize_model(self):
        sentence_transformer_error = None
        model2vec_error = None

        # First attempt: try SentenceTransformer
        try:
            # Get the model config to check max context length
            config = AutoConfig.from_pretrained(self.model_name)
            max_length = config.max_position_embeddings

            # Initialize SentenceTransformer model
            model = SentenceTransformer(self.model_name, trust_remote_code=True)
            model.max_seq_length = max_length  # Set the correct max length
            model.to(self.device).half()

            context_length = model.max_seq_length
            return model, context_length
        except Exception as e:
            sentence_transformer_error = str(e)

        # Second attempt: try Model2Vec
        try:
            # Initialize Model2Vec model
            model = StaticModel.from_pretrained(self.model_name)

            # Get max sequence length from static model config
            context_length = model.config['seq_length']
            return model, context_length
        except Exception as e:
            model2vec_error = str(e)
            error_msg = (
                f"Failed to load model {self.model_name}.\n"
                f"SentenceTransformer error: {sentence_transformer_error}\n"
                f"Model2Vec error: {model2vec_error}"
            )
            raise Exception(error_msg) from e

    async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
        if self.verbose:
            print(f"Processing chunk {current_chunk_idx + 1}/{total_chunks}...")

        # Get the current running event loop
        loop = asyncio.get_running_loop()

        # Generate chunk embeddings
        chunk_embeddings = []
        for start, end in span_annotations:
            # Use a local name that does not shadow the imported `chunk_text` helper
            chunk = text[start:end]

            if self.verbose:
                print("Generating chunk embeddings...")

            # Run the blocking encode call in the thread pool; bind `chunk` as a
            # default argument so the lambda does not capture the loop variable late
            chunk_embedding = await loop.run_in_executor(
                self.executor,
                lambda chunk=chunk: torch.tensor(
                    self.model.encode(chunk, convert_to_tensor=True)
                )
            )
            if isinstance(chunk_embedding, torch.Tensor):
                chunk_embedding = chunk_embedding.clone().detach().to(self.device)

            if self.verbose:
                print(f"Chunk embedding shape: {chunk_embedding.shape}")

            chunk_embeddings.append(chunk_embedding)

        if self.verbose:
            print("Late Chunking applied successfully!")

        return chunk_embeddings if chunk_embeddings else None

    def get_text_embedding(self, text):
        embeddings = self.model.encode(text, convert_to_tensor=True)
        if isinstance(embeddings, torch.Tensor):
            return embeddings.clone().detach().to(self.device)
        return torch.tensor(embeddings).to(self.device)

    def calculate_embedding_similarities(self, text1_embedding, text2_embedding):
        text1_embedding = text1_embedding.cpu().numpy()
        text2_embedding = text2_embedding.cpu().numpy()

        if text1_embedding.ndim == 1:
            text1_embedding = text1_embedding.reshape(1, -1)
        if text2_embedding.ndim == 1:
            text2_embedding = text2_embedding.reshape(1, -1)
        if text1_embedding.shape[1] != text2_embedding.shape[1]:
            text1_embedding = text1_embedding.T
        if text2_embedding.shape[1] != text1_embedding.shape[1]:
            text2_embedding = text2_embedding.T

        return cosine_similarity(text1_embedding, text2_embedding)[0]

    def select_relevant_chunks(self, similarities, chunks, max_tokens):
        sorted_indices = np.argsort(similarities)[::-1]

        selected_chunks = []
        total_tokens = 0
        for i, idx in enumerate(sorted_indices):
            if self.verbose:
                print(f"Selected chunk {i + 1}/{len(sorted_indices)} with similarity {similarities[idx]:.2f}")

            chunk_tokens = self.llm.get_num_tokens(chunks[idx])
            if self.verbose:
                print(f"Chunk tokens: {chunk_tokens}")

            if total_tokens + chunk_tokens > max_tokens:
                if self.verbose:
                    print(f"Total tokens exceed max tokens allowed ({total_tokens} > {max_tokens}). "
                          "Stopping chunk selection.")
                break

            selected_chunks.append((idx, chunks[idx]))
            total_tokens += chunk_tokens

        if self.verbose:
            print("Sorting selected chunks...")
        selected_chunks.sort(key=lambda x: x[0])
        if self.verbose:
            print("Selected chunks sorted successfully!")

        return " ".join([chunk for _, chunk in selected_chunks])

    async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
        # Tokenize the entire text to check its length
        total_tokens = self.llm.get_num_tokens(text)

        # If the text is less than max tokens, return the text as is
        if total_tokens <= max_tokens:
            if self.verbose:
                print(f"Text is less than the max tokens allowed ({total_tokens} <= {max_tokens}). "
                      "Returning original text.")
            return text

        # Chunk the text if it exceeds max tokens
        if self.verbose:
            print(f"Text is greater than the max tokens allowed ({total_tokens} > {max_tokens}). "
                  "Chunking text...")
        chunks, span_annotations = chunk_text(
            text,
            max_chunk_length=max_chunk_length,
            overlap=overlap,
            # Use the smaller of either context length or max tokens
            context_length=min(self.context_length, max_tokens)
        )
        if self.verbose:
            print(f"Text chunked into {len(chunks)} macro chunks.")

        # Process each macro chunk individually
        tasks = []
        for i, macro_chunk in enumerate(chunks):
            # Adjust span annotations relative to the current macro chunk
            start_offset = span_annotations[i][0]
            adjusted_spans = [
                (start - start_offset, end - start_offset)
                for start, end in span_annotations
                if start >= start_offset and end <= start_offset + len(macro_chunk)
            ]

            # Apply late chunking for the current macro chunk
            tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))

        # Aggregate embeddings asynchronously
        results = await asyncio.gather(*tasks)
        chunk_embeddings = torch.stack([result[0] for result in results])

        # Generate query embedding
        if self.verbose:
            print("Generating query embedding...")
        query_embedding = self.get_text_embedding(query)
        if self.verbose:
            print(f"Query embedding shape: {query_embedding.shape}")

        # Calculate similarities between query embedding and chunk embeddings
        if self.verbose:
            print("Calculating embedding similarities...")
        similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
        if self.verbose:
            print(f"Similarities shape: {similarities.shape}")

        # Select relevant chunks based on similarity
        if self.verbose:
            print("Selecting relevant chunks...")
        return self.select_relevant_chunks(similarities, chunks, max_tokens)


if __name__ == "__main__":
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    from src.reasoning.reasoner import Reasoner
    from src.search.search_engine import SearchEngine
    from src.crawl.crawler import CustomCrawler
    import time

    search_engine = SearchEngine()
    crawler = CustomCrawler()
    reasoner = Reasoner()
    chunking = LateChunker(verbose=True)

    loop = asyncio.new_event_loop()

    search1 = loop.run_until_complete(search_engine.search(
        "What is the history of climate change and pollution since the pre-industrial revolution?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls = [result["link"] for result in search1]

    search2 = loop.run_until_complete(search_engine.search(
        "What is the impact of climate change on the Indian economy?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search2])

    search3 = loop.run_until_complete(search_engine.search(
        "What are some of the latest, state-of-the-art techniques used to fight climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search3])

    search4 = loop.run_until_complete(search_engine.search(
        "What does the projection for climate change look like in the next 50 years?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search4])

    search5 = loop.run_until_complete(search_engine.search(
        "What efforts are being made by governments all around the world to combat climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search5])

    results = loop.run_until_complete(crawler.fetch_page_contents(
        urls=urls,
        max_attempts=1,
        delay=0
    ))
    text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])

    num_tokens_before_chunking = chunking.llm.get_num_tokens(text)

    start_time = time.perf_counter()
    response = loop.run_until_complete(chunking.chunker(
        text,
        query="What is this text about? Give me a detailed answer",
        max_tokens=128000
    ))
    end_time = time.perf_counter()

    num_tokens_after_chunking = chunking.llm.get_num_tokens(response)

    print(f"\nResponse:\n{response}")
    print(f"\nNumber of URLs: {len(urls)}")
    print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
    print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
    print(f"\nTime taken: {end_time - start_time:.2f} seconds")

    # Calculate cosine similarity between original text and response
    def calculate_cosine_similarity(text1, text2):
        vectorizer = TfidfVectorizer().fit_transform([text1, text2])
        vectors = vectorizer.toarray()
        return cosine_similarity(vectors)[0][1]

    similarity = calculate_cosine_similarity(text, response)
    print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")