Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,7 @@ from pathlib import Path
|
|
14 |
import tempfile
|
15 |
import spaces # for ZeroGPU
|
16 |
import requests # for IP geolocation
|
|
|
17 |
|
18 |
# Load environment variables and initialize clients
|
19 |
load_dotenv()
|
@@ -98,23 +99,51 @@ def initialize_system_sync():
|
|
98 |
# Use the same ChromaDB client that was loaded from HF
|
99 |
chroma_client = db # Use the global db instance we created
|
100 |
|
101 |
-
# Initialize the embedding function
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
# Get the collection
|
108 |
print("Getting collection...")
|
109 |
collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
|
110 |
print(f"Found {collection.count()} documents in collection")
|
111 |
|
112 |
-
# Initialize the reranker
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@spaces.GPU(memory="40g")
|
120 |
def initialize_system():
|
|
|
14 |
import tempfile
|
15 |
import spaces # for ZeroGPU
|
16 |
import requests # for IP geolocation
|
17 |
+
import time
|
18 |
|
19 |
# Load environment variables and initialize clients
|
20 |
load_dotenv()
|
|
|
99 |
# Use the same ChromaDB client that was loaded from HF
|
100 |
chroma_client = db # Use the global db instance we created
|
101 |
|
102 |
+
# Initialize the embedding function with retries
|
103 |
+
max_retries = 3
|
104 |
+
retry_delay = 5 # seconds
|
105 |
+
|
106 |
+
for attempt in range(max_retries):
|
107 |
+
try:
|
108 |
+
print(f"\nAttempt {attempt + 1} of {max_retries} to initialize embedding function...")
|
109 |
+
embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
|
110 |
+
model_name="sentence-transformers/all-mpnet-base-v2",
|
111 |
+
device=DEVICE
|
112 |
+
)
|
113 |
+
break
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Error initializing embedding function: {str(e)}")
|
116 |
+
if attempt < max_retries - 1:
|
117 |
+
print(f"Retrying in {retry_delay} seconds...")
|
118 |
+
time.sleep(retry_delay)
|
119 |
+
else:
|
120 |
+
raise RuntimeError("Failed to initialize embedding function after multiple attempts")
|
121 |
|
122 |
# Get the collection
|
123 |
print("Getting collection...")
|
124 |
collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
|
125 |
print(f"Found {collection.count()} documents in collection")
|
126 |
|
127 |
+
# Initialize the reranker with retries
|
128 |
+
for attempt in range(max_retries):
|
129 |
+
try:
|
130 |
+
print(f"\nAttempt {attempt + 1} of {max_retries} to initialize reranker...")
|
131 |
+
reranker = CrossEncoder(
|
132 |
+
'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
133 |
+
device=DEVICE,
|
134 |
+
max_length=512 # Add explicit max_length
|
135 |
+
)
|
136 |
+
if torch.cuda.is_available():
|
137 |
+
reranker.model.to('cuda')
|
138 |
+
print("Reranker moved to GPU")
|
139 |
+
break
|
140 |
+
except Exception as e:
|
141 |
+
print(f"Error initializing reranker: {str(e)}")
|
142 |
+
if attempt < max_retries - 1:
|
143 |
+
print(f"Retrying in {retry_delay} seconds...")
|
144 |
+
time.sleep(retry_delay)
|
145 |
+
else:
|
146 |
+
raise RuntimeError("Failed to initialize reranker after multiple attempts")
|
147 |
|
148 |
@spaces.GPU(memory="40g")
|
149 |
def initialize_system():
|