Update app.py

app.py CHANGED
@@ -12,6 +12,8 @@ from datetime import datetime
 from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
 from pathlib import Path
 import tempfile
+import spaces  # for ZeroGPU
+import requests  # for IP geolocation
 
 # Load environment variables and initialize clients
 load_dotenv()
@@ -74,9 +76,19 @@ else:
     # Load from Hugging Face dataset
     db = load_chroma_db()
 
+@spaces.GPU(memory="40g")  # Add GPU decorator for initialize_system
 def initialize_system():
     """Initialize the system components"""
 
+    # Add GPU diagnostics
+    print("\n=== GPU Diagnostics ===")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"Current CUDA device: {torch.cuda.current_device()}")
+        print(f"Device name: {torch.cuda.get_device_name()}")
+        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+    print("=====================\n")
+
     # Use the same ChromaDB client that was loaded from HF
     chroma_client = db  # Use the global db instance we created
 
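
A note on the decorator: on ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated function is running, which is why the diagnostics are printed inside initialize_system() rather than at import time. A minimal sketch of the pattern (gpu_probe is an illustrative name, not part of the app):

import spaces
import torch

@spaces.GPU  # a GPU slice is attached only for the duration of this call
def gpu_probe() -> str:
    """Report the device ZeroGPU hands us, or note that CUDA is absent."""
    if not torch.cuda.is_available():
        return "CUDA not available"
    props = torch.cuda.get_device_properties(0)
    return f"{torch.cuda.get_device_name(0)}: {props.total_memory / 1024**3:.2f} GB"
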
@@ -91,12 +103,16 @@ def initialize_system():
     collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
     print(f"Found {collection.count()} documents in collection")
 
-    # Initialize the reranker
+    # Initialize the reranker and explicitly move to GPU if available
     print("\nInitialising Cross-Encoder...")
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=DEVICE)
+    if torch.cuda.is_available():
+        reranker.model.to('cuda')  # Ensure model is on GPU
+        print("Reranker moved to GPU")
 
     return chroma_client, collection, reranker
 
+@spaces.GPU(memory="40g")  # Add GPU decorator for get_context
 def get_context(message):
     results = collection.query(
         query_texts=[message],
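
For reference, the cross-encoder reranker scores (query, passage) pairs directly rather than embedding them separately; a self-contained sketch with placeholder documents:

from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "What does BAS research in Antarctica?"
docs = ["BAS studies Antarctic ice cores.", "A recipe for sponge cake."]
scores = reranker.predict([(query, doc) for doc in docs])  # higher score = more relevant
ranked = [doc for _, doc in sorted(zip(scores, docs), reverse=True)]
print(ranked[0])
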
@@ -165,8 +181,28 @@ def get_context(message):
     print(f"\nFinal context length: {total_chars} characters")
     return context
 
+def get_ip_info(ip_address):
+    """Get geolocation info for an IP address"""
+    if ip_address in ['127.0.0.1', 'localhost', '0.0.0.0']:
+        return {"country": "Local", "city": "Local"}
+    try:
+        response = requests.get(f'https://ipapi.co/{ip_address}/json/')
+        if response.status_code == 200:
+            data = response.json()
+            return {
+                "country": data.get("country_name", "Unknown"),
+                "city": data.get("city", "Unknown"),
+                "region": data.get("region", "Unknown")
+            }
+    except Exception as e:
+        print(f"Error getting IP info: {str(e)}")
+    return {"country": "Unknown", "city": "Unknown"}
+
 def log_conversation(timestamp, user_message, assistant_response, model_name, context, error=None, client_ip=None):
     """Log conversation details to JSON file - local directory or HuggingFace Dataset repository"""
+    # Get IP geolocation
+    ip_info = get_ip_info(client_ip) if client_ip else {"country": "Unknown", "city": "Unknown"}
+
     # Create a log entry
     log_entry = {
         "timestamp": timestamp,
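
One caveat in get_ip_info(): requests.get() is called without a timeout, so a slow ipapi.co response can stall logging. A hedged variant (get_ip_info_safe is an illustrative name, not in the app):

import requests

def get_ip_info_safe(ip_address, timeout=3.0):
    """Like get_ip_info, but bounded so a slow lookup cannot hang a request."""
    try:
        response = requests.get(f'https://ipapi.co/{ip_address}/json/', timeout=timeout)
        if response.status_code == 200:
            data = response.json()
            return {"country": data.get("country_name", "Unknown"),
                    "city": data.get("city", "Unknown")}
    except requests.RequestException as e:
        print(f"Error getting IP info: {e}")
    return {"country": "Unknown", "city": "Unknown"}
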
@@ -175,7 +211,8 @@ def log_conversation(timestamp, user_message, assistant_response, model_name, co
         "assistant_response": assistant_response,
         "context": context,
         "error": str(error) if error else None,
-        "client_ip": client_ip
+        "client_ip": client_ip,
+        "location": ip_info
     }
 
     # Check if running on Hugging Face Spaces
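
A resulting log entry now carries the geolocation alongside the raw IP; sketched with placeholder values:

log_entry = {
    "timestamp": "2025-01-01T12:00:00",
    "user_message": "What is BAS?",
    "assistant_response": "...",
    "context": "...",
    "error": None,
    "client_ip": "203.0.113.7",  # TEST-NET-3 documentation address
    "location": {"country": "United Kingdom", "city": "Cambridge", "region": "England"},
}
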
@@ -285,9 +322,7 @@ def chat_response(message, history, model_name, request: gr.Request):
     # Add history first without context
     if history:
         for h in history:
-            messages.append({"role": "user", "content": str(h[0])})
-            if h[1]:  # If there's a response
-                messages.append({"role": "assistant", "content": str(h[1])})
+            messages.append({"role": h["role"], "content": str(h["content"])})
 
     # Add current message
     messages.append({"role": "user", "content": str(message)})
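
The change above is the migration from Gradio's legacy tuple-format history to the OpenAI-style messages format; the two shapes side by side (tuples_to_messages is an illustrative helper, not part of the app):

def tuples_to_messages(history):
    """Convert legacy [[user, assistant], ...] pairs to role/content dicts."""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": str(user_msg)})
        if assistant_msg:  # the assistant turn may be empty mid-generation
            messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages

print(tuples_to_messages([["Hi", "Hello!"]]))
# [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'}]
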
@@ -408,7 +443,7 @@ if __name__ == "__main__":
             info="Please try out the other AI models to use for responses (all LLMs are running on [GroqCloud](https://groq.com/groqrack/)) π"
         )
 
-        chatbot = gr.Chatbot(height=600)
+        chatbot = gr.Chatbot(height=600, type="messages")
         with gr.Row(equal_height=True):
             msg = gr.Textbox(
                 placeholder="What would you like to know? Or choose an example question...β",
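
With type="messages" the Chatbot component expects role/content dicts as its value, matching the history format above; a minimal sketch, assuming a recent Gradio release that supports the parameter:

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": "Ask me anything!"}],
        type="messages",  # role/content dicts instead of [user, bot] pairs
        height=600,
    )

demo.launch()
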
@@ -431,13 +466,17 @@ if __name__ == "__main__":
         )
 
         def user(user_message, history):
-
+            history = history or []
+            return "", history + [{"role": "user", "content": user_message}]
 
         def bot(history, model_name, request: gr.Request):
-
-
-
-
+            history = history or []
+            if history and history[-1]["role"] == "user":
+                user_message = history[-1]["content"]
+                history_without_last = history[:-1]
+                for response in chat_response(user_message, history_without_last, model_name, request):
+                    history_with_response = history + [{"role": "assistant", "content": response}]
+                    yield history_with_response
 
         msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
             bot, [chatbot, model_selector], chatbot
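
The new bot() is a generator: each yield pushes an updated history to the Chatbot, which is how partial responses stream into the UI. A self-contained sketch (fake_stream stands in for chat_response):

def fake_stream(message):
    """Yield a growing partial response, like a streaming LLM client."""
    partial = ""
    for token in ["Ice ", "cores ", "record ", "past ", "climate."]:
        partial += token
        yield partial

def bot_sketch(history):
    if history and history[-1]["role"] == "user":
        for response in fake_stream(history[-1]["content"]):
            yield history + [{"role": "assistant", "content": response}]

for update in bot_sketch([{"role": "user", "content": "Tell me about ice cores"}]):
    print(update[-1]["content"])
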
@@ -446,7 +485,7 @@ if __name__ == "__main__":
             bot, [chatbot, model_selector], chatbot
         )
 
-        clear.click(lambda: None, None, chatbot, queue=False)
+        clear.click(lambda: [], None, chatbot, queue=False)  # Updated to return empty list
         gr.Markdown("<footer style='text-align: center; margin-top: 5px;'>🤖 AI-generated content; while the Chat Assistant strives for accuracy, errors may occur; please thoroughly check critical information 🤖<br>⚠️ <strong><u>Disclaimer: This system was not produced by the British Antarctic Survey (BAS) and AI generated output does not reflect the views or opinions of BAS</u></strong> ⚠️ <br>(just a bit of fun :D)</footer>")
     demo.launch(
         server_name="0.0.0.0",