import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from model2vec import StaticModel
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from src.utils.api_key_manager import APIKeyManager
from src.helpers.helper import chunk_text


class LateChunker:
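    """
    Reduce a long text to its query-relevant parts: the text is split into
    macro chunks, each chunk is embedded with a local embedding model
    (SentenceTransformer, with a Model2Vec fallback), and the chunks most
    similar to the query embedding are kept until a token budget is reached.
    """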
    def __init__(
        self,
        model_name='minishlab/potion-base-8M',
        max_workers=os.cpu_count() * 2,
        verbose=False
    ):
        self.verbose = verbose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.llm = APIKeyManager().get_llm()
        self.model_name = model_name

        # Initialize model using the fallback strategy
        self.model, self.context_length = self._initialize_model()

        # Initialize ThreadPoolExecutor
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _initialize_model(self):
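        """
        Load the embedding model, trying SentenceTransformer first and falling
        back to Model2Vec's StaticModel. Returns a (model, context_length)
        tuple; raises with both error messages if both loaders fail.
        """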
        sentence_transformer_error = None
        model2vec_error = None

        # First attempt: try SentenceTransformer
        try:
            # Get the model config to check the max context length
            config = AutoConfig.from_pretrained(self.model_name)
            max_length = config.max_position_embeddings

            # Initialize the SentenceTransformer model
            model = SentenceTransformer(self.model_name, trust_remote_code=True)
            model.max_seq_length = max_length  # Set the correct max length
            model.to(self.device).half()
            context_length = model.max_seq_length
            return model, context_length
        except Exception as e:
            sentence_transformer_error = str(e)

        # Second attempt: try Model2Vec
        try:
            # Initialize the Model2Vec static model
            model = StaticModel.from_pretrained(self.model_name)

            # Get the max sequence length from the static model config
            context_length = model.config['seq_length']
            return model, context_length
        except Exception as e:
            model2vec_error = str(e)
            error_msg = (
                f"Failed to load model {self.model_name}.\n"
                f"SentenceTransformer error: {sentence_transformer_error}\n"
                f"Model2Vec error: {model2vec_error}"
            )
            # Raise inside the except block so `e` is still in scope for chaining
            raise Exception(error_msg) from e

    async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
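        """
        Embed each (start, end) span of `text` in the thread pool executor and
        return the list of chunk embeddings (or None if no spans were given).
        """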
print(f"Processing chunk {current_chunk_idx+1}/{total_chunks}...") \ | |
if self.verbose else None | |
# Get the current running event loop | |
loop = asyncio.get_running_loop() | |
# Generate chunk embeddings | |
chunk_embeddings = [] | |
for start, end in span_annotations: | |
chunk_text = text[start:end] | |
print("Generating chunk embeddings...") if self.verbose else None | |
chunk_embedding = await loop.run_in_executor( | |
self.executor, | |
lambda: torch.tensor( | |
self.model.encode( | |
chunk_text, | |
convert_to_tensor=True | |
) | |
) | |
) | |
if isinstance(chunk_embedding, torch.Tensor): | |
chunk_embedding = chunk_embedding.clone().detach().to(self.device) | |
print(f"Chunk embedding shape: {chunk_embedding.shape}") if self.verbose else None | |
chunk_embeddings.append(chunk_embedding) | |
print("Late Chunking applied successfully!") if self.verbose else None | |
return chunk_embeddings if chunk_embeddings else None | |
def get_text_embedding(self, text): | |
embeddings = self.model.encode(text, convert_to_tensor=True) | |
if isinstance(embeddings, torch.Tensor): | |
return embeddings.clone().detach().to(self.device) | |
return torch.tensor(embeddings).to(self.device) | |
def calculate_embedding_similarities(self, text1_embedding, text2_embedding): | |
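        """
        Compute cosine similarities between two embeddings (or batches of
        embeddings), reshaping or transposing as needed so both arrays share
        the same feature dimension. Returns a 1-D numpy array of similarities.
        """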
        text1_embedding = text1_embedding.cpu().numpy()
        text2_embedding = text2_embedding.cpu().numpy()

        if text1_embedding.ndim == 1:
            text1_embedding = text1_embedding.reshape(1, -1)
        if text2_embedding.ndim == 1:
            text2_embedding = text2_embedding.reshape(1, -1)

        if text1_embedding.shape[1] != text2_embedding.shape[1]:
            text1_embedding = text1_embedding.T
        if text2_embedding.shape[1] != text1_embedding.shape[1]:
            text2_embedding = text2_embedding.T

        return cosine_similarity(text1_embedding, text2_embedding)[0]

    def select_relevant_chunks(self, similarities, chunks, max_tokens):
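        """
        Greedily pick chunks in descending order of similarity until adding the
        next chunk would exceed `max_tokens`, then return the picked chunks
        re-joined in their original document order.
        """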
        sorted_indices = np.argsort(similarities)[::-1]
        selected_chunks = []
        total_tokens = 0
        for i, idx in enumerate(sorted_indices):
            if self.verbose:
                print(f"Considering chunk {i + 1}/{len(sorted_indices)} with similarity {similarities[idx]:.2f}")

            chunk_tokens = self.llm.get_num_tokens(chunks[idx])
            if self.verbose:
                print(f"Chunk tokens: {chunk_tokens}")

            if total_tokens + chunk_tokens > max_tokens:
                if self.verbose:
                    print(
                        f"Adding this chunk would exceed the max tokens allowed "
                        f"({total_tokens + chunk_tokens} > {max_tokens}). Stopping chunk selection."
                    )
                break

            selected_chunks.append((idx, chunks[idx]))
            total_tokens += chunk_tokens

        # Restore the original document order of the selected chunks
        if self.verbose:
            print("Sorting selected chunks...")
        selected_chunks.sort(key=lambda x: x[0])
        if self.verbose:
            print("Selected chunks sorted successfully!")

        return " ".join([chunk for _, chunk in selected_chunks])

    async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
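        """
        Return `text` unchanged if it fits within `max_tokens`; otherwise chunk
        it, embed the chunks, score them against `query`, and return only the
        most relevant chunks within the token budget.
        """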
        # Count the tokens of the entire text to check its length
        total_tokens = self.llm.get_num_tokens(text)

        # If the text fits within the token budget, return it as is
        if total_tokens <= max_tokens:
            if self.verbose:
                print(
                    f"Text is within the max tokens allowed ({total_tokens} <= {max_tokens}). "
                    f"Returning original text."
                )
            return text

        # Chunk the text if it exceeds max tokens
        if self.verbose:
            print(
                f"Text is greater than the max tokens allowed ({total_tokens} > {max_tokens}). "
                f"Chunking text..."
            )
        chunks, span_annotations = chunk_text(
            text,
            max_chunk_length=max_chunk_length,
            overlap=overlap,
            # Use the smaller of the model context length and the token budget
            context_length=min(self.context_length, max_tokens)
        )
        if self.verbose:
            print(f"Text chunked into {len(chunks)} macro chunks.")

        # Process each macro chunk individually
        tasks = []
        for i, macro_chunk in enumerate(chunks):
            # Adjust span annotations relative to the current macro chunk
            start_offset = span_annotations[i][0]
            adjusted_spans = [
                (start - start_offset, end - start_offset)
                for start, end in span_annotations
                if start >= start_offset and end <= start_offset + len(macro_chunk)
            ]

            # Apply late chunking for the current macro chunk
            tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))

        # Aggregate embeddings asynchronously (one embedding per macro chunk)
        results = await asyncio.gather(*tasks)
        chunk_embeddings = torch.stack([result[0] for result in results])

        # Generate query embedding
        if self.verbose:
            print("Generating query embedding...")
        query_embedding = self.get_text_embedding(query)
        if self.verbose:
            print(f"Query embedding shape: {query_embedding.shape}")

        # Calculate similarities between the query embedding and chunk embeddings
        if self.verbose:
            print("Calculating embedding similarities...")
        similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
        if self.verbose:
            print(f"Similarities shape: {similarities.shape}")

        # Select relevant chunks based on similarity
        if self.verbose:
            print("Selecting relevant chunks...")
        return self.select_relevant_chunks(similarities, chunks, max_tokens)
if __name__ == "__main__":
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    from src.reasoning.reasoner import Reasoner
    from src.search.search_engine import SearchEngine
    from src.crawl.crawler import CustomCrawler
    import time

    search_engine = SearchEngine()
    crawler = CustomCrawler()
    reasoner = Reasoner()
    chunking = LateChunker(verbose=True)

    loop = asyncio.new_event_loop()

    search1 = loop.run_until_complete(search_engine.search(
        "What is the history of climate change and pollution since the pre-industrial revolution?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls = [result["link"] for result in search1]

    search2 = loop.run_until_complete(search_engine.search(
        "What is the impact of climate change on the Indian economy?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search2])

    search3 = loop.run_until_complete(search_engine.search(
        "What are some of the latest, state-of-the-art techniques used to fight climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search3])

    search4 = loop.run_until_complete(search_engine.search(
        "What does the projection for climate change look like in the next 50 years?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search4])

    search5 = loop.run_until_complete(search_engine.search(
        "What efforts are being made by governments all around the world to combat climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search5])

    results = loop.run_until_complete(crawler.fetch_page_contents(
        urls=urls,
        max_attempts=1,
        delay=0
    ))
    text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])

    num_tokens_before_chunking = chunking.llm.get_num_tokens(text)

    start_time = time.perf_counter()
    response = loop.run_until_complete(chunking.chunker(
        text,
        query="What is this text about? Give me a detailed answer",
        max_tokens=128000
    ))
    end_time = time.perf_counter()

    num_tokens_after_chunking = chunking.llm.get_num_tokens(response)

    print(f"\nResponse:\n{response}")
    print(f"\nNumber of URLs: {len(urls)}")
    print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
    print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
    print(f"\nTime taken: {end_time - start_time:.2f} seconds")

    # Calculate cosine similarity between the original text and the response
    def calculate_cosine_similarity(text1, text2):
        vectorizer = TfidfVectorizer().fit_transform([text1, text2])
        vectors = vectorizer.toarray()
        return cosine_similarity(vectors)[0][1]

    similarity = calculate_cosine_similarity(text, response)
    print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")