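"""Query-aware late chunking utilities.

LateChunker splits a long document into macro chunks, embeds each chunk
(SentenceTransformer first, with a Model2Vec StaticModel fallback), ranks
the chunks by cosine similarity against a query embedding, and reassembles
the most relevant chunks under a token budget.
"""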
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from model2vec import StaticModel
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from src.utils.api_key_manager import APIKeyManager
from src.helpers.helper import chunk_text

class LateChunker:
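    """Select the most query-relevant chunks of a long text.

    Typical usage (sketch; assumes configured API keys and a running
    event loop):

        chunker = LateChunker(verbose=True)
        reduced = await chunker.chunker(text, query="...", max_tokens=2048)
    """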
    def __init__(
            self, 
            model_name='minishlab/potion-base-8M', 
            max_workers=(os.cpu_count() or 1) * 2,  # os.cpu_count() may return None
            verbose=False
        ):
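        """Set up the embedding model, LLM token counter, and thread pool.

        `model_name` is tried as a SentenceTransformer first, then as a
        Model2Vec static model (see `_initialize_model`).
        """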
        self.verbose = verbose

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.llm = APIKeyManager().get_llm()
        self.model_name = model_name
        
        # Initialize model using the fallback strategy
        self.model, self.context_length = self._initialize_model()
        
        # Initialize ThreadPoolExecutor
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _initialize_model(self):
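        """Load the embedding model.

        Tries SentenceTransformer first and falls back to a Model2Vec
        StaticModel. Returns a (model, context_length) tuple, or raises
        with both error messages if neither backend loads.
        """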
        sentence_transformer_error = None
        model2vec_error = None

        # First attempt: Try SentenceTransformer
        try:
            # Get the model config to check max context length
            config = AutoConfig.from_pretrained(self.model_name)
            max_length = config.max_position_embeddings
            
            # Initialize SentenceTransformer model
            model = SentenceTransformer(self.model_name, trust_remote_code=True)
            model.max_seq_length = max_length  # Set the correct max length
            model.to(self.device)
            if self.device == "cuda":
                model.half()  # half precision is only reliably supported on GPU
            context_length = model.max_seq_length
            return model, context_length
        except Exception as e:
            sentence_transformer_error = str(e)

        # Second attempt: Try Model2Vec
        try:
            # Initialize Model2Vec model
            model = StaticModel.from_pretrained(
                self.model_name
            )
            # Get max sequence length from static model config
            context_length = model.config['seq_length']
            return model, context_length
        except Exception as e:
            model2vec_error = str(e)
            error_msg = (
                f"Failed to load model {self.model_name}.\n"
                f"SentenceTransformer error: {sentence_transformer_error}\n"
                f"Model2Vec error: {model2vec_error}"
            )
            raise RuntimeError(error_msg) from e

    async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
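        """Embed each (start, end) character span of `text`.

        Encoding runs in the thread pool so the event loop stays free.
        Returns a list of embedding tensors, or None if no spans were given.
        """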
        print(f"Processing chunk {current_chunk_idx+1}/{total_chunks}...") \
        if self.verbose else None

        # Get the current running event loop
        loop = asyncio.get_running_loop()

        # Generate chunk embeddings
        chunk_embeddings = []
        for start, end in span_annotations:
            span_text = text[start:end]  # renamed to avoid shadowing the imported chunk_text
            if self.verbose:
                print("Generating chunk embeddings...")
            # Bind span_text as a default argument so the lambda does not
            # capture a later value of the loop variable; torch.as_tensor
            # avoids an extra copy when encode already returns a tensor.
            chunk_embedding = await loop.run_in_executor(
                self.executor,
                lambda s=span_text: torch.as_tensor(
                    self.model.encode(s, convert_to_tensor=True)
                )
            )
            chunk_embedding = chunk_embedding.detach().to(self.device)

            if self.verbose:
                print(f"Chunk embedding shape: {chunk_embedding.shape}")
            chunk_embeddings.append(chunk_embedding)

        print("Late Chunking applied successfully!") if self.verbose else None
        return chunk_embeddings if chunk_embeddings else None

    def get_text_embedding(self, text):
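        """Encode `text` and return its embedding as a tensor on self.device."""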
        embeddings = self.model.encode(text, convert_to_tensor=True)
        if isinstance(embeddings, torch.Tensor):
            return embeddings.clone().detach().to(self.device)
        return torch.tensor(embeddings).to(self.device)
    
    def calculate_embedding_similarities(self, text1_embedding, text2_embedding):            
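        """Cosine similarity between two embedding matrices.

        1-D inputs are promoted to single-row matrices; if the feature
        dimensions still disagree, one matrix is transposed as a
        best-effort alignment. Returns the first row of the similarity
        matrix as a 1-D array.
        """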
        text1_embedding = text1_embedding.cpu().numpy()
        text2_embedding = text2_embedding.cpu().numpy()

        if text1_embedding.ndim == 1:
            text1_embedding = text1_embedding.reshape(1, -1)
        if text2_embedding.ndim == 1:
            text2_embedding = text2_embedding.reshape(1, -1)

        if text1_embedding.shape[1] != text2_embedding.shape[1]:
            text1_embedding = text1_embedding.T
        if text2_embedding.shape[1] != text1_embedding.shape[1]:
            text2_embedding = text2_embedding.T

        return cosine_similarity(text1_embedding, text2_embedding)[0]
    
    def select_relevant_chunks(self, similarities, chunks, max_tokens):
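        """Greedily pick the highest-similarity chunks that fit in max_tokens.

        Selection stops at the first chunk that would exceed the budget;
        the survivors are re-sorted into document order and joined.
        """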
        sorted_indices = np.argsort(similarities)[::-1]
        selected_chunks = []
        total_tokens = 0

        for i, idx in enumerate(sorted_indices):
            if self.verbose:
                print(f"Considering chunk {i+1}/{len(sorted_indices)} with similarity {similarities[idx]:.2f}")
            chunk_tokens = self.llm.get_num_tokens(chunks[idx])
            if self.verbose:
                print(f"Chunk tokens: {chunk_tokens}")

            if total_tokens + chunk_tokens > max_tokens:
                if self.verbose:
                    print(f"Adding this chunk would exceed the max tokens allowed "
                          f"({total_tokens + chunk_tokens} > {max_tokens}). Stopping chunk selection.")
                break

            selected_chunks.append((idx, chunks[idx]))
            total_tokens += chunk_tokens

        if self.verbose:
            print("Sorting selected chunks...")
        # Restore original document order before joining
        selected_chunks.sort(key=lambda x: x[0])
        if self.verbose:
            print("Selected chunks sorted successfully!")
        return " ".join(chunk for _, chunk in selected_chunks)

    async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
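        """Reduce `text` to the chunks most relevant to `query`.

        Returns the text unchanged if it already fits in max_tokens;
        otherwise chunks it, embeds the chunks via late chunking, ranks
        them against the query embedding, and returns the selected
        chunks joined in document order.
        """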
        # Tokenize the entire text to check its length
        total_tokens = self.llm.get_num_tokens(text)

        # If the text already fits within max tokens, return it as is
        if total_tokens <= max_tokens:
            if self.verbose:
                print(f"Text is within the max tokens allowed ({total_tokens} <= {max_tokens}). "
                      "Returning original text.")
            return text
        
        # Chunk the text if it exceeds max tokens
        print(f"Text is greater than the max tokens allowed ({total_tokens} > {max_tokens}). \
Chunking text...") if self.verbose else None
        chunks, span_annotations = chunk_text(
            text,
            max_chunk_length=max_chunk_length,
            overlap=overlap,
            # Use the smaller of either context length or max tokens
            context_length=min(self.context_length, max_tokens)  
        )
        print(f"Text chunked into {len(chunks)} macro chunks.") if self.verbose else None

        # Process each macro chunk individually
        tasks = []

        for i, macro_chunk in enumerate(chunks):
            # Adjust span annotations relative to the current macro chunk
            start_offset = span_annotations[i][0]
            adjusted_spans = [
                (start - start_offset, end - start_offset)
                for start, end in span_annotations
                if start >= start_offset and end <= start_offset + len(macro_chunk)
            ]

            # Apply late chunking for the current macro chunk
            tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))

        # Aggregate embeddings asynchronously; each macro chunk is
        # represented by its first span's embedding
        results = await asyncio.gather(*tasks)
        chunk_embeddings = torch.stack([result[0] for result in results])

        # Generate query embedding
        if self.verbose:
            print("Generating query embedding...")
        query_embedding = self.get_text_embedding(query)
        if self.verbose:
            print(f"Query embedding shape: {query_embedding.shape}")

        # Calculate similarities between the query embedding and chunk embeddings
        if self.verbose:
            print("Calculating embedding similarities...")
        similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
        if self.verbose:
            print(f"Similarities shape: {similarities.shape}")

        # Select relevant chunks based on similarity
        if self.verbose:
            print("Selecting relevant chunks...")
        return self.select_relevant_chunks(similarities, chunks, max_tokens)
            
if __name__ == "__main__":
    from sklearn.feature_extraction.text import TfidfVectorizer
    from src.reasoning.reasoner import Reasoner
    from src.search.search_engine import SearchEngine
    from src.crawl.crawler import CustomCrawler
    import time
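
    # Smoke test: gather a corpus from several related searches, crawl the
    # pages, then reduce the combined text to the most query-relevant chunks
    # and report token counts and timing.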

    search_engine = SearchEngine()
    crawler = CustomCrawler()
    reasoner = Reasoner()
    chunking = LateChunker(verbose=True)

    loop = asyncio.new_event_loop()
    
    # Run several related searches and collect the result URLs
    queries = [
        "What is the history of climate change and pollution since the pre-industrial revolution?",
        "What is the impact of climate change on the Indian economy?",
        "What are some of the latest, state-of-the-art techniques used to fight climate change?",
        "What does the projection for climate change look like in the next 50 years?",
        "What efforts are being made by governments all around the world to combat climate change?",
    ]
    urls = []
    for search_query in queries:
        search_results = loop.run_until_complete(search_engine.search(
            search_query,
            num_results=20,
            exclude_filetypes=["pdf"]
        ))
        urls.extend(result["link"] for result in search_results)

    results = loop.run_until_complete(crawler.fetch_page_contents(
        urls=urls,
        max_attempts=1,
        delay=0 
    ))
    text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])

    num_tokens_before_chunking = chunking.llm.get_num_tokens(text)
    start_time = time.perf_counter()
    response = loop.run_until_complete(chunking.chunker(
        text, 
        query="What is this text about? Give me a detailed answer",
        max_tokens=128000
    ))
    end_time = time.perf_counter()
    num_tokens_after_chunking = chunking.llm.get_num_tokens(response)
    print(f"\nResponse:\n{response}")
    print(f"\nNumber of URLs: {len(urls)}")
    print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
    print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
    print(f"\nTime taken: {end_time - start_time:.2f} seconds")

    # Calculate cosine similarity between original text and response
    def calculate_cosine_similarity(text1, text2):
        # TF-IDF vectors for both texts; the off-diagonal entry is their similarity
        tfidf_matrix = TfidfVectorizer().fit_transform([text1, text2])
        vectors = tfidf_matrix.toarray()
        return cosine_similarity(vectors)[0][1]
        
    similarity = calculate_cosine_similarity(text, response)
    print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")