nishantgaurav23 commited on
Commit
526c0d9
·
verified ·
1 Parent(s): 88e53d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -23
app.py CHANGED
@@ -123,17 +123,21 @@ def check_environment():
123
 
124
  class SentenceTransformerRetriever:
125
  def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
126
- self.device = torch.device("cpu")
127
- self.model_name = model_name
128
- self.cache_dir = cache_dir
129
- self.cache_file = "embeddings.pkl"
130
- self.doc_embeddings = None
131
- os.makedirs(cache_dir, exist_ok=True)
132
- # Initialize model using cached method
133
- self.model = self._load_model(model_name) # Pass model_name as argument
 
 
 
 
134
 
135
  @st.cache_resource(show_spinner=False)
136
- def _load_model(_self, _model_name: str): # Changed to _self and added _model_name
137
  """Load and cache the sentence transformer model"""
138
  try:
139
  with warnings.catch_warnings():
@@ -144,11 +148,17 @@ class SentenceTransformerRetriever:
144
  if not isinstance(test_embedding, torch.Tensor):
145
  raise ValueError("Model initialization failed")
146
  return model
 
 
 
 
147
  def get_cache_path(self, data_folder: str = None) -> str:
 
148
  return os.path.join(self.cache_dir, self.cache_file)
149
 
150
  @log_function
151
  def save_cache(self, data_folder: str, cache_data: dict):
 
152
  try:
153
  cache_path = self.get_cache_path()
154
  if os.path.exists(cache_path):
@@ -162,7 +172,8 @@ class SentenceTransformerRetriever:
162
 
163
  @log_function
164
  @st.cache_data
165
- def load_cache(_self, _data_folder: str = None) -> Optional[Dict]: # Changed to _self and _data_folder
 
166
  try:
167
  cache_path = _self.get_cache_path()
168
  if os.path.exists(cache_path):
@@ -179,6 +190,7 @@ class SentenceTransformerRetriever:
179
 
180
  @log_function
181
  def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
 
182
  try:
183
  embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
184
  return F.normalize(embeddings, p=2, dim=1)
@@ -188,23 +200,29 @@ class SentenceTransformerRetriever:
188
 
189
  @log_function
190
  def store_embeddings(self, embeddings: torch.Tensor):
 
191
  self.doc_embeddings = embeddings
192
 
193
  @log_function
194
  def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
195
- if self.doc_embeddings is None:
196
- raise ValueError("No document embeddings stored!")
197
-
198
- similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
199
- k = min(k, len(documents))
200
- scores, indices = torch.topk(similarities, k=k)
201
-
202
- logging.info(f"\nSimilarity Stats:")
203
- logging.info(f"Max similarity: {similarities.max().item():.4f}")
204
- logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
205
- logging.info(f"Selected similarities: {scores.tolist()}")
206
-
207
- return indices.cpu(), scores.cpu()
 
 
 
 
 
208
 
209
  class RAGPipeline:
210
  def __init__(self, data_folder: str, k: int = 5):
 
123
 
124
  class SentenceTransformerRetriever:
125
  def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
126
+ try:
127
+ self.device = torch.device("cpu")
128
+ self.model_name = model_name
129
+ self.cache_dir = cache_dir
130
+ self.cache_file = "embeddings.pkl"
131
+ self.doc_embeddings = None
132
+ os.makedirs(cache_dir, exist_ok=True)
133
+ # Initialize model using cached method
134
+ self.model = self._load_model(model_name)
135
+ except Exception as e:
136
+ logging.error(f"Error initializing SentenceTransformerRetriever: {str(e)}")
137
+ raise
138
 
139
  @st.cache_resource(show_spinner=False)
140
+ def _load_model(_self, _model_name: str):
141
  """Load and cache the sentence transformer model"""
142
  try:
143
  with warnings.catch_warnings():
 
148
  if not isinstance(test_embedding, torch.Tensor):
149
  raise ValueError("Model initialization failed")
150
  return model
151
+ except Exception as e:
152
+ logging.error(f"Error loading model: {str(e)}")
153
+ raise
154
+
155
  def get_cache_path(self, data_folder: str = None) -> str:
156
+ """Get the path for cache file"""
157
  return os.path.join(self.cache_dir, self.cache_file)
158
 
159
  @log_function
160
  def save_cache(self, data_folder: str, cache_data: dict):
161
+ """Save embeddings to cache"""
162
  try:
163
  cache_path = self.get_cache_path()
164
  if os.path.exists(cache_path):
 
172
 
173
  @log_function
174
  @st.cache_data
175
+ def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
176
+ """Load embeddings from cache"""
177
  try:
178
  cache_path = _self.get_cache_path()
179
  if os.path.exists(cache_path):
 
190
 
191
  @log_function
192
  def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
193
+ """Encode texts into embeddings"""
194
  try:
195
  embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
196
  return F.normalize(embeddings, p=2, dim=1)
 
200
 
201
  @log_function
202
  def store_embeddings(self, embeddings: torch.Tensor):
203
+ """Store embeddings in memory"""
204
  self.doc_embeddings = embeddings
205
 
206
  @log_function
207
  def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
208
+ """Search for similar documents"""
209
+ try:
210
+ if self.doc_embeddings is None:
211
+ raise ValueError("No document embeddings stored!")
212
+
213
+ similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
214
+ k = min(k, len(documents))
215
+ scores, indices = torch.topk(similarities, k=k)
216
+
217
+ logging.info(f"\nSimilarity Stats:")
218
+ logging.info(f"Max similarity: {similarities.max().item():.4f}")
219
+ logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
220
+ logging.info(f"Selected similarities: {scores.tolist()}")
221
+
222
+ return indices.cpu(), scores.cpu()
223
+ except Exception as e:
224
+ logging.error(f"Error in search: {str(e)}")
225
+ raise
226
 
227
  class RAGPipeline:
228
  def __init__(self, data_folder: str, k: int = 5):