"""
Retriever module for Norwegian RAG chatbot.
Retrieves relevant document chunks based on query embeddings.
"""
import os
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from ..api.huggingface_api import HuggingFaceAPI
from ..api.config import MAX_CHUNKS_TO_RETRIEVE, SIMILARITY_THRESHOLD
class Retriever:
    """
    Retrieves relevant document chunks based on query embeddings.
    Uses cosine similarity to find the most relevant chunks.
    """

    def __init__(
        self,
        api_client: Optional[HuggingFaceAPI] = None,
        processed_dir: str = "/home/ubuntu/chatbot_project/data/processed",
        max_chunks: int = MAX_CHUNKS_TO_RETRIEVE,
        similarity_threshold: float = SIMILARITY_THRESHOLD,
    ):
        """
        Initialize the retriever.

        Args:
            api_client: HuggingFaceAPI client for generating embeddings
            processed_dir: Directory containing processed documents
            max_chunks: Maximum number of chunks to retrieve
            similarity_threshold: Minimum similarity score for retrieval
        """
        self.api_client = api_client or HuggingFaceAPI()
        self.processed_dir = processed_dir
        self.max_chunks = max_chunks
        self.similarity_threshold = similarity_threshold

        # Load the document index mapping document IDs to metadata.
        self.document_index_path = os.path.join(self.processed_dir, "document_index.json")
        self.document_index = self._load_document_index()

    def retrieve(self, query: str) -> List[Dict[str, Any]]:
        """
        Retrieve relevant document chunks for a query.

        Args:
            query: User query

        Returns:
            List of retrieved chunks with metadata, best matches first
        """
        # Generate embedding for the query
        query_embedding = self.api_client.generate_embeddings(query)[0]

        # Find relevant chunks across all documents
        all_results = []
        for doc_id in self.document_index:
            try:
                doc_results = self._retrieve_from_document(doc_id, query_embedding)
                all_results.extend(doc_results)
            except Exception as e:
                print(f"Error retrieving from document {doc_id}: {str(e)}")

        # Sort all results by similarity score
        all_results.sort(key=lambda x: x["similarity"], reverse=True)

        # Return top results above threshold
        return [
            result
            for result in all_results[: self.max_chunks]
            if result["similarity"] >= self.similarity_threshold
        ]
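
    # Illustrative result shape, based on the dicts built in
    # _retrieve_from_document below (field values here are hypothetical):
    #
    #   retriever.retrieve("Hva koster et årskort?")
    #   -> [
    #        {
    #            "document_id": "doc_001",
    #            "chunk_index": 3,
    #            "chunk_text": "Et årskort koster ...",
    #            "similarity": 0.87,
    #            "metadata": {"source": "faq.pdf"},
    #        },
    #        ...
    #      ]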

    def _retrieve_from_document(
        self,
        document_id: str,
        query_embedding: List[float],
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant chunks from a specific document.

        Args:
            document_id: Document ID
            query_embedding: Query embedding vector

        Returns:
            List of retrieved chunks with metadata
        """
        document_path = os.path.join(self.processed_dir, f"{document_id}.json")
        if not os.path.exists(document_path):
            return []

        # Load document data
        with open(document_path, "r", encoding="utf-8") as f:
            document_data = json.load(f)

        chunks = document_data.get("chunks", [])
        embeddings = document_data.get("embeddings", [])
        metadata = document_data.get("metadata", {})

        # Skip documents with no chunks or mismatched chunk/embedding lists.
        if not chunks or not embeddings or len(chunks) != len(embeddings):
            return []

        # Calculate similarity scores
        results = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            similarity = self._cosine_similarity(query_embedding, embedding)
            results.append({
                "document_id": document_id,
                "chunk_index": i,
                "chunk_text": chunk,
                "similarity": similarity,
                "metadata": metadata,
            })

        # Sort by similarity and keep the best chunks from this document
        results.sort(key=lambda x: x["similarity"], reverse=True)
        return results[: self.max_chunks]
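
    # Expected on-disk layout for {document_id}.json, inferred from the
    # reads above (a sketch; actual files may carry additional fields):
    #
    #   {
    #     "chunks":     ["chunk text 0", "chunk text 1", ...],
    #     "embeddings": [[0.12, -0.03, ...], [0.08, 0.41, ...], ...],
    #     "metadata":   {"title": "...", "source": "..."}
    #   }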

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Cosine similarity score in [-1, 1]
        """
        v1 = np.array(vec1)
        v2 = np.array(vec2)

        dot_product = np.dot(v1, v2)
        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)

        # Guard against division by zero for zero-length vectors.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Cast to a plain float to match the declared return type.
        return float(dot_product / (norm1 * norm2))
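
    # Worked example: for vec1 = [1, 0] and vec2 = [1, 1],
    # dot = 1, |vec1| = 1, |vec2| = sqrt(2),
    # so similarity = 1 / sqrt(2) ≈ 0.707.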

    def _load_document_index(self) -> Dict[str, Dict[str, Any]]:
        """
        Load the document index from disk.

        Returns:
            Dictionary of document IDs to metadata; empty if the index
            is missing or unreadable
        """
        if os.path.exists(self.document_index_path):
            try:
                with open(self.document_index_path, "r", encoding="utf-8") as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading document index: {str(e)}")
        return {}
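

# Minimal usage sketch. Because this module uses relative imports, it is
# assumed to run in package context (e.g. via `python -m`); it also assumes
# processed documents with embeddings already exist under processed_dir.
# The query string below is illustrative.
if __name__ == "__main__":
    retriever = Retriever()
    for result in retriever.retrieve("Hvordan fungerer chatboten?"):
        print(f"{result['similarity']:.3f}  {result['chunk_text'][:80]}")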