Mr-Geo committed on
Commit
f2e0937
·
verified ·
1 Parent(s): 7754c6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -11
app.py CHANGED
@@ -14,6 +14,7 @@ from pathlib import Path
14
  import tempfile
15
  import spaces # for ZeroGPU
16
  import requests # for IP geolocation
 
17
 
18
  # Load environment variables and initialize clients
19
  load_dotenv()
@@ -98,23 +99,51 @@ def initialize_system_sync():
98
  # Use the same ChromaDB client that was loaded from HF
99
  chroma_client = db # Use the global db instance we created
100
 
101
- # Initialize the embedding function
102
- embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
103
- model_name="sentence-transformers/all-mpnet-base-v2",
104
- device=DEVICE
105
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  # Get the collection
108
  print("Getting collection...")
109
  collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
110
  print(f"Found {collection.count()} documents in collection")
111
 
112
- # Initialize the reranker and explicitly move to GPU if available
113
- print("\nInitialising Cross-Encoder...")
114
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=DEVICE)
115
- if torch.cuda.is_available():
116
- reranker.model.to('cuda') # Ensure model is on GPU
117
- print("Reranker moved to GPU")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @spaces.GPU(memory="40g")
120
  def initialize_system():
 
14
  import tempfile
15
  import spaces # for ZeroGPU
16
  import requests # for IP geolocation
17
+ import time
18
 
19
  # Load environment variables and initialize clients
20
  load_dotenv()
 
99
  # Use the same ChromaDB client that was loaded from HF
100
  chroma_client = db # Use the global db instance we created
101
 
102
+ # Initialize the embedding function with retries
103
+ max_retries = 3
104
+ retry_delay = 5 # seconds
105
+
106
+ for attempt in range(max_retries):
107
+ try:
108
+ print(f"\nAttempt {attempt + 1} of {max_retries} to initialize embedding function...")
109
+ embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
110
+ model_name="sentence-transformers/all-mpnet-base-v2",
111
+ device=DEVICE
112
+ )
113
+ break
114
+ except Exception as e:
115
+ print(f"Error initializing embedding function: {str(e)}")
116
+ if attempt < max_retries - 1:
117
+ print(f"Retrying in {retry_delay} seconds...")
118
+ time.sleep(retry_delay)
119
+ else:
120
+ raise RuntimeError("Failed to initialize embedding function after multiple attempts")
121
 
122
  # Get the collection
123
  print("Getting collection...")
124
  collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
125
  print(f"Found {collection.count()} documents in collection")
126
 
127
+ # Initialize the reranker with retries
128
+ for attempt in range(max_retries):
129
+ try:
130
+ print(f"\nAttempt {attempt + 1} of {max_retries} to initialize reranker...")
131
+ reranker = CrossEncoder(
132
+ 'cross-encoder/ms-marco-MiniLM-L-6-v2',
133
+ device=DEVICE,
134
+ max_length=512 # Add explicit max_length
135
+ )
136
+ if torch.cuda.is_available():
137
+ reranker.model.to('cuda')
138
+ print("Reranker moved to GPU")
139
+ break
140
+ except Exception as e:
141
+ print(f"Error initializing reranker: {str(e)}")
142
+ if attempt < max_retries - 1:
143
+ print(f"Retrying in {retry_delay} seconds...")
144
+ time.sleep(retry_delay)
145
+ else:
146
+ raise RuntimeError("Failed to initialize reranker after multiple attempts")
147
 
148
  @spaces.GPU(memory="40g")
149
  def initialize_system():