Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -123,17 +123,21 @@ def check_environment():
 
 class SentenceTransformerRetriever:
     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-        self.device = torch.device("cpu")
-        self.model_name = model_name
-        self.cache_dir = cache_dir
-        self.cache_file = "embeddings.pkl"
-        self.doc_embeddings = None
-        os.makedirs(cache_dir, exist_ok=True)
-        # Initialize model using cached method
-        self.model = self._load_model(model_name)
+        try:
+            self.device = torch.device("cpu")
+            self.model_name = model_name
+            self.cache_dir = cache_dir
+            self.cache_file = "embeddings.pkl"
+            self.doc_embeddings = None
+            os.makedirs(cache_dir, exist_ok=True)
+            # Initialize model using cached method
+            self.model = self._load_model(model_name)
+        except Exception as e:
+            logging.error(f"Error initializing SentenceTransformerRetriever: {str(e)}")
+            raise
 
     @st.cache_resource(show_spinner=False)
-    def _load_model(_self, _model_name: str):
+    def _load_model(_self, _model_name: str):
         """Load and cache the sentence transformer model"""
         try:
             with warnings.catch_warnings():
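A note on the underscore-prefixed parameters above: Streamlit excludes any cached-function argument whose name starts with an underscore from the cache key, which is how `_load_model` avoids an UnhashableParamError on `self`. A minimal standalone sketch of the pattern (not this app's exact code):

    import streamlit as st
    from sentence_transformers import SentenceTransformer

    class Retriever:
        @st.cache_resource(show_spinner=False)
        def _load_model(_self, model_name: str):
            # _self is skipped when Streamlit builds the cache key, so the
            # unhashable instance does not break caching; model_name, a plain
            # string, still participates in the key.
            return SentenceTransformer(model_name)

In the diff both parameters are underscored (`_self`, `_model_name`), so the model name is also left out of the cache key and every call shares the one cached model.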
@@ -144,11 +148,17 @@ class SentenceTransformerRetriever:
             if not isinstance(test_embedding, torch.Tensor):
                 raise ValueError("Model initialization failed")
             return model
+        except Exception as e:
+            logging.error(f"Error loading model: {str(e)}")
+            raise
+
     def get_cache_path(self, data_folder: str = None) -> str:
+        """Get the path for cache file"""
         return os.path.join(self.cache_dir, self.cache_file)
 
     @log_function
     def save_cache(self, data_folder: str, cache_data: dict):
+        """Save embeddings to cache"""
         try:
             cache_path = self.get_cache_path()
             if os.path.exists(cache_path):
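The bodies of `save_cache` and `load_cache` are cut off by the hunk context. A hedged sketch of the pickle-based pattern the surrounding lines suggest, reusing the `embeddings.pkl` file name set in `__init__` (the dict layout is an assumption):

    import os
    import pickle
    from typing import Optional

    def save_cache(cache_dir: str, cache_data: dict) -> None:
        # Assumed layout: {"embeddings": torch.Tensor, "documents": list[str]}
        os.makedirs(cache_dir, exist_ok=True)
        with open(os.path.join(cache_dir, "embeddings.pkl"), "wb") as f:
            pickle.dump(cache_data, f)

    def load_cache(cache_dir: str) -> Optional[dict]:
        path = os.path.join(cache_dir, "embeddings.pkl")
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return pickle.load(f)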
@@ -162,7 +172,8 @@ class SentenceTransformerRetriever:
 
     @log_function
     @st.cache_data
-    def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
+    def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
+        """Load embeddings from cache"""
         try:
             cache_path = _self.get_cache_path()
             if os.path.exists(cache_path):
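`load_cache` is cached with `@st.cache_data`, which suits picklable values and hands each rerun a copy, whereas the model above uses `@st.cache_resource`, which returns one shared object; the underscore convention applies to both decorators. A compact illustration of the distinction:

    import streamlit as st

    @st.cache_data
    def load_rows() -> list[int]:
        # Serializable result: Streamlit stores it and returns a fresh copy
        # on every rerun, so callers can mutate it safely.
        return [1, 2, 3]

    @st.cache_resource
    def load_handle() -> object:
        # Shared resource: every rerun receives the very same object.
        return object()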
@@ -179,6 +190,7 @@ class SentenceTransformerRetriever:
 
     @log_function
     def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
+        """Encode texts into embeddings"""
         try:
             embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
             return F.normalize(embeddings, p=2, dim=1)
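Because `encode` L2-normalizes its output (`F.normalize(embeddings, p=2, dim=1)`), cosine similarity against these embeddings reduces to a dot product. A small self-contained check (384 matches all-MiniLM-L6-v2's embedding width, but any size works):

    import torch
    import torch.nn.functional as F

    docs = F.normalize(torch.randn(4, 384), p=2, dim=1)   # unit-norm rows
    query = F.normalize(torch.randn(1, 384), p=2, dim=1)

    cos = F.cosine_similarity(query, docs)  # shape (4,): query broadcasts over rows
    dot = docs @ query.squeeze(0)           # identical for unit vectors
    assert torch.allclose(cos, dot, atol=1e-6)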
@@ -188,23 +200,29 @@ class SentenceTransformerRetriever:
 
     @log_function
     def store_embeddings(self, embeddings: torch.Tensor):
+        """Store embeddings in memory"""
         self.doc_embeddings = embeddings
 
     @log_function
     def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
-        if self.doc_embeddings is None:
-            raise ValueError("No document embeddings stored!")
-
-        similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-        k = min(k, len(documents))
-        scores, indices = torch.topk(similarities, k=k)
-
-        logging.info(f"\nSimilarity Stats:")
-        logging.info(f"Max similarity: {similarities.max().item():.4f}")
-        logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
-        logging.info(f"Selected similarities: {scores.tolist()}")
-
-        return indices.cpu(), scores.cpu()
+        """Search for similar documents"""
+        try:
+            if self.doc_embeddings is None:
+                raise ValueError("No document embeddings stored!")
+
+            similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
+            k = min(k, len(documents))
+            scores, indices = torch.topk(similarities, k=k)
+
+            logging.info(f"\nSimilarity Stats:")
+            logging.info(f"Max similarity: {similarities.max().item():.4f}")
+            logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
+            logging.info(f"Selected similarities: {scores.tolist()}")
+
+            return indices.cpu(), scores.cpu()
+        except Exception as e:
+            logging.error(f"Error in search: {str(e)}")
+            raise
 
 class RAGPipeline:
     def __init__(self, data_folder: str, k: int = 5):