Mr-Geo committed on
Commit 5fb219a · verified · 1 Parent(s): d8450ea

Update app.py

Files changed (1)
  app.py +51 -12
app.py CHANGED
@@ -12,6 +12,8 @@ from datetime import datetime
 from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
 from pathlib import Path
 import tempfile
+import spaces  # for ZeroGPU
+import requests  # for IP geolocation
 
 # Load environment variables and initialize clients
 load_dotenv()
@@ -74,9 +76,19 @@ else:
 # Load from Hugging Face dataset
 db = load_chroma_db()
 
+@spaces.GPU(memory="40g")  # Add GPU decorator for initialize_system
 def initialize_system():
     """Initialize the system components"""
 
+    # Add GPU diagnostics
+    print("\n=== GPU Diagnostics ===")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"Current CUDA device: {torch.cuda.current_device()}")
+        print(f"Device name: {torch.cuda.get_device_name()}")
+        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+    print("=====================\n")
+
     # Use the same ChromaDB client that was loaded from HF
     chroma_client = db  # Use the global db instance we created
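Note: on ZeroGPU Spaces a GPU is attached only while a @spaces.GPU-decorated function is executing, which is why the CUDA diagnostics above live inside initialize_system() rather than at module level. A minimal sketch of the pattern (the memory="40g" argument mirrors this commit; the decorator also works bare, and the helper below is hypothetical):

import spaces
import torch

@spaces.GPU  # GPU is attached only for the duration of this call
def describe_device():
    # Hypothetical helper: report whichever device ZeroGPU hands us
    if not torch.cuda.is_available():
        return "CPU only"
    props = torch.cuda.get_device_properties(0)
    return f"{props.name}, {props.total_memory / 1024**3:.1f} GB"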
@@ -91,12 +103,16 @@ def initialize_system():
     collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
     print(f"Found {collection.count()} documents in collection")
 
-    # Initialize the reranker
+    # Initialize the reranker and explicitly move to GPU if available
     print("\nInitialising Cross-Encoder...")
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=DEVICE)
+    if torch.cuda.is_available():
+        reranker.model.to('cuda')  # Ensure model is on GPU
+        print("Reranker moved to GPU")
 
     return chroma_client, collection, reranker
 
+@spaces.GPU(memory="40g")  # Add GPU decorator for get_context
 def get_context(message):
     results = collection.query(
         query_texts=[message],
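For reference, the sentence-transformers CrossEncoder scores (query, passage) pairs jointly rather than embedding them separately, which is what makes it useful as a reranking pass after the ChromaDB retrieval. A short usage sketch with made-up data (query and passages are illustrative only):

from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cuda')
query = "Where is Rothera Research Station?"  # illustrative
passages = [
    "Rothera Research Station is on Adelaide Island.",  # illustrative
    "BAS operates two ships and five aircraft.",
]
scores = reranker.predict([(query, p) for p in passages])  # higher = more relevant
reranked = [p for _, p in sorted(zip(scores, passages), key=lambda t: t[0], reverse=True)]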
@@ -165,8 +181,28 @@ def get_context(message):
     print(f"\nFinal context length: {total_chars} characters")
     return context
 
+def get_ip_info(ip_address):
+    """Get geolocation info for an IP address"""
+    if ip_address in ['127.0.0.1', 'localhost', '0.0.0.0']:
+        return {"country": "Local", "city": "Local"}
+    try:
+        response = requests.get(f'https://ipapi.co/{ip_address}/json/')
+        if response.status_code == 200:
+            data = response.json()
+            return {
+                "country": data.get("country_name", "Unknown"),
+                "city": data.get("city", "Unknown"),
+                "region": data.get("region", "Unknown")
+            }
+    except Exception as e:
+        print(f"Error getting IP info: {str(e)}")
+    return {"country": "Unknown", "city": "Unknown"}
+
 def log_conversation(timestamp, user_message, assistant_response, model_name, context, error=None, client_ip=None):
     """Log conversation details to JSON file - local directory or HuggingFace Dataset repository"""
+    # Get IP geolocation
+    ip_info = get_ip_info(client_ip) if client_ip else {"country": "Unknown", "city": "Unknown"}
+
     # Create a log entry
     log_entry = {
         "timestamp": timestamp,
@@ -175,7 +211,8 @@ def log_conversation(timestamp, user_message, assistant_response, model_name, co
         "assistant_response": assistant_response,
         "context": context,
         "error": str(error) if error else None,
-        "client_ip": client_ip
+        "client_ip": client_ip,
+        "location": ip_info
     }
 
     # Check if running on Hugging Face Spaces
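With the two new fields, a logged entry now looks roughly like this (values are invented, and fields outside this hunk are elided):

log_entry = {
    "timestamp": "2025-01-01T12:00:00",
    # ... fields not shown in this hunk ...
    "assistant_response": "Halley VI sits on the Brunt Ice Shelf.",
    "context": "(retrieved passages)",
    "error": None,
    "client_ip": "203.0.113.7",
    "location": {"country": "United Kingdom", "city": "Cambridge", "region": "England"},
}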
@@ -285,9 +322,7 @@ def chat_response(message, history, model_name, request: gr.Request):
     # Add history first without context
     if history:
         for h in history:
-            messages.append({"role": "user", "content": str(h[0])})
-            if h[1]:  # If there's a response
-                messages.append({"role": "assistant", "content": str(h[1])})
+            messages.append({"role": h["role"], "content": str(h["content"])})
 
     # Add current message
     messages.append({"role": "user", "content": str(message)})
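This simplification works because the "messages"-format history (see the gr.Chatbot change below) already stores OpenAI-style role/content dicts, so entries pass straight through to the LLM request. A self-contained sketch (system prompt and contents are illustrative):

history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, ask me about BAS."},
]
messages = [{"role": "system", "content": "You answer questions about BAS."}]  # illustrative
messages += [{"role": h["role"], "content": str(h["content"])} for h in history]
messages.append({"role": "user", "content": "Where is Bird Island?"})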
@@ -408,7 +443,7 @@ if __name__ == "__main__":
         info="Please try out the other AI models to use for responses (all LLMs are running on [GroqCloud](https://groq.com/groqrack/)) 😊"
     )
 
-    chatbot = gr.Chatbot(height=600)
+    chatbot = gr.Chatbot(height=600, type="messages")
     with gr.Row(equal_height=True):
         msg = gr.Textbox(
             placeholder="What would you like to know? Or choose an example question...❓",
@@ -431,13 +466,17 @@ if __name__ == "__main__":
     )
 
     def user(user_message, history):
-        return "", history + [[user_message, None]]
+        history = history or []
+        return "", history + [{"role": "user", "content": user_message}]
 
     def bot(history, model_name, request: gr.Request):
-        if history and history[-1][1] is None:
-            for response in chat_response(history[-1][0], history[:-1], model_name, request):
-                history[-1][1] = response
-                yield history
+        history = history or []
+        if history and history[-1]["role"] == "user":
+            user_message = history[-1]["content"]
+            history_without_last = history[:-1]
+            for response in chat_response(user_message, history_without_last, model_name, request):
+                history_with_response = history + [{"role": "assistant", "content": response}]
+                yield history_with_response
 
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         bot, [chatbot, model_selector], chatbot
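A minimal standalone version of this two-step submit pattern, runnable on its own (the echoed reply stands in for the real chat_response call):

import gradio as gr

def user(user_message, history):
    history = history or []
    return "", history + [{"role": "user", "content": user_message}]

def bot(history):
    history = history or []
    if history and history[-1]["role"] == "user":
        reply = f"You said: {history[-1]['content']}"  # placeholder for the LLM call
        for i in range(1, len(reply) + 1):  # fake token-by-token streaming
            yield history + [{"role": "assistant", "content": reply[:i]}]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)

demo.launch()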
@@ -446,7 +485,7 @@ if __name__ == "__main__":
         bot, [chatbot, model_selector], chatbot
     )
 
-    clear.click(lambda: None, None, chatbot, queue=False)
+    clear.click(lambda: [], None, chatbot, queue=False)  # Updated to return empty list
     gr.Markdown("<footer style='text-align: center; margin-top: 5px;'>🤖 AI-generated content; while the Chat Assistant strives for accuracy, errors may occur; please thoroughly check critical information 🤖<br>⚠️ <strong><u>Disclaimer: This system was not produced by the British Antarctic Survey (BAS) and AI generated output does not reflect the views or opinions of BAS</u></strong> ⚠️ <br>(just a bit of fun :D)</footer>")
     demo.launch(
         server_name="0.0.0.0",