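"""Gradio chat assistant for the British Antarctic Survey (BAS) website.

Retrieves relevant page chunks from a ChromaDB vector store, reranks them with a
cross-encoder, and streams answers from Groq-hosted LLMs, logging each conversation.
"""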
import os
import zipfile
import json
from dotenv import load_dotenv
from groq import Groq
import chromadb
from chromadb.config import Settings
import torch
from sentence_transformers import CrossEncoder
import gradio as gr
from datetime import datetime
from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
from pathlib import Path
import tempfile

# Load environment variables and initialize clients
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Get the token from environment variables
hf_token = os.getenv("HF_TOKEN")
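# The dataset zip is expected to unpack to a top-level `chroma_db/` folder,
# so extracting into /tmp yields the /tmp/chroma_db path used below.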
def load_chroma_db():
    print("Using ChromaDB from Hugging Face dataset...")

    # Download the zipped database from Hugging Face
    zip_path = hf_hub_download(
        repo_id="Mr-Geo/chroma_db",
        filename="chroma_db.zip",
        repo_type="dataset",
        use_auth_token=hf_token
    )
    print(f"Downloaded database zip to: {zip_path}")

    # Extract to a temporary directory
    extract_dir = "/tmp"  # This will create /tmp/chroma_db
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        print("Zip contents:", zip_ref.namelist())
        zip_ref.extractall(extract_dir)

    db_path = os.path.join(extract_dir, "chroma_db")
    print(f"Using ChromaDB path: {db_path}")
    print(f"Directory contents: {os.listdir(db_path)}")

    db = chromadb.PersistentClient(
        path=db_path,
        settings=Settings(
            anonymized_telemetry=False,
            allow_reset=True,
            is_persistent=True
        )
    )

    # Debug: print collections
    collections = db.list_collections()
    print("Available collections:", collections)

    return db
# Check if running locally
if os.path.exists("./chroma_db/chroma.sqlite3"):
    print("Using local ChromaDB setup...")
    db = chromadb.PersistentClient(
        path="./chroma_db",
        settings=Settings(
            anonymized_telemetry=False,
            allow_reset=True,
            is_persistent=True
        )
    )
else:
    # Load from the Hugging Face dataset
    db = load_chroma_db()
def initialize_system():
    """Initialize the system components"""
    # Use the same ChromaDB client that was loaded from HF
    chroma_client = db  # Use the global db instance we created

    # Initialize the embedding function
    embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/all-mpnet-base-v2",
        device=DEVICE
    )

    # Get the collection
    print("Getting collection...")
    collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
    print(f"Found {collection.count()} documents in collection")

    # Initialize the reranker
    print("\nInitialising Cross-Encoder...")
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=DEVICE)

    return chroma_client, collection, reranker
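# Note: get_context and chat_response rely on the module-level `collection` and
# `reranker` names assigned in the __main__ block via initialize_system().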
def get_context(message):
    results = collection.query(
        query_texts=[message],
        n_results=500,
        include=["metadatas", "documents", "distances"]
    )
    print(f"\n=== Search Results ===")
    print(f"Initial ChromaDB results found: {len(results['documents'][0])}")

    # Rerank all results
    rerank_pairs = [(message, doc) for doc in results['documents'][0]]
    rerank_scores = reranker.predict(rerank_pairs)

    # Create list of results with scores
    all_results = []
    url_chunks = {}  # Group chunks by URL

    # Group chunks by URL and store their scores
    for score, doc, metadata in zip(rerank_scores, results['documents'][0], results['metadatas'][0]):
        url = metadata['url']
        if url not in url_chunks:
            url_chunks[url] = []
        url_chunks[url].append({'text': doc, 'metadata': metadata, 'score': score})

    # For each URL, select the best chunks while maintaining diversity
    for url, chunks in url_chunks.items():
        # Sort chunks for this URL by score
        chunks.sort(key=lambda x: x['score'], reverse=True)

        # Take up to 5 chunks per URL, but only if their scores are good
        selected_chunks = []
        for chunk in chunks[:5]:  # 5 chunks per URL
            # Only include if score is decent
            if chunk['score'] > -10:  # Increased threshold to ensure higher relevance
                selected_chunks.append(chunk)

        # Add selected chunks to final results
        all_results.extend(selected_chunks)

    # Sort all results by score for final ranking
    all_results.sort(key=lambda x: x['score'], reverse=True)

    # Take only top 20 results maximum
    all_results = all_results[:20]
    print(f"\nFinal results after reranking and filtering: {len(all_results)}")

    if all_results:
        print("\nTop Similarity Scores and URLs:")
        for i, result in enumerate(all_results[:20], 1):  # Show only top 20 in logs
            print(f"{i}. Score: {result['score']:.4f} - URL: {result['metadata']['url']}")
        print("=" * 50)

    # Build context from filtered results
    context = "\nRelevant Information:\n"
    total_chars = 0
    max_chars = 30000  # To ensure we don't exceed token limits

    for result in all_results:
        chunk_text = f"\nSource: {result['metadata']['url']}\n{result['text']}\n"
        if total_chars + len(chunk_text) > max_chars:
            break
        context += chunk_text
        total_chars += len(chunk_text)

    print(f"\nFinal context length: {total_chars} characters")
    return context
def log_conversation(timestamp, user_message, assistant_response, model_name, context, error=None):
    """Log conversation details to JSON file - local directory or HuggingFace Dataset repository"""
    # Create a log entry
    log_entry = {
        "timestamp": timestamp,
        "model_name": model_name,
        "user_message": user_message,
        "assistant_response": assistant_response,
        "context": context,
        "error": str(error) if error else None
    }

    # Check if running on Hugging Face Spaces
    is_hf_space = os.getenv('SPACE_ID') is not None
    current_date = datetime.now().strftime("%Y-%m-%d")

    if is_hf_space:
        try:
            # Initialize Hugging Face API
            api = HfApi(token=hf_token)
            filename = f"conversation_logs/daily_{current_date}.json"

            try:
                # Try to download existing file
                existing_file = api.hf_hub_download(
                    repo_id="Mr-Geo/bas_chat_logs",
                    filename=filename,
                    repo_type="dataset",
                    token=hf_token
                )
                # Load existing logs
                with open(existing_file, 'r', encoding='utf-8') as f:
                    logs = json.load(f)
            except Exception:
                # File doesn't exist yet, start with empty list
                logs = []

            # Append new log entry
            logs.append(log_entry)

            # Create temporary file with updated logs
            with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.json') as temp_file:
                json.dump(logs, temp_file, ensure_ascii=False, indent=2)
                temp_file_path = temp_file.name

            # Push to the dataset repository
            operations = [
                CommitOperationAdd(
                    path_in_repo=filename,
                    path_or_fileobj=temp_file_path
                )
            ]
            api.create_commit(
                repo_id="Mr-Geo/bas_chat_logs",
                repo_type="dataset",
                operations=operations,
                commit_message=f"Update conversation logs for {current_date}"
            )

            # Clean up temporary file
            os.unlink(temp_file_path)
        except Exception as e:
            print(f"\n⚠️ Error logging conversation to HuggingFace: {str(e)}")
    else:
        # Local environment - save to file
        try:
            log_dir = Path("logs")
            log_dir.mkdir(exist_ok=True)
            log_file = log_dir / f"conversation_log_{current_date}.json"

            # Load existing logs if file exists
            if log_file.exists():
                with open(log_file, 'r', encoding='utf-8') as f:
                    logs = json.load(f)
            else:
                logs = []

            # Append new log entry
            logs.append(log_entry)

            # Write updated logs
            with open(log_file, 'w', encoding='utf-8') as f:
                json.dump(logs, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"\n⚠️ Error logging conversation locally: {str(e)}")
def chat_response(message, history, model_name):
    """Chat response function for the Gradio interface"""
    context = ""  # Defined up front so the error handler can log it even if retrieval fails
    try:
        # Get context and timestamp
        context = get_context(message)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Build messages list
        messages = [{
            "role": "system",
            "content": f"""You are an AI assistant for the British Antarctic Survey (BAS). Your responses should be based ONLY on the context provided below.

IMPORTANT INSTRUCTIONS:
1. ALWAYS thoroughly check the provided context before saying you don't have information
2. If you find ANY relevant information in the context, use it - even if it's not complete
3. If you find time-sensitive information in the context, share it - it's current as of when the context was retrieved
4. When citing sources, put them on a new line after the relevant information like this:
   Here is some information about BAS.
   Source: https://www.bas.ac.uk/example
5. Do not say things like:
   - "I don't have access to real-time information"
   - "I cannot browse the internet"
   Instead, share what IS in the context, and only say "I don't have enough information" if you truly find nothing relevant to the user's question.
6. Keep responses:
   - With emojis where appropriate
   - Without duplicate source citations
   - Based strictly on the context below

Current Time: {timestamp}

Context: {context}"""
        }]
print("\n\n==========START Contents of the message being sent to the LLM==========\n") | |
print(messages) | |
print("\n\n==========END Contents of the message being sent to the LLM==========\n") | |
# Add history and current message | |
if history: | |
for h in history: | |
messages.append({"role": "user", "content": f"{str(h[0])} at BAS"}) | |
if h[1]: # If there's a response | |
messages.append({"role": "assistant", "content": str(h[1])}) | |
messages.append({"role": "user", "content": str(message)}) | |
# Get response | |
response = "" | |
completion = client.chat.completions.create( | |
model=model_name, | |
messages=messages, | |
temperature=0.7, | |
max_tokens=2500, | |
top_p=0.95, | |
stream=True | |
) | |
print("\n=== LLM Response Start ===") | |
thinking_process = "" | |
final_response = "" | |
is_thinking = False | |
for chunk in completion: | |
if chunk.choices[0].delta.content: | |
content = chunk.choices[0].delta.content | |
print(content, end='', flush=True) | |
# Check for thinking tags | |
if "<think>" in content: | |
is_thinking = True | |
continue | |
elif "</think>" in content: | |
is_thinking = False | |
# Create collapsible thinking section | |
if thinking_process: | |
final_response = f"""<details> | |
<summary>π€ <u>Click to see 'thinking' process</u></summary> | |
<div style="font-size: 0.9em;"> | |
<i>π{thinking_process}</i> | |
</div> | |
<hr style="margin: 0; height: 2px;"> | |
</details> | |
{final_response}""" | |
continue | |
# Append content to appropriate section | |
if is_thinking: | |
thinking_process += content | |
else: | |
final_response += content | |
yield final_response | |
log_conversation(timestamp, message, final_response, model_name, context) | |
print("\n=== LLM Response End ===\n") | |
except Exception as e: | |
error_msg = f"An error occurred: {str(e)}" | |
print(f"\nERROR: {error_msg}") | |
log_conversation(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), | |
message, error_msg, model_name, context, error=e) | |
yield error_msg | |
if __name__ == "__main__":
    try:
        print("\n=== Starting Application ===")
        Path("logs").mkdir(exist_ok=True)

        print("Initialising ChromaDB...")
        chroma_client, collection, reranker = initialize_system()
        print(f"Found {collection.count()} documents in collection")

        print("\nCreating Gradio interface...")
        demo = gr.Blocks()

        with demo:
            gr.Markdown("# 🌍❄️ British Antarctic Survey Website Chat Assistant 🐧🤖")
            gr.Markdown("Accesses text data from 11,982 unique BAS URLs (6GB [Vector Database](https://huggingface.co/datasets/Mr-Geo/chroma_db/tree/main/) extracted 02/02/2025). Created with open source technologies: [Gradio](https://gradio.app) for the interface 🎨, [Groq](https://groq.com) for LLM processing ⚡, and [Chroma](https://www.trychroma.com/) as the vector database 💻")
            model_selector = gr.Dropdown(
                choices=[
                    "llama-3.1-8b-instant",
                    "llama-3.3-70b-versatile",
                    "llama-3.3-70b-specdec",
                    "mixtral-8x7b-32768",
                    "deepseek-r1-distill-llama-70b"
                ],
                value="llama-3.1-8b-instant",
                label="Select AI Large Language Model 🤖",
                info="Choose which AI model to use for responses (all models running on [GroqCloud](https://groq.com/groqrack/))"
            )
            chatbot = gr.Chatbot(height=600)
            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    placeholder="What would you like to know? Or choose an example question...❓",
                    label="Your question",
                    show_label=True,
                    container=True,
                    scale=20
                )
                send = gr.Button("Send ⬆️", scale=1, min_width=50)
            clear = gr.Button("Clear chat history 🧹 (Click here if any errors are returned)")
            gr.Examples(
                examples=[
                    "What research stations does BAS operate in Antarctica? 🏔️",
                    "Tell me about the RRS Sir David Attenborough 🚢",
                    "What kind of science and research does BAS do? 🔬",
                    "What is BAS doing about climate change? 🌡️",
                ],
                inputs=msg,
            )
            def user(user_message, history):
                return "", history + [[user_message, None]]

            def bot(history, model_name):
                if history and history[-1][1] is None:
                    for response in chat_response(history[-1][0], history[:-1], model_name):
                        history[-1][1] = response
                        yield history

            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, model_selector], chatbot
            )
            send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, model_selector], chatbot
            )
            clear.click(lambda: None, None, chatbot, queue=False)

            gr.Markdown("<footer style='text-align: center; margin-top: 5px;'>🤖 AI-generated content; while the Chat Assistant strives for accuracy, errors may occur; please thoroughly check critical information 🤖<br>⚠️ <strong><u>Disclaimer: This system was not produced by the British Antarctic Survey (BAS) and AI generated output does not reflect the views or opinions of BAS</u></strong> ⚠️ <br>(just a bit of fun :D)</footer>")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"\nERROR: {str(e)}")
        raise