"""
Retriever module for Norwegian RAG chatbot.
Retrieves relevant document chunks based on query embeddings.
"""

import os
import json
import numpy as np
from typing import List, Dict, Any, Optional

from ..api.huggingface_api import HuggingFaceAPI
from ..api.config import MAX_CHUNKS_TO_RETRIEVE, SIMILARITY_THRESHOLD

class Retriever:
    """
    Retrieves relevant document chunks based on query embeddings.
    Uses cosine similarity to find the most relevant chunks.
    """
    
    def __init__(
        self,
        api_client: Optional[HuggingFaceAPI] = None,
        processed_dir: str = "/home/ubuntu/chatbot_project/data/processed",
        max_chunks: int = MAX_CHUNKS_TO_RETRIEVE,
        similarity_threshold: float = SIMILARITY_THRESHOLD
    ):
        """
        Initialize the retriever.
        
        Args:
            api_client: HuggingFaceAPI client for generating embeddings
            processed_dir: Directory containing processed documents
            max_chunks: Maximum number of chunks to retrieve
            similarity_threshold: Minimum similarity score for retrieval
        """
        self.api_client = api_client or HuggingFaceAPI()
        self.processed_dir = processed_dir
        self.max_chunks = max_chunks
        self.similarity_threshold = similarity_threshold
        
        # Load document index
        self.document_index_path = os.path.join(self.processed_dir, "document_index.json")
        self.document_index = self._load_document_index()
    
    def retrieve(self, query: str) -> List[Dict[str, Any]]:
        """
        Retrieve relevant document chunks for a query.
        
        Args:
            query: User query
            
        Returns:
            List of retrieved chunks with metadata
        """
        # Generate an embedding for the query; generate_embeddings returns
        # one vector per input text, so take the first (and only) one
        query_embedding = self.api_client.generate_embeddings(query)[0]
        
        # Find relevant chunks across all documents
        all_results = []
        
        for doc_id in self.document_index:
            try:
                # Load document data
                doc_results = self._retrieve_from_document(doc_id, query_embedding)
                all_results.extend(doc_results)
            except Exception as e:
                print(f"Error retrieving from document {doc_id}: {str(e)}")
        
        # Sort all results by similarity score
        all_results.sort(key=lambda x: x["similarity"], reverse=True)
        
        # Return top results above threshold
        return [
            result for result in all_results[:self.max_chunks]
            if result["similarity"] >= self.similarity_threshold
        ]
    
    def _retrieve_from_document(
        self,
        document_id: str,
        query_embedding: List[float]
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant chunks from a specific document.
        
        Args:
            document_id: Document ID
            query_embedding: Query embedding vector
            
        Returns:
            List of retrieved chunks with metadata
        """
        document_path = os.path.join(self.processed_dir, f"{document_id}.json")
        if not os.path.exists(document_path):
            return []
        
        # Load document data
        with open(document_path, 'r', encoding='utf-8') as f:
            document_data = json.load(f)
        
        chunks = document_data.get("chunks", [])
        embeddings = document_data.get("embeddings", [])
        metadata = document_data.get("metadata", {})
        
        if not chunks or not embeddings or len(chunks) != len(embeddings):
            return []
        
        # Calculate similarity scores
        results = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            similarity = self._cosine_similarity(query_embedding, embedding)
            
            results.append({
                "document_id": document_id,
                "chunk_index": i,
                "chunk_text": chunk,
                "similarity": similarity,
                "metadata": metadata
            })
        
        # Sort by similarity
        results.sort(key=lambda x: x["similarity"], reverse=True)
        
        return results[:self.max_chunks]
    
    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.
        
        Args:
            vec1: First vector
            vec2: Second vector
            
        Returns:
            Cosine similarity score
        """
        v1 = np.asarray(vec1, dtype=float)
        v2 = np.asarray(vec2, dtype=float)
        
        dot_product = np.dot(v1, v2)
        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)
        
        # Zero vectors have no defined direction; treat them as dissimilar
        if norm1 == 0 or norm2 == 0:
            return 0.0
        
        # Cast from np.float64 so the return type matches the annotation
        return float(dot_product / (norm1 * norm2))
    
    def _load_document_index(self) -> Dict[str, Dict[str, Any]]:
        """
        Load the document index from disk.
        
        Returns:
            Dictionary of document IDs to metadata
        """
        if os.path.exists(self.document_index_path):
            try:
                with open(self.document_index_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading document index: {str(e)}")
        
        return {}
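

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. Assumes the default
    # processed_dir exists and that HuggingFaceAPI() can reach its endpoint;
    # the query string below is only an illustrative example.
    retriever = Retriever()
    for result in retriever.retrieve("Hva er åpningstidene?"):
        print(
            f"[{result['similarity']:.3f}] {result['document_id']} "
            f"(chunk {result['chunk_index']}): {result['chunk_text'][:80]}"
        )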