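"""BAS Website Chat Assistant.

A Gradio app that answers questions about the British Antarctic Survey (BAS)
website using retrieval-augmented generation: a ChromaDB vector store (hosted
as a Hugging Face dataset) supplies context chunks, a cross-encoder reranks
them, and Groq-hosted LLMs generate the streamed response.

Run locally with `python app.py`; the UI is served on port 7860.
"""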
import os
import zipfile
import json
import ipaddress  # stdlib; used for robust private/loopback IP detection
from dotenv import load_dotenv
from groq import Groq
import chromadb
from chromadb.config import Settings
import torch
from sentence_transformers import CrossEncoder
import gradio as gr
from datetime import datetime
from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
from pathlib import Path
import tempfile
import spaces  # for ZeroGPU
import requests  # for IP geolocation
import time
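# NOTE: on ZeroGPU Spaces the GPU is only attached while a @spaces.GPU-decorated
# function is running, so CUDA work done outside those functions falls back to CPU.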

# Load environment variables and initialize clients
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Get the Hugging Face token from environment variables
hf_token = os.getenv("HF_TOKEN")

# Initialize global variables
chroma_client = None
collection = None
reranker = None
embedding_function = None
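# These globals are populated once by initialize_system_sync() at startup and
# then read by get_context() on every request.
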
def load_chroma_db():
    print("Using ChromaDB from Hugging Face dataset...")
    # Download the zipped database from Hugging Face
    zip_path = hf_hub_download(
        repo_id="Mr-Geo/chroma_db",
        filename="chroma_db.zip",
        repo_type="dataset",
        token=hf_token  # `use_auth_token` is deprecated in huggingface_hub
    )
    print(f"Downloaded database zip to: {zip_path}")

    # Extract to a temporary directory
    extract_dir = "/tmp"  # This will create /tmp/chroma_db
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        print("Zip contents:", zip_ref.namelist())
        zip_ref.extractall(extract_dir)

    db_path = os.path.join(extract_dir, "chroma_db")
    print(f"Using ChromaDB path: {db_path}")
    print(f"Directory contents: {os.listdir(db_path)}")

    db = chromadb.PersistentClient(
        path=db_path,
        settings=Settings(
            anonymized_telemetry=False,
            allow_reset=True,
            is_persistent=True
        )
    )

    # Debug: print collections
    collections = db.list_collections()
    print("Available collections:", collections)
    return db


# Check if running locally
if os.path.exists("./chroma_db/chroma.sqlite3"):
    print("Using local ChromaDB setup...")
    db = chromadb.PersistentClient(
        path="./chroma_db",
        settings=Settings(
            anonymized_telemetry=False,
            allow_reset=True,
            is_persistent=True
        )
    )
else:
    # Load from Hugging Face dataset
    db = load_chroma_db()
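# Either way, `db` is now a chromadb.PersistentClient; initialize_system_sync()
# below reuses it rather than opening a second client against the same path.
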
def initialize_system_sync():
    """Initialize the system components without GPU decoration"""
    global chroma_client, collection, reranker, embedding_function

    # GPU diagnostics
    print("\n=== GPU Diagnostics ===")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")
        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("=====================\n")

    # Use the same ChromaDB client that was loaded above
    chroma_client = db  # the global db instance created at import time

    # Initialize the embedding function, retrying transient failures
    max_retries = 3
    retry_delay = 5  # seconds
    for attempt in range(max_retries):
        try:
            print(f"\nAttempt {attempt + 1} of {max_retries} to initialize embedding function...")
            embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="sentence-transformers/all-mpnet-base-v2",
                device=DEVICE
            )
            break
        except Exception as e:
            print(f"Error initializing embedding function: {str(e)}")
            if attempt < max_retries - 1:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise RuntimeError("Failed to initialize embedding function after multiple attempts")

    # Get the collection
    print("Getting collection...")
    collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
    print(f"Found {collection.count()} documents in collection")

    # Initialize the reranker, retrying transient failures
    for attempt in range(max_retries):
        try:
            print(f"\nAttempt {attempt + 1} of {max_retries} to initialize reranker...")
            reranker = CrossEncoder(
                'cross-encoder/ms-marco-MiniLM-L-6-v2',
                device=DEVICE,
                max_length=512  # explicit max_length for query/passage pairs
            )
            if torch.cuda.is_available():
                reranker.model.to('cuda')
                print("Reranker moved to GPU")
            break
        except Exception as e:
            print(f"Error initializing reranker: {str(e)}")
            if attempt < max_retries - 1:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise RuntimeError("Failed to initialize reranker after multiple attempts")


@spaces.GPU
def initialize_system():
    """GPU-decorated initialization for Gradio context"""
    initialize_system_sync()


@spaces.GPU  # ZeroGPU: attach the GPU while embedding and reranking run
def get_context(message):
    """Two-stage retrieval: dense search for recall, cross-encoder rerank for precision."""
    global collection, reranker  # populated by initialize_system_sync()

    # Stage 1: pull a deliberately wide candidate set from ChromaDB
    results = collection.query(
        query_texts=[message],
        n_results=500,
        include=["metadatas", "documents", "distances"]
    )
print(f"\n=== Search Results ===") | |
print(f"Initial ChromaDB results found: {len(results['documents'][0])}") | |
# Rerank all results | |
rerank_pairs = [(message, doc) for doc in results['documents'][0]] | |
rerank_scores = reranker.predict(rerank_pairs) | |
# Create list of results with scores | |
all_results = [] | |
url_chunks = {} # Group chunks by URL | |
# Group chunks by URL and store their scores | |
for score, doc, metadata in zip(rerank_scores, results['documents'][0], results['metadatas'][0]): | |
url = metadata['url'] | |
if url not in url_chunks: | |
url_chunks[url] = [] | |
url_chunks[url].append({'text': doc, 'metadata': metadata, 'score': score}) | |
# For each URL, select the best chunks while maintaining diversity | |
for url, chunks in url_chunks.items(): | |
# Sort chunks for this URL by score | |
chunks.sort(key=lambda x: x['score'], reverse=True) | |
# Take up to 5 chunks per URL, but only if their scores are good | |
selected_chunks = [] | |
for chunk in chunks[:5]: # 5 chunks per URL | |
# Only include if score is decent | |
if chunk['score'] > -10: # Increased threshold to ensure higher relevance | |
selected_chunks.append(chunk) | |
# Add selected chunks to final results | |
all_results.extend(selected_chunks) | |
# Sort all results by score for final ranking | |
all_results.sort(key=lambda x: x['score'], reverse=True) | |
# Take only top 20 results maximum | |
all_results = all_results[:20] | |
print(f"\nFinal results after reranking and filtering: {len(all_results)}") | |
if all_results: | |
print("\nTop Similarity Scores and URLs:") | |
for i, result in enumerate(all_results[:20], 1): # Show only top 20 in logs | |
print(f"{i}. Score: {result['score']:.4f} - URL: {result['metadata']['url']}") | |
print("=" * 50) | |
# Build context from filtered results | |
context = "\nRelevant Information:\n" | |
total_chars = 0 | |
max_chars = 30000 # To ensure we don't exceed token limits | |
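    # NOTE: 30,000 characters is roughly 7,500 tokens at the common
    # ~4-chars-per-token heuristic, leaving room for the system prompt,
    # chat history, and the 2,500-token response budget set below.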
    for result in all_results:
        chunk_text = f"\nSource: {result['metadata']['url']}\n{result['text']}\n"
        if total_chars + len(chunk_text) > max_chars:
            break
        context += chunk_text
        total_chars += len(chunk_text)

    print(f"\nFinal context length: {total_chars} characters")
    return context


def get_ip_info(ip_address):
    """Get geolocation info for an IP address"""
    if not ip_address:
        return {"country": "Unknown", "city": "Unknown", "region": "Unknown"}

    # Handle local/private addresses (loopback, RFC 1918 ranges, unspecified)
    try:
        if ip_address == 'localhost' or ipaddress.ip_address(ip_address).is_private:
            return {"country": "Local Network", "city": "Local", "region": "Local"}
    except ValueError:
        pass  # not a parseable IP; let the lookup below report what it can

    try:
        # Add a user-agent to be a good API citizen
        headers = {
            'User-Agent': 'BAS-Website-Chat/1.0'
        }
        response = requests.get(
            f'https://ipapi.co/{ip_address}/json/',
            headers=headers,
            timeout=5  # 5 second timeout
        )
        if response.status_code == 200:
            data = response.json()
            # ipapi.co reports errors in the JSON body, not the status code
            if 'error' in data:
                print(f"IP API error: {data.get('reason', 'Unknown error')}")
                return {"country": "Unknown", "city": "Unknown", "region": "Unknown"}
            return {
                "country": data.get("country_name", "Unknown"),
                "city": data.get("city", "Unknown"),
                "region": data.get("region", "Unknown"),
                "latitude": data.get("latitude"),
                "longitude": data.get("longitude"),
                "timezone": data.get("timezone"),
                "org": data.get("org")
            }
        else:
            print(f"IP API returned status code: {response.status_code}")
    except requests.exceptions.Timeout:
        print(f"Timeout getting IP info for {ip_address}")
    except requests.exceptions.RequestException as e:
        print(f"Error getting IP info: {str(e)}")
    except Exception as e:
        print(f"Unexpected error getting IP info: {str(e)}")
    return {"country": "Unknown", "city": "Unknown", "region": "Unknown"}
def log_conversation(timestamp, user_message, assistant_response, model_name, context, error=None, client_ip=None):
    """Log conversation details to a JSON file - local directory or HuggingFace Dataset repository"""
    # Get IP geolocation
    ip_info = get_ip_info(client_ip) if client_ip else {"country": "Unknown", "city": "Unknown"}

    # Create a log entry
    log_entry = {
        "timestamp": timestamp,
        "model_name": model_name,
        "user_message": user_message,
        "assistant_response": assistant_response,
        "context": context,
        "error": str(error) if error else None,
        "client_ip": client_ip,
        "location": ip_info
    }

    # Check if running on Hugging Face Spaces
    is_hf_space = os.getenv('SPACE_ID') is not None
    current_date = datetime.now().strftime("%Y-%m-%d")

    if is_hf_space:
        try:
            # Initialize the Hugging Face API
            api = HfApi(token=hf_token)
            filename = f"conversation_logs/daily_{current_date}.json"

            # Create the dataset repository if it doesn't exist yet
            try:
                api.repo_info(repo_id="Mr-Geo/chroma_db", repo_type="dataset")
            except Exception:
                api.create_repo(
                    repo_id="Mr-Geo/chroma_db",
                    repo_type="dataset",
                    private=True
                )

            try:
                # Try to download today's existing log file
                # (the api instance already carries the token)
                existing_file = api.hf_hub_download(
                    repo_id="Mr-Geo/chroma_db",
                    filename=filename,
                    repo_type="dataset"
                )
                # Load existing logs
                with open(existing_file, 'r', encoding='utf-8') as f:
                    logs = json.load(f)
            except Exception:
                # File doesn't exist yet, start with an empty list
                logs = []

            # Append the new log entry
            logs.append(log_entry)

            # Write the updated logs to a temporary file
            with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.json') as temp_file:
                json.dump(logs, temp_file, ensure_ascii=False, indent=2)
                temp_file_path = temp_file.name

            # Push to the dataset repository
            operations = [
                CommitOperationAdd(
                    path_in_repo=filename,
                    path_or_fileobj=temp_file_path
                )
            ]
            api.create_commit(
                repo_id="Mr-Geo/chroma_db",
                repo_type="dataset",
                operations=operations,
                commit_message=f"Update conversation logs for {current_date}"
            )

            # Clean up the temporary file
            os.unlink(temp_file_path)
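            # NOTE: this download-modify-commit cycle is not atomic; concurrent
            # requests could race and drop a log entry. Acceptable for
            # best-effort analytics, but not for an audit trail.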
        except Exception as e:
            print(f"\n⚠️ Error logging conversation to HuggingFace: {str(e)}")
    else:
        # Local environment - save to file
        try:
            log_dir = Path("logs")
            log_dir.mkdir(exist_ok=True)
            log_file = log_dir / f"conversation_log_{current_date}.json"

            # Load existing logs if the file exists
            if log_file.exists():
                with open(log_file, 'r', encoding='utf-8') as f:
                    logs = json.load(f)
            else:
                logs = []

            # Append the new log entry and write back
            logs.append(log_entry)
            with open(log_file, 'w', encoding='utf-8') as f:
                json.dump(logs, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"\n⚠️ Error logging conversation locally: {str(e)}")


def chat_response(message, history, model_name, request: gr.Request):
    """Chat response function for the Gradio interface"""
    context = ""  # keep defined so the error handler below can always log it
    client_ip = None
    try:
        # Get the client IP address, allowing for reverse proxies
        if request:
            # Try headers in rough order of reliability before the socket address
            client_ip = (
                request.headers.get('X-Forwarded-For', '').split(',')[0].strip() or
                request.headers.get('X-Real-IP') or
                request.headers.get('CF-Connecting-IP') or  # Cloudflare
                request.client.host
            )
            print(f"\nClient IP detected: {client_ip}")
            print(f"Request headers: {request.headers}")

        # Append ' at BAS' to bias retrieval towards BAS-specific content
        message += " at BAS"

        # Get context and timestamp
        context = get_context(message)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Build the messages list: history first, without the retrieval context
        messages = []
        if history:
            for h in history:
                messages.append({"role": h["role"], "content": str(h["content"])})

        # Add the current message
        messages.append({"role": "user", "content": str(message)})

        # Insert the system message (with context) at the beginning
        messages.insert(0, {
            "role": "system",
            "content": f"""You are an AI assistant for the British Antarctic Survey (BAS). Your responses should be based ONLY on the context provided below.

IMPORTANT INSTRUCTIONS:
1. ALWAYS thoroughly check the provided context before saying you don't have information.
2. If you find ANY relevant information in the context, use it - even if it's not complete.
3. If you find time-sensitive information in the context, share it - it's current as of when the context was retrieved.
4. When citing sources, you MUST always provide the URL source after the relevant information, like this:
   Here is some information about BAS.
   Source: https://www.bas.ac.uk/example
5. Do not say things like:
   - "I don't have access to real-time information."
   - "I cannot browse the internet."
   Instead, share what IS in the context, and only say "I don't have enough information" if you truly find nothing relevant to the user's question.
6. Keep responses:
   - With emojis where appropriate.
   - Without duplicate source citations.
   - Based on the context below.

Current Time: {timestamp}

Context: {context}"""
        })

        print("\n\n==========START Contents of the message being sent to the LLM==========\n")
        print(messages)
        print("\n\n==========END Contents of the message being sent to the LLM==========\n")

        # Stream the response from Groq
        completion = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.7,
            max_tokens=2500,
            top_p=0.95,
            stream=True
        )

        print("\n=== LLM Response Start ===")
        thinking_process = ""
        final_response = ""
        is_thinking = False
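
        # Reasoning models (e.g. the DeepSeek R1 distill offered below) emit
        # their chain of thought between <think> and </think> tags; fold that
        # into a collapsible <details> block instead of showing it inline.
        # Note this assumes each tag arrives within a single stream chunk.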
        for chunk in completion:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                print(content, end='', flush=True)

                # Check for thinking tags
                if "<think>" in content:
                    is_thinking = True
                    continue
                elif "</think>" in content:
                    is_thinking = False
                    # Create a collapsible thinking section
                    if thinking_process:
                        final_response = f"""<details>
<summary>🤔 <u>Click to see 'thinking' process</u></summary>
<div style="font-size: 0.9em;">
<i>💭{thinking_process}</i>
</div>
<hr style="margin: 0; height: 2px;">
</details>
{final_response}"""
                    continue

                # Append content to the appropriate section
                if is_thinking:
                    thinking_process += content
                else:
                    final_response += content
                yield final_response

        log_conversation(timestamp, message, final_response, model_name, context, client_ip=client_ip)
        print("\n=== LLM Response End ===\n")

    except Exception as e:
        error_msg = f"An error occurred: {str(e)}"
        print(f"\nERROR: {error_msg}")
        log_conversation(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                         message, error_msg, model_name, context, error=e, client_ip=client_ip)
        yield error_msg


if __name__ == "__main__":
    try:
        print("\n=== Starting Application ===")
        Path("logs").mkdir(exist_ok=True)
        print("Initialising ChromaDB...")
        initialize_system_sync()  # use the synchronous version for initial setup
        if collection is None:
            raise RuntimeError("Failed to initialize collection")
        print(f"Found {collection.count()} documents in collection")

        print("\nCreating Gradio interface...")
        demo = gr.Blocks()
        with demo:
            gr.Markdown("# 🌍❄️British Antarctic Survey Website Chat Assistant 🐧🤖")
            gr.Markdown("Accesses text data from 11,982 unique BAS URLs (6GB [Vector Database](https://huggingface.co/datasets/Mr-Geo/chroma_db/tree/main/) 📅 extracted 02/02/2025). Created with open source technologies: [Gradio](https://gradio.app) for UI 🎨, [Hugging Face](https://huggingface.co/) models for embeddings ⚡, and [Chroma](https://www.trychroma.com/) as the vector database 💻")
            model_selector = gr.Dropdown(
                choices=[
                    ("Llama 3.3 - Versatile 🦙✨", "llama-3.3-70b-versatile"),
                    ("Llama 4 - Latest 🚀", "meta-llama/llama-4-scout-17b-16e-instruct"),
                    ("Mistral Saba - Balanced ⚖️", "mistral-saba-24b"),
                    ("DeepSeek - Reasoning 🧠💭", "deepseek-r1-distill-llama-70b"),
                    ("Compound Beta - Agentic & Live Search 🛠️🌐", "compound-beta")
                ],
                value="llama-3.3-70b-versatile",
                label="Select AI Large Language Model 🤖",
                info="Please try out the other AI models for responses (all LLMs are running on [GroqCloud](https://groq.com/groqrack/)) - Compound Beta includes live internet searching! 🌐"
            )
            chatbot = gr.Chatbot(height=600, type="messages")
            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    placeholder="What would you like to know about BAS? Or choose an example question...❓",
                    label="Your question 🤔",
                    show_label=True,
                    container=True,
                    submit_btn=True,
                    scale=20,
                )
            clear = gr.Button("Clear Chat History 🧹 (Click here if any errors are returned and ask question again)")
            gr.Examples(
                examples=[
                    "What research stations does BAS operate in Antarctica? 🏔️",
                    "Tell me about the RRS Sir David Attenborough 🚢",
                    "What are the latest climate research findings from BAS? 📊",
                    "What current projects is BAS working on in Antarctica? 🔬",
                    "What's the latest news about BAS's Antarctic operations? 📰",
                    "What's happening at Rothera Research Station right now? 🌡️"
                ],
                inputs=msg,
            )

            def user(user_message, history):
                history = history or []
                return "", history + [{"role": "user", "content": user_message}]

            def bot(history, model_name, request: gr.Request):
                history = history or []
                if history and history[-1]["role"] == "user":
                    user_message = history[-1]["content"]
                    history_without_last = history[:-1]
                    for response in chat_response(user_message, history_without_last, model_name, request):
                        history_with_response = history + [{"role": "assistant", "content": response}]
                        yield history_with_response

            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, model_selector], chatbot
            )
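            # Event chain: user() immediately echoes the question into the chat
            # (queue=False for snappy feedback), then bot() streams the
            # assistant turns; Gradio injects the gr.Request argument itself.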
            clear.click(lambda: [], None, chatbot, queue=False)  # reset the chat to an empty message list
            gr.Markdown("<footer style='text-align: center; margin-top: 5px;'>🤖 AI-generated content; while the Chat Assistant strives for accuracy, errors may occur; please thoroughly check critical information 🤖<br>⚠️ <strong><u>Disclaimer: This system was not produced by the British Antarctic Survey (BAS) and AI generated output does not reflect the views or opinions of BAS</u></strong> ⚠️ <br>(just a bit of fun :D)</footer>")

        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_api=False
        )
    except Exception as e:
        print(f"\nERROR: {str(e)}")
        raise