nishantgaurav23 committed on
Commit
88e53d1
·
verified ·
1 Parent(s): f30497e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -130,20 +130,20 @@ class SentenceTransformerRetriever:
130
  self.doc_embeddings = None
131
  os.makedirs(cache_dir, exist_ok=True)
132
  # Initialize model using cached method
133
- self.model = self._load_model()
134
 
135
  @st.cache_resource(show_spinner=False)
136
- def _load_model(self):
137
  """Load and cache the sentence transformer model"""
138
- with warnings.catch_warnings():
139
- warnings.simplefilter("ignore")
140
- model = SentenceTransformer(self.model_name, device="cpu")
141
- # Verify model is loaded correctly
142
- test_embedding = model.encode("test", convert_to_tensor=True)
143
- if not isinstance(test_embedding, torch.Tensor):
144
- raise ValueError("Model initialization failed")
145
- return model
146
-
147
  def get_cache_path(self, data_folder: str = None) -> str:
148
  return os.path.join(self.cache_dir, self.cache_file)
149
 
@@ -162,9 +162,9 @@ class SentenceTransformerRetriever:
162
 
163
  @log_function
164
  @st.cache_data
165
- def load_cache(self, data_folder: str = None) -> Optional[Dict]:
166
  try:
167
- cache_path = self.get_cache_path()
168
  if os.path.exists(cache_path):
169
  with open(cache_path, 'rb') as f:
170
  logging.info(f"Loading cache from: {cache_path}")
@@ -207,7 +207,7 @@ class SentenceTransformerRetriever:
207
  return indices.cpu(), scores.cpu()
208
 
209
  class RAGPipeline:
210
- def __init__(self, data_folder: str, k: int = 5):
211
  self.data_folder = data_folder
212
  self.k = k
213
  self.retriever = SentenceTransformerRetriever()
@@ -218,20 +218,20 @@ class RAGPipeline:
218
  self._initialize_model()
219
 
220
  @st.cache_resource(show_spinner=False)
221
- def _initialize_model(self):
222
  """Initialize the model with proper error handling and verification"""
223
  try:
224
- os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
225
 
226
- if not os.path.exists(self.model_path):
227
  direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
228
- download_file_with_progress(direct_url, self.model_path)
229
 
230
- if not os.path.exists(self.model_path):
231
- raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
232
 
233
- if os.path.getsize(self.model_path) < 1000000: # Less than 1MB
234
- os.remove(self.model_path)
235
  raise ValueError("Downloaded model file is too small, likely corrupted")
236
 
237
  llm_config = {
@@ -242,7 +242,7 @@ class RAGPipeline:
242
  "verbose": False
243
  }
244
 
245
- self.llm = Llama(model_path=self.model_path, **llm_config)
246
  st.success("Model loaded successfully!")
247
 
248
  except Exception as e:
 
130
  self.doc_embeddings = None
131
  os.makedirs(cache_dir, exist_ok=True)
132
  # Initialize model using cached method
133
+ self.model = self._load_model(model_name) # Pass model_name as argument
134
 
135
  @st.cache_resource(show_spinner=False)
136
+ def _load_model(_self, _model_name: str): # Changed to _self and added _model_name
137
  """Load and cache the sentence transformer model"""
138
+ try:
139
+ with warnings.catch_warnings():
140
+ warnings.simplefilter("ignore")
141
+ model = SentenceTransformer(_model_name, device="cpu")
142
+ # Verify model is loaded correctly
143
+ test_embedding = model.encode("test", convert_to_tensor=True)
144
+ if not isinstance(test_embedding, torch.Tensor):
145
+ raise ValueError("Model initialization failed")
146
+ return model
147
  def get_cache_path(self, data_folder: str = None) -> str:
148
  return os.path.join(self.cache_dir, self.cache_file)
149
 
 
162
 
163
  @log_function
164
  @st.cache_data
165
+ def load_cache(_self, _data_folder: str = None) -> Optional[Dict]: # Changed to _self and _data_folder
166
  try:
167
+ cache_path = _self.get_cache_path()
168
  if os.path.exists(cache_path):
169
  with open(cache_path, 'rb') as f:
170
  logging.info(f"Loading cache from: {cache_path}")
 
207
  return indices.cpu(), scores.cpu()
208
 
209
  class RAGPipeline:
210
+ def __init__(self, data_folder: str, k: int = 5):
211
  self.data_folder = data_folder
212
  self.k = k
213
  self.retriever = SentenceTransformerRetriever()
 
218
  self._initialize_model()
219
 
220
  @st.cache_resource(show_spinner=False)
221
+ def _initialize_model(_self): # Changed to _self
222
  """Initialize the model with proper error handling and verification"""
223
  try:
224
+ os.makedirs(os.path.dirname(_self.model_path), exist_ok=True)
225
 
226
+ if not os.path.exists(_self.model_path):
227
  direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
228
+ download_file_with_progress(direct_url, _self.model_path)
229
 
230
+ if not os.path.exists(_self.model_path):
231
+ raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")
232
 
233
+ if os.path.getsize(_self.model_path) < 1000000: # Less than 1MB
234
+ os.remove(_self.model_path)
235
  raise ValueError("Downloaded model file is too small, likely corrupted")
236
 
237
  llm_config = {
 
242
  "verbose": False
243
  }
244
 
245
+ _self.llm = Llama(model_path=_self.model_path, **llm_config)
246
  st.success("Model loaded successfully!")
247
 
248
  except Exception as e: